Unverified Commit d2ad4cf6 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Adds new operations into different managers (#199)

# Description

This MR adds new operations to the various managers. These include:

* Adds a method `compute_group` to compute observations based on
group-names
* Adds methods to retrieve and set term configurations into some of the
managers
* Adds a curriculum term that modifies a reward term's weight

## Type of change

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent 00702aa3
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.9.16"
version = "0.9.17"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.9.17 (2023-10-22)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added setters and getters for term configurations in the :class:`RandomizationManager`, :class:`RewardManager`
and :class:`TerminationManager` classes. This allows the user to modify the term configurations after the
manager has been created.
* Added the method :meth:`compute_group` to the :class:`omni.isaac.orbit.managers.ObservationManager` class to
compute the observations for only a given group.
* Added the curriculum term for modifying reward weights after a certain number of environment steps.
0.9.16 (2023-10-22)
~~~~~~~~~~~~~~~~~~~
......
......@@ -51,3 +51,21 @@ def terrain_levels_vel(
terrain.update_env_origins(env_ids, move_up, move_down)
# return the mean terrain level
return torch.mean(terrain.terrain_levels.float())
def modify_reward_weight(env: RLEnv, env_ids: Sequence[int], term_name: str, weight: float, num_steps: int):
    """Curriculum that modifies the weight of a reward term after a given number of steps.

    Once the environment's common step counter exceeds ``num_steps``, the configuration
    of the reward term is read from the reward manager, its weight is overwritten and the
    configuration is written back into the manager.

    Args:
        env: The learning environment.
        env_ids: Not used since all environments are affected.
        term_name: The name of the reward term.
        weight: The weight of the reward term.
        num_steps: The number of steps after which the change should be applied.
    """
    # nothing to change until the step counter has passed the threshold
    if not env.common_step_counter > num_steps:
        return
    # read the current settings of the term
    reward_term_cfg = env.reward_manager.get_term_cfg(term_name)
    # overwrite the weight and push the settings back into the manager
    reward_term_cfg.weight = weight
    env.reward_manager.set_term_cfg(term_name, reward_term_cfg)
......@@ -96,53 +96,79 @@ class ObservationManager(ManagerBase):
Operations.
"""
def compute(self) -> dict[str, torch.Tensor]:
"""Compute the observations per group.
def compute(self) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
"""Compute the observations per group for all groups.
The method computes the observations for each group and returns a dictionary with keys as
the group names and values as the computed observations. The observations are computed
by calling the registered functions for each term in the group. The functions are called
in the order of the terms in the group. The functions are expected to return a tensor
with shape ``(num_envs, ...)``. The tensors are then concatenated along the last dimension to
form the observations for the group.
The method computes the observations for all the groups handled by the observation manager.
Please check :meth:`compute_group` for details on how the observations are processed per group.
Returns:
A dictionary with keys as the group names and values as the computed observations.
"""
# create a buffer for storing obs from all the groups
obs_buffer = dict()
# iterate over all the terms in each group
for group_name in self._group_obs_term_names:
obs_buffer[group_name] = self.compute_group(group_name)
# return a dict with observations of all groups
return obs_buffer
def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tensor]:
"""Computes the observations for a given group.
The observations for a given group are computed by calling the registered functions for each
term in the group. The functions are called in the order of the terms in the group. The functions
are expected to return a tensor with shape ``(num_envs, ...)``.
If a corruption/noise model is registered for a term, the function is called to corrupt
the observation. The corruption function is expected to return a tensor with the same
shape as the observation. The observations are clipped and scaled as per the configuration
settings. By default, no scaling or clipping is applied.
Args:
group_name: The name of the group for which to compute the observations.
Returns:
A dictionary with keys as the group names and values as the computed observations.
Depending on the group's configuration, the tensors for individual observation terms are
concatenated along the last dimension into a single tensor. Otherwise, they are returned as
a dictionary with keys corresponding to the term's name.
Raises:
ValueError: If input ``group_name`` is not a valid group handled by the manager.
"""
self._obs_buffer = dict()
# check if the group name is valid
if group_name not in self._group_obs_term_names:
raise ValueError(
f"Unable to find the group '{group_name}' in the observation manager."
f" Available groups are: {list(self._group_obs_term_names.keys())}"
)
# iterate over all the terms in each group
for group_name, group_term_names in self._group_obs_term_names.items():
# buffer to store obs per group
group_obs = dict.fromkeys(group_term_names, None)
# read attributes for each term
obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name])
# evaluate terms: compute, add noise, clip, scale.
for name, term_cfg in obs_terms:
# compute term's value
obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params)
# apply post-processing
if term_cfg.noise:
obs = term_cfg.noise.func(obs, term_cfg.noise)
if term_cfg.clip:
obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1])
if term_cfg.scale:
obs = obs.mul_(term_cfg.scale)
# TODO: Introduce delay and filtering models.
# Ref: https://robosuite.ai/docs/modules/sensors.html#observables
# add value to list
group_obs[name] = obs
# concatenate all observations in the group together
if self._group_obs_concatenate[group_name]:
self._obs_buffer[group_name] = torch.cat(list(group_obs.values()), dim=-1)
else:
self._obs_buffer[group_name] = group_obs
# return all group observations
return self._obs_buffer
group_term_names = self._group_obs_term_names[group_name]
# buffer to store obs per group
group_obs = dict.fromkeys(group_term_names, None)
# read attributes for each term
obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name])
# evaluate terms: compute, add noise, clip, scale.
for name, term_cfg in obs_terms:
# compute term's value
obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params)
# apply post-processing
if term_cfg.noise:
obs = term_cfg.noise.func(obs, term_cfg.noise)
if term_cfg.clip:
obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1])
if term_cfg.scale:
obs = obs.mul_(term_cfg.scale)
# TODO: Introduce delay and filtering models.
# Ref: https://robosuite.ai/docs/modules/sensors.html#observables
# add value to list
group_obs[name] = obs
# concatenate all observations in the group together
if self._group_obs_concatenate[group_name]:
return torch.cat(list(group_obs.values()), dim=-1)
else:
return group_obs
"""
Helper functions.
......
......@@ -143,6 +143,52 @@ class RandomizationManager(ManagerBase):
# call the randomization term
term_cfg.func(self._env, env_ids, **term_cfg.params)
"""
Operations - Term settings.
"""
def set_term_cfg(self, term_name: str, cfg: RandomizationTermCfg):
    """Sets the configuration of the specified term into the manager.

    The modes are searched in their stored order. The configuration of the first
    term whose name matches is replaced by the given one.

    Args:
        term_name: The name of the randomization term.
        cfg: The configuration for the randomization term.

    Raises:
        ValueError: If the term name is not found.
    """
    for mode_name, names_in_mode in self._mode_term_names.items():
        # this mode does not contain the term: look into the next one
        if term_name not in names_in_mode:
            continue
        # overwrite the first match and stop searching
        term_index = names_in_mode.index(term_name)
        self._mode_term_cfgs[mode_name][term_index] = cfg
        return
    # no mode contains a term with the given name
    raise ValueError(f"Randomization term '{term_name}' not found.")
def get_term_cfg(self, term_name: str) -> RandomizationTermCfg:
    """Gets the configuration for the specified term.

    The modes are searched in their stored order. The configuration of the first
    term whose name matches is returned.

    Args:
        term_name: The name of the randomization term.

    Returns:
        The configuration of the randomization term.

    Raises:
        ValueError: If the term name is not found.
    """
    for mode_name, names_in_mode in self._mode_term_names.items():
        # this mode does not contain the term: look into the next one
        if term_name not in names_in_mode:
            continue
        # the configurations are stored in the same order as the names
        term_index = names_in_mode.index(term_name)
        return self._mode_term_cfgs[mode_name][term_index]
    # no mode contains a term with the given name
    raise ValueError(f"Randomization term '{term_name}' not found.")
"""
Helper functions.
"""
......
......@@ -127,6 +127,9 @@ class RewardManager(ManagerBase):
self._reward_buf[:] = 0.0
# iterate over all the reward terms
for name, term_cfg in zip(self._term_names, self._term_cfgs):
# skip if weight is zero (kind of a micro-optimization)
if term_cfg.weight == 0.0:
continue
# compute term's value
value = term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt
# update total reward
......@@ -136,6 +139,42 @@ class RewardManager(ManagerBase):
return self._reward_buf
"""
Operations - Term settings.
"""
def set_term_cfg(self, term_name: str, cfg: RewardTermCfg):
    """Sets the configuration of the specified term into the manager.

    Args:
        term_name: The name of the reward term.
        cfg: The configuration for the reward term.

    Raises:
        ValueError: If the term name is not found.
    """
    registered_names = self._term_names
    if term_name in registered_names:
        # the configurations are stored in the same order as the names
        term_index = registered_names.index(term_name)
        self._term_cfgs[term_index] = cfg
        return
    raise ValueError(f"Reward term '{term_name}' not found.")
def get_term_cfg(self, term_name: str) -> RewardTermCfg:
    """Gets the configuration for the specified term.

    Args:
        term_name: The name of the reward term.

    Returns:
        The configuration of the reward term.

    Raises:
        ValueError: If the term name is not found.
    """
    registered_names = self._term_names
    if term_name in registered_names:
        # the configurations are stored in the same order as the names
        term_index = registered_names.index(term_name)
        return self._term_cfgs[term_index]
    raise ValueError(f"Reward term '{term_name}' not found.")
"""
Helper functions.
"""
......@@ -170,11 +209,6 @@ class RewardManager(ManagerBase):
)
# resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
# remove zero scales and multiply non-zero ones by dt
# note: we multiply weights by dt to make them agnostic to control decimation
term_cfg.weight = float(term_cfg.weight)
if term_cfg.weight == 0.0:
continue
# add function to list
self._term_names.append(term_name)
self._term_cfgs.append(term_cfg)
......@@ -138,6 +138,42 @@ class TerminationManager(ManagerBase):
# return termination signal
return self._done_buf
"""
Operations - Term settings.
"""
def set_term_cfg(self, term_name: str, cfg: TerminationTermCfg):
    """Sets the configuration of the specified term into the manager.

    Args:
        term_name: The name of the termination term.
        cfg: The configuration for the termination term.

    Raises:
        ValueError: If the term name is not found.
    """
    registered_names = self._term_names
    if term_name in registered_names:
        # the configurations are stored in the same order as the names
        term_index = registered_names.index(term_name)
        self._term_cfgs[term_index] = cfg
        return
    raise ValueError(f"Termination term '{term_name}' not found.")
def get_term_cfg(self, term_name: str) -> TerminationTermCfg:
    """Gets the configuration for the specified term.

    Args:
        term_name: The name of the termination term.

    Returns:
        The configuration of the termination term.

    Raises:
        ValueError: If the term name is not found.
    """
    registered_names = self._term_names
    if term_name in registered_names:
        # the configurations are stored in the same order as the names
        term_index = registered_names.index(term_name)
        return self._term_cfgs[term_index]
    raise ValueError(f"Termination term '{term_name}' not found.")
"""
Helper functions.
"""
......
......@@ -111,16 +111,6 @@ class TestRewardManager(unittest.TestCase):
self.assertEqual(rew_man_from_cfg._term_cfgs, rew_man_from_annotated_cfg._term_cfgs)
self.assertEqual(rew_man_from_dict._term_cfgs, rew_man_from_cfg._term_cfgs)
def test_config_terms(self):
"""Test the ignoring of terms with zero weight."""
cfg = {
"term_1": RewardTermCfg(func=grilled_chicken, weight=10),
"term_2": RewardTermCfg(func=grilled_chicken_with_curry, weight=0.0, params={"hot": False}),
}
self.rew_man = RewardManager(cfg, self.env)
self.assertEqual(self.rew_man.active_terms, ["term_1"])
def test_compute(self):
"""Test the computation of reward."""
cfg = {
......@@ -137,7 +127,7 @@ class TestRewardManager(unittest.TestCase):
self.assertEqual(tuple(rewards.shape), (self.env.num_envs,))
def test_active_terms(self):
"""Test the ignoring of terms with zero weight."""
"""Test the correct reading of active terms."""
cfg = {
"term_1": RewardTermCfg(func=grilled_chicken, weight=10),
"term_2": RewardTermCfg(func=grilled_chicken_with_bbq, weight=5, params={"bbq": True}),
......@@ -145,7 +135,7 @@ class TestRewardManager(unittest.TestCase):
}
self.rew_man = RewardManager(cfg, self.env)
self.assertEqual(len(self.rew_man.active_terms), 2)
self.assertEqual(len(self.rew_man.active_terms), 3)
def test_missing_weight(self):
"""Test the case of a missing weight in the config."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment