Unverified Commit d2ad4cf6 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Adds new operations into different managers (#199)

# Description

This MR adds new operations to the various managers. These include:

* Adds a method `compute_group` to compute observations based on
group-names
* Adds methods to retrieve and set term configurations into some of the
managers
* Adds a method to modify reward term's weights as a curriculum

## Type of change

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent 00702aa3
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.9.16" version = "0.9.17"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.9.17 (2023-10-22)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added setters and getters for term configurations in the :class:`RandomizationManager`, :class:`RewardManager`
and :class:`TerminationManager` classes. This allows the user to modify the term configurations after the
manager has been created.
* Added the method :meth:`compute_group` to the :class:`omni.isaac.orbit.managers.ObservationManager` class to
compute the observations for only a given group.
* Added the curriculum term for modifying reward weights after a certain number of environment steps.
0.9.16 (2023-10-22) 0.9.16 (2023-10-22)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -51,3 +51,21 @@ def terrain_levels_vel( ...@@ -51,3 +51,21 @@ def terrain_levels_vel(
terrain.update_env_origins(env_ids, move_up, move_down) terrain.update_env_origins(env_ids, move_up, move_down)
# return the mean terrain level # return the mean terrain level
return torch.mean(terrain.terrain_levels.float()) return torch.mean(terrain.terrain_levels.float())
def modify_reward_weight(env: RLEnv, env_ids: Sequence[int], term_name: str, weight: float, num_steps: int):
    """Curriculum that modifies a reward weight after a given number of steps.

    Once ``env.common_step_counter`` is strictly greater than ``num_steps``, the configuration
    of the reward term is read from the reward manager, its weight is overwritten with
    ``weight`` and the configuration is written back. Before that, the call does nothing.

    Args:
        env: The learning environment.
        env_ids: Not used since all environments are affected.
        term_name: The name of the reward term.
        weight: The weight of the reward term.
        num_steps: The number of steps after which the change should be applied.
    """
    if env.common_step_counter > num_steps:
        # obtain term settings
        term_cfg = env.reward_manager.get_term_cfg(term_name)
        # update term settings
        term_cfg.weight = weight
        # write the settings back explicitly so the manager holds the updated configuration
        env.reward_manager.set_term_cfg(term_name, term_cfg)
...@@ -96,53 +96,79 @@ class ObservationManager(ManagerBase): ...@@ -96,53 +96,79 @@ class ObservationManager(ManagerBase):
Operations. Operations.
""" """
def compute(self) -> dict[str, torch.Tensor]: def compute(self) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
"""Compute the observations per group. """Compute the observations per group for all groups.
The method computes the observations for each group and returns a dictionary with keys as The method computes the observations for all the groups handled by the observation manager.
the group names and values as the computed observations. The observations are computed Please check the :meth:`compute_group` on the processing of observations per group.
by calling the registered functions for each term in the group. The functions are called
in the order of the terms in the group. The functions are expected to return a tensor Returns:
with shape ``(num_envs, ...)``. The tensors are then concatenated along the last dimension to A dictionary with keys as the group names and values as the computed observations.
form the observations for the group. """
# create a buffer for storing obs from all the groups
obs_buffer = dict()
# iterate over all the terms in each group
for group_name in self._group_obs_term_names:
obs_buffer[group_name] = self.compute_group(group_name)
# return a dict with observations of all groups
return obs_buffer
def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tensor]:
"""Computes the observations for a given group.
The observations for a given group are computed by calling the registered functions for each
term in the group. The functions are called in the order of the terms in the group. The functions
are expected to return a tensor with shape ``(num_envs, ...)``.
If a corruption/noise model is registered for a term, the function is called to corrupt If a corruption/noise model is registered for a term, the function is called to corrupt
the observation. The corruption function is expected to return a tensor with the same the observation. The corruption function is expected to return a tensor with the same
shape as the observation. The observations are clipped and scaled as per the configuration shape as the observation. The observations are clipped and scaled as per the configuration
settings. By default, no scaling or clipping is applied. settings. By default, no scaling or clipping is applied.
Args:
group_name: The name of the group for which to compute the observations. It must be one of the
groups handled by the manager.
Returns: Returns:
A dictionary with keys as the group names and values as the computed observations. Depending on the group's configuration, the tensors for individual observation terms are
concatenated along the last dimension into a single tensor. Otherwise, they are returned as
a dictionary with keys corresponding to the term's name.
Raises:
ValueError: If input ``group_name`` is not a valid group handled by the manager.
""" """
self._obs_buffer = dict() # check if group name is valid
if group_name not in self._group_obs_term_names:
raise ValueError(
f"Unable to find the group '{group_name}' in the observation manager."
f" Available groups are: {list(self._group_obs_term_names.keys())}"
)
# iterate over all the terms in each group # iterate over all the terms in each group
for group_name, group_term_names in self._group_obs_term_names.items(): group_term_names = self._group_obs_term_names[group_name]
# buffer to store obs per group # buffer to store obs per group
group_obs = dict.fromkeys(group_term_names, None) group_obs = dict.fromkeys(group_term_names, None)
# read attributes for each term # read attributes for each term
obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name]) obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name])
# evaluate terms: compute, add noise, clip, scale. # evaluate terms: compute, add noise, clip, scale.
for name, term_cfg in obs_terms: for name, term_cfg in obs_terms:
# compute term's value # compute term's value
obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params) obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params)
# apply post-processing # apply post-processing
if term_cfg.noise: if term_cfg.noise:
obs = term_cfg.noise.func(obs, term_cfg.noise) obs = term_cfg.noise.func(obs, term_cfg.noise)
if term_cfg.clip: if term_cfg.clip:
obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1]) obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1])
if term_cfg.scale: if term_cfg.scale:
obs = obs.mul_(term_cfg.scale) obs = obs.mul_(term_cfg.scale)
# TODO: Introduce delay and filtering models. # TODO: Introduce delay and filtering models.
# Ref: https://robosuite.ai/docs/modules/sensors.html#observables # Ref: https://robosuite.ai/docs/modules/sensors.html#observables
# add value to list # add value to list
group_obs[name] = obs group_obs[name] = obs
# concatenate all observations in the group together # concatenate all observations in the group together
if self._group_obs_concatenate[group_name]: if self._group_obs_concatenate[group_name]:
self._obs_buffer[group_name] = torch.cat(list(group_obs.values()), dim=-1) return torch.cat(list(group_obs.values()), dim=-1)
else: else:
self._obs_buffer[group_name] = group_obs return group_obs
# return all group observations
return self._obs_buffer
""" """
Helper functions. Helper functions.
......
...@@ -143,6 +143,52 @@ class RandomizationManager(ManagerBase): ...@@ -143,6 +143,52 @@ class RandomizationManager(ManagerBase):
# call the randomization term # call the randomization term
term_cfg.func(self._env, env_ids, **term_cfg.params) term_cfg.func(self._env, env_ids, **term_cfg.params)
"""
Operations - Term settings.
"""
def set_term_cfg(self, term_name: str, cfg: RandomizationTermCfg):
    """Sets the configuration of the specified term into the manager.

    The method finds the term by name by searching through all the modes.
    It then updates the configuration of the term with the first matching name.

    Args:
        term_name: The name of the randomization term.
        cfg: The configuration for the randomization term.

    Raises:
        ValueError: If the term name is not found.
    """
    for mode, terms in self._mode_term_names.items():
        if term_name in terms:
            # only the first mode containing the name is updated
            self._mode_term_cfgs[mode][terms.index(term_name)] = cfg
            return
    # no mode contains a term with this name
    raise ValueError(f"Randomization term '{term_name}' not found.")
def get_term_cfg(self, term_name: str) -> RandomizationTermCfg:
    """Gets the configuration for the specified term.

    The method finds the term by name by searching through all the modes.
    It then returns the configuration of the term with the first matching name.

    Args:
        term_name: The name of the randomization term.

    Returns:
        The configuration of the randomization term.

    Raises:
        ValueError: If the term name is not found.
    """
    for mode, terms in self._mode_term_names.items():
        if term_name in terms:
            # configurations are stored per mode, in the same order as the term names
            return self._mode_term_cfgs[mode][terms.index(term_name)]
    # no mode contains a term with this name
    raise ValueError(f"Randomization term '{term_name}' not found.")
""" """
Helper functions. Helper functions.
""" """
......
...@@ -127,6 +127,9 @@ class RewardManager(ManagerBase): ...@@ -127,6 +127,9 @@ class RewardManager(ManagerBase):
self._reward_buf[:] = 0.0 self._reward_buf[:] = 0.0
# iterate over all the reward terms # iterate over all the reward terms
for name, term_cfg in zip(self._term_names, self._term_cfgs): for name, term_cfg in zip(self._term_names, self._term_cfgs):
# skip if weight is zero (kind of a micro-optimization)
if term_cfg.weight == 0.0:
continue
# compute term's value # compute term's value
value = term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt value = term_cfg.func(self._env, **term_cfg.params) * term_cfg.weight * dt
# update total reward # update total reward
...@@ -136,6 +139,42 @@ class RewardManager(ManagerBase): ...@@ -136,6 +139,42 @@ class RewardManager(ManagerBase):
return self._reward_buf return self._reward_buf
"""
Operations - Term settings.
"""
def set_term_cfg(self, term_name: str, cfg: RewardTermCfg):
    """Sets the configuration of the specified term into the manager.

    Args:
        term_name: The name of the reward term.
        cfg: The configuration for the reward term.

    Raises:
        ValueError: If the term name is not found.
    """
    # a single lookup both validates the name and yields its position
    try:
        term_index = self._term_names.index(term_name)
    except ValueError:
        raise ValueError(f"Reward term '{term_name}' not found.") from None
    # set the configuration
    self._term_cfgs[term_index] = cfg
def get_term_cfg(self, term_name: str) -> RewardTermCfg:
    """Gets the configuration for the specified term.

    Args:
        term_name: The name of the reward term.

    Returns:
        The configuration of the reward term.

    Raises:
        ValueError: If the term name is not found.
    """
    # a single lookup both validates the name and yields its position
    try:
        term_index = self._term_names.index(term_name)
    except ValueError:
        raise ValueError(f"Reward term '{term_name}' not found.") from None
    # return the configuration
    return self._term_cfgs[term_index]
""" """
Helper functions. Helper functions.
""" """
...@@ -170,11 +209,6 @@ class RewardManager(ManagerBase): ...@@ -170,11 +209,6 @@ class RewardManager(ManagerBase):
) )
# resolve common parameters # resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1) self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
# remove zero scales and multiply non-zero ones by dt
# note: we multiply weights by dt to make them agnostic to control decimation
term_cfg.weight = float(term_cfg.weight)
if term_cfg.weight == 0.0:
continue
# add function to list # add function to list
self._term_names.append(term_name) self._term_names.append(term_name)
self._term_cfgs.append(term_cfg) self._term_cfgs.append(term_cfg)
...@@ -138,6 +138,42 @@ class TerminationManager(ManagerBase): ...@@ -138,6 +138,42 @@ class TerminationManager(ManagerBase):
# return termination signal # return termination signal
return self._done_buf return self._done_buf
"""
Operations - Term settings.
"""
def set_term_cfg(self, term_name: str, cfg: TerminationTermCfg):
    """Sets the configuration of the specified term into the manager.

    Args:
        term_name: The name of the termination term.
        cfg: The configuration for the termination term.

    Raises:
        ValueError: If the term name is not found.
    """
    # a single lookup both validates the name and yields its position
    try:
        term_index = self._term_names.index(term_name)
    except ValueError:
        raise ValueError(f"Termination term '{term_name}' not found.") from None
    # set the configuration
    self._term_cfgs[term_index] = cfg
def get_term_cfg(self, term_name: str) -> TerminationTermCfg:
    """Gets the configuration for the specified term.

    Args:
        term_name: The name of the termination term.

    Returns:
        The configuration of the termination term.

    Raises:
        ValueError: If the term name is not found.
    """
    # a single lookup both validates the name and yields its position
    try:
        term_index = self._term_names.index(term_name)
    except ValueError:
        raise ValueError(f"Termination term '{term_name}' not found.") from None
    # return the configuration
    return self._term_cfgs[term_index]
""" """
Helper functions. Helper functions.
""" """
......
...@@ -111,16 +111,6 @@ class TestRewardManager(unittest.TestCase): ...@@ -111,16 +111,6 @@ class TestRewardManager(unittest.TestCase):
self.assertEqual(rew_man_from_cfg._term_cfgs, rew_man_from_annotated_cfg._term_cfgs) self.assertEqual(rew_man_from_cfg._term_cfgs, rew_man_from_annotated_cfg._term_cfgs)
self.assertEqual(rew_man_from_dict._term_cfgs, rew_man_from_cfg._term_cfgs) self.assertEqual(rew_man_from_dict._term_cfgs, rew_man_from_cfg._term_cfgs)
def test_config_terms(self):
"""Test the ignoring of terms with zero weight."""
cfg = {
"term_1": RewardTermCfg(func=grilled_chicken, weight=10),
"term_2": RewardTermCfg(func=grilled_chicken_with_curry, weight=0.0, params={"hot": False}),
}
self.rew_man = RewardManager(cfg, self.env)
self.assertEqual(self.rew_man.active_terms, ["term_1"])
def test_compute(self): def test_compute(self):
"""Test the computation of reward.""" """Test the computation of reward."""
cfg = { cfg = {
...@@ -137,7 +127,7 @@ class TestRewardManager(unittest.TestCase): ...@@ -137,7 +127,7 @@ class TestRewardManager(unittest.TestCase):
self.assertEqual(tuple(rewards.shape), (self.env.num_envs,)) self.assertEqual(tuple(rewards.shape), (self.env.num_envs,))
def test_active_terms(self): def test_active_terms(self):
"""Test the ignoring of terms with zero weight.""" """Test the correct reading of active terms."""
cfg = { cfg = {
"term_1": RewardTermCfg(func=grilled_chicken, weight=10), "term_1": RewardTermCfg(func=grilled_chicken, weight=10),
"term_2": RewardTermCfg(func=grilled_chicken_with_bbq, weight=5, params={"bbq": True}), "term_2": RewardTermCfg(func=grilled_chicken_with_bbq, weight=5, params={"bbq": True}),
...@@ -145,7 +135,7 @@ class TestRewardManager(unittest.TestCase): ...@@ -145,7 +135,7 @@ class TestRewardManager(unittest.TestCase):
} }
self.rew_man = RewardManager(cfg, self.env) self.rew_man = RewardManager(cfg, self.env)
self.assertEqual(len(self.rew_man.active_terms), 2) self.assertEqual(len(self.rew_man.active_terms), 3)
def test_missing_weight(self): def test_missing_weight(self):
"""Test the missing of weight in the config.""" """Test the missing of weight in the config."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment