Unverified commit d7613ce8, authored by ooctipus and committed by GitHub

Supports rl games wrapper with dictionary observation (#3340)

# Description

This PR makes it possible to use dictionary observations with the rl-games
wrapper.
Benefits:
1. A combination of high- and low-dimensional observations can be passed
through to the actor and critic in rl-games.
2. Double computation is avoided when the actor and critic share the same
observation.


## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)

## Screenshots

Please attach before and after screenshots of the change if applicable.

<!--
Example:

| Before | After |
| ------ | ----- |
| _gif/png before_ | _gif/png after_ |

To upload images to a PR -- simply drag and drop an image while in edit
mode and it should upload the image directly. You can then paste that
source into the above before/after sections.
-->

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent 90e5f31a
...@@ -134,6 +134,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -134,6 +134,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
rl_device = agent_cfg["params"]["config"]["device"] rl_device = agent_cfg["params"]["config"]["device"]
clip_obs = agent_cfg["params"]["env"].get("clip_observations", math.inf) clip_obs = agent_cfg["params"]["env"].get("clip_observations", math.inf)
clip_actions = agent_cfg["params"]["env"].get("clip_actions", math.inf) clip_actions = agent_cfg["params"]["env"].get("clip_actions", math.inf)
obs_groups = agent_cfg["params"]["env"].get("obs_groups")
concate_obs_groups = agent_cfg["params"]["env"].get("concate_obs_groups", True)
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
...@@ -155,7 +157,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -155,7 +157,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for rl-games # wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions) env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions, obs_groups, concate_obs_groups)
# register the environment to rl-games registry # register the environment to rl-games registry
# note: in agents configuration: environment name must be "rlgpu" # note: in agents configuration: environment name must be "rlgpu"
......
...@@ -148,6 +148,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -148,6 +148,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
rl_device = agent_cfg["params"]["config"]["device"] rl_device = agent_cfg["params"]["config"]["device"]
clip_obs = agent_cfg["params"]["env"].get("clip_observations", math.inf) clip_obs = agent_cfg["params"]["env"].get("clip_observations", math.inf)
clip_actions = agent_cfg["params"]["env"].get("clip_actions", math.inf) clip_actions = agent_cfg["params"]["env"].get("clip_actions", math.inf)
obs_groups = agent_cfg["params"]["env"].get("obs_groups")
concate_obs_groups = agent_cfg["params"]["env"].get("concate_obs_groups", True)
# set the IO descriptors output directory if requested # set the IO descriptors output directory if requested
if isinstance(env_cfg, ManagerBasedRLEnvCfg): if isinstance(env_cfg, ManagerBasedRLEnvCfg):
...@@ -178,7 +180,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -178,7 +180,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for rl-games # wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions) env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions, obs_groups, concate_obs_groups)
# register the environment to rl-games registry # register the environment to rl-games registry
# note: in agents configuration: environment name must be "rlgpu" # note: in agents configuration: environment name must be "rlgpu"
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.2.4" version = "0.3.0"
# Description # Description
title = "Isaac Lab RL" title = "Isaac Lab RL"
......
Changelog Changelog
--------- ---------
0.3.0 (2025-09-03)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Enhanced rl-games wrapper to allow dict observation.
0.2.4 (2025-08-07) 0.2.4 (2025-08-07)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -37,6 +37,7 @@ from __future__ import annotations ...@@ -37,6 +37,7 @@ from __future__ import annotations
import gym.spaces # needed for rl-games incompatibility: https://github.com/Denys88/rl_games/issues/261 import gym.spaces # needed for rl-games incompatibility: https://github.com/Denys88/rl_games/issues/261
import gymnasium import gymnasium
import torch import torch
from collections.abc import Callable
from rl_games.common import env_configurations from rl_games.common import env_configurations
from rl_games.common.vecenv import IVecEnv from rl_games.common.vecenv import IVecEnv
...@@ -60,12 +61,14 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -60,12 +61,14 @@ class RlGamesVecEnvWrapper(IVecEnv):
observations. This dictionary contains "obs" and "states" which typically correspond observations. This dictionary contains "obs" and "states" which typically correspond
to the actor and critic observations respectively. to the actor and critic observations respectively.
To use asymmetric actor-critic, the environment observations from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv` To use asymmetric actor-critic, map privileged observation groups under ``"states"`` (e.g. ``["critic"]``).
must have the key or group name "critic". The observation group is used to set the
:attr:`num_states` (int) and :attr:`state_space` (:obj:`gym.spaces.Box`). These are The wrapper supports **either** concatenated tensors (default) **or** Dict inputs:
used by the learning agent in RL-Games to allocate buffers in the trajectory memory. when wrapper is concate mode, rl-games sees {"obs": Tensor, (optional)"states": Tensor}
Since this is optional for some environments, the wrapper checks if these attributes exist. when wrapper is not concate mode, rl-games sees {"obs": dict[str, Tensor], (optional)"states": dict[str, Tensor]}
If they don't then the wrapper defaults to zero as number of privileged observations. - Concatenated mode (``concate_obs_group=True``): ``observation_space``/``state_space`` are ``gym.spaces.Box``.
- Dict mode (``concate_obs_group=False``): ``observation_space``/``state_space`` are ``gym.spaces.Dict`` keyed by
the requested groups. When no ``"states"`` groups are provided, the states Dict is omitted at runtime.
.. caution:: .. caution::
...@@ -79,7 +82,15 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -79,7 +82,15 @@ class RlGamesVecEnvWrapper(IVecEnv):
https://github.com/NVIDIA-Omniverse/IsaacGymEnvs https://github.com/NVIDIA-Omniverse/IsaacGymEnvs
""" """
def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, rl_device: str, clip_obs: float, clip_actions: float): def __init__(
self,
env: ManagerBasedRLEnv | DirectRLEnv,
rl_device: str,
clip_obs: float,
clip_actions: float,
obs_groups: dict[str, list[str]] | None = None,
concate_obs_group: bool = True,
):
"""Initializes the wrapper instance. """Initializes the wrapper instance.
Args: Args:
...@@ -87,6 +98,9 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -87,6 +98,9 @@ class RlGamesVecEnvWrapper(IVecEnv):
rl_device: The device on which agent computations are performed. rl_device: The device on which agent computations are performed.
clip_obs: The clipping value for observations. clip_obs: The clipping value for observations.
clip_actions: The clipping value for actions. clip_actions: The clipping value for actions.
obs_groups: The mapping from Isaac Lab observation groups to the rl-games ``"obs"`` and ``"states"`` inputs.
Defaults to None for backward compatibility.
concate_obs_group: Whether the input to the rl-games network is a concatenated tensor (True) or a dict of
tensors (False). Defaults to True for backward compatibility.
Raises: Raises:
ValueError: The environment is not inherited from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`. ValueError: The environment is not inherited from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
...@@ -105,11 +119,36 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -105,11 +119,36 @@ class RlGamesVecEnvWrapper(IVecEnv):
self._clip_obs = clip_obs self._clip_obs = clip_obs
self._clip_actions = clip_actions self._clip_actions = clip_actions
self._sim_device = env.unwrapped.device self._sim_device = env.unwrapped.device
# information for privileged observations
if self.state_space is None: # resolve the observation group
self.rlg_num_states = 0 self._concate_obs_groups = concate_obs_group
else: self._obs_groups = obs_groups
if obs_groups is None:
self._obs_groups = {"obs": ["policy"], "states": []}
if not self.unwrapped.single_observation_space.get("policy"):
raise KeyError("Policy observation group is expected if no explicit groups is defined")
if self.unwrapped.single_observation_space.get("critic"):
self._obs_groups["states"] = ["critic"]
if (
self._concate_obs_groups
and isinstance(self.state_space, gym.spaces.Box)
and isinstance(self.observation_space, gym.spaces.Box)
):
self.rlg_num_states = self.state_space.shape[0] self.rlg_num_states = self.state_space.shape[0]
elif (
not self._concate_obs_groups
and isinstance(self.state_space, gym.spaces.Dict)
and isinstance(self.observation_space, gym.spaces.Dict)
):
space = [space.shape[0] for space in self.state_space.values()]
self.rlg_num_states = sum(space)
else:
raise TypeError(
"only valid combination for state space is gym.space.Box when concate_obs_groups is True, "
" and gym.space.Dict when concate_obs_groups is False. You have concate_obs_groups: "
f" {self._concate_obs_groups}, and state_space: {self.state_space.__class__}"
)
def __str__(self): def __str__(self):
"""Returns the wrapper name and the :attr:`env` representation string.""" """Returns the wrapper name and the :attr:`env` representation string."""
...@@ -135,19 +174,18 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -135,19 +174,18 @@ class RlGamesVecEnvWrapper(IVecEnv):
return self.env.render_mode return self.env.render_mode
@property @property
def observation_space(self) -> gym.spaces.Box: def observation_space(self) -> gym.spaces.Box | gym.spaces.Dict:
"""Returns the :attr:`Env` :attr:`observation_space`.""" """Returns the :attr:`Env` :attr:`observation_space` (``Box`` if concatenated, otherwise ``Dict``)."""
# note: rl-games only wants single observation space # note: rl-games only wants single observation space
policy_obs_space = self.unwrapped.single_observation_space["policy"] space = self.unwrapped.single_observation_space
if not isinstance(policy_obs_space, gymnasium.spaces.Box): clip = self._clip_obs
raise NotImplementedError( if not self._concate_obs_groups:
f"The RL-Games wrapper does not currently support observation space: '{type(policy_obs_space)}'." policy_space = {grp: gym.spaces.Box(-clip, clip, space.get(grp).shape) for grp in self._obs_groups["obs"]}
f" If you need to support this, please modify the wrapper: {self.__class__.__name__}," return gym.spaces.Dict(policy_space)
" and if you are nice, please send a merge-request." else:
) shapes = [space.get(group).shape for group in self._obs_groups["obs"]]
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since cat_shape, self._obs_concat_fn = make_concat_plan(shapes)
# in ManagerBasedRLEnv we are setting action space as (-inf, inf). return gym.spaces.Box(-clip, clip, cat_shape)
return gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)
@property @property
def action_space(self) -> gym.Space: def action_space(self) -> gym.Space:
...@@ -193,23 +231,18 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -193,23 +231,18 @@ class RlGamesVecEnvWrapper(IVecEnv):
return self.unwrapped.device return self.unwrapped.device
@property @property
def state_space(self) -> gym.spaces.Box | None: def state_space(self) -> gym.spaces.Box | gym.spaces.Dict | None:
"""Returns the :attr:`Env` :attr:`observation_space`.""" """Returns the privileged observation space for the critic (``Box`` if concatenated, otherwise ``Dict``)."""
# note: rl-games only wants single observation space # # note: rl-games only wants single observation space
critic_obs_space = self.unwrapped.single_observation_space.get("critic") space = self.unwrapped.single_observation_space
# check if we even have a critic obs clip = self._clip_obs
if critic_obs_space is None: if not self._concate_obs_groups:
return None state_space = {grp: gym.spaces.Box(-clip, clip, space.get(grp).shape) for grp in self._obs_groups["states"]}
elif not isinstance(critic_obs_space, gymnasium.spaces.Box): return gym.spaces.Dict(state_space)
raise NotImplementedError( else:
f"The RL-Games wrapper does not currently support state space: '{type(critic_obs_space)}'." shapes = [space.get(group).shape for group in self._obs_groups["states"]]
f" If you need to support this, please modify the wrapper: {self.__class__.__name__}," cat_shape, self._states_concat_fn = make_concat_plan(shapes)
" and if you are nice, please send a merge-request." return gym.spaces.Box(-self._clip_obs, self._clip_obs, cat_shape)
)
# return casted space in gym.spaces.Box (OpenAI Gym)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in ManagerBasedRLEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_obs, self._clip_obs, critic_obs_space.shape)
def get_number_of_agents(self) -> int: def get_number_of_agents(self) -> int:
"""Returns number of actors in the environment.""" """Returns number of actors in the environment."""
...@@ -270,7 +303,7 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -270,7 +303,7 @@ class RlGamesVecEnvWrapper(IVecEnv):
Helper functions Helper functions
""" """
def _process_obs(self, obs_dict: VecEnvObs) -> torch.Tensor | dict[str, torch.Tensor]: def _process_obs(self, obs_dict: VecEnvObs) -> dict[str, torch.Tensor] | dict[str, dict[str, torch.Tensor]]:
"""Processing of the observations and states from the environment. """Processing of the observations and states from the environment.
Note: Note:
...@@ -281,31 +314,60 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -281,31 +314,60 @@ class RlGamesVecEnvWrapper(IVecEnv):
obs_dict: The current observations from environment. obs_dict: The current observations from environment.
Returns: Returns:
If environment provides states, then a dictionary containing the observations and states is returned. A dictionary for RL-Games with keys:
Otherwise just the observations tensor is returned. - ``"obs"``: either a concatenated tensor (``concate_obs_group=True``) or a Dict of group tensors.
- ``"states"`` (optional): same structure as above when state groups are configured; omitted otherwise.
""" """
# process policy obs
obs = obs_dict["policy"]
# clip the observations # clip the observations
obs = torch.clamp(obs, -self._clip_obs, self._clip_obs) for key, obs in obs_dict.items():
# move the buffer to rl-device obs_dict[key] = torch.clamp(obs, -self._clip_obs, self._clip_obs)
obs = obs.to(device=self._rl_device).clone()
# process input obs dict
# check if asymmetric actor-critic or not rl_games_obs = {"obs": {group: obs_dict[group] for group in self._obs_groups["obs"]}}
if self.rlg_num_states > 0: if len(self._obs_groups["states"]) > 0:
# acquire states from the environment if it exists rl_games_obs["states"] = {group: obs_dict[group] for group in self._obs_groups["states"]}
try:
states = obs_dict["critic"] if self._concate_obs_groups:
except AttributeError: rl_games_obs["obs"] = self._obs_concat_fn(list(rl_games_obs["obs"].values()))
raise NotImplementedError("Environment does not define key 'critic' for privileged observations.") if "states" in rl_games_obs:
# clip the states rl_games_obs["states"] = self._states_concat_fn(list(rl_games_obs["states"].values()))
states = torch.clamp(states, -self._clip_obs, self._clip_obs)
# move buffers to rl-device return rl_games_obs
states = states.to(self._rl_device).clone()
# convert to dictionary
return {"obs": obs, "states": states} def make_concat_plan(shapes: list[tuple[int, ...]]) -> tuple[tuple[int, ...], Callable]:
"""
Given per-sample shapes (no batch dim), return:
- the concatenated per-sample shape
- a function that concatenates a list of batch tensors accordingly.
Rules:
0) Empty -> (0,), No-op
1) All 1D -> concat features (dim=1).
2) Same rank > 1:
2a) If all s[:-1] equal -> concat along last dim (channels-last, dim=-1).
2b) If all s[1:] equal -> concat along first dim (channels-first, dim=1).
"""
if len(shapes) == 0:
return (0,), lambda x: x
# case 1: all vectors
if all(len(s) == 1 for s in shapes):
return (sum(s[0] for s in shapes),), lambda x: torch.cat(x, dim=1)
# case 2: same rank > 1
rank = len(shapes[0])
if all(len(s) == rank for s in shapes) and rank > 1:
# 2a: concat along last axis (…C)
if all(s[:-1] == shapes[0][:-1] for s in shapes):
out_shape = shapes[0][:-1] + (sum(s[-1] for s in shapes),)
return out_shape, lambda x: torch.cat(x, dim=-1)
# 2b: concat along first axis (C…)
if all(s[1:] == shapes[0][1:] for s in shapes):
out_shape = (sum(s[0] for s in shapes),) + shapes[0][1:]
return out_shape, lambda x: torch.cat(x, dim=1)
else:
raise ValueError(f"Could not find a valid concatenation plan for rank {[(len(s),) for s in shapes]}")
else: else:
return obs raise ValueError("Could not find a valid concatenation plan, please make sure all value share the same size")
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment