Unverified Commit be41bb0d authored by Pascal Roth's avatar Pascal Roth Committed by GitHub

Adds option to define the concatenation dimension in the `ObservationManager`...

Adds option to define the concatenation dimension in the `ObservationManager` and change counter update in `CommandManager` (#2393)

# Description


Added support for concatenation of observations along different
dimensions in `ObservationManager`.

Updates the position where the command counter is increased to allow
checking for reset environments in the resample call of the
`CommandManager`

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarPascal Roth <57946385+pascal-roth@users.noreply.github.com>
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Signed-off-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
Co-authored-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
parent 963b53b9
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.39.4" version = "0.39.5"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.39.5 (2025-05-16)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added support for concatenation of observations along different dimensions in :class:`~isaaclab.managers.observation_manager.ObservationManager`.
Changed
^^^^^^^
* Updated the :class:`~isaaclab.managers.command_manager.CommandManager` to update the command counter after the
resampling call.
0.39.4 (2025-05-16) 0.39.4 (2025-05-16)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -181,10 +181,10 @@ class CommandTerm(ManagerTermBase): ...@@ -181,10 +181,10 @@ class CommandTerm(ManagerTermBase):
if len(env_ids) != 0: if len(env_ids) != 0:
# resample the time left before resampling # resample the time left before resampling
self.time_left[env_ids] = self.time_left[env_ids].uniform_(*self.cfg.resampling_time_range) self.time_left[env_ids] = self.time_left[env_ids].uniform_(*self.cfg.resampling_time_range)
# increment the command counter
self.command_counter[env_ids] += 1
# resample the command # resample the command
self._resample_command(env_ids) self._resample_command(env_ids)
# increment the command counter
self.command_counter[env_ids] += 1
""" """
Implementation specific functions. Implementation specific functions.
......
...@@ -201,12 +201,22 @@ class ObservationGroupCfg: ...@@ -201,12 +201,22 @@ class ObservationGroupCfg:
concatenate_terms: bool = True concatenate_terms: bool = True
"""Whether to concatenate the observation terms in the group. Defaults to True. """Whether to concatenate the observation terms in the group. Defaults to True.
If true, the observation terms in the group are concatenated along the last dimension. If true, the observation terms in the group are concatenated along the dimension specified through :attr:`concatenate_dim`.
Otherwise, they are kept separate and returned as a dictionary. Otherwise, they are kept separate and returned as a dictionary.
If the observation group contains terms of different dimensions, it must be set to False. If the observation group contains terms of different dimensions, it must be set to False.
""" """
concatenate_dim: int = -1
"""Dimension along to concatenate the different observation terms. Defaults to -1, which
means the last dimension of the observation terms.
If :attr:`concatenate_terms` is True, this parameter specifies the dimension along which the observation terms are concatenated.
The indicated dimension depends on the shape of the observations. For instance, for a 2D RGB image of shape (H, W, C), the dimension
0 means concatenating along the height, 1 along the width, and 2 along the channels. The offset due
to the batched environment is handled automatically.
"""
enable_corruption: bool = False enable_corruption: bool = False
"""Whether to enable corruption for the observation group. Defaults to False. """Whether to enable corruption for the observation group. Defaults to False.
......
...@@ -88,8 +88,18 @@ class ObservationManager(ManagerBase): ...@@ -88,8 +88,18 @@ class ObservationManager(ManagerBase):
# otherwise, keep the list of shapes as is # otherwise, keep the list of shapes as is
if self._group_obs_concatenate[group_name]: if self._group_obs_concatenate[group_name]:
try: try:
term_dims = [torch.tensor(dims, device="cpu") for dims in group_term_dims] term_dims = torch.stack([torch.tensor(dims, device="cpu") for dims in group_term_dims], dim=0)
self._group_obs_dim[group_name] = tuple(torch.sum(torch.stack(term_dims, dim=0), dim=0).tolist()) if len(term_dims.shape) > 1:
if self._group_obs_concatenate_dim[group_name] >= 0:
dim = self._group_obs_concatenate_dim[group_name] - 1 # account for the batch offset
else:
dim = self._group_obs_concatenate_dim[group_name]
dim_sum = torch.sum(term_dims[:, dim], dim=0)
term_dims[0, dim] = dim_sum
term_dims = term_dims[0]
else:
term_dims = torch.sum(term_dims, dim=0)
self._group_obs_dim[group_name] = tuple(term_dims.tolist())
except RuntimeError: except RuntimeError:
raise RuntimeError( raise RuntimeError(
f"Unable to concatenate observation terms in group '{group_name}'." f"Unable to concatenate observation terms in group '{group_name}'."
...@@ -330,7 +340,8 @@ class ObservationManager(ManagerBase): ...@@ -330,7 +340,8 @@ class ObservationManager(ManagerBase):
# concatenate all observations in the group together # concatenate all observations in the group together
if self._group_obs_concatenate[group_name]: if self._group_obs_concatenate[group_name]:
return torch.cat(list(group_obs.values()), dim=-1) # set the concatenate dimension, account for the batch dimension if positive dimension is given
return torch.cat(list(group_obs.values()), dim=self._group_obs_concatenate_dim[group_name])
else: else:
return group_obs return group_obs
...@@ -370,6 +381,8 @@ class ObservationManager(ManagerBase): ...@@ -370,6 +381,8 @@ class ObservationManager(ManagerBase):
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict() self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict() self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_concatenate: dict[str, bool] = dict() self._group_obs_concatenate: dict[str, bool] = dict()
self._group_obs_concatenate_dim: dict[str, int] = dict()
self._group_obs_term_history_buffer: dict[str, dict] = dict() self._group_obs_term_history_buffer: dict[str, dict] = dict()
# create a list to store modifiers that are classes # create a list to store modifiers that are classes
# we store it as a separate list to only call reset on them and prevent unnecessary calls # we store it as a separate list to only call reset on them and prevent unnecessary calls
...@@ -407,6 +420,9 @@ class ObservationManager(ManagerBase): ...@@ -407,6 +420,9 @@ class ObservationManager(ManagerBase):
group_entry_history_buffer: dict[str, CircularBuffer] = dict() group_entry_history_buffer: dict[str, CircularBuffer] = dict()
# read common config for the group # read common config for the group
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
self._group_obs_concatenate_dim[group_name] = (
group_cfg.concatenate_dim + 1 if group_cfg.concatenate_dim >= 0 else group_cfg.concatenate_dim
)
# check if config is dict already # check if config is dict already
if isinstance(group_cfg, dict): if isinstance(group_cfg, dict):
group_cfg_items = group_cfg.items() group_cfg_items = group_cfg.items()
...@@ -415,7 +431,13 @@ class ObservationManager(ManagerBase): ...@@ -415,7 +431,13 @@ class ObservationManager(ManagerBase):
# iterate over all the terms in each group # iterate over all the terms in each group
for term_name, term_cfg in group_cfg_items: for term_name, term_cfg in group_cfg_items:
# skip non-obs settings # skip non-obs settings
if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]: if term_name in [
"enable_corruption",
"concatenate_terms",
"history_length",
"flatten_history_dim",
"concatenate_dim",
]:
continue continue
# check for non config # check for non config
if term_cfg is None: if term_cfg is None:
......
...@@ -667,41 +667,43 @@ def test_modifier_compute(setup_env): ...@@ -667,41 +667,43 @@ def test_modifier_compute(setup_env):
assert torch.min(obs_critic["term_4"]) >= -0.5 assert torch.min(obs_critic["term_4"]) >= -0.5
assert torch.max(obs_critic["term_4"]) <= 0.5 assert torch.max(obs_critic["term_4"]) <= 0.5
def test_serialize(self):
"""Test serialize call for ManagerTermBase terms."""
serialize_data = {"test": 0} def test_serialize(setup_env):
"""Test serialize call for ManagerTermBase terms."""
env = setup_env
class test_serialize_term(ManagerTermBase): serialize_data = {"test": 0}
def __init__(self, cfg: RewardTermCfg, env: ManagerBasedEnv): class test_serialize_term(ManagerTermBase):
super().__init__(cfg, env)
def __call__(self, env: ManagerBasedEnv) -> torch.Tensor: def __init__(self, cfg: RewardTermCfg, env: ManagerBasedEnv):
return grilled_chicken(env) super().__init__(cfg, env)
def serialize(self) -> dict: def __call__(self, env: ManagerBasedEnv) -> torch.Tensor:
return serialize_data return grilled_chicken(env)
@configclass def serialize(self) -> dict:
class MyObservationManagerCfg: return serialize_data
"""Test config class for observation manager."""
@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""
@configclass @configclass
class PolicyCfg(ObservationGroupCfg): class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group.""" """Test config class for policy observation group."""
concatenate_terms = False concatenate_terms = False
term_1 = ObservationTermCfg(func=test_serialize_term) term_1 = ObservationTermCfg(func=test_serialize_term)
policy: ObservationGroupCfg = PolicyCfg() policy: ObservationGroupCfg = PolicyCfg()
# create observation manager # create observation manager
cfg = MyObservationManagerCfg() cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env) obs_man = ObservationManager(cfg, env)
# check expected output # check expected output
self.assertEqual(self.obs_man.serialize(), {"policy": {"term_1": serialize_data}}) assert obs_man.serialize() == {"policy": {"term_1": serialize_data}}
def test_modifier_invalid_config(setup_env): def test_modifier_invalid_config(setup_env):
...@@ -728,3 +730,71 @@ def test_modifier_invalid_config(setup_env): ...@@ -728,3 +730,71 @@ def test_modifier_invalid_config(setup_env):
with pytest.raises(ValueError): with pytest.raises(ValueError):
ObservationManager(cfg, env) ObservationManager(cfg, env)
def test_concatenate_dim(setup_env):
"""Test concatenation of observations along different dimensions."""
env = setup_env
@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""
@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""
concatenate_terms = True
concatenate_dim = 1 # Concatenate along dimension 1
term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
@configclass
class CriticCfg(ObservationGroupCfg):
"""Test config class for critic observation group."""
concatenate_terms = True
concatenate_dim = 2 # Concatenate along dimension 2
term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
@configclass
class CriticCfg_neg_dim(ObservationGroupCfg):
"""Test config class for critic observation group."""
concatenate_terms = True
concatenate_dim = -1 # Concatenate along last dimension
term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
policy: ObservationGroupCfg = PolicyCfg()
critic: ObservationGroupCfg = CriticCfg()
critic_neg_dim: ObservationGroupCfg = CriticCfg_neg_dim()
# create observation manager
cfg = MyObservationManagerCfg()
obs_man = ObservationManager(cfg, env)
# compute observation using manager
observations = obs_man.compute()
# obtain the group observations
obs_policy: torch.Tensor = observations["policy"]
obs_critic: torch.Tensor = observations["critic"]
obs_critic_neg_dim: torch.Tensor = observations["critic_neg_dim"]
# check the observation shapes
# For policy: concatenated along dim 1, so width should be doubled
assert obs_policy.shape == (env.num_envs, 128, 512, 1)
# For critic: concatenated along last dim, so channels should be doubled
assert obs_critic.shape == (env.num_envs, 128, 256, 2)
# For critic_neg_dim: concatenated along last dim, so channels should be doubled
assert obs_critic_neg_dim.shape == (env.num_envs, 128, 256, 2)
# verify the data is concatenated correctly
# For policy: check that the second half matches the first half
torch.testing.assert_close(obs_policy[:, :, :256, :], obs_policy[:, :, 256:, :])
# For critic: check that the second channel matches the first channel
torch.testing.assert_close(obs_critic[:, :, :, 0], obs_critic[:, :, :, 1])
# For critic_neg_dim: check that it is the same as critic
torch.testing.assert_close(obs_critic_neg_dim, obs_critic)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment