Commit 4cb968ac authored by Mayank Mittal, committed by GitHub

Allows having hybrid dimensional terms inside an observation group (#772)

# Description

Previously, the observation manager always tried to concatenate all the
terms of a group along the last dimension at construction. When a group
contained hybrid-dimensional terms (for example, a vector and an image),
this operation failed with an error. The concatenation should only happen
when users set the attribute `concatenate_terms` to True inside the group
settings.

This MR removes the unconditional concatenation and raises a descriptive
error when users try to concatenate observation terms of incompatible
shapes. Users must set the concatenation flag to False if they want a
"hybrid" observation group.
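
A minimal sketch of the shape rule described above, written in plain Python so it runs without Isaac Lab. The helper name `combine_term_shapes` and the example shapes are hypothetical and not part of the library, and per-term shapes are assumed to exclude the leading `num_envs` dimension. The actual change computes the combined shape by stacking the per-term shape tuples as tensors and summing them. The sketch instead checks the leading dimensions explicitly and sums only the last one, which is an assumption about what "concatenate along the last dimension" should mean.

```python
from typing import Sequence, Union

Shape = tuple[int, ...]


def combine_term_shapes(
    term_shapes: Sequence[Shape], concatenate: bool
) -> Union[Shape, list[Shape]]:
    """Return one combined shape, or the per-term shapes when not concatenating.

    Hypothetical helper for illustration only; not an Isaac Lab API.
    """
    if not concatenate:
        # hybrid group: keep every term's shape as is
        return list(term_shapes)
    first = term_shapes[0]
    for shape in term_shapes[1:]:
        # all dimensions except the last must match to concatenate on the last one
        if len(shape) != len(first) or shape[:-1] != first[:-1]:
            raise RuntimeError(
                f"Unable to concatenate observation terms with shapes {list(term_shapes)}."
                " Set 'concatenate_terms' to False for a hybrid group."
            )
    return first[:-1] + (sum(shape[-1] for shape in term_shapes),)


vector, image = (3,), (128, 256, 1)
# hybrid group, not concatenated: the list of shapes is returned unchanged
print(combine_term_shapes([vector, image], concatenate=False))
# two images that differ only in the channel count: channels are summed
print(combine_term_shapes([(128, 256, 1), (128, 256, 3)], concatenate=True))
```

Calling the helper with `concatenate=True` on the vector and image pair raises the `RuntimeError`, which mirrors the invalid-configuration case that the new unit test exercises.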

Fixes #768

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 9050d2f6
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.20.7"
version = "0.20.8"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.20.8 (2024-08-02)
~~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Fixed the handling of observation terms with different shapes in the
:class:`~omni.isaac.lab.managers.ObservationManager` class. Earlier, the constructor would throw an error if the
shapes of the observation terms were different. Now, this operation only happens when the terms in an observation
group are being concatenated. Otherwise, the terms are stored as a dictionary of tensors.
* Improved the error message when the observation terms are not of the same shape in the
:class:`~omni.isaac.lab.managers.ObservationManager` class and the terms are being concatenated.
0.20.7 (2024-08-02)
~~~~~~~~~~~~~~~~~~~
......
@@ -289,7 +289,6 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
# extract quantities about the group
has_concatenated_obs = self.observation_manager.group_obs_concatenate[group_name]
group_dim = self.observation_manager.group_obs_dim[group_name]
group_term_dim = self.observation_manager.group_obs_term_dim[group_name]
# check if group is concatenated or not
# if not concatenated, then we need to add each term separately as a dictionary
if has_concatenated_obs:
@@ -297,7 +296,7 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
else:
self.single_observation_space[group_name] = gym.spaces.Dict({
term_name: gym.spaces.Box(low=-np.inf, high=np.inf, shape=term_dim)
for term_name, term_dim in zip(group_term_names, group_term_dim)
for term_name, term_dim in zip(group_term_names, group_dim)
})
# action space (unbounded since we don't impose any limits)
action_dim = sum(self.action_manager.action_term_dim)
......
@@ -154,6 +154,8 @@ class ObservationGroupCfg:
If true, the observation terms in the group are concatenated along the last dimension.
Otherwise, they are kept separate and returned as a dictionary.
If the observation group contains terms of different dimensions, it must be set to False.
"""
enable_corruption: bool = False
......
@@ -28,7 +28,26 @@ class ObservationManager(ManagerBase):
corruption model to use, and the sensor to retrieve data from.
Each observation group should inherit from the :class:`ObservationGroupCfg` class. Within each group, each
observation term should instantiate the :class:`ObservationTermCfg` class.
observation term should instantiate the :class:`ObservationTermCfg` class. Based on the configuration, the
observations in a group can be concatenated into a single tensor or returned as a dictionary with keys
corresponding to the term's name.
If the observations in a group are concatenated, the shape of the concatenated tensor is computed based on the
shapes of the individual observation terms. This information is stored in the :attr:`group_obs_dim` dictionary
with keys as the group names and values as the shape of the observation tensor. When the terms in a group are not
concatenated, the attribute stores a list of shapes for each term in the group.
.. note::
When the observation terms in a group do not have the same shape, the observation terms cannot be
concatenated. In this case, please set the :attr:`ObservationGroupCfg.concatenate_terms` attribute in the
group configuration to False.
The observation manager can be used to compute observations for all the groups or for a specific group. The
observations are computed by calling the registered functions for each term in the group. The functions are
called in the order of the terms in the group. The functions are expected to return a tensor with shape
(num_envs, ...). If a corruption/noise model is registered for a term, the function is called to corrupt
the observation. The corruption function is expected to return a tensor with the same shape as the observation.
The observations are clipped and scaled as per the configuration settings.
"""
def __init__(self, cfg: object, env: ManagerBasedEnv):
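
The computation order described in the class docstring above (term function, optional corruption, clipping, scaling, then optional concatenation) can be sketched in plain Python. This is only an illustration under stated assumptions: nested lists stand in for tensors of shape `(num_envs, dim)`, and the `compute_group` helper and the `func`/`noise`/`clip`/`scale` dictionary keys are hypothetical, not the manager's real data structures.

```python
from typing import Union

Row = list[float]
Batch = list[Row]  # stand-in for a tensor of shape (num_envs, dim)


def compute_group(
    terms: dict[str, dict], num_envs: int, concatenate: bool
) -> Union[Batch, dict[str, Batch]]:
    """Evaluate each term in order, then corrupt, clip and scale it.

    Hypothetical sketch; the real manager operates on torch tensors.
    """
    outputs: dict[str, Batch] = {}
    for name, term in terms.items():  # dicts preserve the term order
        obs = term["func"](num_envs)
        noise = term.get("noise")
        if noise is not None:
            # the corruption must keep the shape of the observation
            obs = [[noise(value) for value in row] for row in obs]
        clip = term.get("clip")
        if clip is not None:
            low, high = clip
            obs = [[min(max(value, low), high) for value in row] for row in obs]
        scale = term.get("scale")
        if scale is not None:
            obs = [[value * scale for value in row] for row in obs]
        outputs[name] = obs
    if not concatenate:
        # hybrid group: one entry per term, keyed by the term's name
        return outputs
    # concatenate along the last dimension: join the rows environment by environment
    return [sum((outputs[name][i] for name in outputs), []) for i in range(num_envs)]


terms = {
    "term_1": {"func": lambda n: [[1.0, 2.0] for _ in range(n)], "scale": 10.0},
    "term_2": {"func": lambda n: [[5.0] for _ in range(n)], "clip": (0.0, 3.0)},
}
print(compute_group(terms, num_envs=2, concatenate=True))
print(compute_group(terms, num_envs=2, concatenate=False))
```

The same terms yield either a single batch with three values per environment or a dictionary with one batch per term, depending only on the `concatenate` flag.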
@@ -37,13 +56,31 @@ class ObservationManager(ManagerBase):
Args:
cfg: The configuration object or dictionary (``dict[str, ObservationGroupCfg]``).
env: The environment instance.
Raises:
RuntimeError: If the shapes of the observation terms in a group are not compatible for concatenation
and the :attr:`~ObservationGroupCfg.concatenate_terms` attribute is set to True.
"""
super().__init__(cfg, env)
# compute combined vector for obs group
self._group_obs_dim: dict[str, tuple[int, ...]] = dict()
self._group_obs_dim: dict[str, tuple[int, ...] | list[tuple[int, ...]]] = dict()
for group_name, group_term_dims in self._group_obs_term_dim.items():
# if terms are concatenated, compute the combined shape into a single tuple
# otherwise, keep the list of shapes as is
if self._group_obs_concatenate[group_name]:
try:
term_dims = [torch.tensor(dims, device="cpu") for dims in group_term_dims]
self._group_obs_dim[group_name] = tuple(torch.sum(torch.stack(term_dims, dim=0), dim=0).tolist())
except RuntimeError:
raise RuntimeError(
f"Unable to concatenate observation terms in group '{group_name}'."
f" The shapes of the terms are: {group_term_dims}."
" Please ensure that the shapes are compatible for concatenation."
" Otherwise, set 'concatenate_terms' to False in the group configuration."
)
else:
self._group_obs_dim[group_name] = group_term_dims
def __str__(self) -> str:
"""Returns: A string representation for the observation manager."""
@@ -53,7 +90,9 @@ class ObservationManager(ManagerBase):
for group_name, group_dim in self._group_obs_dim.items():
# create table for term information
table = PrettyTable()
table.title = f"Active Observation Terms in Group: '{group_name}' (shape: {group_dim})"
table.title = f"Active Observation Terms in Group: '{group_name}'"
if self._group_obs_concatenate[group_name]:
table.title += f" (shape: {group_dim})"
table.field_names = ["Index", "Name", "Shape"]
# set alignment of table columns
table.align["Name"] = "l"
@@ -79,22 +118,43 @@ class ObservationManager(ManagerBase):
@property
def active_terms(self) -> dict[str, list[str]]:
"""Name of active observation terms in each group."""
"""Name of active observation terms in each group.
The keys are the group names and the values are the list of observation term names in the group.
"""
return self._group_obs_term_names
@property
def group_obs_dim(self) -> dict[str, tuple[int, ...]]:
"""Shape of observation tensor in each group."""
def group_obs_dim(self) -> dict[str, tuple[int, ...] | list[tuple[int, ...]]]:
"""Shape of computed observations in each group.
The key is the group name and the value is the shape of the observation tensor.
If the terms in the group are concatenated, the value is a single tuple representing the
shape of the concatenated observation tensor. Otherwise, the value is a list of tuples,
where each tuple represents the shape of the observation tensor for a term in the group.
"""
return self._group_obs_dim
@property
def group_obs_term_dim(self) -> dict[str, list[tuple[int, ...]]]:
"""Shape of observation tensor for each term in each group."""
"""Shape of individual observation terms in each group.
The key is the group name and the value is a list of tuples representing the shape of the observation terms
in the group. The order of the tuples corresponds to the order of the terms in the group.
This matches the order of the terms in the :attr:`active_terms`.
"""
return self._group_obs_term_dim
@property
def group_obs_concatenate(self) -> dict[str, bool]:
"""Whether the observation terms are concatenated in each group."""
"""Whether the observation terms are concatenated in each group or not.
The key is the group name and the value is a boolean specifying whether the observation terms in the group
are concatenated into a single tensor. If True, the observations are concatenated along the last dimension.
The values are set based on the :attr:`~ObservationGroupCfg.concatenate_terms` attribute in the group
configuration.
"""
return self._group_obs_concatenate
"""
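
Because `group_obs_dim` now holds either a single shape or a list of shapes, code that consumes it has to branch on `group_obs_concatenate`. A small sketch of that pattern follows, using plain dictionaries in place of a real `ObservationManager`. The helper `flat_sizes` and the sample values are hypothetical.

```python
import math
from typing import Union

Shape = tuple[int, ...]


def flat_sizes(
    group_obs_dim: dict[str, Union[Shape, list[Shape]]],
    group_obs_concatenate: dict[str, bool],
) -> dict[str, Union[int, list[int]]]:
    """Number of values per environment for each group (or for each term)."""
    sizes: dict[str, Union[int, list[int]]] = {}
    for group_name, dim in group_obs_dim.items():
        if group_obs_concatenate[group_name]:
            # concatenated group: a single shape tuple
            sizes[group_name] = math.prod(dim)
        else:
            # hybrid group: one shape tuple per term, in term order
            sizes[group_name] = [math.prod(shape) for shape in dim]
    return sizes


sizes = flat_sizes(
    {"policy": (11,), "mixed": [(3,), (128, 256, 1)]},
    {"policy": True, "mixed": False},
)
print(sizes)  # {'policy': 11, 'mixed': [3, 32768]}
```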
@@ -117,6 +177,8 @@ class ObservationManager(ManagerBase):
Returns:
A dictionary with keys as the group names and values as the computed observations.
The observations are either concatenated into a single tensor or returned as a dictionary
with keys corresponding to the term's name.
"""
# create a buffer for storing obs from all the groups
obs_buffer = dict()
@@ -195,7 +257,7 @@ class ObservationManager(ManagerBase):
# create buffers to store information for each observation group
# TODO: Make this more convenient by using data structures.
self._group_obs_term_names: dict[str, list[str]] = dict()
self._group_obs_term_dim: dict[str, list[int]] = dict()
self._group_obs_term_dim: dict[str, list[tuple[int, ...]]] = dict()
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_concatenate: dict[str, bool] = dict()
......
@@ -43,6 +43,10 @@ def grilled_chicken_with_yoghurt_and_bbq(env, hot: bool, bland: float, bbq: bool
return hot * bland * bbq * torch.ones(env.num_envs, 3, device=env.device)
def grilled_chicken_image(env, bland: float, channel: int = 1):
return bland * torch.ones(env.num_envs, 128, 256, channel, device=env.device)
class complex_function_class(ManagerTermBase):
def __init__(self, cfg: ObservationTermCfg, env: object):
self.cfg = cfg
@@ -201,8 +205,24 @@ class TestObservationManager(unittest.TestCase):
term_1 = ObservationTermCfg(func=grilled_chicken, scale=10)
term_2 = ObservationTermCfg(func=grilled_chicken_with_curry, scale=0.0, params={"hot": False})
@configclass
class SampleMixedGroupCfg(ObservationGroupCfg):
"""Test config class for policy observation group with a mix of vector and matrix terms."""
concatenate_terms = False
term_1 = ObservationTermCfg(func=grilled_chicken, scale=2.0)
term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=1.5, params={"bland": 0.5})
@configclass
class SampleImageGroupCfg(ObservationGroupCfg):
term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.5, params={"bland": 0.5, "channel": 1})
term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=0.5, params={"bland": 0.1, "channel": 3})
policy: ObservationGroupCfg = SampleGroupCfg()
critic: ObservationGroupCfg = SampleGroupCfg(term_2=None)
mixed: ObservationGroupCfg = SampleMixedGroupCfg()
image: ObservationGroupCfg = SampleImageGroupCfg()
# create observation manager
cfg = MyObservationManagerCfg()
@@ -210,6 +230,15 @@ class TestObservationManager(unittest.TestCase):
self.assertEqual(len(self.obs_man.active_terms["policy"]), 2)
self.assertEqual(len(self.obs_man.active_terms["critic"]), 1)
self.assertEqual(len(self.obs_man.active_terms["mixed"]), 2)
self.assertEqual(len(self.obs_man.active_terms["image"]), 2)
# create a new obs manager but where mixed group has invalid config
cfg = MyObservationManagerCfg()
cfg.mixed.concatenate_terms = True
with self.assertRaises(RuntimeError):
ObservationManager(cfg, self.env)
def test_compute(self):
"""Test the observation computation."""
@@ -234,8 +263,15 @@ class TestObservationManager(unittest.TestCase):
term_3 = ObservationTermCfg(func=pos_w_data, scale=2.0)
term_4 = ObservationTermCfg(func=lin_vel_w_data, scale=1.5)
@configclass
class ImageCfg(ObservationGroupCfg):
term_1 = ObservationTermCfg(func=grilled_chicken_image, scale=1.5, params={"bland": 0.5, "channel": 1})
term_2 = ObservationTermCfg(func=grilled_chicken_image, scale=0.5, params={"bland": 0.1, "channel": 3})
policy: ObservationGroupCfg = PolicyCfg()
critic: ObservationGroupCfg = CriticCfg()
image: ObservationGroupCfg = ImageCfg()
# create observation manager
cfg = MyObservationManagerCfg()
@@ -246,10 +282,12 @@ class TestObservationManager(unittest.TestCase):
# obtain the group observations
obs_policy: torch.Tensor = observations["policy"]
obs_critic: torch.Tensor = observations["critic"]
obs_image: torch.Tensor = observations["image"]
# check the observation shape
self.assertEqual((self.env.num_envs, 11), obs_policy.shape)
self.assertEqual((self.env.num_envs, 12), obs_critic.shape)
self.assertEqual((self.env.num_envs, 128, 256, 4), obs_image.shape)
# make sure that the data are the same for same terms
# -- within group
torch.testing.assert_close(obs_critic[:, 0:3], obs_critic[:, 6:9])
......