Unverified Commit bf068c83 authored by David Hoeller's avatar David Hoeller Committed by GitHub

Adds the Hydra configuration system for RL training (#700)

# Description

This MR adds utilities to enable the hydra configuration system.
Using the `train.py` scripts, the user can now change any parameter in
the environment or agent configs from command line inputs, for example:
``` 
python source/standalone/workflows/rsl_rl/train.py --task Isaac-Cartpole-v0 --headless env.actions.joint_effort.scale=10.0
```

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [ ] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarDavid Hoeller <dhoeller@nvidia.com>
Co-authored-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent 173b5f3e
...@@ -70,6 +70,7 @@ Table of Contents ...@@ -70,6 +70,7 @@ Table of Contents
:caption: Features :caption: Features
source/features/task_workflows source/features/task_workflows
source/features/hydra
source/features/multi_gpu source/features/multi_gpu
source/features/tiled_rendering source/features/tiled_rendering
source/features/environments source/features/environments
......
Hydra Configuration System
==========================
.. currentmodule:: omni.isaac.lab
Isaac Lab supports the `Hydra <https://hydra.cc/docs/intro/>`_ configuration system to modify the task's
configuration using command line arguments, which can be useful to automate experiments and perform hyperparameter tuning.
Any parameter of the environment can be modified by adding one or multiple elements of the form ``env.a.b.param1=value``
to the command line input, where ``a.b.param1`` reflects the parameter's hierarchy, for example ``env.actions.joint_effort.scale=10.0``.
Similarly, the agent's parameters can be modified by using the ``agent`` prefix, for example ``agent.seed=2024``.
.. note::
The way these command line arguments are set follow the exact structure of the configuration files. Since the different
RL frameworks use different conventions, there might be differences in the way the parameters are set. For example,
with `rl_games` the seed will be set with ``agent.params.seed``, while with `rsl_rl` and `skrl` it will be set with
``agent.seed``.
As a result, training with hydra arguments can be run with the following syntax:
.. tab-set::
:sync-group: rl-train
.. tab-item:: rsl_rl
:sync: rsl_rl
.. code-block:: shell
python source/standalone/workflows/rsl_rl/train.py --task=Isaac-Cartpole-v0 --headless env.actions.joint_effort.scale=10.0 agent.seed=2024
.. tab-item:: rl_games
:sync: rl_games
.. code-block:: shell
python source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-v0 --headless env.actions.joint_effort.scale=10.0 agent.params.seed=2024
.. tab-item:: skrl
:sync: skrl
.. code-block:: shell
python source/standalone/workflows/skrl/train.py --task=Isaac-Cartpole-v0 --headless env.actions.joint_effort.scale=10.0 agent.seed=2024
The above command will run the training script with the task ``Isaac-Cartpole-v0`` in headless mode, and set the
``env.actions.joint_effort.scale`` parameter to 10.0 and the ``agent.seed`` parameter to 2024.
.. note::
To keep backwards compatibility, and to provide a more user-friendly experience, we have kept the old cli arguments
of the form ``--param``, for example ``--num_envs``, ``--seed``, ``--max_iterations``. These arguments have precedence
over the hydra arguments, and will overwrite the values set by the hydra arguments.
.. attention::
Particular care should be taken when modifying the parameters using command line arguments. Some of the configurations
perform intermediate computations based on other parameters. These computations will not be updated when the parameters
are modified.
For example, for the configuration of the Cartpole camera depth environment:
.. literalinclude:: ../../../source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/cartpole/cartpole_camera_env.py
:language: python
:start-at: class CartpoleDepthCameraEnvCfg
:end-at: tiled_camera.width
:emphasize-lines: 16
If the user were to modify the width of the camera, i.e. ``env.tiled_camera.width=128``, then the parameter
``env.num_observations=10240`` (1*80*128) must be updated and given as input as well.
Similarly, the ``__post_init__`` method is not updated with the command line inputs. In the ``LocomotionVelocityRoughEnvCfg``, for example,
the post init update is as follows:
.. literalinclude:: ../../../source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/locomotion/velocity/velocity_env_cfg.py
:language: python
:start-at: class LocomotionVelocityRoughEnvCfg
:emphasize-lines: 23, 29, 31
Here, when modifying ``env.decimation`` or ``env.sim.dt``, the user would have to give the updated ``env.sim.render_interval``,
``env.scene.height_scanner.update_period``, and ``env.scene.contact_forces.update_period`` as input as well.
Modifying advanced parameters
-----------------------------
Callables
^^^^^^^^^
It is possible to modify functions and classes in the configuration files by using the syntax ``module:attribute_name``.
For example, in the Cartpole environment:
.. literalinclude:: ../../../source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_env_cfg.py
:language: python
:start-at: class ObservationsCfg
:end-at: policy: PolicyCfg = PolicyCfg()
:emphasize-lines: 9
we could modify ``joint_pos_rel`` to compute absolute positions instead of relative positions with
``env.observations.policy.joint_pos_rel.func=omni.isaac.lab.envs.mdp:joint_pos``.
Setting parameters to None
^^^^^^^^^^^^^^^^^^^^^^^^^^
To set parameters to None, use the ``null`` keyword, which is a special keyword in Hydra that is automatically converted to None.
In the above example, we could also disable the ``joint_pos_rel`` observation by setting it to None with
``env.observations.policy.joint_pos_rel.func=null``.
Dictionaries
^^^^^^^^^^^^
Elements in dictionaries are handled as a parameters in the hierarchy. For example, in the Cartpole environment:
.. literalinclude:: ../../../source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_env_cfg.py
:language: python
:lines: 99-111
:emphasize-lines: 10
the ``position_range`` parameter can be modified with ``env.events.reset_cart_position.params.position_range="[-2.0, 2.0]"``.
This example shows two noteworthy points:
- The parameter we set has a space, so it must be enclosed in quotes.
- The parameter is a list while it is a tuple in the config. This is due to the fact that Hydra does not support tuples.
...@@ -482,13 +482,12 @@ class AppLauncher: ...@@ -482,13 +482,12 @@ class AppLauncher:
if launcher_args.get("cpu", False): if launcher_args.get("cpu", False):
raise ValueError("The `--cpu` flag is deprecated. Please use `--device cpu` instead.") raise ValueError("The `--cpu` flag is deprecated. Please use `--device cpu` instead.")
if "distributed" in launcher_args: if "distributed" in launcher_args and launcher_args["distributed"]:
distributed_train = launcher_args["distributed"]
# local rank (GPU id) in a current multi-gpu mode # local rank (GPU id) in a current multi-gpu mode
self.local_rank = int(os.getenv("LOCAL_RANK", "0")) self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
# global rank (GPU id) in multi-gpu multi-node mode # global rank (GPU id) in multi-gpu multi-node mode
self.global_rank = int(os.getenv("RANK", "0")) self.global_rank = int(os.getenv("RANK", "0"))
if distributed_train:
self.device_id = self.local_rank self.device_id = self.local_rank
launcher_args["multi_gpu"] = False launcher_args["multi_gpu"] = False
# limit CPU threads to minimize thread context switching # limit CPU threads to minimize thread context switching
...@@ -570,6 +569,9 @@ class AppLauncher: ...@@ -570,6 +569,9 @@ class AppLauncher:
# add Isaac Lab modules back to sys.modules # add Isaac Lab modules back to sys.modules
for key, value in hacked_modules.items(): for key, value in hacked_modules.items():
sys.modules[key] = value sys.modules[key] = value
# remove the threadCount argument from sys.argv if it was added for distributed training
pattern = r"--/plugins/carb\.tasking\.plugin/threadCount=\d+"
sys.argv = [arg for arg in sys.argv if not re.match(pattern, arg)]
def _rendering_enabled(self) -> bool: def _rendering_enabled(self) -> bool:
"""Check if rendering is required by the app.""" """Check if rendering is required by the app."""
......
...@@ -107,7 +107,7 @@ class RewardManager(ManagerBase): ...@@ -107,7 +107,7 @@ class RewardManager(ManagerBase):
# store information # store information
# r_1 + r_2 + ... + r_n # r_1 + r_2 + ... + r_n
episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids]) episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids])
extras["Episode Reward/" + key] = episodic_sum_avg / self._env.max_episode_length_s extras["Episode_Reward/" + key] = episodic_sum_avg / self._env.max_episode_length_s
# reset episodic sum # reset episodic sum
self._episode_sums[key][env_ids] = 0.0 self._episode_sums[key][env_ids] = 0.0
# reset all the reward terms # reset all the reward terms
......
...@@ -135,7 +135,7 @@ class TerminationManager(ManagerBase): ...@@ -135,7 +135,7 @@ class TerminationManager(ManagerBase):
extras = {} extras = {}
for key in self._term_dones.keys(): for key in self._term_dones.keys():
# store information # store information
extras["Episode Termination/" + key] = torch.count_nonzero(self._term_dones[key][env_ids]).item() extras["Episode_Termination/" + key] = torch.count_nonzero(self._term_dones[key][env_ids]).item()
# reset all the reward terms # reset all the reward terms
for term_cfg in self._class_term_cfgs: for term_cfg in self._class_term_cfgs:
term_cfg.func.reset(env_ids=env_ids) term_cfg.func.reset(env_ids=env_ids)
......
...@@ -113,6 +113,9 @@ These are redefined here to add new docstrings. ...@@ -113,6 +113,9 @@ These are redefined here to add new docstrings.
def _class_to_dict(obj: object) -> dict[str, Any]: def _class_to_dict(obj: object) -> dict[str, Any]:
"""Convert an object into dictionary recursively. """Convert an object into dictionary recursively.
Args:
obj: The object to convert.
Returns: Returns:
Converted dictionary mapping. Converted dictionary mapping.
""" """
...@@ -125,6 +128,7 @@ def _update_class_from_dict(obj, data: dict[str, Any]) -> None: ...@@ -125,6 +128,7 @@ def _update_class_from_dict(obj, data: dict[str, Any]) -> None:
This function performs in-place update of the class member attributes. This function performs in-place update of the class member attributes.
Args: Args:
obj: The object to update.
data: Input (nested) dictionary to update from. data: Input (nested) dictionary to update from.
Raises: Raises:
...@@ -132,7 +136,7 @@ def _update_class_from_dict(obj, data: dict[str, Any]) -> None: ...@@ -132,7 +136,7 @@ def _update_class_from_dict(obj, data: dict[str, Any]) -> None:
ValueError: When dictionary has a value that does not match default config type. ValueError: When dictionary has a value that does not match default config type.
KeyError: When dictionary has a key that does not exist in the default config type. KeyError: When dictionary has a key that does not exist in the default config type.
""" """
return update_class_from_dict(obj, data, _ns="") update_class_from_dict(obj, data, _ns="")
def _replace_class_with_kwargs(obj: object, **kwargs) -> object: def _replace_class_with_kwargs(obj: object, **kwargs) -> object:
......
...@@ -12,7 +12,7 @@ from collections.abc import Iterable, Mapping ...@@ -12,7 +12,7 @@ from collections.abc import Iterable, Mapping
from typing import Any from typing import Any
from .array import TENSOR_TYPE_CONVERSIONS, TENSOR_TYPES from .array import TENSOR_TYPE_CONVERSIONS, TENSOR_TYPES
from .string import callable_to_string, string_to_callable from .string import callable_to_string, string_to_callable, string_to_slice
""" """
Dictionary <-> Class operations. Dictionary <-> Class operations.
...@@ -42,6 +42,7 @@ def class_to_dict(obj: object) -> dict[str, Any]: ...@@ -42,6 +42,7 @@ def class_to_dict(obj: object) -> dict[str, Any]:
obj_dict = obj obj_dict = obj
else: else:
obj_dict = obj.__dict__ obj_dict = obj.__dict__
# convert to dictionary # convert to dictionary
data = dict() data = dict()
for key, value in obj_dict.items(): for key, value in obj_dict.items():
...@@ -79,39 +80,36 @@ def update_class_from_dict(obj, data: dict[str, Any], _ns: str = "") -> None: ...@@ -79,39 +80,36 @@ def update_class_from_dict(obj, data: dict[str, Any], _ns: str = "") -> None:
# key_ns is the full namespace of the key # key_ns is the full namespace of the key
key_ns = _ns + "/" + key key_ns = _ns + "/" + key
# check if key is present in the object # check if key is present in the object
if hasattr(obj, key): if hasattr(obj, key) or isinstance(obj, dict):
obj_mem = getattr(obj, key) obj_mem = obj[key] if isinstance(obj, dict) else getattr(obj, key)
if isinstance(obj_mem, Mapping): if isinstance(value, Mapping):
# Note: We don't handle two-level nested dictionaries. Just use configclass if this is needed.
# iterate over the dictionary to look for callable values
for k, v in obj_mem.items():
if callable(v):
value[k] = string_to_callable(value[k])
setattr(obj, key, value)
elif isinstance(value, Mapping):
# recursively call if it is a dictionary # recursively call if it is a dictionary
update_class_from_dict(obj_mem, value, _ns=key_ns) update_class_from_dict(obj_mem, value, _ns=key_ns)
elif isinstance(value, Iterable) and not isinstance(value, str): continue
if isinstance(value, Iterable) and not isinstance(value, str):
# check length of value to be safe # check length of value to be safe
if len(obj_mem) != len(value) and obj_mem is not None: if len(obj_mem) != len(value) and obj_mem is not None:
raise ValueError( raise ValueError(
f"[Config]: Incorrect length under namespace: {key_ns}." f"[Config]: Incorrect length under namespace: {key_ns}."
f" Expected: {len(obj_mem)}, Received: {len(value)}." f" Expected: {len(obj_mem)}, Received: {len(value)}."
) )
# set value if isinstance(obj_mem, tuple):
setattr(obj, key, value) value = tuple(value)
elif callable(obj_mem): elif callable(obj_mem):
# update function name # update function name
value = string_to_callable(value) value = string_to_callable(value)
setattr(obj, key, value) elif isinstance(value, type(obj_mem)) or value is None:
elif isinstance(value, type(obj_mem)): pass
# check that they are type-safe
setattr(obj, key, value)
else: else:
raise ValueError( raise ValueError(
f"[Config]: Incorrect type under namespace: {key_ns}." f"[Config]: Incorrect type under namespace: {key_ns}."
f" Expected: {type(obj_mem)}, Received: {type(value)}." f" Expected: {type(obj_mem)}, Received: {type(value)}."
) )
# set value
if isinstance(obj, dict):
obj[key] = value
else:
setattr(obj, key, value)
else: else:
raise KeyError(f"[Config]: Key not found under namespace: {key_ns}.") raise KeyError(f"[Config]: Key not found under namespace: {key_ns}.")
...@@ -237,6 +235,40 @@ def update_dict(orig_dict: dict, new_dict: collections.abc.Mapping) -> dict: ...@@ -237,6 +235,40 @@ def update_dict(orig_dict: dict, new_dict: collections.abc.Mapping) -> dict:
return orig_dict return orig_dict
def replace_slices_with_strings(data: dict) -> dict:
"""Replace slice objects with their string representations in a dictionary.
Args:
data: The dictionary to process.
Returns:
The dictionary with slice objects replaced by their string representations.
"""
if isinstance(data, dict):
return {k: replace_slices_with_strings(v) for k, v in data.items()}
elif isinstance(data, slice):
return f"slice({data.start},{data.stop},{data.step})"
else:
return data
def replace_strings_with_slices(data: dict) -> dict:
"""Replace string representations of slices with slice objects in a dictionary.
Args:
data: The dictionary to process.
Returns:
The dictionary with string representations of slices replaced by slice objects.
"""
if isinstance(data, dict):
return {k: replace_strings_with_slices(v) for k, v in data.items()}
elif isinstance(data, str) and data.startswith("slice("):
return string_to_slice(data)
else:
return data
def print_dict(val, nesting: int = -4, start: bool = True): def print_dict(val, nesting: int = -4, start: bool = True):
"""Outputs a nested dictionary.""" """Outputs a nested dictionary."""
if isinstance(val, dict): if isinstance(val, dict):
......
...@@ -58,6 +58,32 @@ def to_snake_case(camel_str: str) -> str: ...@@ -58,6 +58,32 @@ def to_snake_case(camel_str: str) -> str:
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", camel_str).lower() return re.sub("([a-z0-9])([A-Z])", r"\1_\2", camel_str).lower()
def string_to_slice(s: str):
"""Convert a string representation of a slice to a slice object.
Args:
s: The string representation of the slice.
Returns:
The slice object.
"""
# extract the content inside the slice()
match = re.match(r"slice\((.*),(.*),(.*)\)", s)
if not match:
raise ValueError(f"Invalid slice string format: {s}")
# extract start, stop, and step values
start_str, stop_str, step_str = match.groups()
# convert 'None' to None and other strings to integers
start = None if start_str == "None" else int(start_str)
stop = None if stop_str == "None" else int(stop_str)
step = None if step_str == "None" else int(step_str)
# create and return the slice object
return slice(start, stop, step)
""" """
String <-> Callable operations. String <-> Callable operations.
""" """
......
...@@ -292,6 +292,18 @@ class FunctionImplementedDemoCfg: ...@@ -292,6 +292,18 @@ class FunctionImplementedDemoCfg:
self.a = a self.a = a
"""
Dummy configuration: Nested dictionaries
"""
@configclass
class NestedDictCfg:
"""Dummy configuration class with nested dictionaries."""
dict_1: dict = {"dict_2": {"func": dummy_function1}}
""" """
Test solutions: Basic Test solutions: Basic
""" """
...@@ -318,6 +330,23 @@ basic_demo_cfg_change_correct = { ...@@ -318,6 +330,23 @@ basic_demo_cfg_change_correct = {
"device_id": 0, "device_id": 0,
} }
basic_demo_cfg_change_with_none_correct = {
"env": {"num_envs": 22, "episode_length": 2000, "viewer": None},
"robot_default_state": {
"pos": (0.0, 0.0, 0.0),
"rot": (1.0, 0.0, 0.0, 0.0),
"dof_pos": (0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
"dof_vel": [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
},
"device_id": 0,
}
basic_demo_cfg_nested_dict = {
"dict_1": {
"dict_2": {"func": dummy_function2},
},
}
basic_demo_post_init_cfg_correct = { basic_demo_post_init_cfg_correct = {
"env": {"num_envs": 56, "episode_length": 2000, "viewer": {"eye": [7.5, 7.5, 7.5], "lookat": [0.0, 0.0, 0.0]}}, "env": {"num_envs": 56, "episode_length": 2000, "viewer": {"eye": [7.5, 7.5, 7.5], "lookat": [0.0, 0.0, 0.0]}},
"robot_default_state": { "robot_default_state": {
...@@ -427,6 +456,20 @@ class TestConfigClass(unittest.TestCase): ...@@ -427,6 +456,20 @@ class TestConfigClass(unittest.TestCase):
update_class_from_dict(cfg, cfg_dict) update_class_from_dict(cfg, cfg_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_correct) self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_correct)
def test_config_update_dict_with_none(self):
"""Test updating configclass using a dictionary that contains None."""
cfg = BasicDemoCfg()
cfg_dict = {"env": {"num_envs": 22, "viewer": None}}
update_class_from_dict(cfg, cfg_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_with_none_correct)
def test_config_update_nested_dict(self):
"""Test updating configclass with sub-dictionnaries."""
cfg = NestedDictCfg()
cfg_dict = {"dict_1": {"dict_2": {"func": "__main__:dummy_function2"}}}
update_class_from_dict(cfg, cfg_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_nested_dict)
def test_config_update_dict_using_internal(self): def test_config_update_dict_using_internal(self):
"""Test updating configclass from a dictionary using configclass method.""" """Test updating configclass from a dictionary using configclass method."""
cfg = BasicDemoCfg() cfg = BasicDemoCfg()
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.9.0" version = "0.10.0"
# Description # Description
title = "Isaac Lab Environments" title = "Isaac Lab Environments"
......
Changelog Changelog
--------- ---------
0.10.0 (2024-08-14)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added support for the Hydra configuration system to all the train scripts. As a result, parameters of the environment
and the agent can be modified using command line arguments, for example ``env.actions.joint_effort.scale=10``.
0.9.0 (2024-08-05) 0.9.0 (2024-08-05)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -299,11 +299,11 @@ class AnymalCEnv(DirectRLEnv): ...@@ -299,11 +299,11 @@ class AnymalCEnv(DirectRLEnv):
extras = dict() extras = dict()
for key in self._episode_sums.keys(): for key in self._episode_sums.keys():
episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids]) episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids])
extras["Episode Reward/" + key] = episodic_sum_avg / self.max_episode_length_s extras["Episode_Reward/" + key] = episodic_sum_avg / self.max_episode_length_s
self._episode_sums[key][env_ids] = 0.0 self._episode_sums[key][env_ids] = 0.0
self.extras["log"] = dict() self.extras["log"] = dict()
self.extras["log"].update(extras) self.extras["log"].update(extras)
extras = dict() extras = dict()
extras["Episode Termination/base_contact"] = torch.count_nonzero(self.reset_terminated[env_ids]).item() extras["Episode_Termination/base_contact"] = torch.count_nonzero(self.reset_terminated[env_ids]).item()
extras["Episode Termination/time_out"] = torch.count_nonzero(self.reset_time_outs[env_ids]).item() extras["Episode_Termination/time_out"] = torch.count_nonzero(self.reset_time_outs[env_ids]).item()
self.extras["log"].update(extras) self.extras["log"].update(extras)
...@@ -199,13 +199,13 @@ class QuadcopterEnv(DirectRLEnv): ...@@ -199,13 +199,13 @@ class QuadcopterEnv(DirectRLEnv):
extras = dict() extras = dict()
for key in self._episode_sums.keys(): for key in self._episode_sums.keys():
episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids]) episodic_sum_avg = torch.mean(self._episode_sums[key][env_ids])
extras["Episode Reward/" + key] = episodic_sum_avg / self.max_episode_length_s extras["Episode_Reward/" + key] = episodic_sum_avg / self.max_episode_length_s
self._episode_sums[key][env_ids] = 0.0 self._episode_sums[key][env_ids] = 0.0
self.extras["log"] = dict() self.extras["log"] = dict()
self.extras["log"].update(extras) self.extras["log"].update(extras)
extras = dict() extras = dict()
extras["Episode Termination/died"] = torch.count_nonzero(self.reset_terminated[env_ids]).item() extras["Episode_Termination/died"] = torch.count_nonzero(self.reset_terminated[env_ids]).item()
extras["Episode Termination/time_out"] = torch.count_nonzero(self.reset_time_outs[env_ids]).item() extras["Episode_Termination/time_out"] = torch.count_nonzero(self.reset_time_outs[env_ids]).item()
extras["Metrics/final_distance_to_goal"] = final_distance_to_goal.item() extras["Metrics/final_distance_to_goal"] = final_distance_to_goal.item()
self.extras["log"].update(extras) self.extras["log"].update(extras)
......
...@@ -92,6 +92,7 @@ class H1RoughEnvCfg(LocomotionVelocityRoughEnvCfg): ...@@ -92,6 +92,7 @@ class H1RoughEnvCfg(LocomotionVelocityRoughEnvCfg):
super().__post_init__() super().__post_init__()
# Scene # Scene
self.scene.robot = H1_MINIMAL_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") self.scene.robot = H1_MINIMAL_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")
if self.scene.height_scanner:
self.scene.height_scanner.prim_path = "{ENV_REGEX_NS}/Robot/torso_link" self.scene.height_scanner.prim_path = "{ENV_REGEX_NS}/Robot/torso_link"
# Randomization # Randomization
......
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Sub-module with utilities for the hydra configuration system."""
import functools
from collections.abc import Callable
try:
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf
except ImportError:
raise ImportError("Hydra is not installed. Please install it by running 'pip install hydra-core'.")
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils import replace_slices_with_strings, replace_strings_with_slices
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry
def register_task_to_hydra(
task_name: str, agent_cfg_entry_point: str
) -> tuple[ManagerBasedRLEnvCfg | DirectRLEnvCfg, dict]:
"""Register the task configuration to the Hydra configuration store.
This function resolves the configuration file for the environment and agent based on the task's name.
It then registers the configurations to the Hydra configuration store.
Args:
task_name: The name of the task.
agent_cfg_entry_point: The entry point key to resolve the agent's configuration file.
Returns:
A tuple containing the parsed environment and agent configuration objects.
"""
# load the configurations
env_cfg = load_cfg_from_registry(task_name, "env_cfg_entry_point")
agent_cfg = load_cfg_from_registry(task_name, agent_cfg_entry_point)
# convert the configs to dictionary
env_cfg_dict = env_cfg.to_dict()
if isinstance(agent_cfg, dict):
agent_cfg_dict = agent_cfg
else:
agent_cfg_dict = agent_cfg.to_dict()
cfg_dict = {"env": env_cfg_dict, "agent": agent_cfg_dict}
# replace slices with strings because OmegaConf does not support slices
cfg_dict = replace_slices_with_strings(cfg_dict)
# store the configuration to Hydra
ConfigStore.instance().store(name=task_name, node=cfg_dict)
return env_cfg, agent_cfg
def hydra_task_config(task_name: str, agent_cfg_entry_point: str) -> Callable:
"""Decorator to handle the Hydra configuration for a task.
This decorator registers the task to Hydra and updates the environment and agent configurations from Hydra parsed
command line arguments.
Args:
task_name: The name of the task.
agent_cfg_entry_point: The entry point key to resolve the agent's configuration file.
Returns:
The decorated function with the envrionment's and agent's configurations updated from command line arguments.
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# register the task to Hydra
env_cfg, agent_cfg = register_task_to_hydra(task_name, agent_cfg_entry_point)
# define thr new Hydra main function
@hydra.main(config_path=None, config_name=task_name, version_base="1.3")
def hydra_main(hydra_env_cfg: DictConfig, env_cfg=env_cfg, agent_cfg=agent_cfg):
# convert to a native dictionary
hydra_env_cfg = OmegaConf.to_container(hydra_env_cfg, resolve=True)
# replace string with slices because OmegaConf does not support slices
hydra_env_cfg = replace_strings_with_slices(hydra_env_cfg)
# update the configs with the Hydra command line arguments
env_cfg.from_dict(hydra_env_cfg["env"])
if isinstance(agent_cfg, dict):
agent_cfg = hydra_env_cfg["agent"]
else:
agent_cfg.from_dict(hydra_env_cfg["agent"])
# call the original function
func(env_cfg, agent_cfg, *args, **kwargs)
# call the new Hydra main function
hydra_main()
return wrapper
return decorator
...@@ -26,6 +26,8 @@ INSTALL_REQUIRES = [ ...@@ -26,6 +26,8 @@ INSTALL_REQUIRES = [
# 5.26.0 introduced a breaking change, so we restricted it for now. # 5.26.0 introduced a breaking change, so we restricted it for now.
# See issue https://github.com/tensorflow/tensorboard/issues/6808 for details. # See issue https://github.com/tensorflow/tensorboard/issues/6808 for details.
"protobuf>=3.20.2, < 5.0.0", "protobuf>=3.20.2, < 5.0.0",
# configuration management
"hydra-core",
# data collection # data collection
"h5py", "h5py",
# basic logger # basic logger
......
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Launch Isaac Sim Simulator first."""
import sys
from omni.isaac.lab.app import AppLauncher, run_tests
# launch the simulator
app_launcher = AppLauncher(headless=True)
simulation_app = app_launcher.app
"""Rest everything follows."""
import unittest
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
class TestHydra(unittest.TestCase):
"""Test the Hydra configuration system."""
def test_hydra(self):
"""Test the hydra configuration system."""
# set hardcoded command line arguments
sys.argv = [
sys.argv[0],
"env.decimation=42", # test simple env modification
"env.events.physics_material.params.asset_cfg.joint_ids='slice(0 ,1, 2)'", # test slice setting
"env.scene.robot.init_state.joint_vel={.*: 4.0}", # test regex setting
"env.rewards.feet_air_time=null", # test setting to none
"agent.max_iterations=3", # test simple agent modification
]
@hydra_task_config("Isaac-Velocity-Flat-H1-v0", "rsl_rl_cfg_entry_point")
def main(env_cfg, agent_cfg, self):
# env
self.assertEqual(env_cfg.decimation, 42)
self.assertEqual(env_cfg.events.physics_material.params["asset_cfg"].joint_ids, slice(0, 1, 2))
self.assertEqual(env_cfg.scene.robot.init_state.joint_vel, {".*": 4.0})
self.assertIsNone(env_cfg.rewards.feet_air_time)
# agent
self.assertEqual(agent_cfg.max_iterations, 3)
main(self)
if __name__ == "__main__":
run_tests()
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
import sys
from omni.isaac.lab.app import AppLauncher from omni.isaac.lab.app import AppLauncher
...@@ -16,17 +17,12 @@ parser = argparse.ArgumentParser(description="Train an RL agent with RL-Games.") ...@@ -16,17 +17,12 @@ parser = argparse.ArgumentParser(description="Train an RL agent with RL-Games.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.") parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).") parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).") parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument( parser.add_argument(
"--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes." "--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes."
) )
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.") parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument("--sigma", type=str, default=None, help="The policy's initial standard deviation.") parser.add_argument("--sigma", type=str, default=None, help="The policy's initial standard deviation.")
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.") parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
...@@ -34,11 +30,14 @@ parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy ...@@ -34,11 +30,14 @@ parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
args_cli = parser.parse_args() args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video # always enable cameras to record video
if args_cli.video: if args_cli.video:
args_cli.enable_cameras = True args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app simulation_app = app_launcher.app
...@@ -54,28 +53,40 @@ from rl_games.common import env_configurations, vecenv ...@@ -54,28 +53,40 @@ from rl_games.common import env_configurations, vecenv
from rl_games.common.algo_observer import IsaacAlgoObserver from rl_games.common.algo_observer import IsaacAlgoObserver
from rl_games.torch_runner import Runner from rl_games.torch_runner import Runner
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils.assets import retrieve_file_path from omni.isaac.lab.utils.assets import retrieve_file_path
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.wrappers.rl_games import RlGamesGpuEnv, RlGamesVecEnvWrapper from omni.isaac.lab_tasks.utils.wrappers.rl_games import RlGamesGpuEnv, RlGamesVecEnvWrapper
def main(): @hydra_task_config(args_cli.task, "rl_games_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict):
"""Train with RL-Games agent.""" """Train with RL-Games agent."""
# parse seed from command line # override configurations with non-hydra CLI arguments
args_cli_seed = args_cli.seed env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
agent_cfg["params"]["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["params"]["seed"]
# parse configuration agent_cfg["params"]["config"]["max_epochs"] = (
env_cfg = parse_env_cfg( args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg["params"]["config"]["max_epochs"]
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
) )
agent_cfg = load_cfg_from_registry(args_cli.task, "rl_games_cfg_entry_point") if args_cli.checkpoint is not None:
# override from command line resume_path = retrieve_file_path(args_cli.checkpoint)
if args_cli_seed is not None: agent_cfg["params"]["load_checkpoint"] = True
agent_cfg["params"]["seed"] = args_cli_seed agent_cfg["params"]["load_path"] = resume_path
print(f"[INFO]: Loading model checkpoint from: {agent_cfg['params']['load_path']}")
train_sigma = float(args_cli.sigma) if args_cli.sigma is not None else None
# multi-gpu training config
if args_cli.distributed:
agent_cfg["params"]["seed"] += app_launcher.global_rank
agent_cfg["params"]["config"]["device"] = f"cuda:{app_launcher.local_rank}"
agent_cfg["params"]["config"]["device_name"] = f"cuda:{app_launcher.local_rank}"
agent_cfg["params"]["config"]["multi_gpu"] = True
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"]) log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"])
...@@ -88,19 +99,6 @@ def main(): ...@@ -88,19 +99,6 @@ def main():
agent_cfg["params"]["config"]["train_dir"] = log_root_path agent_cfg["params"]["config"]["train_dir"] = log_root_path
agent_cfg["params"]["config"]["full_experiment_name"] = log_dir agent_cfg["params"]["config"]["full_experiment_name"] = log_dir
# multi-gpu training config
if args_cli.distributed:
agent_cfg["params"]["seed"] += app_launcher.global_rank
agent_cfg["params"]["config"]["device"] = f"cuda:{app_launcher.local_rank}"
agent_cfg["params"]["config"]["device_name"] = f"cuda:{app_launcher.local_rank}"
agent_cfg["params"]["config"]["multi_gpu"] = True
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# max iterations
if args_cli.max_iterations:
agent_cfg["params"]["config"]["max_epochs"] = args_cli.max_iterations
# dump the configuration into log-directory # dump the configuration into log-directory
dump_yaml(os.path.join(log_root_path, log_dir, "params", "env.yaml"), env_cfg) dump_yaml(os.path.join(log_root_path, log_dir, "params", "env.yaml"), env_cfg)
dump_yaml(os.path.join(log_root_path, log_dir, "params", "agent.yaml"), agent_cfg) dump_yaml(os.path.join(log_root_path, log_dir, "params", "agent.yaml"), agent_cfg)
...@@ -135,17 +133,6 @@ def main(): ...@@ -135,17 +133,6 @@ def main():
) )
env_configurations.register("rlgpu", {"vecenv_type": "IsaacRlgWrapper", "env_creator": lambda **kwargs: env}) env_configurations.register("rlgpu", {"vecenv_type": "IsaacRlgWrapper", "env_creator": lambda **kwargs: env})
if args_cli.checkpoint is not None:
resume_path = retrieve_file_path(args_cli.checkpoint)
agent_cfg["params"]["load_checkpoint"] = True
agent_cfg["params"]["load_path"] = resume_path
print(f"[INFO]: Loading model checkpoint from: {agent_cfg['params']['load_path']}")
if args_cli.sigma is not None:
train_sigma = float(args_cli.sigma)
else:
train_sigma = None
# set number of actors into agent config # set number of actors into agent config
agent_cfg["params"]["config"]["num_actors"] = env.unwrapped.num_envs agent_cfg["params"]["config"]["num_actors"] = env.unwrapped.num_envs
# create runner from rl-games # create runner from rl-games
......
...@@ -52,23 +52,36 @@ def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPol ...@@ -52,23 +52,36 @@ def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPol
# load the default configuration # load the default configuration
rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point") rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli)
return rslrl_cfg
def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace):
"""Update configuration for RSL-RL agent based on inputs.
Args:
agent_cfg: The configuration for RSL-RL agent.
args_cli: The command line arguments.
Returns:
The updated configuration for RSL-RL agent based on inputs.
"""
# override the default configuration with CLI arguments # override the default configuration with CLI arguments
if args_cli.seed is not None: if args_cli.seed is not None:
rslrl_cfg.seed = args_cli.seed agent_cfg.seed = args_cli.seed
if args_cli.resume is not None: if args_cli.resume is not None:
rslrl_cfg.resume = args_cli.resume agent_cfg.resume = args_cli.resume
if args_cli.load_run is not None: if args_cli.load_run is not None:
rslrl_cfg.load_run = args_cli.load_run agent_cfg.load_run = args_cli.load_run
if args_cli.checkpoint is not None: if args_cli.checkpoint is not None:
rslrl_cfg.load_checkpoint = args_cli.checkpoint agent_cfg.load_checkpoint = args_cli.checkpoint
if args_cli.run_name is not None: if args_cli.run_name is not None:
rslrl_cfg.run_name = args_cli.run_name agent_cfg.run_name = args_cli.run_name
if args_cli.logger is not None: if args_cli.logger is not None:
rslrl_cfg.logger = args_cli.logger agent_cfg.logger = args_cli.logger
# set the project name for wandb and neptune # set the project name for wandb and neptune
if rslrl_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name: if agent_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name:
rslrl_cfg.wandb_project = args_cli.log_project_name agent_cfg.wandb_project = args_cli.log_project_name
rslrl_cfg.neptune_project = args_cli.log_project_name agent_cfg.neptune_project = args_cli.log_project_name
return rslrl_cfg return agent_cfg
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
import sys
from omni.isaac.lab.app import AppLauncher from omni.isaac.lab.app import AppLauncher
...@@ -20,9 +21,6 @@ parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.") ...@@ -20,9 +21,6 @@ parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.") parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).") parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).") parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
...@@ -31,11 +29,15 @@ parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy ...@@ -31,11 +29,15 @@ parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy
cli_args.add_rsl_rl_args(parser) cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args() args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video # always enable cameras to record video
if args_cli.video: if args_cli.video:
args_cli.enable_cameras = True args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app simulation_app = app_launcher.app
...@@ -49,12 +51,13 @@ from datetime import datetime ...@@ -49,12 +51,13 @@ from datetime import datetime
from rsl_rl.runners import OnPolicyRunner from rsl_rl.runners import OnPolicyRunner
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, parse_env_cfg from omni.isaac.lab_tasks.utils import get_checkpoint_path
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
...@@ -63,13 +66,15 @@ torch.backends.cudnn.deterministic = False ...@@ -63,13 +66,15 @@ torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
def main(): @hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg):
"""Train with RSL-RL agent.""" """Train with RSL-RL agent."""
# parse configuration # override configurations with non-hydra CLI arguments
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg( agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
agent_cfg.max_iterations = (
args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg.max_iterations
) )
agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli)
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
...@@ -81,10 +86,6 @@ def main(): ...@@ -81,10 +86,6 @@ def main():
log_dir += f"_{agent_cfg.run_name}" log_dir += f"_{agent_cfg.run_name}"
log_dir = os.path.join(log_root_path, log_dir) log_dir = os.path.join(log_root_path, log_dir)
# max iterations for training
if args_cli.max_iterations:
agent_cfg.max_iterations = args_cli.max_iterations
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording # wrap for video recording
......
...@@ -13,6 +13,7 @@ there will be significant overhead in GPU->CPU transfer. ...@@ -13,6 +13,7 @@ there will be significant overhead in GPU->CPU transfer.
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
import sys
from omni.isaac.lab.app import AppLauncher from omni.isaac.lab.app import AppLauncher
...@@ -21,9 +22,6 @@ parser = argparse.ArgumentParser(description="Train an RL agent with Stable-Base ...@@ -21,9 +22,6 @@ parser = argparse.ArgumentParser(description="Train an RL agent with Stable-Base
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.") parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).") parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).") parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
...@@ -31,11 +29,14 @@ parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy ...@@ -31,11 +29,14 @@ parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
args_cli = parser.parse_args() args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video # always enable cameras to record video
if args_cli.video: if args_cli.video:
args_cli.enable_cameras = True args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app simulation_app = app_launcher.app
...@@ -52,28 +53,23 @@ from stable_baselines3.common.callbacks import CheckpointCallback ...@@ -52,28 +53,23 @@ from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.logger import configure from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import VecNormalize from stable_baselines3.common.vec_env import VecNormalize
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb3_cfg from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb3_cfg
def main(): @hydra_task_config(args_cli.task, "sb3_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict):
"""Train with stable-baselines agent.""" """Train with stable-baselines agent."""
# parse configuration # override configurations with non-hydra CLI arguments
env_cfg = parse_env_cfg( env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
)
agent_cfg = load_cfg_from_registry(args_cli.task, "sb3_cfg_entry_point")
# override configuration with command line arguments
if args_cli.seed is not None:
agent_cfg["seed"] = args_cli.seed
# max iterations for training # max iterations for training
if args_cli.max_iterations: if args_cli.max_iterations is not None:
agent_cfg["n_timesteps"] = args_cli.max_iterations * agent_cfg["n_steps"] * env_cfg.scene.num_envs agent_cfg["n_timesteps"] = args_cli.max_iterations * agent_cfg["n_steps"] * env_cfg.scene.num_envs
# directory for logging into # directory for logging into
......
...@@ -14,6 +14,7 @@ a more user-friendly way. ...@@ -14,6 +14,7 @@ a more user-friendly way.
import argparse import argparse
import sys
from omni.isaac.lab.app import AppLauncher from omni.isaac.lab.app import AppLauncher
...@@ -22,9 +23,6 @@ parser = argparse.ArgumentParser(description="Train an RL agent with skrl.") ...@@ -22,9 +23,6 @@ parser = argparse.ArgumentParser(description="Train an RL agent with skrl.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.") parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).") parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).") parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
...@@ -43,11 +41,14 @@ parser.add_argument( ...@@ -43,11 +41,14 @@ parser.add_argument(
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
args_cli = parser.parse_args() args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video
if args_cli.video: if args_cli.video:
args_cli.enable_cameras = True args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app simulation_app = app_launcher.app
...@@ -72,59 +73,53 @@ elif args_cli.ml_framework.startswith("jax"): ...@@ -72,59 +73,53 @@ elif args_cli.ml_framework.startswith("jax"):
from skrl.trainers.jax import SequentialTrainer from skrl.trainers.jax import SequentialTrainer
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper, process_skrl_cfg from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper, process_skrl_cfg
def main(): @hydra_task_config(args_cli.task, "skrl_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict):
"""Train with skrl agent.""" """Train with skrl agent."""
# override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
set_seed(args_cli.seed if args_cli.seed is not None else agent_cfg["seed"])
# multi-gpu training config
if args_cli.distributed:
if args_cli.ml_framework.startswith("jax"):
raise ValueError("Multi-GPU distributed training not yet supported in JAX")
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# max iterations for training
if args_cli.max_iterations:
agent_cfg["trainer"]["timesteps"] = args_cli.max_iterations * agent_cfg["agent"]["rollouts"]
# configure the ML framework into the global skrl variable # configure the ML framework into the global skrl variable
if args_cli.ml_framework.startswith("jax"): if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy" skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
# read the seed from command line
args_cli_seed = args_cli.seed
# parse configuration
env_cfg = parse_env_cfg(
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
)
experiment_cfg = load_cfg_from_registry(args_cli.task, "skrl_cfg_entry_point")
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"]) log_root_path = os.path.join("logs", "skrl", agent_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.abspath(log_root_path) log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Logging experiment in directory: {log_root_path}") print(f"[INFO] Logging experiment in directory: {log_root_path}")
# specify directory for logging runs: {time-stamp}_{run_name} # specify directory for logging runs: {time-stamp}_{run_name}
log_dir = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") log_dir = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if experiment_cfg["agent"]["experiment"]["experiment_name"]: if agent_cfg["agent"]["experiment"]["experiment_name"]:
log_dir += f'_{experiment_cfg["agent"]["experiment"]["experiment_name"]}' log_dir += f'_{agent_cfg["agent"]["experiment"]["experiment_name"]}'
# set directory into agent config # set directory into agent config
experiment_cfg["agent"]["experiment"]["directory"] = log_root_path agent_cfg["agent"]["experiment"]["directory"] = log_root_path
experiment_cfg["agent"]["experiment"]["experiment_name"] = log_dir agent_cfg["agent"]["experiment"]["experiment_name"] = log_dir
# update log_dir # update log_dir
log_dir = os.path.join(log_root_path, log_dir) log_dir = os.path.join(log_root_path, log_dir)
# multi-gpu training config
if args_cli.distributed:
if args_cli.ml_framework.startswith("jax"):
raise ValueError("Multi-GPU distributed training not yet supported in JAX")
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# max iterations for training
if args_cli.max_iterations:
experiment_cfg["trainer"]["timesteps"] = args_cli.max_iterations * experiment_cfg["agent"]["rollouts"]
# dump the configuration into log-directory # dump the configuration into log-directory
dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg) dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg)
dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), experiment_cfg) dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg)
dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg) dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg)
dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), experiment_cfg) dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg)
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
...@@ -142,27 +137,24 @@ def main(): ...@@ -142,27 +137,24 @@ def main():
# wrap around environment for skrl # wrap around environment for skrl
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")` env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")`
# set seed for the experiment (override from command line)
set_seed(args_cli_seed if args_cli_seed is not None else experiment_cfg["seed"])
# instantiate models using skrl model instantiator utility # instantiate models using skrl model instantiator utility
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html # https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
models = {} models = {}
if args_cli.ml_framework.startswith("jax"): if args_cli.ml_framework.startswith("jax"):
experiment_cfg["models"]["separate"] = True # shared model is not supported in JAX agent_cfg["models"]["separate"] = True # shared model is not supported in JAX
# non-shared models # non-shared models
if experiment_cfg["models"]["separate"]: if agent_cfg["models"]["separate"]:
models["policy"] = gaussian_model( models["policy"] = gaussian_model(
observation_space=env.observation_space, observation_space=env.observation_space,
action_space=env.action_space, action_space=env.action_space,
device=env.device, device=env.device,
**process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework), **process_skrl_cfg(agent_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
) )
models["value"] = deterministic_model( models["value"] = deterministic_model(
observation_space=env.observation_space, observation_space=env.observation_space,
action_space=env.action_space, action_space=env.action_space,
device=env.device, device=env.device,
**process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework), **process_skrl_cfg(agent_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
) )
# shared models # shared models
else: else:
...@@ -173,8 +165,8 @@ def main(): ...@@ -173,8 +165,8 @@ def main():
structure=None, structure=None,
roles=["policy", "value"], roles=["policy", "value"],
parameters=[ parameters=[
process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework), process_skrl_cfg(agent_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework), process_skrl_cfg(agent_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
], ],
) )
models["value"] = models["policy"] models["value"] = models["policy"]
...@@ -185,22 +177,22 @@ def main(): ...@@ -185,22 +177,22 @@ def main():
# instantiate a RandomMemory as rollout buffer (any memory can be used for this) # instantiate a RandomMemory as rollout buffer (any memory can be used for this)
# https://skrl.readthedocs.io/en/latest/api/memories/random.html # https://skrl.readthedocs.io/en/latest/api/memories/random.html
memory_size = experiment_cfg["agent"]["rollouts"] # memory_size is the agent's number of rollouts memory_size = agent_cfg["agent"]["rollouts"] # memory_size is the agent's number of rollouts
memory = RandomMemory(memory_size=memory_size, num_envs=env.num_envs, device=env.device) memory = RandomMemory(memory_size=memory_size, num_envs=env.num_envs, device=env.device)
# configure and instantiate PPO agent # configure and instantiate PPO agent
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html # https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
agent_cfg = PPO_DEFAULT_CONFIG.copy() default_agent_cfg = PPO_DEFAULT_CONFIG.copy()
experiment_cfg["agent"]["rewards_shaper"] = None # avoid 'dictionary changed size during iteration' agent_cfg["agent"]["rewards_shaper"] = None # avoid 'dictionary changed size during iteration'
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"], ml_framework=args_cli.ml_framework)) default_agent_cfg.update(process_skrl_cfg(agent_cfg["agent"], ml_framework=args_cli.ml_framework))
agent_cfg["state_preprocessor_kwargs"].update({"size": env.observation_space, "device": env.device}) default_agent_cfg["state_preprocessor_kwargs"].update({"size": env.observation_space, "device": env.device})
agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": env.device}) default_agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": env.device})
agent = PPO( agent = PPO(
models=models, models=models,
memory=memory, memory=memory,
cfg=agent_cfg, cfg=default_agent_cfg,
observation_space=env.observation_space, observation_space=env.observation_space,
action_space=env.action_space, action_space=env.action_space,
device=env.device, device=env.device,
...@@ -208,7 +200,7 @@ def main(): ...@@ -208,7 +200,7 @@ def main():
# configure and instantiate a custom RL trainer for logging episode events # configure and instantiate a custom RL trainer for logging episode events
# https://skrl.readthedocs.io/en/latest/api/trainers.html # https://skrl.readthedocs.io/en/latest/api/trainers.html
trainer_cfg = experiment_cfg["trainer"] trainer_cfg = agent_cfg["trainer"]
trainer_cfg["close_environment_at_exit"] = False trainer_cfg["close_environment_at_exit"] = False
trainer = SequentialTrainer(cfg=trainer_cfg, env=env, agents=agent) trainer = SequentialTrainer(cfg=trainer_cfg, env=env, agents=agent)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment