Unverified Commit 16c740f5 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Adds support for callable classes in manager terms (#104)

# Description

This PR adds support for callable classes to the `term_cfg.func`
attribute. This is needed for complex behaviors where users may want to
define certain persistent behaviors as part of the terms.

The callable class should take in the term configuration object and the
environment instance as inputs to its constructor. Additionally, they
should implement the `__call__` function with the signature expected by
the manager.

For example, in the case of observation terms, this looks like:

```python
class complex_function_class:
    def __init__(self, cfg: ObservationTermCfg, env: object):
        self.cfg = cfg
        self.env = env
        # define some variables
        self.history_length = 2
        self._obs_history = torch.zeros(self.env.num_envs, self.history_length, 2, device=self.env.device)

    def __call__(self, env: object) -> torch.Tensor:
        new_obs = torch.rand(env.num_envs, 2, device=env.device)
        # update history
        self._obs_history[:, 1:] = self._obs_history[:, :1].clone()
        self._obs_history[:, 0] = new_obs
        # return obs
        return new_obs
```

## Type of change

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent 619337ed
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.8.5"
version = "0.8.6"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.8.6 (2023-08-03)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added support for callable classes in the :class:`omni.isaac.orbit.managers.ManagerBase`.
0.8.5 (2023-08-03)
~~~~~~~~~~~~~~~~~~
......
......@@ -137,6 +137,11 @@ class CurriculumManager(ManagerBase):
# check for non config
if term_cfg is None:
continue
# check if the term is a valid term config
if not isinstance(term_cfg, CurriculumTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type CurriculumTermCfg. Received '{type(term_cfg)}'."
)
# resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)
# add name and config to list
......
......@@ -130,9 +130,14 @@ class ManagerBase(ABC):
# acquire the body indices
body_ids, _ = getattr(self._env, term_cfg.asset_name).find_bodies(term_cfg.body_names)
term_cfg.params["body_ids"] = body_ids
# get the corresponding function
# get the corresponding function or functional class
if isinstance(term_cfg.func, str):
term_cfg.func = string_to_callable(term_cfg.func)
# initialize the term if it is a class
if inspect.isclass(term_cfg.func):
term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env)
# add the "self" argument to the count
min_argc += 1
# check if function is callable
if not callable(term_cfg.func):
raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}")
......
......@@ -27,6 +27,13 @@ class ManagerBaseTermCfg:
"""The function to be called for the term.
The function must take the environment object as the first argument.
Note:
It also supports `callable classes`_, i.e. classes that implement the :meth:`__call__`
method.
..`callable objects`: https://docs.python.org/3/reference/datamodel.html#object.__call__
"""
sensor_name: str | None = None
"""The name of the sensor required by the term. Defaults to None.
......
......@@ -188,6 +188,10 @@ class ObservationManager(ManagerBase):
# check for non config
if term_cfg is None:
continue
if not isinstance(term_cfg, ObservationTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type ObservationTermCfg. Received '{type(term_cfg)}'."
)
# resolve common terms in the config
self._resolve_common_term_cfg(f"{group_name}/{term_name}", term_cfg, min_argc=1)
# check noise settings
......
......@@ -155,6 +155,11 @@ class RandomizationManager(ManagerBase):
# check for non config
if term_cfg is None:
continue
# check for valid config type
if not isinstance(term_cfg, RandomizationTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type RandomizationTermCfg. Received '{type(term_cfg)}'."
)
# resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)
# check if mode is a new mode
......
......@@ -148,6 +148,11 @@ class RewardManager(ManagerBase):
# check for non config
if term_cfg is None:
continue
# check for valid config type
if not isinstance(term_cfg, RewardTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type RewardTermCfg. Received '{type(term_cfg)}'."
)
# resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
# remove zero scales and multiply non-zero ones by dt
......
......@@ -148,6 +148,11 @@ class TerminationManager(ManagerBase):
# check for non config
if term_cfg is None:
continue
# check for valid config type
if not isinstance(term_cfg, TerminationTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type TerminationTermCfg. Received '{type(term_cfg)}'."
)
# resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
# add function to list
......
......@@ -27,6 +27,28 @@ def grilled_chicken_with_yoghurt(env, hot: bool, bland: float):
return hot * bland * torch.ones(env.num_envs, 5, device=env.device)
class complex_function_class:
def __init__(self, cfg: ObservationTermCfg, env: object):
self.cfg = cfg
self.env = env
# define some variables
self._cost = 2 * self.env.num_envs
def __call__(self, env: object) -> torch.Tensor:
return torch.ones(env.num_envs, 2, device=env.device) * self._cost
class non_callable_complex_function_class:
def __init__(self, cfg: ObservationTermCfg, env: object):
self.cfg = cfg
self.env = env
# define some variables
self._cost = 2 * self.env.num_envs
def call_me(self, env: object) -> torch.Tensor:
return torch.ones(env.num_envs, 2, device=env.device) * self._cost
class TestObservationManager(unittest.TestCase):
"""Test cases for various situations with observation manager."""
......@@ -189,6 +211,53 @@ class TestObservationManager(unittest.TestCase):
with self.assertRaises(ValueError):
self.obs_man = ObservationManager(cfg, self.env)
def test_callable_class_term(self):
"""Test the observation computation with callable class term."""
@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""
@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""
term_1 = ObservationTermCfg(func=grilled_chicken, scale=10)
term_2 = ObservationTermCfg(func=complex_function_class, scale=0.2)
policy: ObservationGroupCfg = PolicyCfg()
# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# check the observation shape
self.assertEqual((self.env.num_envs, 6), observations["policy"].shape)
self.assertEqual(observations["policy"][0, -1].item(), 2 * self.env.num_envs * 0.2)
def test_non_callable_class_term(self):
"""Test the observation computation with non-callable class term."""
@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""
@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""
term_1 = ObservationTermCfg(func=grilled_chicken, scale=10)
term_2 = ObservationTermCfg(func=non_callable_complex_function_class, scale=0.2)
policy: ObservationGroupCfg = PolicyCfg()
# create observation manager config
cfg = MyObservationManagerCfg()
# create observation manager
with self.assertRaises(AttributeError):
self.obs_man = ObservationManager(cfg, self.env)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment