Commit bf068c83, authored by David Hoeller, committed by GitHub

Adds the Hydra configuration system for RL training (#700)

# Description

This MR adds utilities that enable the Hydra configuration system.
Using the `train.py` scripts, the user can now change any parameter in
the environment or agent configs from command line inputs, for example:
``` 
python source/standalone/workflows/rsl_rl/train.py --task Isaac-Cartpole-v0 --headless env.actions.joint_effort.scale=10.0
```

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [ ] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: David Hoeller <dhoeller@nvidia.com>
Co-authored-by: Mayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Co-authored-by: Kelly Guo <kellyg@nvidia.com>
parent 173b5f3e
......@@ -70,6 +70,7 @@ Table of Contents
:caption: Features
source/features/task_workflows
source/features/hydra
source/features/multi_gpu
source/features/tiled_rendering
source/features/environments
......
Hydra Configuration System
==========================
.. currentmodule:: omni.isaac.lab
Isaac Lab supports the `Hydra <https://hydra.cc/docs/intro/>`_ configuration system to modify the task's
configuration using command line arguments, which can be useful to automate experiments and perform hyperparameter tuning.
Any parameter of the environment can be modified by adding one or multiple elements of the form ``env.a.b.param1=value``
to the command line input, where ``a.b.param1`` reflects the parameter's hierarchy, for example ``env.actions.joint_effort.scale=10.0``.
Similarly, the agent's parameters can be modified by using the ``agent`` prefix, for example ``agent.seed=2024``.
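To make the mapping between a dotted override and the configuration hierarchy concrete, here is a minimal sketch in plain Python. It does not use Hydra: ``apply_override`` is a hypothetical helper, and the dictionary and its default values are made up to mimic the shape of a task configuration. Hydra's real parsing of values is richer than the type cast used here.

```python
def apply_override(cfg: dict, override: str) -> None:
    """Apply one 'a.b.c=value' override to a nested dict in place (illustrative only)."""
    dotted_key, raw_value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for name in parents:
        node = node[name]  # raises KeyError if the path does not exist
    if leaf not in node:
        raise KeyError(dotted_key)
    # Reuse the type of the existing value to convert the raw string.
    node[leaf] = type(node[leaf])(raw_value)


# Hypothetical defaults, shaped like the env/agent hierarchy described above.
cfg = {
    "env": {"actions": {"joint_effort": {"scale": 100.0}}},
    "agent": {"seed": 42},
}
apply_override(cfg, "env.actions.joint_effort.scale=10.0")
apply_override(cfg, "agent.seed=2024")
print(cfg)
```

Each dot in the key descends one level, which is why the override has to follow the exact structure of the configuration.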
.. note::
The way these command line arguments are set follows the exact structure of the configuration files. Since the different
RL frameworks use different conventions, there might be differences in the way the parameters are set. For example,
with `rl_games` the seed will be set with ``agent.params.seed``, while with `rsl_rl` and `skrl` it will be set with
``agent.seed``.
As a result, training with Hydra arguments can be run with the following syntax:
.. tab-set::
:sync-group: rl-train
.. tab-item:: rsl_rl
:sync: rsl_rl
.. code-block:: shell
python source/standalone/workflows/rsl_rl/train.py --task=Isaac-Cartpole-v0 --headless env.actions.joint_effort.scale=10.0 agent.seed=2024
.. tab-item:: rl_games
:sync: rl_games
.. code-block:: shell
python source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-v0 --headless env.actions.joint_effort.scale=10.0 agent.params.seed=2024
.. tab-item:: skrl
:sync: skrl
.. code-block:: shell
python source/standalone/workflows/skrl/train.py --task=Isaac-Cartpole-v0 --headless env.actions.joint_effort.scale=10.0 agent.seed=2024
Each of the above commands runs the training script with the task ``Isaac-Cartpole-v0`` in headless mode, sets the
``env.actions.joint_effort.scale`` parameter to 10.0, and sets the agent's seed to 2024.
.. note::
To keep backwards compatibility, and to provide a more user-friendly experience, we have kept the old CLI arguments
of the form ``--param``, for example ``--num_envs``, ``--seed``, ``--max_iterations``. These arguments take precedence
over the Hydra arguments and overwrite the values set by them.
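The precedence rule can be sketched as a small helper. This is only an illustration of the "use the legacy flag if it was given, otherwise keep the Hydra value" pattern; ``with_cli_precedence`` is a hypothetical name, and the values are made up.

```python
def with_cli_precedence(cli_value, cfg_value):
    """Return the legacy CLI value if it was given, else the Hydra-updated config value."""
    return cli_value if cli_value is not None else cfg_value


hydra_seed = 2024  # stand-in for the value left in the config by 'agent.seed=2024'
seed_with_flag = with_cli_precedence(7, hydra_seed)  # '--seed 7' was also passed
seed_without_flag = with_cli_precedence(None, hydra_seed)  # no '--seed' on the command line
print(seed_with_flag, seed_without_flag)
```

Because the legacy flags default to ``None``, an omitted flag leaves the Hydra value untouched.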
.. attention::
Particular care should be taken when modifying the parameters using command line arguments. Some of the configurations
perform intermediate computations based on other parameters. These computations will not be updated when the parameters
are modified.
For example, for the configuration of the Cartpole camera depth environment:
.. literalinclude:: ../../../source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/cartpole/cartpole_camera_env.py
:language: python
:start-at: class CartpoleDepthCameraEnvCfg
:end-at: tiled_camera.width
:emphasize-lines: 16
If the user were to modify the width of the camera, i.e. ``env.tiled_camera.width=128``, then the parameter
``env.num_observations=10240`` (1*80*128) must be updated and given as input as well.
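A small helper can keep the two overrides consistent. The sketch below recomputes the observation size from the camera resolution and prints both arguments; ``depth_observation_size`` is a hypothetical helper, and the height of 80 and the single channel are taken from the ``1*80*128`` product quoted above.

```python
def depth_observation_size(width: int, height: int, channels: int = 1) -> int:
    """Number of observation values for a flattened depth image."""
    return channels * height * width


width, height = 128, 80
overrides = [
    f"env.tiled_camera.width={width}",
    f"env.num_observations={depth_observation_size(width, height)}",
]
print(" ".join(overrides))
```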
Similarly, the ``__post_init__`` method is not re-run after the command line inputs are applied, so values derived there
keep their original settings. In the ``LocomotionVelocityRoughEnvCfg``, for example, the post init update is as follows:
.. literalinclude:: ../../../source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion/velocity/velocity_env_cfg.py
:language: python
:start-at: class LocomotionVelocityRoughEnvCfg
:emphasize-lines: 23, 29, 31
Here, when modifying ``env.decimation`` or ``env.sim.dt``, the user would have to give the updated ``env.sim.render_interval``,
``env.scene.height_scanner.update_period``, and ``env.scene.contact_forces.update_period`` as input as well.
Modifying advanced parameters
-----------------------------
Callables
^^^^^^^^^
It is possible to modify functions and classes in the configuration files by using the syntax ``module:attribute_name``.
For example, in the Cartpole environment:
.. literalinclude:: ../../../source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_env_cfg.py
:language: python
:start-at: class ObservationsCfg
:end-at: policy: PolicyCfg = PolicyCfg()
:emphasize-lines: 9
we could modify ``joint_pos_rel`` to compute absolute positions instead of relative positions with
``env.observations.policy.joint_pos_rel.func=omni.isaac.lab.envs.mdp:joint_pos``.
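The ``module:attribute_name`` convention can be resolved with the standard library alone. The sketch below shows the general idea and is not Isaac Lab's ``string_to_callable``; ``resolve_callable`` is a hypothetical name, and ``math:sqrt`` is used only so that the example runs without Isaac Lab installed.

```python
import importlib


def resolve_callable(spec: str):
    """Resolve a 'module:attribute_name' string to the callable it names."""
    module_name, attr_name = spec.split(":")
    module = importlib.import_module(module_name)
    obj = getattr(module, attr_name)
    if not callable(obj):
        raise ValueError(f"'{spec}' does not refer to a callable.")
    return obj


sqrt = resolve_callable("math:sqrt")
print(sqrt(16.0))
```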
Setting parameters to None
^^^^^^^^^^^^^^^^^^^^^^^^^^
To set parameters to None, use the ``null`` keyword, which is a special keyword in Hydra that is automatically converted to None.
In the above example, we could also disable the ``joint_pos_rel`` observation by setting it to None with
``env.observations.policy.joint_pos_rel.func=null``.
Dictionaries
^^^^^^^^^^^^
Elements in dictionaries are handled as parameters in the hierarchy. For example, in the Cartpole environment:
.. literalinclude:: ../../../source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_env_cfg.py
:language: python
:lines: 99-111
:emphasize-lines: 10
the ``position_range`` parameter can be modified with ``env.events.reset_cart_position.params.position_range="[-2.0, 2.0]"``.
This example shows two noteworthy points:
- The parameter we set has a space, so it must be enclosed in quotes.
- The parameter is a list while it is a tuple in the config. This is due to the fact that Hydra does not support tuples.
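Since the value arrives as a list, it has to be cast back if the configuration expects a tuple. A minimal sketch of that cast is shown below; ``coerce_like`` is a hypothetical helper and the default range is made up for illustration.

```python
def coerce_like(original, new_value):
    """Cast a list from the command line back to a tuple if the config held a tuple."""
    if isinstance(original, tuple) and isinstance(new_value, (list, tuple)):
        if len(original) != len(new_value):
            raise ValueError(f"Expected {len(original)} elements, received {len(new_value)}.")
        return tuple(new_value)
    return new_value


default_range = (-1.0, 1.0)  # hypothetical default stored in the config
updated_range = coerce_like(default_range, [-2.0, 2.0])
print(updated_range)
```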
......@@ -482,15 +482,14 @@ class AppLauncher:
if launcher_args.get("cpu", False):
raise ValueError("The `--cpu` flag is deprecated. Please use `--device cpu` instead.")
if "distributed" in launcher_args:
distributed_train = launcher_args["distributed"]
if "distributed" in launcher_args and launcher_args["distributed"]:
# local rank (GPU id) in a current multi-gpu mode
self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
# global rank (GPU id) in multi-gpu multi-node mode
self.global_rank = int(os.getenv("RANK", "0"))
if distributed_train:
self.device_id = self.local_rank
launcher_args["multi_gpu"] = False
self.device_id = self.local_rank
launcher_args["multi_gpu"] = False
# limit CPU threads to minimize thread context switching
# this ensures processes do not take up all available threads and fight for resources
num_cpu_cores = os.cpu_count()
......@@ -570,6 +569,9 @@ class AppLauncher:
# add Isaac Lab modules back to sys.modules
for key, value in hacked_modules.items():
sys.modules[key] = value
# remove the threadCount argument from sys.argv if it was added for distributed training
pattern = r"--/plugins/carb\.tasking\.plugin/threadCount=\d+"
sys.argv = [arg for arg in sys.argv if not re.match(pattern, arg)]
def _rendering_enabled(self) -> bool:
"""Check if rendering is required by the app."""
......
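The effect of the ``sys.argv`` filter added above can be checked in isolation. The sketch below applies the same pattern to a made-up argument list; the script name and the thread count of 8 are assumptions for illustration. Presumably the filter exists so that the later Hydra parsing does not see an argument it cannot interpret.

```python
import re

pattern = r"--/plugins/carb\.tasking\.plugin/threadCount=\d+"

# Hypothetical argument list after the launcher appended the threadCount setting.
argv = ["train.py", "--/plugins/carb.tasking.plugin/threadCount=8", "env.decimation=42"]

# re.match anchors at the start of each argument, so only the launcher setting is dropped.
filtered = [arg for arg in argv if not re.match(pattern, arg)]
print(filtered)
```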
......@@ -107,7 +107,7 @@ class RewardManager(ManagerBase):
# store information
# r_1 + r_2 + ... + r_n
episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids])
extras["Episode Reward/" + key] = episodic_sum_avg / self._env.max_episode_length_s
extras["Episode_Reward/" + key] = episodic_sum_avg / self._env.max_episode_length_s
# reset episodic sum
self._episode_sums[key][env_ids] = 0.0
# reset all the reward terms
......
......@@ -135,7 +135,7 @@ class TerminationManager(ManagerBase):
extras = {}
for key in self._term_dones.keys():
# store information
extras["Episode Termination/" + key] = torch.count_nonzero(self._term_dones[key][env_ids]).item()
extras["Episode_Termination/" + key] = torch.count_nonzero(self._term_dones[key][env_ids]).item()
# reset all the reward terms
for term_cfg in self._class_term_cfgs:
term_cfg.func.reset(env_ids=env_ids)
......
......@@ -113,6 +113,9 @@ These are redefined here to add new docstrings.
def _class_to_dict(obj: object) -> dict[str, Any]:
"""Convert an object into dictionary recursively.
Args:
obj: The object to convert.
Returns:
Converted dictionary mapping.
"""
......@@ -125,6 +128,7 @@ def _update_class_from_dict(obj, data: dict[str, Any]) -> None:
This function performs in-place update of the class member attributes.
Args:
obj: The object to update.
data: Input (nested) dictionary to update from.
Raises:
......@@ -132,7 +136,7 @@ def _update_class_from_dict(obj, data: dict[str, Any]) -> None:
ValueError: When dictionary has a value that does not match default config type.
KeyError: When dictionary has a key that does not exist in the default config type.
"""
return update_class_from_dict(obj, data, _ns="")
update_class_from_dict(obj, data, _ns="")
def _replace_class_with_kwargs(obj: object, **kwargs) -> object:
......
......@@ -12,7 +12,7 @@ from collections.abc import Iterable, Mapping
from typing import Any
from .array import TENSOR_TYPE_CONVERSIONS, TENSOR_TYPES
from .string import callable_to_string, string_to_callable
from .string import callable_to_string, string_to_callable, string_to_slice
"""
Dictionary <-> Class operations.
......@@ -42,6 +42,7 @@ def class_to_dict(obj: object) -> dict[str, Any]:
obj_dict = obj
else:
obj_dict = obj.__dict__
# convert to dictionary
data = dict()
for key, value in obj_dict.items():
......@@ -79,39 +80,36 @@ def update_class_from_dict(obj, data: dict[str, Any], _ns: str = "") -> None:
# key_ns is the full namespace of the key
key_ns = _ns + "/" + key
# check if key is present in the object
if hasattr(obj, key):
obj_mem = getattr(obj, key)
if isinstance(obj_mem, Mapping):
# Note: We don't handle two-level nested dictionaries. Just use configclass if this is needed.
# iterate over the dictionary to look for callable values
for k, v in obj_mem.items():
if callable(v):
value[k] = string_to_callable(value[k])
setattr(obj, key, value)
elif isinstance(value, Mapping):
if hasattr(obj, key) or isinstance(obj, dict):
obj_mem = obj[key] if isinstance(obj, dict) else getattr(obj, key)
if isinstance(value, Mapping):
# recursively call if it is a dictionary
update_class_from_dict(obj_mem, value, _ns=key_ns)
elif isinstance(value, Iterable) and not isinstance(value, str):
continue
if isinstance(value, Iterable) and not isinstance(value, str):
# check length of value to be safe
if obj_mem is not None and len(obj_mem) != len(value):
raise ValueError(
f"[Config]: Incorrect length under namespace: {key_ns}."
f" Expected: {len(obj_mem)}, Received: {len(value)}."
)
# set value
setattr(obj, key, value)
if isinstance(obj_mem, tuple):
value = tuple(value)
elif callable(obj_mem):
# update function name
value = string_to_callable(value)
setattr(obj, key, value)
elif isinstance(value, type(obj_mem)):
# check that they are type-safe
setattr(obj, key, value)
elif isinstance(value, type(obj_mem)) or value is None:
pass
else:
raise ValueError(
f"[Config]: Incorrect type under namespace: {key_ns}."
f" Expected: {type(obj_mem)}, Received: {type(value)}."
)
# set value
if isinstance(obj, dict):
obj[key] = value
else:
setattr(obj, key, value)
else:
raise KeyError(f"[Config]: Key not found under namespace: {key_ns}.")
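The core of the change above is that the update now recurses through plain dictionaries as well as objects. The following is a simplified sketch of that recursion, without the type checks, callable handling, or tuple casting of the real function; ``update_from_dict`` and ``SceneCfg`` are hypothetical names.

```python
from collections.abc import Mapping


def update_from_dict(obj, data: Mapping, ns: str = "") -> None:
    """Recursively update an object or dict in place from a nested dictionary."""
    for key, value in data.items():
        key_ns = f"{ns}/{key}"
        is_dict = isinstance(obj, dict)
        if (is_dict and key not in obj) or (not is_dict and not hasattr(obj, key)):
            raise KeyError(f"Key not found under namespace: {key_ns}.")
        current = obj[key] if is_dict else getattr(obj, key)
        if isinstance(value, Mapping):
            # descend into the nested member, whether it is an object or a dict
            update_from_dict(current, value, key_ns)
            continue
        if is_dict:
            obj[key] = value
        else:
            setattr(obj, key, value)


class SceneCfg:
    """Hypothetical config object holding a two-level nested dictionary."""

    def __init__(self):
        self.num_envs = 16
        self.params = {"range": {"low": -1.0, "high": 1.0}}


cfg = SceneCfg()
update_from_dict(cfg, {"num_envs": 32, "params": {"range": {"high": 2.0}}})
print(cfg.num_envs, cfg.params)
```

Only the keys present in the update are touched, so ``low`` keeps its original value.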
......@@ -237,6 +235,40 @@ def update_dict(orig_dict: dict, new_dict: collections.abc.Mapping) -> dict:
return orig_dict
def replace_slices_with_strings(data: dict) -> dict:
"""Replace slice objects with their string representations in a dictionary.
Args:
data: The dictionary to process.
Returns:
The dictionary with slice objects replaced by their string representations.
"""
if isinstance(data, dict):
return {k: replace_slices_with_strings(v) for k, v in data.items()}
elif isinstance(data, slice):
return f"slice({data.start},{data.stop},{data.step})"
else:
return data
def replace_strings_with_slices(data: dict) -> dict:
"""Replace string representations of slices with slice objects in a dictionary.
Args:
data: The dictionary to process.
Returns:
The dictionary with string representations of slices replaced by slice objects.
"""
if isinstance(data, dict):
return {k: replace_strings_with_slices(v) for k, v in data.items()}
elif isinstance(data, str) and data.startswith("slice("):
return string_to_slice(data)
else:
return data
def print_dict(val, nesting: int = -4, start: bool = True):
"""Outputs a nested dictionary."""
if isinstance(val, dict):
......
......@@ -58,6 +58,32 @@ def to_snake_case(camel_str: str) -> str:
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", camel_str).lower()
def string_to_slice(s: str):
"""Convert a string representation of a slice to a slice object.
Args:
s: The string representation of the slice.
Returns:
The slice object.
"""
# extract the content inside the slice()
match = re.match(r"slice\((.*),(.*),(.*)\)", s)
if not match:
raise ValueError(f"Invalid slice string format: {s}")
# extract start, stop, and step values
start_str, stop_str, step_str = match.groups()
# convert 'None' to None and other strings to integers
start = None if start_str == "None" else int(start_str)
stop = None if stop_str == "None" else int(stop_str)
step = None if step_str == "None" else int(step_str)
# create and return the slice object
return slice(start, stop, step)
"""
String <-> Callable operations.
"""
......
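The two slice helpers are meant to be inverses of each other. The round trip can be sketched in a self-contained form as below. The functions are re-implemented here so the example runs without Isaac Lab; ``slice_to_string`` is a hypothetical name, and the ``strip()`` call is an addition of this sketch rather than part of the change above.

```python
import re


def slice_to_string(s: slice) -> str:
    """Encode a slice as 'slice(start,stop,step)'."""
    return f"slice({s.start},{s.stop},{s.step})"


def string_to_slice(text: str) -> slice:
    """Decode a 'slice(start,stop,step)' string back into a slice object."""
    match = re.match(r"slice\((.*),(.*),(.*)\)", text)
    if not match:
        raise ValueError(f"Invalid slice string format: {text}")
    # 'None' stays None, everything else is parsed as an integer
    parts = [None if p.strip() == "None" else int(p) for p in match.groups()]
    return slice(*parts)


original = slice(0, None, 2)
encoded = slice_to_string(original)
decoded = string_to_slice(encoded)
print(encoded, decoded == original)
```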
......@@ -292,6 +292,18 @@ class FunctionImplementedDemoCfg:
self.a = a
"""
Dummy configuration: Nested dictionaries
"""
@configclass
class NestedDictCfg:
"""Dummy configuration class with nested dictionaries."""
dict_1: dict = {"dict_2": {"func": dummy_function1}}
"""
Test solutions: Basic
"""
......@@ -318,6 +330,23 @@ basic_demo_cfg_change_correct = {
"device_id": 0,
}
basic_demo_cfg_change_with_none_correct = {
"env": {"num_envs": 22, "episode_length": 2000, "viewer": None},
"robot_default_state": {
"pos": (0.0, 0.0, 0.0),
"rot": (1.0, 0.0, 0.0, 0.0),
"dof_pos": (0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
"dof_vel": [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
},
"device_id": 0,
}
basic_demo_cfg_nested_dict = {
"dict_1": {
"dict_2": {"func": dummy_function2},
},
}
basic_demo_post_init_cfg_correct = {
"env": {"num_envs": 56, "episode_length": 2000, "viewer": {"eye": [7.5, 7.5, 7.5], "lookat": [0.0, 0.0, 0.0]}},
"robot_default_state": {
......@@ -427,6 +456,20 @@ class TestConfigClass(unittest.TestCase):
update_class_from_dict(cfg, cfg_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_correct)
def test_config_update_dict_with_none(self):
"""Test updating configclass using a dictionary that contains None."""
cfg = BasicDemoCfg()
cfg_dict = {"env": {"num_envs": 22, "viewer": None}}
update_class_from_dict(cfg, cfg_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_with_none_correct)
def test_config_update_nested_dict(self):
"""Test updating configclass with sub-dictionnaries."""
cfg = NestedDictCfg()
cfg_dict = {"dict_1": {"dict_2": {"func": "__main__:dummy_function2"}}}
update_class_from_dict(cfg, cfg_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_nested_dict)
def test_config_update_dict_using_internal(self):
"""Test updating configclass from a dictionary using configclass method."""
cfg = BasicDemoCfg()
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.9.0"
version = "0.10.0"
# Description
title = "Isaac Lab Environments"
......
Changelog
---------
0.10.0 (2024-08-14)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added support for the Hydra configuration system to all the train scripts. As a result, parameters of the environment
and the agent can be modified using command line arguments, for example ``env.actions.joint_effort.scale=10``.
0.9.0 (2024-08-05)
~~~~~~~~~~~~~~~~~~~
......
......@@ -299,11 +299,11 @@ class AnymalCEnv(DirectRLEnv):
extras = dict()
for key in self._episode_sums.keys():
episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids])
extras["Episode Reward/" + key] = episodic_sum_avg / self.max_episode_length_s
extras["Episode_Reward/" + key] = episodic_sum_avg / self.max_episode_length_s
self._episode_sums[key][env_ids] = 0.0
self.extras["log"] = dict()
self.extras["log"].update(extras)
extras = dict()
extras["Episode Termination/base_contact"] = torch.count_nonzero(self.reset_terminated[env_ids]).item()
extras["Episode Termination/time_out"] = torch.count_nonzero(self.reset_time_outs[env_ids]).item()
extras["Episode_Termination/base_contact"] = torch.count_nonzero(self.reset_terminated[env_ids]).item()
extras["Episode_Termination/time_out"] = torch.count_nonzero(self.reset_time_outs[env_ids]).item()
self.extras["log"].update(extras)
......@@ -199,13 +199,13 @@ class QuadcopterEnv(DirectRLEnv):
extras = dict()
for key in self._episode_sums.keys():
episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids])
extras["Episode Reward/" + key] = episodic_sum_avg / self.max_episode_length_s
extras["Episode_Reward/" + key] = episodic_sum_avg / self.max_episode_length_s
self._episode_sums[key][env_ids] = 0.0
self.extras["log"] = dict()
self.extras["log"].update(extras)
extras = dict()
extras["Episode Termination/died"] = torch.count_nonzero(self.reset_terminated[env_ids]).item()
extras["Episode Termination/time_out"] = torch.count_nonzero(self.reset_time_outs[env_ids]).item()
extras["Episode_Termination/died"] = torch.count_nonzero(self.reset_terminated[env_ids]).item()
extras["Episode_Termination/time_out"] = torch.count_nonzero(self.reset_time_outs[env_ids]).item()
extras["Metrics/final_distance_to_goal"] = final_distance_to_goal.item()
self.extras["log"].update(extras)
......
......@@ -92,7 +92,8 @@ class H1RoughEnvCfg(LocomotionVelocityRoughEnvCfg):
super().__post_init__()
# Scene
self.scene.robot = H1_MINIMAL_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")
self.scene.height_scanner.prim_path = "{ENV_REGEX_NS}/Robot/torso_link"
if self.scene.height_scanner:
self.scene.height_scanner.prim_path = "{ENV_REGEX_NS}/Robot/torso_link"
# Randomization
self.events.push_robot = None
......
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Sub-module with utilities for the hydra configuration system."""
import functools
from collections.abc import Callable
try:
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf
except ImportError:
raise ImportError("Hydra is not installed. Please install it by running 'pip install hydra-core'.")
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils import replace_slices_with_strings, replace_strings_with_slices
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry
def register_task_to_hydra(
task_name: str, agent_cfg_entry_point: str
) -> tuple[ManagerBasedRLEnvCfg | DirectRLEnvCfg, dict]:
"""Register the task configuration to the Hydra configuration store.
This function resolves the configuration file for the environment and agent based on the task's name.
It then registers the configurations to the Hydra configuration store.
Args:
task_name: The name of the task.
agent_cfg_entry_point: The entry point key to resolve the agent's configuration file.
Returns:
A tuple containing the parsed environment and agent configuration objects.
"""
# load the configurations
env_cfg = load_cfg_from_registry(task_name, "env_cfg_entry_point")
agent_cfg = load_cfg_from_registry(task_name, agent_cfg_entry_point)
# convert the configs to dictionary
env_cfg_dict = env_cfg.to_dict()
if isinstance(agent_cfg, dict):
agent_cfg_dict = agent_cfg
else:
agent_cfg_dict = agent_cfg.to_dict()
cfg_dict = {"env": env_cfg_dict, "agent": agent_cfg_dict}
# replace slices with strings because OmegaConf does not support slices
cfg_dict = replace_slices_with_strings(cfg_dict)
# store the configuration to Hydra
ConfigStore.instance().store(name=task_name, node=cfg_dict)
return env_cfg, agent_cfg
def hydra_task_config(task_name: str, agent_cfg_entry_point: str) -> Callable:
"""Decorator to handle the Hydra configuration for a task.
This decorator registers the task to Hydra and updates the environment and agent configurations from Hydra parsed
command line arguments.
Args:
task_name: The name of the task.
agent_cfg_entry_point: The entry point key to resolve the agent's configuration file.
Returns:
The decorated function with the environment's and agent's configurations updated from command line arguments.
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# register the task to Hydra
env_cfg, agent_cfg = register_task_to_hydra(task_name, agent_cfg_entry_point)
# define the new Hydra main function
@hydra.main(config_path=None, config_name=task_name, version_base="1.3")
def hydra_main(hydra_env_cfg: DictConfig, env_cfg=env_cfg, agent_cfg=agent_cfg):
# convert to a native dictionary
hydra_env_cfg = OmegaConf.to_container(hydra_env_cfg, resolve=True)
# replace string with slices because OmegaConf does not support slices
hydra_env_cfg = replace_strings_with_slices(hydra_env_cfg)
# update the configs with the Hydra command line arguments
env_cfg.from_dict(hydra_env_cfg["env"])
if isinstance(agent_cfg, dict):
agent_cfg = hydra_env_cfg["agent"]
else:
agent_cfg.from_dict(hydra_env_cfg["agent"])
# call the original function
func(env_cfg, agent_cfg, *args, **kwargs)
# call the new Hydra main function
hydra_main()
return wrapper
return decorator
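The decorator above follows a common pattern: a decorator factory whose wrapper prepares some objects and passes them to the wrapped function as leading positional arguments. The sketch below shows that pattern without Hydra or Isaac Lab; ``inject_configs`` and ``fake_loader`` are hypothetical stand-ins for the registration and parsing steps.

```python
import functools


def inject_configs(load_configs):
    """Decorator factory: call `load_configs` and prepend its results to the call."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            env_cfg, agent_cfg = load_configs()
            # configs come first, then whatever the caller passed
            return func(env_cfg, agent_cfg, *args, **kwargs)

        return wrapper

    return decorator


def fake_loader():
    """Stand-in for loading and updating the configs (values are made up)."""
    return {"decimation": 42}, {"seed": 2024}


@inject_configs(fake_loader)
def main(env_cfg, agent_cfg, tag):
    return f"{tag}: decimation={env_cfg['decimation']}, seed={agent_cfg['seed']}"


result = main("run")
print(result)
```

This ordering of arguments is why the decorated ``main`` in the test file takes ``env_cfg`` and ``agent_cfg`` before ``self``.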
......@@ -26,6 +26,8 @@ INSTALL_REQUIRES = [
# 5.26.0 introduced a breaking change, so we restricted it for now.
# See issue https://github.com/tensorflow/tensorboard/issues/6808 for details.
"protobuf>=3.20.2, < 5.0.0",
# configuration management
"hydra-core",
# data collection
"h5py",
# basic logger
......
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Launch Isaac Sim Simulator first."""
import sys
from omni.isaac.lab.app import AppLauncher, run_tests
# launch the simulator
app_launcher = AppLauncher(headless=True)
simulation_app = app_launcher.app
"""Rest everything follows."""
import unittest
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
class TestHydra(unittest.TestCase):
"""Test the Hydra configuration system."""
def test_hydra(self):
"""Test the hydra configuration system."""
# set hardcoded command line arguments
sys.argv = [
sys.argv[0],
"env.decimation=42", # test simple env modification
"env.events.physics_material.params.asset_cfg.joint_ids='slice(0 ,1, 2)'", # test slice setting
"env.scene.robot.init_state.joint_vel={.*: 4.0}", # test regex setting
"env.rewards.feet_air_time=null", # test setting to none
"agent.max_iterations=3", # test simple agent modification
]
@hydra_task_config("Isaac-Velocity-Flat-H1-v0", "rsl_rl_cfg_entry_point")
def main(env_cfg, agent_cfg, self):
# env
self.assertEqual(env_cfg.decimation, 42)
self.assertEqual(env_cfg.events.physics_material.params["asset_cfg"].joint_ids, slice(0, 1, 2))
self.assertEqual(env_cfg.scene.robot.init_state.joint_vel, {".*": 4.0})
self.assertIsNone(env_cfg.rewards.feet_air_time)
# agent
self.assertEqual(agent_cfg.max_iterations, 3)
main(self)
if __name__ == "__main__":
run_tests()
......@@ -8,6 +8,7 @@
"""Launch Isaac Sim Simulator first."""
import argparse
import sys
from omni.isaac.lab.app import AppLauncher
......@@ -16,17 +17,12 @@ parser = argparse.ArgumentParser(description="Train an RL agent with RL-Games.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument(
"--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes."
)
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument("--sigma", type=str, default=None, help="The policy's initial standard deviation.")
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
......@@ -34,11 +30,14 @@ parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args()
args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
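The split between argparse flags and Hydra overrides in the hunk above relies on ``parse_known_args``, which returns unrecognized arguments instead of raising an error. A minimal sketch with a reduced parser follows; the two flags and the argument list are chosen for illustration only.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default=None)
parser.add_argument("--headless", action="store_true", default=False)

# Hypothetical command line mixing argparse flags and Hydra-style overrides.
argv = [
    "train.py",
    "--task=Isaac-Cartpole-v0",
    "--headless",
    "env.actions.joint_effort.scale=10.0",
    "agent.seed=2024",
]

# Known flags end up in args_cli; everything else is returned untouched.
args_cli, hydra_args = parser.parse_known_args(argv[1:])

# Rebuild the argument list that Hydra will see: program name plus the overrides.
hydra_argv = [argv[0]] + hydra_args
print(args_cli.task, args_cli.headless)
print(hydra_argv)
```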
......@@ -54,28 +53,40 @@ from rl_games.common import env_configurations, vecenv
from rl_games.common.algo_observer import IsaacAlgoObserver
from rl_games.torch_runner import Runner
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils.assets import retrieve_file_path
from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.wrappers.rl_games import RlGamesGpuEnv, RlGamesVecEnvWrapper
def main():
@hydra_task_config(args_cli.task, "rl_games_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict):
"""Train with RL-Games agent."""
# parse seed from command line
args_cli_seed = args_cli.seed
# parse configuration
env_cfg = parse_env_cfg(
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
# override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
agent_cfg["params"]["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["params"]["seed"]
agent_cfg["params"]["config"]["max_epochs"] = (
args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg["params"]["config"]["max_epochs"]
)
agent_cfg = load_cfg_from_registry(args_cli.task, "rl_games_cfg_entry_point")
# override from command line
if args_cli_seed is not None:
agent_cfg["params"]["seed"] = args_cli_seed
if args_cli.checkpoint is not None:
resume_path = retrieve_file_path(args_cli.checkpoint)
agent_cfg["params"]["load_checkpoint"] = True
agent_cfg["params"]["load_path"] = resume_path
print(f"[INFO]: Loading model checkpoint from: {agent_cfg['params']['load_path']}")
train_sigma = float(args_cli.sigma) if args_cli.sigma is not None else None
# multi-gpu training config
if args_cli.distributed:
agent_cfg["params"]["seed"] += app_launcher.global_rank
agent_cfg["params"]["config"]["device"] = f"cuda:{app_launcher.local_rank}"
agent_cfg["params"]["config"]["device_name"] = f"cuda:{app_launcher.local_rank}"
agent_cfg["params"]["config"]["multi_gpu"] = True
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# specify directory for logging experiments
log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"])
......@@ -88,19 +99,6 @@ def main():
agent_cfg["params"]["config"]["train_dir"] = log_root_path
agent_cfg["params"]["config"]["full_experiment_name"] = log_dir
# multi-gpu training config
if args_cli.distributed:
agent_cfg["params"]["seed"] += app_launcher.global_rank
agent_cfg["params"]["config"]["device"] = f"cuda:{app_launcher.local_rank}"
agent_cfg["params"]["config"]["device_name"] = f"cuda:{app_launcher.local_rank}"
agent_cfg["params"]["config"]["multi_gpu"] = True
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# max iterations
if args_cli.max_iterations:
agent_cfg["params"]["config"]["max_epochs"] = args_cli.max_iterations
# dump the configuration into log-directory
dump_yaml(os.path.join(log_root_path, log_dir, "params", "env.yaml"), env_cfg)
dump_yaml(os.path.join(log_root_path, log_dir, "params", "agent.yaml"), agent_cfg)
......@@ -135,17 +133,6 @@ def main():
)
env_configurations.register("rlgpu", {"vecenv_type": "IsaacRlgWrapper", "env_creator": lambda **kwargs: env})
if args_cli.checkpoint is not None:
resume_path = retrieve_file_path(args_cli.checkpoint)
agent_cfg["params"]["load_checkpoint"] = True
agent_cfg["params"]["load_path"] = resume_path
print(f"[INFO]: Loading model checkpoint from: {agent_cfg['params']['load_path']}")
if args_cli.sigma is not None:
train_sigma = float(args_cli.sigma)
else:
train_sigma = None
# set number of actors into agent config
agent_cfg["params"]["config"]["num_actors"] = env.unwrapped.num_envs
# create runner from rl-games
......
......@@ -52,23 +52,36 @@ def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPol
# load the default configuration
rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli)
return rslrl_cfg
def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace):
"""Update configuration for RSL-RL agent based on inputs.
Args:
agent_cfg: The configuration for RSL-RL agent.
args_cli: The command line arguments.
Returns:
The updated configuration for RSL-RL agent based on inputs.
"""
# override the default configuration with CLI arguments
if args_cli.seed is not None:
rslrl_cfg.seed = args_cli.seed
agent_cfg.seed = args_cli.seed
if args_cli.resume is not None:
rslrl_cfg.resume = args_cli.resume
agent_cfg.resume = args_cli.resume
if args_cli.load_run is not None:
rslrl_cfg.load_run = args_cli.load_run
agent_cfg.load_run = args_cli.load_run
if args_cli.checkpoint is not None:
rslrl_cfg.load_checkpoint = args_cli.checkpoint
agent_cfg.load_checkpoint = args_cli.checkpoint
if args_cli.run_name is not None:
rslrl_cfg.run_name = args_cli.run_name
agent_cfg.run_name = args_cli.run_name
if args_cli.logger is not None:
rslrl_cfg.logger = args_cli.logger
agent_cfg.logger = args_cli.logger
# set the project name for wandb and neptune
if rslrl_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name:
rslrl_cfg.wandb_project = args_cli.log_project_name
rslrl_cfg.neptune_project = args_cli.log_project_name
if agent_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name:
agent_cfg.wandb_project = args_cli.log_project_name
agent_cfg.neptune_project = args_cli.log_project_name
return rslrl_cfg
return agent_cfg
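Splitting the override logic out of `parse_rsl_rl_cfg` means the same function can be applied to a config that did not come from the registry, such as one Hydra has already modified. The following is a minimal sketch of that split; `ToyRunnerCfg`, `update_cfg` and `parse_cfg` are hypothetical stand-ins and not the RSL-RL wrapper's actual classes or functions.

```python
import argparse
from dataclasses import dataclass


@dataclass
class ToyRunnerCfg:
    """Hypothetical stand-in for RslRlOnPolicyRunnerCfg with only two fields."""

    seed: int = 42
    run_name: str = ""


def update_cfg(agent_cfg: ToyRunnerCfg, args_cli: argparse.Namespace) -> ToyRunnerCfg:
    """Copy CLI values onto the config, skipping flags the user left unset (None)."""
    if args_cli.seed is not None:
        agent_cfg.seed = args_cli.seed
    if args_cli.run_name is not None:
        agent_cfg.run_name = args_cli.run_name
    return agent_cfg


def parse_cfg(args_cli: argparse.Namespace) -> ToyRunnerCfg:
    """Registry-style entry point: build the default config, then reuse update_cfg."""
    return update_cfg(ToyRunnerCfg(), args_cli)


# A config that arrived from elsewhere, for example already changed by Hydra overrides.
hydra_cfg = ToyRunnerCfg(seed=7, run_name="from_hydra")
updated = update_cfg(hydra_cfg, argparse.Namespace(seed=123, run_name=None))
print(updated)
```

Only `seed` changes here, because `run_name` was left unset on the command line and the value already present in the config is kept.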
......@@ -8,6 +8,7 @@
"""Launch Isaac Sim Simulator first."""
import argparse
import sys
from omni.isaac.lab.app import AppLauncher
......@@ -20,9 +21,6 @@ parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
......@@ -31,11 +29,15 @@ parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy
cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()
args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
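The switch from `parse_args()` to `parse_known_args()` is what lets argparse and Hydra share one command line: argparse keeps the flags it declares and returns everything else untouched, and those leftovers are written back to `sys.argv` for Hydra. Below is a minimal standalone sketch of that split, assuming a toy parser with only `--task` and `--headless`. The real scripts declare many more flags, and `split_cli` is a hypothetical helper, not part of Isaac Lab.

```python
import argparse


def split_cli(argv: list[str]) -> tuple[argparse.Namespace, list[str]]:
    """Split argv into argparse-known flags and leftover override strings."""
    parser = argparse.ArgumentParser(description="Toy stand-in for train.py")
    parser.add_argument("--task", type=str, default=None)
    parser.add_argument("--headless", action="store_true", default=False)
    # parse_known_args() returns (namespace, list of unrecognized strings)
    # instead of exiting with an error on the first unknown argument.
    return parser.parse_known_args(argv)


args_cli, hydra_args = split_cli(
    ["--task", "Isaac-Cartpole-v0", "--headless", "env.actions.joint_effort.scale=10.0"]
)
print(args_cli.task)
print(hydra_args)
```

With plain `parse_args()`, the same command line would be rejected because of the unrecognized `env.actions.joint_effort.scale=10.0` argument.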
......@@ -49,12 +51,13 @@ from datetime import datetime
from rsl_rl.runners import OnPolicyRunner
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, parse_env_cfg
from omni.isaac.lab_tasks.utils import get_checkpoint_path
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper
torch.backends.cuda.matmul.allow_tf32 = True
......@@ -63,13 +66,15 @@ torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
def main():
@hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg):
"""Train with RSL-RL agent."""
# parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
# override configurations with non-hydra CLI arguments
agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
agent_cfg.max_iterations = (
args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg.max_iterations
)
agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli)
# specify directory for logging experiments
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
......@@ -81,10 +86,6 @@ def main():
log_dir += f"_{agent_cfg.run_name}"
log_dir = os.path.join(log_root_path, log_dir)
# max iterations for training
if args_cli.max_iterations:
agent_cfg.max_iterations = args_cli.max_iterations
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording
......
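With `sys.argv` reduced to the override strings, the `hydra_task_config` decorator passes `main` an `env_cfg` and `agent_cfg` that already reflect them. The sketch below is neither Isaac Lab's nor Hydra's implementation. It only illustrates, with a hypothetical `apply_overrides` helper on plain nested dicts and made-up default values, what an override such as `env.actions.joint_effort.scale=10.0` amounts to.

```python
import ast


def apply_overrides(cfg: dict, overrides: list[str]) -> dict:
    """Apply 'a.b.c=value' strings to a nested dict in place and return it."""
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # raises KeyError on an unknown path
        if leaf not in node:
            raise KeyError(f"Unknown parameter: {dotted_key}")
        try:
            # Interpret numbers, booleans, None, lists and similar literals.
            node[leaf] = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            # Anything that is not a Python literal is kept as a plain string.
            node[leaf] = raw_value
    return cfg


# Hypothetical defaults, chosen only for illustration.
cfg = {
    "env": {"actions": {"joint_effort": {"scale": 100.0}}},
    "agent": {"seed": 42, "logger": "tensorboard"},
}
apply_overrides(cfg, ["env.actions.joint_effort.scale=10.0", "agent.logger=wandb"])
print(cfg)
```

Rejecting unknown keys instead of creating them is a design choice made for this sketch: a mistyped parameter path then fails immediately rather than being silently ignored.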
......@@ -13,6 +13,7 @@ there will be significant overhead in GPU->CPU transfer.
"""Launch Isaac Sim Simulator first."""
import argparse
import sys
from omni.isaac.lab.app import AppLauncher
......@@ -21,9 +22,6 @@ parser = argparse.ArgumentParser(description="Train an RL agent with Stable-Base
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
......@@ -31,11 +29,14 @@ parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args()
args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
......@@ -52,28 +53,23 @@ from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import VecNormalize
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb3_cfg
def main():
@hydra_task_config(args_cli.task, "sb3_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict):
"""Train with stable-baselines agent."""
# parse configuration
env_cfg = parse_env_cfg(
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
)
agent_cfg = load_cfg_from_registry(args_cli.task, "sb3_cfg_entry_point")
# override configuration with command line arguments
if args_cli.seed is not None:
agent_cfg["seed"] = args_cli.seed
# override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
# max iterations for training
if args_cli.max_iterations:
if args_cli.max_iterations is not None:
agent_cfg["n_timesteps"] = args_cli.max_iterations * agent_cfg["n_steps"] * env_cfg.scene.num_envs
# directory for logging into
......
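The Stable-Baselines3 hunk above changes how `--max_iterations` is handled in two ways: the check becomes `is not None`, and the override is converted into a total timestep budget. The sketch below uses made-up numbers and a hypothetical `total_timesteps` function to show the arithmetic and how the explicit `None` check differs from a truthiness check.

```python
from typing import Optional


def total_timesteps(max_iterations: Optional[int], n_steps: int, num_envs: int, default: int) -> int:
    """Return the timestep budget, overriding the default only when max_iterations is given."""
    if max_iterations is not None:
        # One iteration collects n_steps transitions from each of num_envs environments.
        return max_iterations * n_steps * num_envs
    return default


# Made-up values: 100 iterations, 16 steps per environment, 4096 environments.
print(total_timesteps(100, n_steps=16, num_envs=4096, default=1_000_000))
# No override given: the configured default is kept.
print(total_timesteps(None, n_steps=16, num_envs=4096, default=1_000_000))
# An explicit 0 is honored; `if max_iterations:` would have ignored it.
print(total_timesteps(0, n_steps=16, num_envs=4096, default=1_000_000))
```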
......@@ -14,6 +14,7 @@ a more user-friendly way.
import argparse
import sys
from omni.isaac.lab.app import AppLauncher
......@@ -22,9 +23,6 @@ parser = argparse.ArgumentParser(description="Train an RL agent with skrl.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
......@@ -43,11 +41,14 @@ parser.add_argument(
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args()
# always enable cameras to record video
args_cli, hydra_args = parser.parse_known_args()
if args_cli.video:
args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
......@@ -72,59 +73,53 @@ elif args_cli.ml_framework.startswith("jax"):
from skrl.trainers.jax import SequentialTrainer
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper, process_skrl_cfg
def main():
@hydra_task_config(args_cli.task, "skrl_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict):
"""Train with skrl agent."""
# override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
set_seed(args_cli.seed if args_cli.seed is not None else agent_cfg["seed"])
# multi-gpu training config
if args_cli.distributed:
if args_cli.ml_framework.startswith("jax"):
raise ValueError("Multi-GPU distributed training not yet supported in JAX")
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# max iterations for training
if args_cli.max_iterations:
agent_cfg["trainer"]["timesteps"] = args_cli.max_iterations * agent_cfg["agent"]["rollouts"]
# configure the ML framework into the global skrl variable
if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
# read the seed from command line
args_cli_seed = args_cli.seed
# parse configuration
env_cfg = parse_env_cfg(
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
)
experiment_cfg = load_cfg_from_registry(args_cli.task, "skrl_cfg_entry_point")
# specify directory for logging experiments
log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.join("logs", "skrl", agent_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Logging experiment in directory: {log_root_path}")
# specify directory for logging runs: {time-stamp}_{run_name}
log_dir = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if experiment_cfg["agent"]["experiment"]["experiment_name"]:
log_dir += f'_{experiment_cfg["agent"]["experiment"]["experiment_name"]}'
if agent_cfg["agent"]["experiment"]["experiment_name"]:
log_dir += f'_{agent_cfg["agent"]["experiment"]["experiment_name"]}'
# set directory into agent config
experiment_cfg["agent"]["experiment"]["directory"] = log_root_path
experiment_cfg["agent"]["experiment"]["experiment_name"] = log_dir
agent_cfg["agent"]["experiment"]["directory"] = log_root_path
agent_cfg["agent"]["experiment"]["experiment_name"] = log_dir
# update log_dir
log_dir = os.path.join(log_root_path, log_dir)
# multi-gpu training config
if args_cli.distributed:
if args_cli.ml_framework.startswith("jax"):
raise ValueError("Multi-GPU distributed training not yet supported in JAX")
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# max iterations for training
if args_cli.max_iterations:
experiment_cfg["trainer"]["timesteps"] = args_cli.max_iterations * experiment_cfg["agent"]["rollouts"]
# dump the configuration into log-directory
dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg)
dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), experiment_cfg)
dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg)
dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg)
dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), experiment_cfg)
dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg)
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
......@@ -142,27 +137,24 @@ def main():
# wrap around environment for skrl
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")`
# set seed for the experiment (override from command line)
set_seed(args_cli_seed if args_cli_seed is not None else experiment_cfg["seed"])
# instantiate models using skrl model instantiator utility
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
models = {}
if args_cli.ml_framework.startswith("jax"):
experiment_cfg["models"]["separate"] = True # shared model is not supported in JAX
agent_cfg["models"]["separate"] = True # shared model is not supported in JAX
# non-shared models
if experiment_cfg["models"]["separate"]:
if agent_cfg["models"]["separate"]:
models["policy"] = gaussian_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
**process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
**process_skrl_cfg(agent_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
)
models["value"] = deterministic_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
**process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
**process_skrl_cfg(agent_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
)
# shared models
else:
......@@ -173,8 +165,8 @@ def main():
structure=None,
roles=["policy", "value"],
parameters=[
process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
process_skrl_cfg(agent_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
process_skrl_cfg(agent_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
],
)
models["value"] = models["policy"]
......@@ -185,22 +177,22 @@ def main():
# instantiate a RandomMemory as rollout buffer (any memory can be used for this)
# https://skrl.readthedocs.io/en/latest/api/memories/random.html
memory_size = experiment_cfg["agent"]["rollouts"] # memory_size is the agent's number of rollouts
memory_size = agent_cfg["agent"]["rollouts"] # memory_size is the agent's number of rollouts
memory = RandomMemory(memory_size=memory_size, num_envs=env.num_envs, device=env.device)
# configure and instantiate PPO agent
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
agent_cfg = PPO_DEFAULT_CONFIG.copy()
experiment_cfg["agent"]["rewards_shaper"] = None # avoid 'dictionary changed size during iteration'
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"], ml_framework=args_cli.ml_framework))
default_agent_cfg = PPO_DEFAULT_CONFIG.copy()
agent_cfg["agent"]["rewards_shaper"] = None # avoid 'dictionary changed size during iteration'
default_agent_cfg.update(process_skrl_cfg(agent_cfg["agent"], ml_framework=args_cli.ml_framework))
agent_cfg["state_preprocessor_kwargs"].update({"size": env.observation_space, "device": env.device})
agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": env.device})
default_agent_cfg["state_preprocessor_kwargs"].update({"size": env.observation_space, "device": env.device})
default_agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": env.device})
agent = PPO(
models=models,
memory=memory,
cfg=agent_cfg,
cfg=default_agent_cfg,
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
......@@ -208,7 +200,7 @@ def main():
# configure and instantiate a custom RL trainer for logging episode events
# https://skrl.readthedocs.io/en/latest/api/trainers.html
trainer_cfg = experiment_cfg["trainer"]
trainer_cfg = agent_cfg["trainer"]
trainer_cfg["close_environment_at_exit"] = False
trainer = SequentialTrainer(cfg=trainer_cfg, env=env, agents=agent)
......