Unverified Commit ba314082 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Initializes manager term classes only when sim starts (#2117)

# Description

To support creation of managers before the simulation starts playing (as
needed by the event manager for USD randomizations), the MR in #2040
added a callback to resolve scene entities at runtime. However, certain
class-based manager terms can also not be initialized if the simulation
is not playing. Those terms may often rely on parameters that are only
available once simulation plays (for instance, joint position limits).

This MR moves the initializations of class-based manager terms to the
callback too.

Fixes #2136

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent 875dc627
......@@ -135,7 +135,6 @@ class DirectMARLEnv(gym.Env):
# that must happen before the simulation starts. Example: randomizing mesh scale
if self.cfg.events:
self.event_manager = EventManager(self.cfg.events, self)
print("[INFO] Event Manager: ", self.event_manager)
# apply USD-related randomization events
if "prestartup" in self.event_manager.available_modes:
......@@ -198,6 +197,9 @@ class DirectMARLEnv(gym.Env):
# perform events at the start of the simulation
if self.cfg.events:
# we print it here to make the logging consistent
print("[INFO] Event Manager: ", self.event_manager)
if "startup" in self.event_manager.available_modes:
self.event_manager.apply(mode="startup")
......
......@@ -141,7 +141,6 @@ class DirectRLEnv(gym.Env):
# that must happen before the simulation starts. Example: randomizing mesh scale
if self.cfg.events:
self.event_manager = EventManager(self.cfg.events, self)
print("[INFO] Event Manager: ", self.event_manager)
# apply USD-related randomization events
if "prestartup" in self.event_manager.available_modes:
......@@ -202,6 +201,9 @@ class DirectRLEnv(gym.Env):
# perform events at the start of the simulation
if self.cfg.events:
# we print it here to make the logging consistent
print("[INFO] Event Manager: ", self.event_manager)
if "startup" in self.event_manager.available_modes:
self.event_manager.apply(mode="startup")
......
......@@ -140,7 +140,6 @@ class ManagerBasedEnv:
# note: this is needed here (rather than after simulation play) to allow USD-related randomization events
# that must happen before the simulation starts. Example: randomizing mesh scale
self.event_manager = EventManager(self.cfg.events, self)
print("[INFO] Event Manager: ", self.event_manager)
# apply USD-related randomization events
if "prestartup" in self.event_manager.available_modes:
......@@ -232,6 +231,8 @@ class ManagerBasedEnv:
"""
# prepare the managers
# -- event manager (we print it here to make the logging consistent)
print("[INFO] Event Manager: ", self.event_manager)
# -- recorder manager
self.recorder_manager = RecorderManager(self.cfg.recorders, self)
print("[INFO] Recorder Manager: ", self.recorder_manager)
......
......@@ -1162,7 +1162,9 @@ class randomize_visual_texture_material(ManagerTermBase):
event_name = cfg.params.get("event_name")
texture_rotation = cfg.params.get("texture_rotation", (0.0, 0.0))
# check to make sure replicate_physics is set to False, else raise warning
# check to make sure replicate_physics is set to False, else raise error
# note: We add an explicit check here since texture randomization can happen outside of 'prestartup' mode
# and the event manager doesn't check in that case.
if env.cfg.scene.replicate_physics:
raise RuntimeError(
"Unable to randomize visual texture material with scene replication enabled."
......@@ -1260,7 +1262,9 @@ class randomize_visual_color(ManagerTermBase):
event_name = cfg.params.get("event_name")
mesh_name: str = cfg.params.get("mesh_name", "") # type: ignore
# check to make sure replicate_physics is set to False, else raise warning
# check to make sure replicate_physics is set to False, else raise error
# note: We add an explicit check here since texture randomization can happen outside of 'prestartup' mode
# and the event manager doesn't check in that case.
if env.cfg.scene.replicate_physics:
raise RuntimeError(
"Unable to randomize visual color with scene replication enabled."
......
......@@ -7,6 +7,7 @@
from __future__ import annotations
import inspect
import torch
from collections.abc import Sequence
from prettytable import PrettyTable
......@@ -186,15 +187,6 @@ class EventManager(ManagerBase):
if mode not in self._mode_term_names:
omni.log.warn(f"Event mode '{mode}' is not defined. Skipping event.")
return
# check if mode is pre-startup and scene replication is enabled
if mode == "prestartup" and self._env.scene.cfg.replicate_physics:
omni.log.warn(
"Scene replication is enabled, which may affect USD-level randomization."
" When assets are replicated, their properties are shared across instances,"
" potentially leading to unintended behavior."
" For stable USD-level randomization, consider disabling scene replication"
" by setting 'replicate_physics' to False in 'InteractiveSceneCfg'."
)
# check if mode is interval and dt is not provided
if mode == "interval" and dt is None:
......@@ -363,6 +355,24 @@ class EventManager(ManagerBase):
# resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)
# check if mode is pre-startup and scene replication is enabled
if term_cfg.mode == "prestartup" and self._env.scene.cfg.replicate_physics:
raise RuntimeError(
"Scene replication is enabled, which may affect USD-level randomization."
" When assets are replicated, their properties are shared across instances,"
" potentially leading to unintended behavior."
" For stable USD-level randomization, please disable scene replication"
" by setting 'replicate_physics' to False in 'InteractiveSceneCfg'."
)
# for event terms with mode "prestartup", we assume a callable class term
# can be initialized before the simulation starts.
# this is done to ensure that the USD-level randomization is possible before the simulation starts.
if inspect.isclass(term_cfg.func) and term_cfg.mode == "prestartup":
omni.log.info(f"Initializing term '{term_name}' with class '{term_cfg.func.__name__}'.")
term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env)
# check if mode is a new mode
if term_cfg.mode not in self._mode_term_names:
# add new mode
......
......@@ -112,7 +112,7 @@ class ManagerTermBase(ABC):
Returns:
The value of the term.
"""
raise NotImplementedError
raise NotImplementedError("The method '__call__' should be implemented by the subclass.")
class ManagerBase(ABC):
......@@ -136,32 +136,34 @@ class ManagerBase(ABC):
self.cfg = copy.deepcopy(cfg)
self._env = env
# parse config to create terms information
if self.cfg:
self._prepare_terms()
# if the simulation is not playing, we use callbacks to trigger the resolution of the scene
# entities configuration. this is needed for cases where the manager is created after the
# simulation, but before the simulation is playing.
# FIXME: Once Isaac Sim supports storing this information as USD schema, we can remove this
# callback and resolve the scene entities directly inside `_prepare_terms`.
if not self._env.sim.is_playing():
# note: Use weakref on all callbacks to ensure that this object can be deleted when its destructor
# is called
# The order is set to 20 to allow asset/sensor initialization to complete before the scene entities
# are resolved. Those have the order 10.
timeline_event_stream = omni.timeline.get_timeline_interface().get_timeline_event_stream()
self._resolve_scene_entities_handle = timeline_event_stream.create_subscription_to_pop_by_type(
self._resolve_terms_handle = timeline_event_stream.create_subscription_to_pop_by_type(
int(omni.timeline.TimelineEventType.PLAY),
lambda event, obj=weakref.proxy(self): obj._resolve_scene_entities_callback(event),
lambda event, obj=weakref.proxy(self): obj._resolve_terms_callback(event),
order=20,
)
else:
self._resolve_scene_entities_handle = None
self._resolve_terms_handle = None
# parse config to create terms information
if self.cfg:
self._prepare_terms()
def __del__(self):
"""Delete the manager."""
if self._resolve_scene_entities_handle:
self._resolve_scene_entities_handle.unsubscribe()
self._resolve_scene_entities_handle = None
if self._resolve_terms_handle:
self._resolve_terms_handle.unsubscribe()
self._resolve_terms_handle = None
"""
Properties.
......@@ -206,7 +208,7 @@ class ManagerBase(ABC):
specified as regular expressions or a list of regular expressions. The search is
performed on the active terms in the manager.
Please check the :meth:`isaaclab.utils.string_utils.resolve_matching_names` function for more
Please check the :meth:`~isaaclab.utils.string_utils.resolve_matching_names` function for more
information on the name matching.
Args:
......@@ -249,11 +251,10 @@ class ManagerBase(ABC):
Internal callbacks.
"""
def _resolve_scene_entities_callback(self, event):
"""Resolve the scene entities configuration.
def _resolve_terms_callback(self, event):
"""Resolve configurations of terms once the simulation starts.
This callback is called when the simulation starts. It is used to resolve the
scene entities configuration for the terms.
Please check the :meth:`_process_term_cfg_at_play` method for more information.
"""
# check if config is dict already
if isinstance(self.cfg, dict):
......@@ -266,17 +267,26 @@ class ManagerBase(ABC):
# check for non config
if term_cfg is None:
continue
# resolve the scene entity configuration
self._resolve_scene_entity_cfg(term_name, term_cfg)
# process attributes at runtime
# these properties are only resolvable once the simulation starts playing
self._process_term_cfg_at_play(term_name, term_cfg)
"""
Helper functions.
Internal functions.
"""
def _resolve_common_term_cfg(self, term_name: str, term_cfg: ManagerTermBaseCfg, min_argc: int = 1):
"""Resolve common term configuration.
"""Resolve common attributes of the term configuration.
Usually, called by the :meth:`_prepare_terms` method to resolve common attributes of the term
configuration. These include:
Usually, called by the :meth:`_prepare_terms` method to resolve common term configuration.
* Resolving the term function and checking if it is callable.
* Checking if the term function's arguments are matched by the parameters.
* Resolving special attributes of the term configuration like ``asset_cfg``, ``sensor_cfg``, etc.
* Initializing the term if it is a class.
The last two steps are only possible once the simulation starts playing.
By default, all term functions are expected to have at least one argument, which is the
environment object. Some other managers may expect functions to take more arguments, for
......@@ -303,29 +313,31 @@ class ManagerBase(ABC):
f" Received: '{type(term_cfg)}'."
)
# iterate over all the entities and parse the joint and body names
if self._env.sim.is_playing():
self._resolve_scene_entity_cfg(term_name, term_cfg)
# get the corresponding function or functional class
if isinstance(term_cfg.func, str):
term_cfg.func = string_to_callable(term_cfg.func)
# check if function is callable
if not callable(term_cfg.func):
raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}")
# initialize the term if it is a class
# check if the term is a class of valid type
if inspect.isclass(term_cfg.func):
if not issubclass(term_cfg.func, ManagerTermBase):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type ManagerTermBase."
f" Received: '{type(term_cfg.func)}'."
)
term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env)
func_static = term_cfg.func.__call__
min_argc += 1 # forward by 1 to account for 'self' argument
else:
func_static = term_cfg.func
# check if function is callable
if not callable(term_cfg.func):
if not callable(func_static):
raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}")
# check if term's arguments are matched by params
# check statically if the term's arguments are matched by params
term_params = list(term_cfg.params.keys())
args = inspect.signature(term_cfg.func).parameters
args = inspect.signature(func_static).parameters
args_with_defaults = [arg for arg in args if args[arg].default is not inspect.Parameter.empty]
args_without_defaults = [arg for arg in args if args[arg].default is inspect.Parameter.empty]
args = args_without_defaults + args_with_defaults
......@@ -338,8 +350,22 @@ class ManagerBase(ABC):
f" and optional parameters: {args_with_defaults}, but received: {term_params}."
)
def _resolve_scene_entity_cfg(self, term_name: str, term_cfg: ManagerTermBaseCfg):
"""Resolve the scene entity configuration for the term.
# process attributes at runtime
# these properties are only resolvable once the simulation starts playing
if self._env.sim.is_playing():
self._process_term_cfg_at_play(term_name, term_cfg)
def _process_term_cfg_at_play(self, term_name: str, term_cfg: ManagerTermBaseCfg):
"""Process the term configuration at runtime.
This function is called when the simulation starts playing. It is used to process the term
configuration at runtime. This includes:
* Resolving the scene entity configuration for the term.
* Initializing the term if it is a class.
Since the above steps rely on PhysX to parse over the simulation scene, they are deferred
until the simulation starts playing.
Args:
term_name: The name of the term.
......@@ -362,3 +388,8 @@ class ManagerBase(ABC):
omni.log.info(msg)
# store the entity
term_cfg.params[key] = value
# initialize the term if it is a class
if inspect.isclass(term_cfg.func):
omni.log.info(f"Initializing term '{term_name}' with class '{term_cfg.func.__name__}'.")
term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env)
......@@ -352,6 +352,14 @@ class ObservationManager(ManagerBase):
# we store it as a separate list to only call reset on them and prevent unnecessary calls
self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list()
# make sure the simulation is playing since we compute obs dims which needs asset quantities
if not self._env.sim.is_playing():
raise RuntimeError(
"Simulation is not playing. Observation manager requires the simulation to be playing"
" to compute observation dimensions. Please start the simulation before using the"
" observation manager."
)
# check if config is dict already
if isinstance(self.cfg, dict):
group_cfg_items = self.cfg.items()
......@@ -407,8 +415,10 @@ class ObservationManager(ManagerBase):
# add term config to list to list
self._group_obs_term_names[group_name].append(term_name)
self._group_obs_term_cfgs[group_name].append(term_cfg)
# call function the first time to fill up dimensions
obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)
# create history buffers and calculate history term dimensions
if term_cfg.history_length > 0:
group_entry_history_buffer[term_name] = CircularBuffer(
......
......@@ -275,70 +275,86 @@ class CubeEnvCfg(ManagerBasedEnvCfg):
# simulation settings
self.sim.dt = 0.01
self.sim.physics_material = self.scene.terrain.physics_material
self.sim.render_interval = self.decimation
class TestScaleRandomization(unittest.TestCase):
"""Test for texture randomization"""
"""Test for scale randomization."""
"""
Tests
"""
def test_scale_randomization(self):
"""Main function."""
# setup base environment
env = ManagerBasedEnv(cfg=CubeEnvCfg())
# setup target position commands
target_position = torch.rand(env.num_envs, 3, device=env.device) * 2
target_position[:, 2] += 2.0
# offset all targets so that they move to the world origin
target_position -= env.scene.env_origins
stage = omni.usd.get_context().get_stage()
# test to make sure all assets in the scene are created
all_prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube.*/.*")
self.assertEqual(len(all_prim_paths), (env.num_envs * 2))
# test to make sure randomized values are truly random
applied_scaling_randomization = set()
prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube1")
for i in range(3):
prim_spec = Sdf.CreatePrimInLayer(stage.GetRootLayer(), prim_paths[i])
scale_spec = prim_spec.GetAttributeAtPath(prim_paths[i] + ".xformOp:scale")
if scale_spec.default in applied_scaling_randomization:
raise ValueError(
"Detected repeat in applied scale values - indication scaling randomization is not working."
)
applied_scaling_randomization.add(scale_spec.default)
# test to make sure that fixed values are assigned correctly
prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube2")
for i in range(3):
prim_spec = Sdf.CreatePrimInLayer(stage.GetRootLayer(), prim_paths[i])
scale_spec = prim_spec.GetAttributeAtPath(prim_paths[i] + ".xformOp:scale")
self.assertEqual(tuple(scale_spec.default), (1.0, 1.0, 1.0))
# simulate physics
with torch.inference_mode():
for count in range(200):
# reset every few steps to check nothing breaks
if count % 100 == 0:
env.reset()
# step the environment
env.step(target_position)
env.close()
"""Test scale randomization for cube environment."""
for device in ["cpu", "cuda"]:
with self.subTest(device=device):
# create a new stage
omni.usd.get_context().new_stage()
# set the device
env_cfg = CubeEnvCfg()
env_cfg.sim.device = device
# setup base environment
env = ManagerBasedEnv(cfg=env_cfg)
# setup target position commands
target_position = torch.rand(env.num_envs, 3, device=env.device) * 2
target_position[:, 2] += 2.0
# offset all targets so that they move to the world origin
target_position -= env.scene.env_origins
# test to make sure all assets in the scene are created
all_prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube.*/.*")
self.assertEqual(len(all_prim_paths), (env.num_envs * 2))
# test to make sure randomized values are truly random
applied_scaling_randomization = set()
prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube1")
# get the stage
stage = omni.usd.get_context().get_stage()
# check if the scale values are truly random
for i in range(3):
prim_spec = Sdf.CreatePrimInLayer(stage.GetRootLayer(), prim_paths[i])
scale_spec = prim_spec.GetAttributeAtPath(prim_paths[i] + ".xformOp:scale")
if scale_spec.default in applied_scaling_randomization:
raise ValueError(
"Detected repeat in applied scale values - indication scaling randomization is not working."
)
applied_scaling_randomization.add(scale_spec.default)
# test to make sure that fixed values are assigned correctly
prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube2")
for i in range(3):
prim_spec = Sdf.CreatePrimInLayer(stage.GetRootLayer(), prim_paths[i])
scale_spec = prim_spec.GetAttributeAtPath(prim_paths[i] + ".xformOp:scale")
self.assertEqual(tuple(scale_spec.default), (1.0, 1.0, 1.0))
# simulate physics
with torch.inference_mode():
for count in range(200):
# reset every few steps to check nothing breaks
if count % 100 == 0:
env.reset()
# step the environment
env.step(target_position)
env.close()
def test_scale_randomization_failure_replicate_physics(self):
with self.assertRaises(ValueError):
cfg_failure = CubeEnvCfg()
cfg_failure.scene.replicate_physics = True
"""Test scale randomization failure when replicate physics is set to True."""
# create a new stage
omni.usd.get_context().new_stage()
# set the arguments
cfg_failure = CubeEnvCfg()
cfg_failure.scene.replicate_physics = True
# run the test
with self.assertRaises(RuntimeError):
env = ManagerBasedEnv(cfg_failure)
env.close()
env.close()
if __name__ == "__main__":
......
......@@ -21,6 +21,8 @@ import math
import torch
import unittest
import omni.usd
import isaaclab.envs.mdp as mdp
from isaaclab.envs import ManagerBasedEnv, ManagerBasedEnvCfg
from isaaclab.managers import EventTermCfg as EventTerm
......@@ -64,10 +66,12 @@ class ObservationsCfg:
class EventCfg:
"""Configuration for events."""
# on reset apply a new set of textures
# on prestartup apply a new set of textures
# note from @mayank: Changed from 'reset' to 'prestartup' to make test pass.
# The error happens otherwise on Kit thread which is not the main thread.
cart_texture_randomizer = EventTerm(
func=mdp.randomize_visual_texture_material,
mode="reset",
mode="prestartup",
params={
"asset_cfg": SceneEntityCfg("robot", body_names=["cart"]),
"texture_paths": [
......@@ -83,6 +87,7 @@ class EventCfg:
},
)
# on reset apply a new set of textures
pole_texture_randomizer = EventTerm(
func=mdp.randomize_visual_texture_material,
mode="reset",
......@@ -153,35 +158,47 @@ class TestTextureRandomization(unittest.TestCase):
"""
def test_texture_randomization(self):
# set the arguments
env_cfg = CartpoleEnvCfg()
env_cfg.scene.num_envs = 16
env_cfg.scene.replicate_physics = False
# setup base environment
env = ManagerBasedEnv(cfg=env_cfg)
# simulate physics
with torch.inference_mode():
for count in range(50):
# reset every few steps to check nothing breaks
if count % 10 == 0:
env.reset()
# sample random actions
joint_efforts = torch.randn_like(env.action_manager.action)
# step the environment
env.step(joint_efforts)
env.close()
"""Test texture randomization for cartpole environment."""
for device in ["cpu", "cuda"]:
with self.subTest(device=device):
# create a new stage
omni.usd.get_context().new_stage()
# set the arguments
env_cfg = CartpoleEnvCfg()
env_cfg.scene.num_envs = 16
env_cfg.scene.replicate_physics = False
env_cfg.sim.device = device
# setup base environment
env = ManagerBasedEnv(cfg=env_cfg)
# simulate physics
with torch.inference_mode():
for count in range(50):
# reset every few steps to check nothing breaks
if count % 10 == 0:
env.reset()
# sample random actions
joint_efforts = torch.randn_like(env.action_manager.action)
# step the environment
env.step(joint_efforts)
env.close()
def test_texture_randomization_failure_replicate_physics(self):
with self.assertRaises(ValueError):
cfg_failure = CartpoleEnvCfg()
cfg_failure.scene.num_envs = 16
cfg_failure.scene.replicate_physics = True
env = ManagerBasedEnv(cfg_failure)
"""Test texture randomization failure when replicate physics is set to True."""
# create a new stage
omni.usd.get_context().new_stage()
# set the arguments
cfg_failure = CartpoleEnvCfg()
cfg_failure.scene.num_envs = 16
cfg_failure.scene.replicate_physics = True
env.close()
with self.assertRaises(RuntimeError):
env = ManagerBasedEnv(cfg_failure)
env.close()
if __name__ == "__main__":
......
......@@ -19,8 +19,8 @@ import torch
import unittest
from collections import namedtuple
import isaaclab.sim as sim_utils
from isaaclab.managers import ManagerTermBase, ObservationGroupCfg, ObservationManager, ObservationTermCfg
from isaaclab.sim import SimulationContext
from isaaclab.utils import configclass, modifiers
......@@ -100,11 +100,15 @@ class TestObservationManager(unittest.TestCase):
self.num_envs = 20
self.device = "cuda:0"
# set up sim
self.sim = SimulationContext()
sim_cfg = sim_utils.SimulationCfg(dt=self.dt, device=self.device)
sim = sim_utils.SimulationContext(sim_cfg)
# create dummy environment
self.env = namedtuple("ManagerBasedEnv", ["num_envs", "device", "data", "dt", "sim"])(
self.num_envs, self.device, MyDataClass(self.num_envs, self.device), self.dt, self.sim
self.num_envs, self.device, MyDataClass(self.num_envs, self.device), self.dt, sim
)
# let the simulation play (we need this for observation manager to compute obs dims)
self.env.sim._app_control_on_stop_handle = None
self.env.sim.reset()
def test_str(self):
"""Test the string representation of the observation manager."""
......@@ -382,24 +386,25 @@ class TestObservationManager(unittest.TestCase):
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env)
expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
print(expected_obs_data_t0, obs_policy)
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
torch.testing.assert_close(expected_obs_data_t0, obs_policy)
# test that the history buffer holds previous data
for _ in range(HISTORY_LENGTH):
observations = self.obs_man.compute()
obs_policy = observations["policy"]
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device)
expected_obs_data_t5 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t5, obs_policy))
torch.testing.assert_close(expected_obs_data_t5, obs_policy)
# test reset
self.obs_man.reset()
observations = self.obs_man.compute()
obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
torch.testing.assert_close(expected_obs_data_t0, obs_policy)
# test reset of specific env ids
reset_env_ids = [2, 4, 16]
self.obs_man.reset(reset_env_ids)
self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids]))
torch.testing.assert_close(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids])
def test_compute_with_2d_history(self):
"""Test the observation computation with history buffers for 2D observations."""
......@@ -482,7 +487,7 @@ class TestObservationManager(unittest.TestCase):
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH)
expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
torch.testing.assert_close(expected_obs_data_t0, obs_policy)
# test that the history buffer holds previous data
for _ in range(GROUP_HISTORY_LENGTH):
observations = self.obs_man.compute()
......@@ -490,16 +495,16 @@ class TestObservationManager(unittest.TestCase):
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH)
expected_obs_data_t10 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t10, obs_policy))
torch.testing.assert_close(expected_obs_data_t10, obs_policy)
# test reset
self.obs_man.reset()
observations = self.obs_man.compute()
obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
torch.testing.assert_close(expected_obs_data_t0, obs_policy)
# test reset of specific env ids
reset_env_ids = [2, 4, 16]
self.obs_man.reset(reset_env_ids)
self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids]))
torch.testing.assert_close(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids])
def test_invalid_observation_config(self):
"""Test the invalid observation config."""
......
......@@ -82,10 +82,10 @@ class joint_pos_limits_penalty_ratio(ManagerTermBase):
def __init__(self, env: ManagerBasedRLEnv, cfg: RewardTermCfg):
# add default argument
if "asset_cfg" not in cfg.params:
cfg.params["asset_cfg"] = SceneEntityCfg("robot")
asset_cfg = cfg.params.get("asset_cfg", SceneEntityCfg("robot"))
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[cfg.params["asset_cfg"].name]
asset: Articulation = env.scene[asset_cfg.name]
# resolve the gear ratio for each joint
self.gear_ratio = torch.ones(env.num_envs, asset.num_joints, device=env.device)
index_list, _, value_list = string_utils.resolve_matching_names_values(
......@@ -95,7 +95,11 @@ class joint_pos_limits_penalty_ratio(ManagerTermBase):
self.gear_ratio_scaled = self.gear_ratio / torch.max(self.gear_ratio)
def __call__(
self, env: ManagerBasedRLEnv, threshold: float, gear_ratio: dict[str, float], asset_cfg: SceneEntityCfg
self,
env: ManagerBasedRLEnv,
threshold: float,
gear_ratio: dict[str, float],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
......@@ -118,10 +122,10 @@ class power_consumption(ManagerTermBase):
def __init__(self, env: ManagerBasedRLEnv, cfg: RewardTermCfg):
# add default argument
if "asset_cfg" not in cfg.params:
cfg.params["asset_cfg"] = SceneEntityCfg("robot")
asset_cfg = cfg.params.get("asset_cfg", SceneEntityCfg("robot"))
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[cfg.params["asset_cfg"].name]
asset: Articulation = env.scene[asset_cfg.name]
# resolve the gear ratio for each joint
self.gear_ratio = torch.ones(env.num_envs, asset.num_joints, device=env.device)
index_list, _, value_list = string_utils.resolve_matching_names_values(
......@@ -130,7 +134,9 @@ class power_consumption(ManagerTermBase):
self.gear_ratio[:, index_list] = torch.tensor(value_list, device=env.device)
self.gear_ratio_scaled = self.gear_ratio / torch.max(self.gear_ratio)
def __call__(self, env: ManagerBasedRLEnv, gear_ratio: dict[str, float], asset_cfg: SceneEntityCfg) -> torch.Tensor:
def __call__(
self, env: ManagerBasedRLEnv, gear_ratio: dict[str, float], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
# return power = torque * velocity (here actions: joint torques)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment