Unverified Commit 51f62922 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Adds reset for manager terms (#255)

# Description

Previously, while the manager classes supported "functional" class
terms, these terms only had to implement the `__call__` method. However,
more often than not, class terms keep some history (internal state) that
we want to clear out on reset.

This MR introduces a `ManagerTermBase` class for class terms, which serves
as a protocol for how manager terms that are classes should be implemented.
The `ActionTerm` class inherits from `ManagerTermBase`.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 84d41834
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.9.48"
version = "0.9.49"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.9.49 (2023-11-27)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added an interface class, :class:`omni.isaac.orbit.managers.ManagerTermBase`, to serve as the parent class
for term implementations that are functional classes.
* Adapted all managers to support terms that are classes and not just functions. This allows the user to
create more complex terms that require additional state information.
0.9.48 (2023-11-24)
~~~~~~~~~~~~~~~~~~~
......
# Copyright (c) 2022-2023, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""This sub-module introduces the base managers for defining MDPs."""
from __future__ import annotations
from .observation_manager import ObservationManager
from .reward_manager import RewardManager
__all__ = ["RewardManager", "ObservationManager"]
......@@ -9,55 +9,34 @@ This sub-module introduces the managers for handling various aspects of the envi
The managers are used to handle various aspects of the environment such as randomization, curriculum, and
observations. Each manager implements a specific functionality for the environment. The managers are
designed to be modular and can be easily extended to support new functionality.
Each manager is implemented as a class that inherits from the :class:`ManagerBase` class. Each manager
class should also have a corresponding configuration class that defines the configuration terms for the
manager. Each term should be an instance of the :class:`ManagerBaseTermCfg` class or its subclass.
Example pseudo-code for a manager:
.. code-block:: python
from omni.isaac.orbit.utils import configclass
from omni.isaac.orbit.utils.mdp import ManagerBase, ManagerBaseTermCfg
@configclass
class MyManagerCfg:
my_term_1: ManagerBaseTermCfg = ManagerBaseTermCfg(...)
my_term_2: ManagerBaseTermCfg = ManagerBaseTermCfg(...)
my_term_3: ManagerBaseTermCfg = ManagerBaseTermCfg(...)
# define manager instance
my_manager = ManagerBase(cfg=ManagerCfg(), env=env)
"""
from __future__ import annotations
from .action_manager import ActionManager, ActionTerm
from .curriculum_manager import CurriculumManager
from .manager_base import ManagerBase
from .manager_cfg import (
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import (
ActionTermCfg,
CurriculumTermCfg,
ManagerBaseTermCfg,
ManagerTermBaseCfg,
ObservationGroupCfg,
ObservationTermCfg,
RandomizationTermCfg,
RewardTermCfg,
SceneEntityCfg,
TerminationTermCfg,
)
from .observation_manager import ObservationManager
from .randomization_manager import RandomizationManager
from .reward_manager import RewardManager
from .scene_entity_cfg import SceneEntityCfg
from .termination_manager import TerminationManager
__all__ = [
# base
"SceneEntityCfg",
"ManagerBaseTermCfg",
# base
"ManagerTermBaseCfg",
"ManagerTermBase",
"ManagerBase",
# action
"ActionTermCfg",
......
......@@ -8,20 +8,20 @@
from __future__ import annotations
import torch
from abc import ABC, abstractmethod
from abc import abstractmethod
from prettytable import PrettyTable
from typing import TYPE_CHECKING, Sequence
from omni.isaac.orbit.assets import AssetBase
from .manager_base import ManagerBase
from .manager_cfg import ActionTermCfg
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import ActionTermCfg
if TYPE_CHECKING:
from omni.isaac.orbit.envs import BaseEnv
class ActionTerm(ABC):
class ActionTerm(ManagerTermBase):
"""Base class for action terms.
The action term is responsible for processing the raw actions sent to the environment
......@@ -41,9 +41,8 @@ class ActionTerm(ABC):
cfg: The configuration object.
env: The environment instance.
"""
# store the inputs
self.cfg = cfg
self._env = env
# call the base class constructor
super().__init__(cfg, env)
# parse config to obtain asset to which the term is applied
self._asset: AssetBase = self._env.scene[self.cfg.asset_name]
......@@ -51,16 +50,6 @@ class ActionTerm(ABC):
Properties.
"""
@property
def num_envs(self) -> int:
"""Number of environments."""
return self._env.num_envs
@property
def device(self) -> str:
"""Device on which to perform computations."""
return self._asset.device
@property
@abstractmethod
def action_dim(self) -> int:
......@@ -200,7 +189,10 @@ class ActionManager(ManagerBase):
# reset the action history
self._prev_action[env_ids] = 0.0
self._action[env_ids] = 0.0
# reset the terms
# reset all action terms
for term in self._terms:
term.reset(env_ids=env_ids)
# nothing to log here
return {}
def process_action(self, action: torch.Tensor):
......
......@@ -11,8 +11,8 @@ import torch
from prettytable import PrettyTable
from typing import TYPE_CHECKING, Sequence
from .manager_base import ManagerBase
from .manager_cfg import CurriculumTermCfg
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import CurriculumTermCfg
if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
......@@ -106,6 +106,10 @@ class CurriculumManager(ManagerBase):
if isinstance(term_state, torch.Tensor):
term_state = term_state.item()
extras[f"Curriculum/{term_name}"] = term_state
# reset all the curriculum terms
for term_cfg in self._class_term_cfgs:
term_cfg.func.reset(env_ids=env_ids)
# return logged information
return extras
def compute(self, env_ids: Sequence[int] | None = None):
......@@ -133,6 +137,7 @@ class CurriculumManager(ManagerBase):
# parse remaining curriculum terms and decimate their information
self._term_names: list[str] = list()
self._term_cfgs: list[CurriculumTermCfg] = list()
self._class_term_cfgs: list[CurriculumTermCfg] = list()
# check if config is dict already
if isinstance(self.cfg, dict):
......@@ -155,3 +160,6 @@ class CurriculumManager(ManagerBase):
# add name and config to list
self._term_names.append(term_name)
self._term_cfgs.append(term_cfg)
# check if the term is a class
if isinstance(term_cfg.func, ManagerTermBase):
self._class_term_cfgs.append(term_cfg)
......@@ -8,18 +8,109 @@ from __future__ import annotations
import copy
import inspect
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Sequence
import carb
from omni.isaac.orbit.utils import string_to_callable
from .manager_cfg import ManagerBaseTermCfg, SceneEntityCfg
from .manager_term_cfg import ManagerTermBaseCfg
from .scene_entity_cfg import SceneEntityCfg
if TYPE_CHECKING:
from omni.isaac.orbit.envs import BaseEnv
class ManagerTermBase(ABC):
    """Base class for manager terms.

    Manager term implementations can be functions or classes. If the term is a class, it should
    inherit from this base class and implement the required methods.

    Each manager is implemented as a class that inherits from the :class:`ManagerBase` class. Each manager
    class should also have a corresponding configuration class that defines the configuration terms for the
    manager. Each term should be an instance of the :class:`ManagerTermBaseCfg` class or its subclass.

    Example pseudo-code for creating a manager:

    .. code-block:: python

        from omni.isaac.orbit.utils import configclass
        from omni.isaac.orbit.managers import ManagerBase, ManagerTermBaseCfg

        @configclass
        class MyManagerCfg:

            my_term_1: ManagerTermBaseCfg = ManagerTermBaseCfg(...)
            my_term_2: ManagerTermBaseCfg = ManagerTermBaseCfg(...)
            my_term_3: ManagerTermBaseCfg = ManagerTermBaseCfg(...)

        # define manager instance
        my_manager = ManagerBase(cfg=MyManagerCfg(), env=env)

    """

    def __init__(self, cfg: ManagerTermBaseCfg, env: BaseEnv):
        """Initialize the manager term.

        Args:
            cfg: The configuration object.
            env: The environment instance.
        """
        # store the inputs
        self.cfg = cfg
        self._env = env

    """
    Properties.
    """

    @property
    def num_envs(self) -> int:
        """Number of environments (as reported by the environment instance)."""
        return self._env.num_envs

    @property
    def device(self) -> str:
        """Device on which to perform computations (as reported by the environment instance)."""
        return self._env.device

    """
    Operations.
    """

    def reset(self, env_ids: Sequence[int] | None = None) -> None:
        """Resets the manager term.

        The base implementation does nothing. Terms that keep internal state (e.g. some history)
        should override this method to clear that state.

        Args:
            env_ids: The environment ids. Defaults to None, in which case
                all environments are considered.
        """

    def __call__(self, *args, **kwargs) -> Any:
        """Returns the value of the term required by the manager.

        In case of a class implementation, this function is called by the manager
        to get the value of the term. The arguments passed to this function are
        the ones specified in the term configuration (see :attr:`ManagerTermBaseCfg.params`).

        .. attention::
            To be consistent with memory-less implementation of terms with functions, it is
            recommended to ensure that the returned mutable quantities are cloned before
            returning them. For instance, if the term returns a tensor, it is recommended
            to ensure that the returned tensor is a clone of the original tensor. This prevents
            the manager from storing references to the tensors and altering the original tensors.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments. Accepted so that a term which does not
                override this method fails with :class:`NotImplementedError` (and not a
                :class:`TypeError`) when the manager passes the term parameters by keyword.

        Returns:
            The value of the term.

        Raises:
            NotImplementedError: Always, unless the method is overridden by the subclass.
        """
        raise NotImplementedError
class ManagerBase(ABC):
"""Base class for all managers."""
......@@ -85,7 +176,7 @@ class ManagerBase(ABC):
Helper functions.
"""
def _resolve_common_term_cfg(self, term_name: str, term_cfg: ManagerBaseTermCfg, min_argc: int = 1):
def _resolve_common_term_cfg(self, term_name: str, term_cfg: ManagerTermBaseCfg, min_argc: int = 1):
"""Resolve common term configuration.
Usually, called by the :meth:`_prepare_terms` method to resolve common term configuration.
......@@ -104,44 +195,26 @@ class ManagerBase(ABC):
by the manager.
Raises:
TypeError: If the term configuration is not of type :class:`ManagerBaseTermCfg`.
TypeError: If the term configuration is not of type :class:`ManagerTermBaseCfg`.
ValueError: If the scene entity defined in the term configuration does not exist.
AttributeError: If the term function is not callable.
ValueError: If the term function's arguments are not matched by the parameters.
"""
# check if the term is a valid term config
if not isinstance(term_cfg, ManagerBaseTermCfg):
if not isinstance(term_cfg, ManagerTermBaseCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type ManagerBaseTermCfg."
f"Configuration for the term '{term_name}' is not of type ManagerTermBaseCfg."
f" Received: '{type(term_cfg)}'."
)
# iterate over all the entities and parse the joint and body names
for key, value in term_cfg.params.items():
# deal with string
if isinstance(value, SceneEntityCfg):
# check if the entity is valid
if value.name not in self._env.scene.keys():
raise ValueError(f"For the term '{term_name}', the scene entity '{value.name}' does not exist.")
# convert joint names to indices based on regex
if value.joint_names is not None and value.joint_ids is not None:
raise ValueError(
f"For the term '{term_name}', both 'joint_names' and 'joint_ids' are specified in '{key}'."
)
if value.joint_names is not None:
if isinstance(value.joint_names, str):
value.joint_names = [value.joint_names]
joint_ids, _ = self._env.scene[value.name].find_joints(value.joint_names)
value.joint_ids = joint_ids
# convert body names to indices based on regex
if value.body_names is not None and value.body_ids is not None:
raise ValueError(
f"For the term '{term_name}', both 'body_names' and 'body_ids' are specified in '{key}'."
)
if value.body_names is not None:
if isinstance(value.body_names, str):
value.body_names = [value.body_names]
body_ids, _ = self._env.scene[value.name].find_bodies(value.body_names)
value.body_ids = body_ids
# load the entity
try:
value.resolve(self._env.scene)
except ValueError as e:
raise ValueError(f"Error while parsing '{term_name}:{key}'. {e}")
# log the entity for checking later
msg = f"[{term_cfg.__class__.__name__}:{term_name}] Found entity '{value.name}'."
if value.joint_ids is not None:
......@@ -158,9 +231,12 @@ class ManagerBase(ABC):
term_cfg.func = string_to_callable(term_cfg.func)
# initialize the term if it is a class
if inspect.isclass(term_cfg.func):
if not issubclass(term_cfg.func, ManagerTermBase):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type ManagerTermBase."
f" Received: '{type(term_cfg.func)}'."
)
term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env)
# add the "self" argument to the count
min_argc += 1
# check if function is callable
if not callable(term_cfg.func):
raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}")
......
......@@ -14,73 +14,28 @@ from typing import TYPE_CHECKING, Any, Callable
from omni.isaac.orbit.utils import configclass
from omni.isaac.orbit.utils.noise import NoiseCfg
from .scene_entity_cfg import SceneEntityCfg
if TYPE_CHECKING:
from .action_manager import ActionTerm
from .manager_base import ManagerTermBase
@configclass
class SceneEntityCfg:
    """Selects a scene entity (and optionally some of its joints/bodies) for a manager term.

    The entity is identified by its name, queried from the :class:`InteractiveScene` and
    passed to the manager's term function.
    """

    name: str = MISSING
    """Name of the scene entity.

    This must match the name used in the scene configuration file. See the
    :class:`InteractiveSceneCfg` class for more details.
    """

    joint_names: str | list[str] | None = None
    """Joint names (or regular expressions matching joint names) of the scene entity. Defaults to None.

    On initialization of the manager these are converted to joint indices, which the term
    function receives as a list under :attr:`joint_ids`.
    """

    joint_ids: list[int] | None = None
    """Joint indices of the asset that the term requires. Defaults to None.

    Filled in automatically on initialization of the manager when ``joint_names`` is specified.
    """

    body_names: str | list[str] | None = None
    """Body names (or regular expressions matching body names) of the asset. Defaults to None.

    On initialization of the manager these are converted to body indices, which the term
    function receives as a list under :attr:`body_ids`.
    """

    body_ids: list[int] | None = None
    """Body indices of the asset that the term requires. Defaults to None.

    Filled in automatically on initialization of the manager when ``body_names`` is specified.
    """
@configclass
class ManagerBaseTermCfg:
class ManagerTermBaseCfg:
"""Configuration for a manager term."""
func: Callable = MISSING
"""The function to be called for the term.
func: Callable | ManagerTermBase = MISSING
"""The function or class to be called for the term.
The function must take the environment object as the first argument.
The remaining arguments are specified in the :attr:`params` attribute.
Note:
It also supports `callable classes`_, i.e. classes that implement the :meth:`__call__`
method.
method. In this case, the class should inherit from the :class:`ManagerTermBase` class
and implement the required methods.
.. _`callable classes`: https://docs.python.org/3/reference/datamodel.html#object.__call__
"""
params: dict[str, Any | SceneEntityCfg] = dict()
......@@ -122,7 +77,7 @@ class ActionTermCfg:
@configclass
class CurriculumTermCfg(ManagerBaseTermCfg):
class CurriculumTermCfg(ManagerTermBaseCfg):
"""Configuration for a curriculum term."""
func: Callable[..., float | dict[str, float]] = MISSING
......@@ -140,7 +95,7 @@ class CurriculumTermCfg(ManagerBaseTermCfg):
@configclass
class ObservationTermCfg(ManagerBaseTermCfg):
class ObservationTermCfg(ManagerTermBaseCfg):
"""Configuration for an observation term."""
func: Callable[..., torch.Tensor] = MISSING
......@@ -188,7 +143,7 @@ class ObservationGroupCfg:
@configclass
class RandomizationTermCfg(ManagerBaseTermCfg):
class RandomizationTermCfg(ManagerTermBaseCfg):
"""Configuration for a randomization term."""
func: Callable[..., None] = MISSING
......@@ -224,7 +179,7 @@ class RandomizationTermCfg(ManagerBaseTermCfg):
@configclass
class RewardTermCfg(ManagerBaseTermCfg):
class RewardTermCfg(ManagerTermBaseCfg):
"""Configuration for a reward term."""
func: Callable[..., torch.Tensor] = MISSING
......@@ -252,7 +207,7 @@ class RewardTermCfg(ManagerBaseTermCfg):
@configclass
class TerminationTermCfg(ManagerBaseTermCfg):
class TerminationTermCfg(ManagerTermBaseCfg):
"""Configuration for a termination term."""
func: Callable[..., torch.Tensor] = MISSING
......
......@@ -9,10 +9,10 @@ from __future__ import annotations
import torch
from prettytable import PrettyTable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence
from .manager_base import ManagerBase
from .manager_cfg import ObservationGroupCfg, ObservationTermCfg
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import ObservationGroupCfg, ObservationTermCfg
if TYPE_CHECKING:
from omni.isaac.orbit.envs import BaseEnv
......@@ -100,6 +100,14 @@ class ObservationManager(ManagerBase):
Operations.
"""
def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
# call all terms that are classes
for group_cfg in self._group_obs_class_term_cfgs.values():
for term_cfg in group_cfg:
term_cfg.func.reset(env_ids=env_ids)
# nothing to log here
return {}
def compute(self) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
"""Compute the observations per group for all groups.
......@@ -188,6 +196,7 @@ class ObservationManager(ManagerBase):
self._group_obs_term_names: dict[str, list[str]] = dict()
self._group_obs_term_dim: dict[str, list[int]] = dict()
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_concatenate: dict[str, bool] = dict()
# check if config is dict already
......@@ -210,6 +219,7 @@ class ObservationManager(ManagerBase):
self._group_obs_term_names[group_name] = list()
self._group_obs_term_dim[group_name] = list()
self._group_obs_term_cfgs[group_name] = list()
self._group_obs_class_term_cfgs[group_name] = list()
# read common config for the group
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
......@@ -242,3 +252,8 @@ class ObservationManager(ManagerBase):
# call function the first time to fill up dimensions
obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape[1:])
self._group_obs_term_dim[group_name].append(obs_dims)
# add term in a separate list if term is a class
if isinstance(term_cfg.func, ManagerTermBase):
self._group_obs_class_term_cfgs[group_name].append(term_cfg)
# call reset (in-case above call to get obs dims changed the state)
term_cfg.func.reset()
......@@ -13,8 +13,8 @@ from typing import TYPE_CHECKING, Sequence
import carb
from .manager_base import ManagerBase
from .manager_cfg import RandomizationTermCfg
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import RandomizationTermCfg
if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
......@@ -108,6 +108,14 @@ class RandomizationManager(ManagerBase):
Operations.
"""
def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
# call all terms that are classes
for mode_cfg in self._mode_class_term_cfgs.values():
for term_cfg in mode_cfg:
term_cfg.func.reset(env_ids=env_ids)
# nothing to log here
return {}
def randomize(self, mode: str, env_ids: Sequence[int] | None = None, dt: float | None = None):
"""Calls each randomization term in the specified mode.
......@@ -205,6 +213,7 @@ class RandomizationManager(ManagerBase):
# parse remaining randomization terms and decimate their information
self._mode_term_names: dict[str, list[str]] = dict()
self._mode_term_cfgs: dict[str, list[RandomizationTermCfg]] = dict()
self._mode_class_term_cfgs: dict[str, list[RandomizationTermCfg]] = dict()
# buffer to store the time left for each environment for "interval" mode
self._interval_mode_time_left: list[torch.Tensor] = list()
......@@ -231,9 +240,13 @@ class RandomizationManager(ManagerBase):
# add new mode
self._mode_term_names[term_cfg.mode] = list()
self._mode_term_cfgs[term_cfg.mode] = list()
self._mode_class_term_cfgs[term_cfg.mode] = list()
# add term name and parameters
self._mode_term_names[term_cfg.mode].append(term_name)
self._mode_term_cfgs[term_cfg.mode].append(term_cfg)
# check if the term is a class
if isinstance(term_cfg.func, ManagerTermBase):
self._mode_class_term_cfgs[term_cfg.mode].append(term_cfg)
# resolve the mode of randomization
if term_cfg.mode == "interval":
......
......@@ -11,8 +11,8 @@ import torch
from prettytable import PrettyTable
from typing import TYPE_CHECKING, Sequence
from .manager_base import ManagerBase
from .manager_cfg import RewardTermCfg
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import RewardTermCfg
if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
......@@ -108,6 +108,10 @@ class RewardManager(ManagerBase):
extras["Episode Reward/" + key] = episodic_sum_avg / self._env.max_episode_length_s
# reset episodic sum
self._episode_sums[key][env_ids] = 0.0
# reset all the reward terms
for term_cfg in self._class_term_cfgs:
term_cfg.func.reset(env_ids=env_ids)
# return logged information
return extras
def compute(self, dt: float) -> torch.Tensor:
......@@ -183,6 +187,7 @@ class RewardManager(ManagerBase):
# parse remaining reward terms and decimate their information
self._term_names: list[str] = list()
self._term_cfgs: list[RewardTermCfg] = list()
self._class_term_cfgs: list[RewardTermCfg] = list()
# check if config is dict already
if isinstance(self.cfg, dict):
......@@ -211,3 +216,6 @@ class RewardManager(ManagerBase):
# add function to list
self._term_names.append(term_name)
self._term_cfgs.append(term_cfg)
# check if the term is a class
if isinstance(term_cfg.func, ManagerTermBase):
self._class_term_cfgs.append(term_cfg)
# Copyright (c) 2022-2023, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Configuration terms for different managers."""
from __future__ import annotations
from dataclasses import MISSING
from omni.isaac.orbit.assets import Articulation, RigidObject
from omni.isaac.orbit.scene import InteractiveScene
from omni.isaac.orbit.utils import configclass
@configclass
class SceneEntityCfg:
    """Selects a scene entity (and optionally some of its joints/bodies) for a manager term.

    The entity is identified by its name, queried from the :class:`InteractiveScene` and
    passed to the manager's term function.
    """

    name: str = MISSING
    """Name of the scene entity.

    This must match the name used in the scene configuration file. See the
    :class:`InteractiveSceneCfg` class for more details.
    """

    joint_names: str | list[str] | None = None
    """Joint names (or regular expressions matching joint names) of the scene entity. Defaults to None.

    On initialization of the manager these are converted to joint indices, which the term
    function receives as a list under :attr:`joint_ids`.
    """

    joint_ids: list[int] | None = None
    """Joint indices of the asset that the term requires. Defaults to None.

    Filled in automatically on initialization of the manager when ``joint_names`` is specified.
    """

    body_names: str | list[str] | None = None
    """Body names (or regular expressions matching body names) of the asset. Defaults to None.

    On initialization of the manager these are converted to body indices, which the term
    function receives as a list under :attr:`body_ids`.
    """

    body_ids: list[int] | None = None
    """Body indices of the asset that the term requires. Defaults to None.

    Filled in automatically on initialization of the manager when ``body_names`` is specified.
    """

    def resolve(self, scene: InteractiveScene):
        """Looks up the scene entity and fills in the missing joint/body names or indices.

        For joints and bodies independently: if only names are given, the indices are computed
        from them; if only indices are given, the names are looked up from them; if both are
        given, they are cross-checked against each other. A single name (``str``) or a single
        index (``int``) is wrapped into a list in place. Resolving names involves regular
        expressions, so the call is expensive and should be made only once.

        Args:
            scene: The interactive scene instance.

        Raises:
            ValueError: If the scene entity is not found.
            ValueError: If both ``joint_names`` and ``joint_ids`` are specified and are not consistent.
            ValueError: If both ``body_names`` and ``body_ids`` are specified and are not consistent.
        """
        # the entity has to be part of the scene
        if self.name not in scene.keys():
            raise ValueError(f"The scene entity '{self.name}' does not exist. Available entities: {scene.keys()}.")

        # -- joints: names <-> indices
        has_joint_names = self.joint_names is not None
        has_joint_ids = self.joint_ids is not None
        if has_joint_names or has_joint_ids:
            articulation: Articulation = scene[self.name]
            # wrap a single name / index into a list (no-op for None and for lists)
            if isinstance(self.joint_names, str):
                self.joint_names = [self.joint_names]
            if isinstance(self.joint_ids, int):
                self.joint_ids = [self.joint_ids]
            if has_joint_names and has_joint_ids:
                # both given: each must reproduce the other
                ids_from_names, _ = articulation.find_joints(self.joint_names)
                names_from_ids = [articulation.joint_names[index] for index in self.joint_ids]
                if ids_from_names != self.joint_ids or names_from_ids != self.joint_names:
                    raise ValueError(
                        "Both 'joint_names' and 'joint_ids' are specified, and are not consistent."
                        f"\n\tfrom joint names: {self.joint_names} [{ids_from_names}]"
                        f"\n\tfrom joint ids: {names_from_ids} [{self.joint_ids}]"
                        "\nHint: Use either 'joint_names' or 'joint_ids' to avoid confusion."
                    )
            elif has_joint_names:
                # only names given: derive the indices
                self.joint_ids, _ = articulation.find_joints(self.joint_names)
            else:
                # only indices given: derive the names
                self.joint_names = [articulation.joint_names[index] for index in self.joint_ids]

        # -- bodies: names <-> indices
        has_body_names = self.body_names is not None
        has_body_ids = self.body_ids is not None
        if has_body_names or has_body_ids:
            rigid_object: RigidObject = scene[self.name]
            # wrap a single name / index into a list (no-op for None and for lists)
            if isinstance(self.body_names, str):
                self.body_names = [self.body_names]
            if isinstance(self.body_ids, int):
                self.body_ids = [self.body_ids]
            if has_body_names and has_body_ids:
                # both given: each must reproduce the other
                ids_from_names, _ = rigid_object.find_bodies(self.body_names)
                names_from_ids = [rigid_object.body_names[index] for index in self.body_ids]
                if ids_from_names != self.body_ids or names_from_ids != self.body_names:
                    raise ValueError(
                        "Both 'body_names' and 'body_ids' are specified, and are not consistent."
                        f"\n\tfrom body names: {self.body_names} [{ids_from_names}]"
                        f"\n\tfrom body ids: {names_from_ids} [{self.body_ids}]"
                        "\nHint: Use either 'body_names' or 'body_ids' to avoid confusion."
                    )
            elif has_body_names:
                # only names given: derive the indices
                self.body_ids, _ = rigid_object.find_bodies(self.body_names)
            else:
                # only indices given: derive the names
                self.body_names = [rigid_object.body_names[index] for index in self.body_ids]
......@@ -11,8 +11,8 @@ import torch
from prettytable import PrettyTable
from typing import TYPE_CHECKING, Sequence
from .manager_base import ManagerBase
from .manager_cfg import TerminationTermCfg
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import TerminationTermCfg
if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
......@@ -136,6 +136,10 @@ class TerminationManager(ManagerBase):
extras["Episode Termination/" + key] = torch.count_nonzero(self._episode_dones[key][env_ids]).item()
# reset episode dones
self._episode_dones[key][env_ids] = False
# reset all the termination terms
for term_cfg in self._class_term_cfgs:
term_cfg.func.reset(env_ids=env_ids)
# return logged information
return extras
def compute(self) -> torch.Tensor:
......@@ -208,6 +212,7 @@ class TerminationManager(ManagerBase):
# parse remaining termination terms and decimate their information
self._term_names: list[str] = list()
self._term_cfgs: list[TerminationTermCfg] = list()
self._class_term_cfgs: list[TerminationTermCfg] = list()
# check if config is dict already
if isinstance(self.cfg, dict):
......@@ -230,3 +235,6 @@ class TerminationManager(ManagerBase):
# add function to list
self._term_names.append(term_name)
self._term_cfgs.append(term_cfg)
# check if the term is a class
if isinstance(term_cfg.func, ManagerTermBase):
self._class_term_cfgs.append(term_cfg)
......@@ -7,6 +7,7 @@ from __future__ import annotations
"""Wrapper around the Python 3.7 onwards `dataclasses` module."""
import inspect
import sys
from copy import deepcopy
from dataclasses import MISSING, Field, dataclass, field, replace
......@@ -192,18 +193,13 @@ def _add_annotation_types(cls):
# Note: Do not change this to dir(base) since it orders the members alphabetically.
# This is not desirable since the order of the members is important in some cases.
for key in base.__dict__:
# skip dunder members
if key.startswith("__"):
continue
# skip class functions
if key in _CONFIGCLASS_METHODS:
continue
# check if key is already present
if key in hints:
# get class member
value = getattr(base, key)
# skip members
if _skippable_class_member(key, value, hints):
continue
# add type annotations for members that don't have explicit type annotations
# for these, we deduce the type from the default value
value = getattr(base, key)
if not isinstance(value, type):
if key not in hints:
# check if var type is not MISSING
......@@ -263,14 +259,11 @@ def _process_mutable_types(cls):
continue
# iterate over base class members
for key in base.__dict__:
# skip dunder members
if key.startswith("__"):
continue
# skip class functions
if key in _CONFIGCLASS_METHODS:
continue
# get class member
f = getattr(base, key)
# skip members
if _skippable_class_member(key, f):
continue
# store class member if it is not a type or if it is already present in annotations
if not isinstance(f, type) or key in ann:
class_members[key] = f
......@@ -356,6 +349,43 @@ Helper functions
"""
def _skippable_class_member(key: str, value: Any, hints: dict | None = None) -> bool:
    """Check if the class member should be skipped in configclass processing.

    The following members are skipped:

    * Members whose name starts with a double underscore, e.g. ``__name__``,
      ``__module__``, ``__qualname__``, ``__annotations__``, ``__dict__``.
    * Manually-added special class functions: From :obj:`_CONFIGCLASS_METHODS`.
    * Members that are already present in the type annotations.
    * Functions bound to class object or class, i.e. callables whose signature
      contains a parameter named ``self`` or ``cls``.

    Callables whose signature cannot be introspected (for instance some built-in
    functions implemented in C) are not skipped, since there is no evidence that
    they are bound to the class.

    Args:
        key: The class member name.
        value: The class member value.
        hints: The type hints for the class. Defaults to None, in which case the
            member's existence in the type hints is not checked.

    Returns:
        True if the class member should be skipped, False otherwise.
    """
    # skip dunder members
    if key.startswith("__"):
        return True
    # skip manually-added special class functions
    if key in _CONFIGCLASS_METHODS:
        return True
    # check if key is already present
    if hints is not None and key in hints:
        return True
    # skip functions bound to class
    if callable(value):
        try:
            parameters = inspect.signature(value).parameters
        except (TypeError, ValueError):
            # ``inspect.signature`` raises for callables without an introspectable
            # signature. Treat them as regular values instead of crashing.
            parameters = {}
        if "self" in parameters or "cls" in parameters:
            return True
    # Otherwise, don't skip
    return False
def _return_f(f: Any) -> Callable[[], Any]:
"""Returns default factory function for creating mutable/immutable variables.
......
# Copyright (c) 2022-2023, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch
import unittest
from collections import namedtuple
from omni.isaac.orbit.compat.utils.mdp.observation_manager import ObservationManager
class DefaultObservationManager(ObservationManager):
    """Observation manager providing dummy observation terms for the tests."""

    def grilled_chicken(self, env):
        """Return a tensor of ones with four columns per environment."""
        shape = (env.num_envs, 4)
        return torch.ones(shape, device=self.device)

    def grilled_chicken_with_bbq(self, env, bbq: bool):
        """Return a single column of ones per environment, multiplied by ``bbq``."""
        base = torch.ones((env.num_envs, 1), device=self.device)
        return bbq * base

    def grilled_chicken_with_curry(self, env, hot: bool):
        """Return a single column of ones per environment, multiplied by ``hot * 2``."""
        gain = hot * 2
        return gain * torch.ones((env.num_envs, 1), device=self.device)

    def grilled_chicken_with_yoghurt(self, env, hot: bool, bland: float):
        """Return five columns of ones per environment, multiplied by ``hot * bland``."""
        gain = hot * bland
        return gain * torch.ones((env.num_envs, 5), device=self.device)
class TestObservationManager(unittest.TestCase):
    """Test cases for various situations with observation manager."""

    def setUp(self) -> None:
        """Create a minimal stand-in environment with 20 instances and use the CPU."""
        env_type = namedtuple("IsaacEnv", ["num_envs"])
        self.env = env_type(20)
        self.device = "cpu"

    def _build_manager(self, cfg: dict) -> None:
        """Instantiate the manager under test from ``cfg`` and store it on the test case."""
        self.obs_man = DefaultObservationManager(cfg, self.env, self.device)

    def test_str(self):
        """Build a manager with three terms and print its string representation."""
        policy_terms = {
            "grilled_chicken": {"scale": 10},
            "grilled_chicken_with_bbq": {"scale": 5, "bbq": True},
            "grilled_chicken_with_yoghurt": {"scale": 1.0, "hot": False, "bland": 2.0},
        }
        self._build_manager({"policy": policy_terms})
        self.assertEqual(len(self.obs_man.active_terms["policy"]), 3)
        # print the expected string
        print()
        print(self.obs_man)

    def test_config_terms(self):
        """Both configured terms are expected to be active, including the zero-scale one."""
        policy_terms = {
            "grilled_chicken": {"scale": 10},
            "grilled_chicken_with_curry": {"scale": 0.0, "hot": False},
        }
        self._build_manager({"policy": policy_terms})
        self.assertEqual(len(self.obs_man.active_terms["policy"]), 2)

    def test_compute(self):
        """The computed policy observation is expected to have five columns per environment."""
        policy_terms = {
            "grilled_chicken": {"scale": 10},
            "grilled_chicken_with_curry": {"scale": 0.0, "hot": False},
        }
        self._build_manager({"policy": policy_terms})
        # compute observation using manager
        observations = self.obs_man.compute()
        # check the observation shape
        expected_shape = (self.env.num_envs, 5)
        self.assertEqual(expected_shape, observations["policy"].shape)

    def test_active_terms(self):
        """All three configured terms are expected to be reported as active."""
        policy_terms = {
            "grilled_chicken": {"scale": 10},
            "grilled_chicken_with_bbq": {"scale": 5, "bbq": True},
            "grilled_chicken_with_curry": {"scale": 0.0, "hot": False},
        }
        self._build_manager({"policy": policy_terms})
        self.assertEqual(len(self.obs_man.active_terms["policy"]), 3)

    def test_invalid_observation_name(self):
        """A term name without a matching method is expected to raise ``AttributeError``."""
        policy_terms = {
            "grilled_chicken": {"scale": 10},
            "grilled_chicken_with_bbq": {"scale": 5, "bbq": True},
            "grilled_chicken_with_no_bbq": {"scale": 0.1, "hot": False},
        }
        with self.assertRaises(AttributeError):
            self._build_manager({"policy": policy_terms})

    def test_invalid_observation_config(self):
        """Term parameters that do not match the method arguments are expected to raise ``ValueError``."""
        policy_terms = {
            "grilled_chicken_with_bbq": {"scale": 0.1, "hot": False},
            "grilled_chicken_with_yoghurt": {"scale": 2.0, "hot": False},
        }
        with self.assertRaises(ValueError):
            self._build_manager({"policy": policy_terms})
# run the test suite when this file is executed as a script
if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2022-2023, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import unittest
from collections import namedtuple
from omni.isaac.orbit.compat.utils.mdp.reward_manager import RewardManager
class DefaultRewardManager(RewardManager):
    """Reward manager providing constant-valued dummy reward terms for the tests."""

    def grilled_chicken(self, env) -> int:
        """Return a constant reward of one."""
        return 1

    def grilled_chicken_with_bbq(self, env, bbq: bool) -> int:
        """Return a constant reward of zero; ``bbq`` is accepted but not used."""
        return 0

    def grilled_chicken_with_curry(self, env, hot: bool) -> int:
        """Return a constant reward of zero; ``hot`` is accepted but not used."""
        return 0

    def grilled_chicken_with_yoghurt(self, env, hot: bool, bland: float) -> int:
        """Return a constant reward of zero; ``hot`` and ``bland`` are accepted but not used."""
        return 0
class TestRewardManager(unittest.TestCase):
    """Test cases for various situations with reward manager."""

    def setUp(self) -> None:
        """Create an empty stand-in environment and the common manager arguments."""
        env_type = namedtuple("IsaacEnv", [])
        self.env = env_type()
        self.device = "cpu"
        self.num_envs = 20
        self.dt = 0.1

    def _build_manager(self, cfg: dict) -> None:
        """Instantiate the manager under test from ``cfg`` and store it on the test case."""
        self.rew_man = DefaultRewardManager(cfg, self.env, self.num_envs, self.dt, self.device)

    def test_str(self):
        """Build a manager with three terms and print its string representation."""
        terms = {
            "grilled_chicken": {"weight": 10},
            "grilled_chicken_with_bbq": {"weight": 5, "bbq": True},
            "grilled_chicken_with_yoghurt": {"weight": 1.0, "hot": False, "bland": 2.0},
        }
        self._build_manager(terms)
        self.assertEqual(len(self.rew_man.active_terms), 3)
        # print the expected string
        print()
        print(self.rew_man)

    def test_config_terms(self):
        """Only one active term is expected when the second term has zero weight."""
        terms = {
            "grilled_chicken": {"weight": 10},
            "grilled_chicken_with_curry": {"weight": 0.0, "hot": False},
        }
        self._build_manager(terms)
        self.assertEqual(len(self.rew_man.active_terms), 1)

    def test_compute(self):
        """The reward of environment 0 is expected to equal the term weight times ``dt``."""
        terms = {
            "grilled_chicken": {"weight": 10},
            "grilled_chicken_with_curry": {"weight": 0.0, "hot": False},
        }
        self._build_manager(terms)
        # compute expected reward
        weight = terms["grilled_chicken"]["weight"]
        expected_reward = weight * self.dt
        # compute reward using manager
        rewards = self.rew_man.compute()
        # check the reward for environment index 0
        self.assertEqual(float(rewards[0]), expected_reward)

    def test_active_terms(self):
        """Two active terms are expected when one of three terms has zero weight."""
        terms = {
            "grilled_chicken": {"weight": 10},
            "grilled_chicken_with_bbq": {"weight": 5, "bbq": True},
            "grilled_chicken_with_curry": {"weight": 0.0, "hot": False},
        }
        self._build_manager(terms)
        self.assertEqual(len(self.rew_man.active_terms), 2)

    def test_invalid_reward_name(self):
        """A term name without a matching method is expected to raise ``AttributeError``."""
        terms = {
            "grilled_chicken": {"weight": 10},
            "grilled_chicken_with_bbq": {"weight": 5, "bbq": True},
            "grilled_chicken_with_no_bbq": {"weight": 0.1, "hot": False},
        }
        with self.assertRaises(AttributeError):
            self._build_manager(terms)

    def test_invalid_reward_weight_config(self):
        """A term configured without a ``weight`` entry is expected to raise ``KeyError``."""
        terms = {"grilled_chicken": {}}
        with self.assertRaises(KeyError):
            self._build_manager(terms)

    def test_invalid_reward_config(self):
        """Term parameters that do not match the method arguments are expected to raise ``ValueError``."""
        terms = {
            "grilled_chicken_with_bbq": {"weight": 0.1, "hot": False},
            "grilled_chicken_with_yoghurt": {"weight": 2.0, "hot": False},
        }
        with self.assertRaises(ValueError):
            self._build_manager(terms)
# run the test suite when this file is executed as a script
if __name__ == "__main__":
    unittest.main()
......@@ -19,7 +19,7 @@ import torch
import unittest
from collections import namedtuple
from omni.isaac.orbit.managers import ObservationGroupCfg, ObservationManager, ObservationTermCfg
from omni.isaac.orbit.managers import ManagerTermBase, ObservationGroupCfg, ObservationManager, ObservationTermCfg
from omni.isaac.orbit.utils import configclass
......@@ -43,18 +43,24 @@ def grilled_chicken_with_yoghurt_and_bbq(env, hot: bool, bland: float, bbq: bool
return hot * bland * bbq * torch.ones(env.num_envs, 3, device=env.device)
class complex_function_class(ManagerTermBase):
    """Observation term implemented as a class that keeps per-environment state.

    The term accumulates the ``interval`` passed on every call into an internal
    buffer with one entry per environment. The buffer is cleared through
    :meth:`reset`.
    """

    def __init__(self, cfg: ObservationTermCfg, env: object):
        """Store the configuration and environment and allocate the internal buffers.

        Args:
            cfg: The configuration of the observation term.
            env: The environment object; ``num_envs`` and ``device`` are read from it.
        """
        self.cfg = cfg
        self.env = env
        # define some variables
        self._cost = 2 * self.env.num_envs
        self._time_passed = torch.zeros(env.num_envs, device=env.device)

    def reset(self, env_ids: torch.Tensor | None = None):
        """Zero the accumulated time for the given environments.

        Args:
            env_ids: Indices of the environments to reset. Defaults to None, in
                which case all environments are reset.
        """
        if env_ids is None:
            env_ids = slice(None)
        self._time_passed[env_ids] = 0.0

    def __call__(self, env: object, interval: float) -> torch.Tensor:
        """Advance the accumulated time by ``interval`` and return it.

        Args:
            env: The environment object (not used by this term).
            interval: The amount added to the accumulated time of every environment.

        Returns:
            A copy of the accumulated time with a trailing dimension of size one.
        """
        self._time_passed += interval
        return self._time_passed.clone().unsqueeze(-1)
class non_callable_complex_function_class:
class non_callable_complex_function_class(ManagerTermBase):
def __init__(self, cfg: ObservationTermCfg, env: object):
self.cfg = cfg
self.env = env
......@@ -242,7 +248,7 @@ class TestObservationManager(unittest.TestCase):
"""Test config class for policy observation group."""
term_1 = ObservationTermCfg(func=grilled_chicken, scale=10)
term_2 = ObservationTermCfg(func=complex_function_class, scale=0.2)
term_2 = ObservationTermCfg(func=complex_function_class, scale=0.2, params={"interval": 0.5})
policy: ObservationGroupCfg = PolicyCfg()
......@@ -251,9 +257,21 @@ class TestObservationManager(unittest.TestCase):
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# check the observation shape
self.assertEqual((self.env.num_envs, 6), observations["policy"].shape)
self.assertEqual(observations["policy"][0, -1].item(), 2 * self.env.num_envs * 0.2)
# check the observation
self.assertEqual((self.env.num_envs, 5), observations["policy"].shape)
self.assertAlmostEqual(observations["policy"][0, -1].item(), 0.2 * 0.5)
# check memory in term
num_exec_count = 10
for _ in range(num_exec_count):
observations = self.obs_man.compute()
self.assertAlmostEqual(observations["policy"][0, -1].item(), 0.2 * 0.5 * (num_exec_count + 1))
# check reset works
self.obs_man.reset(env_ids=[0, 4, 9, 14, 19])
observations = self.obs_man.compute()
self.assertAlmostEqual(observations["policy"][0, -1].item(), 0.2 * 0.5)
self.assertAlmostEqual(observations["policy"][1, -1].item(), 0.2 * 0.5 * (num_exec_count + 2))
def test_non_callable_class_term(self):
"""Test the observation computation with non-callable class term."""
......@@ -274,9 +292,10 @@ class TestObservationManager(unittest.TestCase):
# create observation manager config
cfg = MyObservationManagerCfg()
# create observation manager
with self.assertRaises(AttributeError):
with self.assertRaises(NotImplementedError):
self.obs_man = ObservationManager(cfg, self.env)
if __name__ == "__main__":
    try:
        # ``unittest.main()`` finishes by raising ``SystemExit``, so any statement
        # placed after it would never run.
        unittest.main()
    finally:
        # close the simulation app even though the test runner exits the interpreter
        simulation_app.close()
......@@ -169,3 +169,4 @@ class TestRewardManager(unittest.TestCase):
if __name__ == "__main__":
    try:
        # ``unittest.main()`` finishes by raising ``SystemExit``, so any statement
        # placed after it would never run.
        unittest.main()
    finally:
        # close the simulation app even though the test runner exits the interpreter
        simulation_app.close()
......@@ -256,6 +256,18 @@ class FunctionsDemoCfg:
func_in_dict = {"func": dummy_function1}
@configclass
class FunctionImplementedDemoCfg:
    """Dummy configuration class with functions as attributes.

    In addition to a function-valued attribute, the class defines a regular
    instance method, :meth:`set_a`, that modifies one of its attributes.
    """

    # function stored as an un-annotated attribute
    func = dummy_function1
    # annotated integer attribute; modified through ``set_a``
    a: int = 5
    # un-annotated float attribute
    k = 100.0

    def set_a(self, a: int):
        """Set the attribute ``a`` to the given value.

        Args:
            a: The new value for the attribute.
        """
        self.a = a
"""
Test solutions: Basic
"""
......@@ -517,6 +529,13 @@ class TestConfigClass(unittest.TestCase):
self.assertEqual(cfg.wrapped_func(), 4)
self.assertEqual(cfg.func_in_dict["func"](), 1)
def test_function_impl_config(self):
    """Tests that a method implemented on a config class can modify its attributes."""
    demo_cfg = FunctionImplementedDemoCfg()
    # check the default value
    self.assertEqual(demo_cfg.a, 5)
    # change value through the method defined on the config class
    demo_cfg.set_a(10)
    self.assertEqual(demo_cfg.a, 10)
def test_dict_conversion_functions_config(self):
"""Tests conversion of config with functions into dictionary."""
cfg = FunctionsDemoCfg()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment