Unverified Commit d7613ce8 authored by ooctipus's avatar ooctipus Committed by GitHub

Supports rl games wrapper with dictionary observation (#3340)

# Description

This PR makes it possible to use dictionary observations in rl-games
applications.
This benefits:
1. a combination of high- and low-dimensional observations can be passed to
the actor and critic in rl-games
2. duplicate computation is avoided when the actor and critic share the same
observation


## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)

## Screenshots

Please attach before and after screenshots of the change if applicable.

<!--
Example:

| Before | After |
| ------ | ----- |
| _gif/png before_ | _gif/png after_ |

To upload images to a PR -- simply drag and drop an image while in edit
mode and it should upload the image directly. You can then paste that
source into the above before/after sections.
-->

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent 90e5f31a
......@@ -134,6 +134,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
rl_device = agent_cfg["params"]["config"]["device"]
clip_obs = agent_cfg["params"]["env"].get("clip_observations", math.inf)
clip_actions = agent_cfg["params"]["env"].get("clip_actions", math.inf)
obs_groups = agent_cfg["params"]["env"].get("obs_groups")
concate_obs_groups = agent_cfg["params"]["env"].get("concate_obs_groups", True)
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
......@@ -155,7 +157,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions, obs_groups, concate_obs_groups)
# register the environment to rl-games registry
# note: in agents configuration: environment name must be "rlgpu"
......
......@@ -148,6 +148,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
rl_device = agent_cfg["params"]["config"]["device"]
clip_obs = agent_cfg["params"]["env"].get("clip_observations", math.inf)
clip_actions = agent_cfg["params"]["env"].get("clip_actions", math.inf)
obs_groups = agent_cfg["params"]["env"].get("obs_groups")
concate_obs_groups = agent_cfg["params"]["env"].get("concate_obs_groups", True)
# set the IO descriptors output directory if requested
if isinstance(env_cfg, ManagerBasedRLEnvCfg):
......@@ -178,7 +180,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions, obs_groups, concate_obs_groups)
# register the environment to rl-games registry
# note: in agents configuration: environment name must be "rlgpu"
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.2.4"
version = "0.3.0"
# Description
title = "Isaac Lab RL"
......
Changelog
---------
0.3.0 (2025-09-03)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Enhanced rl-games wrapper to allow dict observation.
0.2.4 (2025-08-07)
~~~~~~~~~~~~~~~~~~
......
......@@ -37,6 +37,7 @@ from __future__ import annotations
import gym.spaces # needed for rl-games incompatibility: https://github.com/Denys88/rl_games/issues/261
import gymnasium
import torch
from collections.abc import Callable
from rl_games.common import env_configurations
from rl_games.common.vecenv import IVecEnv
......@@ -60,12 +61,14 @@ class RlGamesVecEnvWrapper(IVecEnv):
observations. This dictionary contains "obs" and "states" which typically correspond
to the actor and critic observations respectively.
To use asymmetric actor-critic, the environment observations from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`
must have the key or group name "critic". The observation group is used to set the
:attr:`num_states` (int) and :attr:`state_space` (:obj:`gym.spaces.Box`). These are
used by the learning agent in RL-Games to allocate buffers in the trajectory memory.
Since this is optional for some environments, the wrapper checks if these attributes exist.
If they don't then the wrapper defaults to zero as number of privileged observations.
To use asymmetric actor-critic, map privileged observation groups under ``"states"`` (e.g. ``["critic"]``).
The wrapper supports **either** concatenated tensors (default) **or** Dict inputs:
when wrapper is concate mode, rl-games sees {"obs": Tensor, (optional)"states": Tensor}
when wrapper is not concate mode, rl-games sees {"obs": dict[str, Tensor], (optional)"states": dict[str, Tensor]}
- Concatenated mode (``concate_obs_group=True``): ``observation_space``/``state_space`` are ``gym.spaces.Box``.
- Dict mode (``concate_obs_group=False``): ``observation_space``/``state_space`` are ``gym.spaces.Dict`` keyed by
the requested groups. When no ``"states"`` groups are provided, the states Dict is omitted at runtime.
.. caution::
......@@ -79,7 +82,15 @@ class RlGamesVecEnvWrapper(IVecEnv):
https://github.com/NVIDIA-Omniverse/IsaacGymEnvs
"""
def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, rl_device: str, clip_obs: float, clip_actions: float):
def __init__(
self,
env: ManagerBasedRLEnv | DirectRLEnv,
rl_device: str,
clip_obs: float,
clip_actions: float,
obs_groups: dict[str, list[str]] | None = None,
concate_obs_group: bool = True,
):
"""Initializes the wrapper instance.
Args:
......@@ -87,6 +98,9 @@ class RlGamesVecEnvWrapper(IVecEnv):
rl_device: The device on which agent computations are performed.
clip_obs: The clipping value for observations.
clip_actions: The clipping value for actions.
obs_groups: The remapping from isaaclab observation to rl-games, default to None for backward compatible.
concate_obs_group: The boolean value indicates if input to rl-games network is dict or tensor. Default to
True for backward compatible.
Raises:
ValueError: The environment is not inherited from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
......@@ -105,11 +119,36 @@ class RlGamesVecEnvWrapper(IVecEnv):
self._clip_obs = clip_obs
self._clip_actions = clip_actions
self._sim_device = env.unwrapped.device
# information for privileged observations
if self.state_space is None:
self.rlg_num_states = 0
else:
# resolve the observation group
self._concate_obs_groups = concate_obs_group
self._obs_groups = obs_groups
if obs_groups is None:
self._obs_groups = {"obs": ["policy"], "states": []}
if not self.unwrapped.single_observation_space.get("policy"):
raise KeyError("Policy observation group is expected if no explicit groups is defined")
if self.unwrapped.single_observation_space.get("critic"):
self._obs_groups["states"] = ["critic"]
if (
self._concate_obs_groups
and isinstance(self.state_space, gym.spaces.Box)
and isinstance(self.observation_space, gym.spaces.Box)
):
self.rlg_num_states = self.state_space.shape[0]
elif (
not self._concate_obs_groups
and isinstance(self.state_space, gym.spaces.Dict)
and isinstance(self.observation_space, gym.spaces.Dict)
):
space = [space.shape[0] for space in self.state_space.values()]
self.rlg_num_states = sum(space)
else:
raise TypeError(
"only valid combination for state space is gym.space.Box when concate_obs_groups is True, "
" and gym.space.Dict when concate_obs_groups is False. You have concate_obs_groups: "
f" {self._concate_obs_groups}, and state_space: {self.state_space.__class__}"
)
def __str__(self):
"""Returns the wrapper name and the :attr:`env` representation string."""
......@@ -135,19 +174,18 @@ class RlGamesVecEnvWrapper(IVecEnv):
return self.env.render_mode
@property
def observation_space(self) -> gym.spaces.Box:
"""Returns the :attr:`Env` :attr:`observation_space`."""
def observation_space(self) -> gym.spaces.Box | gym.spaces.Dict:
"""Returns the :attr:`Env` :attr:`observation_space` (``Box`` if concatenated, otherwise ``Dict``)."""
# note: rl-games only wants single observation space
policy_obs_space = self.unwrapped.single_observation_space["policy"]
if not isinstance(policy_obs_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support observation space: '{type(policy_obs_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in ManagerBasedRLEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)
space = self.unwrapped.single_observation_space
clip = self._clip_obs
if not self._concate_obs_groups:
policy_space = {grp: gym.spaces.Box(-clip, clip, space.get(grp).shape) for grp in self._obs_groups["obs"]}
return gym.spaces.Dict(policy_space)
else:
shapes = [space.get(group).shape for group in self._obs_groups["obs"]]
cat_shape, self._obs_concat_fn = make_concat_plan(shapes)
return gym.spaces.Box(-clip, clip, cat_shape)
@property
def action_space(self) -> gym.Space:
......@@ -193,23 +231,18 @@ class RlGamesVecEnvWrapper(IVecEnv):
return self.unwrapped.device
@property
def state_space(self) -> gym.spaces.Box | None:
"""Returns the :attr:`Env` :attr:`observation_space`."""
# note: rl-games only wants single observation space
critic_obs_space = self.unwrapped.single_observation_space.get("critic")
# check if we even have a critic obs
if critic_obs_space is None:
return None
elif not isinstance(critic_obs_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support state space: '{type(critic_obs_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# return casted space in gym.spaces.Box (OpenAI Gym)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in ManagerBasedRLEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_obs, self._clip_obs, critic_obs_space.shape)
def state_space(self) -> gym.spaces.Box | gym.spaces.Dict | None:
"""Returns the privileged observation space for the critic (``Box`` if concatenated, otherwise ``Dict``)."""
# # note: rl-games only wants single observation space
space = self.unwrapped.single_observation_space
clip = self._clip_obs
if not self._concate_obs_groups:
state_space = {grp: gym.spaces.Box(-clip, clip, space.get(grp).shape) for grp in self._obs_groups["states"]}
return gym.spaces.Dict(state_space)
else:
shapes = [space.get(group).shape for group in self._obs_groups["states"]]
cat_shape, self._states_concat_fn = make_concat_plan(shapes)
return gym.spaces.Box(-self._clip_obs, self._clip_obs, cat_shape)
def get_number_of_agents(self) -> int:
"""Returns number of actors in the environment."""
......@@ -270,7 +303,7 @@ class RlGamesVecEnvWrapper(IVecEnv):
Helper functions
"""
def _process_obs(self, obs_dict: VecEnvObs) -> torch.Tensor | dict[str, torch.Tensor]:
def _process_obs(self, obs_dict: VecEnvObs) -> dict[str, torch.Tensor] | dict[str, dict[str, torch.Tensor]]:
"""Processing of the observations and states from the environment.
Note:
......@@ -280,32 +313,61 @@ class RlGamesVecEnvWrapper(IVecEnv):
Args:
obs_dict: The current observations from environment.
Returns:
If environment provides states, then a dictionary containing the observations and states is returned.
Otherwise just the observations tensor is returned.
Returns:
A dictionary for RL-Games with keys:
- ``"obs"``: either a concatenated tensor (``concate_obs_group=True``) or a Dict of group tensors.
- ``"states"`` (optional): same structure as above when state groups are configured; omitted otherwise.
"""
# process policy obs
obs = obs_dict["policy"]
# clip the observations
obs = torch.clamp(obs, -self._clip_obs, self._clip_obs)
# move the buffer to rl-device
obs = obs.to(device=self._rl_device).clone()
# check if asymmetric actor-critic or not
if self.rlg_num_states > 0:
# acquire states from the environment if it exists
try:
states = obs_dict["critic"]
except AttributeError:
raise NotImplementedError("Environment does not define key 'critic' for privileged observations.")
# clip the states
states = torch.clamp(states, -self._clip_obs, self._clip_obs)
# move buffers to rl-device
states = states.to(self._rl_device).clone()
# convert to dictionary
return {"obs": obs, "states": states}
for key, obs in obs_dict.items():
obs_dict[key] = torch.clamp(obs, -self._clip_obs, self._clip_obs)
# process input obs dict
rl_games_obs = {"obs": {group: obs_dict[group] for group in self._obs_groups["obs"]}}
if len(self._obs_groups["states"]) > 0:
rl_games_obs["states"] = {group: obs_dict[group] for group in self._obs_groups["states"]}
if self._concate_obs_groups:
rl_games_obs["obs"] = self._obs_concat_fn(list(rl_games_obs["obs"].values()))
if "states" in rl_games_obs:
rl_games_obs["states"] = self._states_concat_fn(list(rl_games_obs["states"].values()))
return rl_games_obs
def make_concat_plan(shapes: list[tuple[int, ...]]) -> tuple[tuple[int, ...], Callable]:
    """Plan how a list of observation groups is concatenated into a single tensor.

    The given shapes are per-sample shapes (no batch dimension). The returned callable is
    applied to the matching list of *batched* tensors, so per-sample axis ``i`` corresponds
    to tensor axis ``i + 1``.

    Rules, tried in order:
        0) No shapes -> ``(0,)`` and an identity function (the input is returned untouched).
        1) All shapes are 1D -> concatenate the features (``dim=1``).
        2) All shapes share the same rank > 1:
            2a) If all ``s[:-1]`` are equal -> concatenate along the last dim (channels-last, ``dim=-1``).
            2b) If all ``s[1:]`` are equal -> concatenate along the first per-sample dim
                (channels-first, ``dim=1``).

    Args:
        shapes: Per-sample shape of every group, in concatenation order. Any sequence of
            ints is accepted (``tuple``, ``list``, ``torch.Size``).

    Returns:
        A tuple with the concatenated per-sample shape and a function that concatenates a
        list of batched tensors accordingly.

    Raises:
        ValueError: If the shapes have different ranks (and are not all 1D), or if they share
            a rank > 1 but differ in more than the first or the last dimension.
    """
    # normalize so that lists and torch.Size compare and add like plain tuples
    shapes = [tuple(s) for s in shapes]

    # case 0: nothing to concatenate
    if not shapes:
        return (0,), lambda x: x

    # case 1: all vectors
    if all(len(s) == 1 for s in shapes):
        return (sum(s[0] for s in shapes),), lambda x: torch.cat(x, dim=1)

    # case 2: same rank > 1
    first = shapes[0]
    rank = len(first)
    if rank <= 1 or any(len(s) != rank for s in shapes):
        raise ValueError(
            f"Could not find a valid concatenation plan for shapes {shapes}:"
            " all groups must be 1D or share the same rank > 1."
        )

    # 2a: concat along last axis (...C)
    if all(s[:-1] == first[:-1] for s in shapes):
        out_shape = first[:-1] + (sum(s[-1] for s in shapes),)
        return out_shape, lambda x: torch.cat(x, dim=-1)

    # 2b: concat along first per-sample axis (C...), which is dim=1 of the batched tensors
    if all(s[1:] == first[1:] for s in shapes):
        out_shape = (sum(s[0] for s in shapes),) + first[1:]
        return out_shape, lambda x: torch.cat(x, dim=1)

    # report the actual shapes: the ranks are all equal here and would not help debugging
    raise ValueError(
        f"Could not find a valid concatenation plan for shapes {shapes}:"
        " groups must match in all but the first or all but the last dimension."
    )
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment