Commit 82cbdc22 (unverified), authored by Mayank Mittal, committed by GitHub

Adds the termination, curriculum and randomization managers (#88)

# Description

This PR adds the following managers similar to how we currently handle
observation and reward terms.

* **Termination Manager**: Iterates over all the configured terms and
computes the done signals as an OR operator over each term's output.
Additionally, `time_outs` are handled separately as they are optional
(i.e. only used in fixed-length episodic learning).
* **Randomization Manager**: Handles various randomizations (such as
resetting the state of the environments and modifying various physics
attributes).
* **Curriculum Manager**: Iterates over all the configured terms and
sets the curriculum setting into the environment accordingly.

## Type of change

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file

---------
Co-authored-by: David Hoeller <dhoeller@ethz.ch>
Co-authored-by: Nikita Rudin <nrudin@nvidia.com>
parent b25d6673
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.7.0" version = "0.7.1"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.7.1 (2023-07-10)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added the :class:`TerminationManager`, :class:`CurriculumManager`, and :class:`RandomizationManager` classes
to the :mod:`omni.isaac.orbit.managers` module to handle termination, curriculum, and randomization respectively.
0.7.0 (2023-07-10) 0.7.0 (2023-07-10)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -33,14 +33,35 @@ Example pseudo-code for a manager: ...@@ -33,14 +33,35 @@ Example pseudo-code for a manager:
""" """
from .manager_cfg import ObservationGroupCfg, ObservationTermCfg, RewardTermCfg from .curriculum_manager import CurriculumManager
from .manager_cfg import (
CurriculumTermCfg,
ObservationGroupCfg,
ObservationTermCfg,
RandomizationTermCfg,
RewardTermCfg,
TerminationTermCfg,
)
from .observation_manager import ObservationManager from .observation_manager import ObservationManager
from .randomization_manager import RandomizationManager
from .reward_manager import RewardManager from .reward_manager import RewardManager
from .termination_manager import TerminationManager
__all__ = [ __all__ = [
# curriculum
"CurriculumTermCfg",
"CurriculumManager",
# observation
"ObservationGroupCfg", "ObservationGroupCfg",
"ObservationTermCfg", "ObservationTermCfg",
"ObservationManager", "ObservationManager",
# reward
"RewardTermCfg", "RewardTermCfg",
"RewardManager", "RewardManager",
# randomization
"RandomizationTermCfg",
"RandomizationManager",
# termination
"TerminationTermCfg",
"TerminationManager",
] ]
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES, ETH Zurich, and University of Toronto
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Curriculum manager for updating environment quantities subject to a training curriculum."""
import torch
from prettytable import PrettyTable
from typing import Dict, List, Optional, Sequence
from .manager_base import ManagerBase
from .manager_cfg import CurriculumTermCfg
class CurriculumManager(ManagerBase):
    """Manager that executes the configured curriculum terms.

    A curriculum term is a callable that adjusts some quantity of the environment as training
    progresses, typically making the task progressively harder as the agent improves. The manager
    calls every active term in order and remembers the value each term returned so that it can be
    reported through :meth:`log_info`.

    Terms are read from a config object (or a plain dictionary) whose entries are instances of
    :class:`CurriculumTermCfg`. Entries set to ``None`` are treated as disabled.
    """

    def __init__(self, cfg: object, env: object):
        """Initialize the manager.

        Args:
            cfg (object): The configuration object or dictionary (``dict[str, CurriculumTermCfg]``)
            env (object): An environment object.

        Raises:
            TypeError: If curriculum term is not of type :class:`CurriculumTermCfg`.
            ValueError: If curriculum term configuration does not satisfy its function signature.
        """
        super().__init__(cfg, env)
        # Every active term starts without a recorded state; the entry is filled
        # in the first time ``compute`` runs the term.
        self._curriculum_state = dict.fromkeys(self._term_names)

    def __str__(self) -> str:
        """Returns: A string representation for curriculum manager."""
        # one table row per active term: (index, name)
        table = PrettyTable()
        table.title = "Active Curriculum Terms"
        table.field_names = ["Index", "Name"]
        table.align["Name"] = "l"
        for position, term_name in enumerate(self._term_names):
            table.add_row([position, term_name])
        header = f"<CurriculumManager> contains {len(self._term_names)} active terms.\n"
        return header + table.get_string()

    """
    Properties.
    """

    @property
    def active_terms(self) -> List[str]:
        """Name of active curriculum terms."""
        return self._term_names

    """
    Operations.
    """

    def log_info(self, env_ids: Optional[Sequence[int]] = None) -> Dict[str, float]:
        """Collect the most recently recorded state of every curriculum term.

        Terms whose recorded state is ``None`` are left out. A term that returned a dictionary
        contributes one entry per key (``"Curriculum/<term>/<key>"``); any other state is stored
        under ``"Curriculum/<term>"``. Tensor values are converted with ``Tensor.item()``.

        Note:
            This function does not use the environment indices :attr:`env_ids`
            and logs the state of all the terms. The argument is only present
            to maintain consistency with other classes.

        Returns:
            Dict[str, float]: Dictionary of curriculum terms and their states.
        """

        def _as_scalar(entry):
            # tensors are unwrapped into plain Python numbers for logging
            return entry.item() if isinstance(entry, torch.Tensor) else entry

        extras = {}
        for term_name, term_state in self._curriculum_state.items():
            # nothing recorded for this term
            if term_state is None:
                continue
            if isinstance(term_state, dict):
                # each key of the returned dictionary is logged separately
                for key, value in term_state.items():
                    extras[f"Curriculum/{term_name}/{key}"] = _as_scalar(value)
            else:
                extras[f"Curriculum/{term_name}"] = _as_scalar(term_state)
        return extras

    def compute(self, env_ids: Optional[Sequence[int]] = None):
        """Run every curriculum term and record what it returns.

        Each term is called as ``func(env, env_ids, **params)`` and its return value replaces the
        previously stored state of that term.

        Args:
            env_ids (Optional[Sequence[int]]): The list of environment IDs to update.
                If None, all the environments are updated. Defaults to None.
        """
        # ``Ellipsis`` stands in for "all environments" when no indices are given
        target_ids = Ellipsis if env_ids is None else env_ids
        for term_name, term_cfg in zip(self._term_names, self._term_cfgs):
            self._curriculum_state[term_name] = term_cfg.func(self._env, target_ids, **term_cfg.params)

    """
    Helper functions.
    """

    def _prepare_terms(self):
        """Build the parallel lists of active term names and term configurations."""
        self._term_names: List[str] = []
        self._term_cfgs: List[CurriculumTermCfg] = []
        # the configuration may be a plain dictionary or an object holding the terms as attributes
        cfg_items = self.cfg.items() if isinstance(self.cfg, dict) else self.cfg.__dict__.items()
        for term_name, term_cfg in cfg_items:
            # entries set to None are disabled terms
            if term_cfg is None:
                continue
            # shared resolution of the term settings (terms take at least: env, env_ids)
            self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)
            self._term_names.append(term_name)
            self._term_cfgs.append(term_cfg)
...@@ -53,6 +53,22 @@ class ManagerBaseTermCfg: ...@@ -53,6 +53,22 @@ class ManagerBaseTermCfg:
"""The parameters to be passed to the function as keyword arguments. Defaults to an empty dict.""" """The parameters to be passed to the function as keyword arguments. Defaults to an empty dict."""
"""Curriculum manager."""
@configclass
class CurriculumTermCfg(ManagerBaseTermCfg):
    """Configuration for a curriculum term.

    The curriculum manager calls :attr:`func` as ``func(env, env_ids, **params)`` and stores
    the returned value as the term's state for logging.
    """

    # Required setting (no usable default): must be provided by the user.
    func: Callable[..., float | dict[str, float]] = MISSING
    """The name of the function to be called.

    This function should take the environment object, environment indices
    and any other parameters as input and return the curriculum state for
    logging purposes.
    """
"""Observation manager.""" """Observation manager."""
...@@ -99,6 +115,40 @@ class ObservationGroupCfg: ...@@ -99,6 +115,40 @@ class ObservationGroupCfg:
""" """
"""Randomization manager."""
@configclass
class RandomizationTermCfg(ManagerBaseTermCfg):
    """Configuration for a randomization term.

    The randomization manager groups terms by :attr:`mode` and calls each term as
    ``func(env, env_ids, **params)`` when that mode is requested.
    """

    func: Callable[..., None] = MISSING
    """The function to be called to apply the randomization.

    This function should take the environment object, environment indices
    and any other parameters as input. Its return value is not used.
    """

    mode: str = MISSING
    """The mode in which the randomization term is applied.

    Note:
        The mode name ``"interval"`` is a special mode that is handled by the
        manager. Hence, its name is reserved and cannot be used for other modes.
    """

    interval_range_s: tuple[float, float] | None = None
    """The range of time in seconds at which the term is applied. Defaults to None.

    Based on this, the interval is sampled uniformly between the specified
    interval range for each environment instance and the term is applied for
    the environment instances if the current time hits the interval time.

    Note:
        This is only used if the mode is ``"interval"``, in which case it must be specified.
    """
"""Reward manager.""" """Reward manager."""
...@@ -123,3 +173,26 @@ class RewardTermCfg(ManagerBaseTermCfg): ...@@ -123,3 +173,26 @@ class RewardTermCfg(ManagerBaseTermCfg):
Note: Note:
If the weight is zero, the reward term is ignored. If the weight is zero, the reward term is ignored.
""" """
"""Termination manager."""
@configclass
class TerminationTermCfg(ManagerBaseTermCfg):
    """Configuration for a termination term.

    The termination manager calls :attr:`func` as ``func(env, **params)`` and combines the
    returned signal into the overall done signal with a logical OR.
    """

    # Required setting (no usable default): must be provided by the user.
    func: Callable[..., torch.Tensor] = MISSING
    """The name of the function to be called.

    This function should take the environment object and any other parameters
    as input and return the termination signals as torch boolean tensors of
    shape ``(num_envs,)``.
    """

    # When True, the term's signal is additionally OR-ed into the manager's time-out buffer.
    time_out: bool = False
    """Whether the termination term contributes towards episodic timeouts. Defaults to False.

    Note:
        These usually correspond to tasks that have a fixed time limit.
    """
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES, ETH Zurich, and University of Toronto
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Randomization manager for randomizing different elements in the scene."""
import logging
import torch
from prettytable import PrettyTable
from typing import Dict, List, Optional, Sequence
from .manager_base import ManagerBase
from .manager_cfg import RandomizationTermCfg
class RandomizationManager(ManagerBase):
    """Manager for randomizing different elements in the scene.

    The randomization manager applies randomization to any instance in the scene. For example, changing the
    masses of objects or their friction coefficients, or applying random pushes to the robot. The user can
    specify several modes of randomization to specialize the behavior based on when to apply the randomization.

    The randomization terms are parsed from a config class containing the manager's settings and each term's
    parameters. Each randomization term should instantiate the :class:`RandomizationTermCfg` class.

    Randomization terms can be grouped by their mode. The mode is a user-defined string that specifies when
    the randomization term should be applied. This provides the user complete control over when randomization
    terms should be applied.

    For a typical training process, you may want to randomize in the following modes:

    - "startup": Randomization term is applied once at the beginning of the training.
    - "reset": Randomization is applied at every reset.
    - "interval": Randomization is applied at pre-specified intervals of time.

    However, you can also define your own modes and use them in the training process as you see fit.

    .. note::

        The mode ``"interval"`` is the only mode that is handled by the manager itself which is based on
        the environment's time step.
    """

    def __init__(self, cfg: object, env: object):
        """Initialize the randomization manager.

        Args:
            cfg (object): A configuration object or dictionary (``dict[str, RandomizationTermCfg]``).
            env (object): An environment object.
        """
        super().__init__(cfg, env)

    def __str__(self) -> str:
        """Returns: A string representation for randomization manager."""
        # count the terms across all modes (the dictionary itself is keyed by mode)
        num_terms = sum(len(names) for names in self._mode_term_names.values())
        msg = f"<RandomizationManager> contains {num_terms} active terms.\n"
        # add one table per mode
        for mode, term_names in self._mode_term_names.items():
            table = PrettyTable()
            table.title = f"Active Randomization Terms in Mode: '{mode}'"
            if mode == "interval":
                # interval terms additionally show the range their interval is sampled from
                table.field_names = ["Index", "Name", "Interval time range (s)"]
                table.align["Name"] = "l"
                for index, (name, cfg) in enumerate(zip(term_names, self._mode_term_cfgs[mode])):
                    table.add_row([index, name, cfg.interval_range_s])
            else:
                table.field_names = ["Index", "Name"]
                table.align["Name"] = "l"
                for index, name in enumerate(term_names):
                    table.add_row([index, name])
            # convert table to string
            msg += table.get_string()
            msg += "\n"
        return msg

    """
    Properties.
    """

    @property
    def dt(self) -> float:
        """The environment time-step (in seconds)."""
        return self._env.dt

    @property
    def active_terms(self) -> Dict[str, List[str]]:
        """Name of active randomization terms, grouped by mode."""
        return self._mode_term_names

    """
    Operations.
    """

    def randomize(self, mode: str, env_ids: Optional[Sequence[int]] = None, dt: Optional[float] = None):
        """Calls each randomization term in the specified mode.

        For every mode other than ``"interval"``, each term is called as
        ``func(env, env_ids, **params)`` with the given :attr:`env_ids`.

        For the ``"interval"`` mode, the manager keeps a per-environment countdown for every term.
        The countdown is decreased by :attr:`dt`; environments whose countdown has reached zero get
        a freshly sampled interval and the term is called for exactly those environments. If no
        environment is due, the term is not called.

        Note:
            For interval mode, the time step of the environment is used to determine if the randomization should be
            applied. If the time step is not constant, the user should pass the time step to this function.
            In this mode the argument :attr:`env_ids` is not used, since the affected environments are
            determined by the elapsed intervals.

        Args:
            mode (str): The mode of randomization. If no term is registered under this mode, a warning
                is logged and nothing is applied.
            env_ids (Optional[Sequence[int]]): The indices of the environments to apply randomization to.
                Defaults to None, in which case the randomization is applied to all environments.
            dt (Optional[float]): The time step of the environment. Defaults to None, in which case the time
                step of the environment is used.
        """
        # check if mode is valid
        if mode not in self._mode_term_names:
            logging.warning(f"Randomization mode '{mode}' is not defined. Skipping randomization.")
            return
        term_cfgs = self._mode_term_cfgs[mode]

        # user-driven modes: apply every term to the requested environments
        if mode != "interval":
            for term_cfg in term_cfgs:
                term_cfg.func(self._env, env_ids, **term_cfg.params)
            return

        # interval mode: resolve the time step once for all terms
        if dt is None:
            dt = self.dt
        # the countdown buffers are stored in the same order as the interval terms
        for term_cfg, time_left in zip(term_cfgs, self._interval_mode_time_left):
            # in-place update so that the stored countdown buffer is modified
            time_left -= dt
            # environments for which the interval has passed
            elapsed_ids = (time_left <= 0.0).nonzero().flatten()
            if len(elapsed_ids) == 0:
                continue
            # sample the next interval for these environments
            lower, upper = term_cfg.interval_range_s
            time_left[elapsed_ids] = torch.rand(len(elapsed_ids), device=self.device) * (upper - lower) + lower
            # call the randomization term
            term_cfg.func(self._env, elapsed_ids, **term_cfg.params)

    """
    Helper functions.
    """

    def _prepare_terms(self):
        """Prepares the randomization terms grouped by their mode.

        Raises:
            ValueError: If a term uses the mode ``"interval"`` without specifying ``interval_range_s``.
        """
        # term names and configurations, keyed by mode
        self._mode_term_names: Dict[str, List[str]] = dict()
        self._mode_term_cfgs: Dict[str, List[RandomizationTermCfg]] = dict()
        # buffer to store the time left for each environment for "interval" mode
        # (one tensor per interval term, in the order the terms are registered)
        self._interval_mode_time_left: List[torch.Tensor] = list()
        # check if config is dict already
        if isinstance(self.cfg, dict):
            cfg_items = self.cfg.items()
        else:
            cfg_items = self.cfg.__dict__.items()
        # iterate over all the terms
        for term_name, term_cfg in cfg_items:
            # entries set to None are disabled terms
            if term_cfg is None:
                continue
            # resolve common parameters (terms take at least: env, env_ids)
            self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)
            # register the term under its mode, creating the mode on first use
            self._mode_term_names.setdefault(term_cfg.mode, list()).append(term_name)
            self._mode_term_cfgs.setdefault(term_cfg.mode, list()).append(term_cfg)
            # resolve the mode of randomization
            if term_cfg.mode == "interval":
                if term_cfg.interval_range_s is None:
                    raise ValueError(
                        f"Randomization term '{term_name}' has mode 'interval' but 'interval_range_s' is not specified."
                    )
                # sample the time left for each environment
                lower, upper = term_cfg.interval_range_s
                time_left = torch.rand(self.num_envs, device=self.device) * (upper - lower) + lower
                self._interval_mode_time_left.append(time_left)
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES, ETH Zurich, and University of Toronto
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Termination manager for computing done signals for a given world."""
import torch
from prettytable import PrettyTable
from typing import Dict, List, Optional, Sequence
from .manager_base import ManagerBase
from .manager_cfg import TerminationTermCfg
class TerminationManager(ManagerBase):
    """Manager for computing done signals for a given world.

    The termination manager computes the termination signal (also called dones) as a combination
    of termination terms. Each termination term is a function which takes the environment as an
    argument and returns a boolean tensor of shape ``(num_envs,)``. The termination manager
    computes the termination signal as the union (logical or) of all the termination terms.

    The termination terms are parsed from a config class containing the manager's settings and each term's
    parameters. Each termination term should instantiate the :class:`TerminationTermCfg` class.
    """

    def __init__(self, cfg: object, env: object):
        """Initializes the termination manager.

        Args:
            cfg (object): The configuration object or dictionary (``dict[str, TerminationTermCfg]``).
            env (object): An environment object.
        """
        super().__init__(cfg, env)
        # per-term flags remembering which environments the term has fired for since
        # the flags were last cleared by ``log_info``
        self._episode_dones = {
            term_name: torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
            for term_name in self._term_names
        }
        # create buffer for managing termination per environment
        self._done_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
        self._time_out_buf = torch.zeros_like(self._done_buf)

    def __str__(self) -> str:
        """Returns: A string representation for termination manager."""
        msg = f"<TerminationManager> contains {len(self._term_names)} active terms.\n"
        # create table for term information
        table = PrettyTable()
        table.title = "Active Termination Terms"
        table.field_names = ["Index", "Name", "Time Out"]
        # set alignment of table columns
        table.align["Name"] = "l"
        # add info on each term
        for index, (name, term_cfg) in enumerate(zip(self._term_names, self._term_cfgs)):
            table.add_row([index, name, term_cfg.time_out])
        # convert table to string
        msg += table.get_string()
        return msg

    """
    Properties.
    """

    @property
    def active_terms(self) -> List[str]:
        """Name of active termination terms."""
        return self._term_names

    @property
    def dones(self) -> torch.Tensor:
        """The net termination signal. Shape is ``(num_envs,)``."""
        return self._done_buf

    @property
    def time_outs(self) -> torch.Tensor:
        """The timeout signal. Shape is ``(num_envs,)``."""
        return self._time_out_buf

    """
    Operations.
    """

    def log_info(self, env_ids: Optional[Sequence[int]] = None) -> Dict[str, torch.Tensor]:
        """Returns the episodic counts of individual termination terms.

        For every term, the number of selected environments in which the term has fired is counted.
        The flags of the selected environments are cleared afterwards, so a following call only
        reports terminations that happened in between.

        Args:
            env_ids (Optional[Sequence[int]], optional): The environment ids. Defaults to None, in which case
                all environments are considered.

        Returns:
            Dict[str, torch.Tensor]: Dictionary mapping ``"Episode Termination/<term name>"`` to the
            number of selected environments terminated by that term.
        """
        # resolve environment ids
        if env_ids is None:
            env_ids = ...
        # add to episode dict
        extras = {}
        for term_name, term_dones in self._episode_dones.items():
            extras["Episode Termination/" + term_name] = torch.count_nonzero(term_dones[env_ids])
            # clear the flags that have just been reported
            term_dones[env_ids] = False
        return extras

    def compute(self) -> torch.Tensor:
        """Computes the termination signal as union of individual terms.

        This function calls each termination term managed by the class and performs a logical OR operation
        to compute the net termination signal. Terms configured with ``time_out=True`` are additionally
        combined into the timeout signal.

        Returns:
            torch.Tensor: The combined termination signal of shape ``(num_envs,)``.
        """
        # reset computation (the buffers are boolean)
        self._done_buf[:] = False
        self._time_out_buf[:] = False
        # iterate over all the termination terms
        for name, term_cfg in zip(self._term_names, self._term_cfgs):
            value = term_cfg.func(self._env, **term_cfg.params)
            # update total termination
            self._done_buf |= value
            # store timeout signal separately
            if term_cfg.time_out:
                self._time_out_buf |= value
            # remember which environments this term has fired for
            self._episode_dones[name] |= value
        # return termination signal
        return self._done_buf

    """
    Helper functions.
    """

    def _prepare_terms(self):
        """Prepares a list of termination functions."""
        # names and configurations of the active terms (parallel lists)
        self._term_names: List[str] = list()
        self._term_cfgs: List[TerminationTermCfg] = list()
        # check if config is dict already
        if isinstance(self.cfg, dict):
            cfg_items = self.cfg.items()
        else:
            cfg_items = self.cfg.__dict__.items()
        # iterate over all the terms
        for term_name, term_cfg in cfg_items:
            # entries set to None are disabled terms
            if term_cfg is None:
                continue
            # resolve common parameters (terms take at least: env)
            self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
            # add function to list
            self._term_names.append(term_name)
            self._term_cfgs.append(term_cfg)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment