Unverified Commit 505679ff authored by James Smith's avatar James Smith Committed by GitHub

Adds serialization to observation and action managers (#2234)

# Description
This PR adds a serialize method to observation and action manager terms,
so that they can be stored alongside a trained policy for later
introspection.

Fixes # (issue)

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent b1cd175f
......@@ -358,6 +358,14 @@ class ActionManager(ManagerBase):
"""
return self._terms[name]
def serialize(self) -> dict:
"""Serialize the action manager configuration.
Returns:
A dictionary of serialized action term configurations.
"""
return {term_name: term.serialize() for term_name, term in self._terms.items()}
"""
Helper functions.
"""
......
......@@ -16,7 +16,7 @@ import omni.log
import omni.timeline
import isaaclab.utils.string as string_utils
from isaaclab.utils import string_to_callable
from isaaclab.utils import class_to_dict, string_to_callable
from .manager_term_cfg import ManagerTermBaseCfg
from .scene_entity_cfg import SceneEntityCfg
......@@ -79,6 +79,11 @@ class ManagerTermBase(ABC):
"""Device on which to perform computations."""
return self._env.device
@property
def __name__(self) -> str:
"""Return the name of the class or subclass."""
return self.__class__.__name__
"""
Operations.
"""
......@@ -92,6 +97,10 @@ class ManagerTermBase(ABC):
"""
pass
def serialize(self) -> dict:
"""General serialization call. Includes the configuration dict."""
return {"cfg": class_to_dict(self.cfg)}
def __call__(self, *args) -> Any:
"""Returns the value of the term required by the manager.
......
......@@ -14,7 +14,7 @@ from collections.abc import Sequence
from prettytable import PrettyTable
from typing import TYPE_CHECKING
from isaaclab.utils import modifiers
from isaaclab.utils import class_to_dict, modifiers
from isaaclab.utils.buffers import CircularBuffer
from .manager_base import ManagerBase, ManagerTermBase
......@@ -334,6 +334,29 @@ class ObservationManager(ManagerBase):
else:
return group_obs
def serialize(self) -> dict:
"""Serialize the observation term configurations for all active groups.
Returns:
A dictionary where each group name maps to its serialized observation term configurations.
"""
output = {
group_name: {
term_name: (
term_cfg.func.serialize()
if isinstance(term_cfg.func, ManagerTermBase)
else {"cfg": class_to_dict(term_cfg)}
)
for term_name, term_cfg in zip(
self._group_obs_term_names[group_name],
self._group_obs_term_cfgs[group_name],
)
}
for group_name in self.active_terms.keys()
}
return output
"""
Helper functions.
"""
......
......@@ -18,11 +18,21 @@ simulation_app = AppLauncher(headless=True).app
import torch
import unittest
from collections import namedtuple
from typing import TYPE_CHECKING
import isaaclab.sim as sim_utils
from isaaclab.managers import ManagerTermBase, ObservationGroupCfg, ObservationManager, ObservationTermCfg
from isaaclab.managers import (
ManagerTermBase,
ObservationGroupCfg,
ObservationManager,
ObservationTermCfg,
RewardTermCfg,
)
from isaaclab.utils import configclass, modifiers
if TYPE_CHECKING:
from isaaclab.envs import ManagerBasedEnv
def grilled_chicken(env):
return torch.ones(env.num_envs, 4, device=env.device)
......@@ -662,6 +672,42 @@ class TestObservationManager(unittest.TestCase):
with self.assertRaises(ValueError):
self.obs_man = ObservationManager(cfg, self.env)
def test_serialize(self):
"""Test serialize call for ManagerTermBase terms."""
serialize_data = {"test": 0}
class test_serialize_term(ManagerTermBase):
def __init__(self, cfg: RewardTermCfg, env: ManagerBasedEnv):
super().__init__(cfg, env)
def __call__(self, env: ManagerBasedEnv) -> torch.Tensor:
return grilled_chicken(env)
def serialize(self) -> dict:
return serialize_data
@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""
@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""
concatenate_terms = False
term_1 = ObservationTermCfg(func=test_serialize_term)
policy: ObservationGroupCfg = PolicyCfg()
# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# check expected output
self.assertEqual(self.obs_man.serialize(), {"policy": {"term_1": serialize_data}})
if __name__ == "__main__":
run_tests()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment