Unverified Commit be41bb0d authored by Pascal Roth's avatar Pascal Roth Committed by GitHub

Adds option to define the concatenation dimension in the `ObservationManager`...

Adds option to define the concatenation dimension in the `ObservationManager` and change counter update in `CommandManager` (#2393)

# Description


Added support for concatenation of observations along different
dimensions in `ObservationManager`.

Updates the position where the command counter is increased to allow
checking for reset environments in the resample call of the
`CommandManager`

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarPascal Roth <57946385+pascal-roth@users.noreply.github.com>
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Signed-off-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
Co-authored-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
parent 963b53b9
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.39.4"
version = "0.39.5"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.39.5 (2025-05-16)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added support for concatenation of observations along different dimensions in :class:`~isaaclab.managers.observation_manager.ObservationManager`.
Changed
^^^^^^^
* Updated the :class:`~isaaclab.managers.command_manager.CommandManager` to update the command counter after the
resampling call.
0.39.4 (2025-05-16)
~~~~~~~~~~~~~~~~~~~
......
......@@ -181,10 +181,10 @@ class CommandTerm(ManagerTermBase):
if len(env_ids) != 0:
# resample the time left before resampling
self.time_left[env_ids] = self.time_left[env_ids].uniform_(*self.cfg.resampling_time_range)
# increment the command counter
self.command_counter[env_ids] += 1
# resample the command
self._resample_command(env_ids)
# increment the command counter
self.command_counter[env_ids] += 1
"""
Implementation specific functions.
......
......@@ -201,12 +201,22 @@ class ObservationGroupCfg:
concatenate_terms: bool = True
"""Whether to concatenate the observation terms in the group. Defaults to True.
If true, the observation terms in the group are concatenated along the last dimension.
If true, the observation terms in the group are concatenated along the dimension specified through :attr:`concatenate_dim`.
Otherwise, they are kept separate and returned as a dictionary.
If the observation group contains terms of different dimensions, it must be set to False.
"""
concatenate_dim: int = -1
"""Dimension along to concatenate the different observation terms. Defaults to -1, which
means the last dimension of the observation terms.
If :attr:`concatenate_terms` is True, this parameter specifies the dimension along which the observation terms are concatenated.
The indicated dimension depends on the shape of the observations. For instance, for a 2D RGB image of shape (H, W, C), the dimension
0 means concatenating along the height, 1 along the width, and 2 along the channels. The offset due
to the batched environment is handled automatically.
"""
enable_corruption: bool = False
"""Whether to enable corruption for the observation group. Defaults to False.
......
......@@ -88,8 +88,18 @@ class ObservationManager(ManagerBase):
# otherwise, keep the list of shapes as is
if self._group_obs_concatenate[group_name]:
try:
term_dims = [torch.tensor(dims, device="cpu") for dims in group_term_dims]
self._group_obs_dim[group_name] = tuple(torch.sum(torch.stack(term_dims, dim=0), dim=0).tolist())
term_dims = torch.stack([torch.tensor(dims, device="cpu") for dims in group_term_dims], dim=0)
if len(term_dims.shape) > 1:
if self._group_obs_concatenate_dim[group_name] >= 0:
dim = self._group_obs_concatenate_dim[group_name] - 1 # account for the batch offset
else:
dim = self._group_obs_concatenate_dim[group_name]
dim_sum = torch.sum(term_dims[:, dim], dim=0)
term_dims[0, dim] = dim_sum
term_dims = term_dims[0]
else:
term_dims = torch.sum(term_dims, dim=0)
self._group_obs_dim[group_name] = tuple(term_dims.tolist())
except RuntimeError:
raise RuntimeError(
f"Unable to concatenate observation terms in group '{group_name}'."
......@@ -330,7 +340,8 @@ class ObservationManager(ManagerBase):
# concatenate all observations in the group together
if self._group_obs_concatenate[group_name]:
return torch.cat(list(group_obs.values()), dim=-1)
# set the concatenate dimension, account for the batch dimension if positive dimension is given
return torch.cat(list(group_obs.values()), dim=self._group_obs_concatenate_dim[group_name])
else:
return group_obs
......@@ -370,6 +381,8 @@ class ObservationManager(ManagerBase):
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_concatenate: dict[str, bool] = dict()
self._group_obs_concatenate_dim: dict[str, int] = dict()
self._group_obs_term_history_buffer: dict[str, dict] = dict()
# create a list to store modifiers that are classes
# we store it as a separate list to only call reset on them and prevent unnecessary calls
......@@ -407,6 +420,9 @@ class ObservationManager(ManagerBase):
group_entry_history_buffer: dict[str, CircularBuffer] = dict()
# read common config for the group
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
self._group_obs_concatenate_dim[group_name] = (
group_cfg.concatenate_dim + 1 if group_cfg.concatenate_dim >= 0 else group_cfg.concatenate_dim
)
# check if config is dict already
if isinstance(group_cfg, dict):
group_cfg_items = group_cfg.items()
......@@ -415,7 +431,13 @@ class ObservationManager(ManagerBase):
# iterate over all the terms in each group
for term_name, term_cfg in group_cfg_items:
# skip non-obs settings
if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]:
if term_name in [
"enable_corruption",
"concatenate_terms",
"history_length",
"flatten_history_dim",
"concatenate_dim",
]:
continue
# check for non config
if term_cfg is None:
......
......@@ -667,8 +667,10 @@ def test_modifier_compute(setup_env):
assert torch.min(obs_critic["term_4"]) >= -0.5
assert torch.max(obs_critic["term_4"]) <= 0.5
def test_serialize(self):
def test_serialize(setup_env):
"""Test serialize call for ManagerTermBase terms."""
env = setup_env
serialize_data = {"test": 0}
......@@ -698,10 +700,10 @@ def test_modifier_compute(setup_env):
# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
obs_man = ObservationManager(cfg, env)
# check expected output
self.assertEqual(self.obs_man.serialize(), {"policy": {"term_1": serialize_data}})
assert obs_man.serialize() == {"policy": {"term_1": serialize_data}}
def test_modifier_invalid_config(setup_env):
......@@ -728,3 +730,71 @@ def test_modifier_invalid_config(setup_env):
with pytest.raises(ValueError):
ObservationManager(cfg, env)
def test_concatenate_dim(setup_env):
"""Test concatenation of observations along different dimensions."""
env = setup_env
@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""
@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""
concatenate_terms = True
concatenate_dim = 1 # Concatenate along dimension 1
term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
@configclass
class CriticCfg(ObservationGroupCfg):
"""Test config class for critic observation group."""
concatenate_terms = True
concatenate_dim = 2 # Concatenate along dimension 2
term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
@configclass
class CriticCfg_neg_dim(ObservationGroupCfg):
"""Test config class for critic observation group."""
concatenate_terms = True
concatenate_dim = -1 # Concatenate along last dimension
term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.0, params={"bland": 1.0, "channel": 1})
policy: ObservationGroupCfg = PolicyCfg()
critic: ObservationGroupCfg = CriticCfg()
critic_neg_dim: ObservationGroupCfg = CriticCfg_neg_dim()
# create observation manager
cfg = MyObservationManagerCfg()
obs_man = ObservationManager(cfg, env)
# compute observation using manager
observations = obs_man.compute()
# obtain the group observations
obs_policy: torch.Tensor = observations["policy"]
obs_critic: torch.Tensor = observations["critic"]
obs_critic_neg_dim: torch.Tensor = observations["critic_neg_dim"]
# check the observation shapes
# For policy: concatenated along dim 1, so width should be doubled
assert obs_policy.shape == (env.num_envs, 128, 512, 1)
# For critic: concatenated along last dim, so channels should be doubled
assert obs_critic.shape == (env.num_envs, 128, 256, 2)
# For critic_neg_dim: concatenated along last dim, so channels should be doubled
assert obs_critic_neg_dim.shape == (env.num_envs, 128, 256, 2)
# verify the data is concatenated correctly
# For policy: check that the second half matches the first half
torch.testing.assert_close(obs_policy[:, :, :256, :], obs_policy[:, :, 256:, :])
# For critic: check that the second channel matches the first channel
torch.testing.assert_close(obs_critic[:, :, :, 0], obs_critic[:, :, :, 1])
# For critic_neg_dim: check that it is the same as critic
torch.testing.assert_close(obs_critic_neg_dim, obs_critic)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment