Unverified commit ba314082, authored by Mayank Mittal and committed by GitHub

Initializes manager term classes only when sim starts (#2117)

# Description

To support the creation of managers before the simulation starts playing (as
needed by the event manager for USD randomizations), the MR in #2040
added a callback to resolve scene entities at runtime. However, certain
class-based manager terms also cannot be initialized while the simulation
is not playing, because they often rely on parameters that are only
available once the simulation plays (for instance, joint position limits).

This MR moves the initialization of class-based manager terms into that
callback as well.

Fixes #2136

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: Kelly Guo <kellyg@nvidia.com>
parent 875dc627
...@@ -135,7 +135,6 @@ class DirectMARLEnv(gym.Env): ...@@ -135,7 +135,6 @@ class DirectMARLEnv(gym.Env):
# that must happen before the simulation starts. Example: randomizing mesh scale # that must happen before the simulation starts. Example: randomizing mesh scale
if self.cfg.events: if self.cfg.events:
self.event_manager = EventManager(self.cfg.events, self) self.event_manager = EventManager(self.cfg.events, self)
print("[INFO] Event Manager: ", self.event_manager)
# apply USD-related randomization events # apply USD-related randomization events
if "prestartup" in self.event_manager.available_modes: if "prestartup" in self.event_manager.available_modes:
...@@ -198,6 +197,9 @@ class DirectMARLEnv(gym.Env): ...@@ -198,6 +197,9 @@ class DirectMARLEnv(gym.Env):
# perform events at the start of the simulation # perform events at the start of the simulation
if self.cfg.events: if self.cfg.events:
# we print it here to make the logging consistent
print("[INFO] Event Manager: ", self.event_manager)
if "startup" in self.event_manager.available_modes: if "startup" in self.event_manager.available_modes:
self.event_manager.apply(mode="startup") self.event_manager.apply(mode="startup")
......
...@@ -141,7 +141,6 @@ class DirectRLEnv(gym.Env): ...@@ -141,7 +141,6 @@ class DirectRLEnv(gym.Env):
# that must happen before the simulation starts. Example: randomizing mesh scale # that must happen before the simulation starts. Example: randomizing mesh scale
if self.cfg.events: if self.cfg.events:
self.event_manager = EventManager(self.cfg.events, self) self.event_manager = EventManager(self.cfg.events, self)
print("[INFO] Event Manager: ", self.event_manager)
# apply USD-related randomization events # apply USD-related randomization events
if "prestartup" in self.event_manager.available_modes: if "prestartup" in self.event_manager.available_modes:
...@@ -202,6 +201,9 @@ class DirectRLEnv(gym.Env): ...@@ -202,6 +201,9 @@ class DirectRLEnv(gym.Env):
# perform events at the start of the simulation # perform events at the start of the simulation
if self.cfg.events: if self.cfg.events:
# we print it here to make the logging consistent
print("[INFO] Event Manager: ", self.event_manager)
if "startup" in self.event_manager.available_modes: if "startup" in self.event_manager.available_modes:
self.event_manager.apply(mode="startup") self.event_manager.apply(mode="startup")
......
...@@ -140,7 +140,6 @@ class ManagerBasedEnv: ...@@ -140,7 +140,6 @@ class ManagerBasedEnv:
# note: this is needed here (rather than after simulation play) to allow USD-related randomization events # note: this is needed here (rather than after simulation play) to allow USD-related randomization events
# that must happen before the simulation starts. Example: randomizing mesh scale # that must happen before the simulation starts. Example: randomizing mesh scale
self.event_manager = EventManager(self.cfg.events, self) self.event_manager = EventManager(self.cfg.events, self)
print("[INFO] Event Manager: ", self.event_manager)
# apply USD-related randomization events # apply USD-related randomization events
if "prestartup" in self.event_manager.available_modes: if "prestartup" in self.event_manager.available_modes:
...@@ -232,6 +231,8 @@ class ManagerBasedEnv: ...@@ -232,6 +231,8 @@ class ManagerBasedEnv:
""" """
# prepare the managers # prepare the managers
# -- event manager (we print it here to make the logging consistent)
print("[INFO] Event Manager: ", self.event_manager)
# -- recorder manager # -- recorder manager
self.recorder_manager = RecorderManager(self.cfg.recorders, self) self.recorder_manager = RecorderManager(self.cfg.recorders, self)
print("[INFO] Recorder Manager: ", self.recorder_manager) print("[INFO] Recorder Manager: ", self.recorder_manager)
......
...@@ -1162,7 +1162,9 @@ class randomize_visual_texture_material(ManagerTermBase): ...@@ -1162,7 +1162,9 @@ class randomize_visual_texture_material(ManagerTermBase):
event_name = cfg.params.get("event_name") event_name = cfg.params.get("event_name")
texture_rotation = cfg.params.get("texture_rotation", (0.0, 0.0)) texture_rotation = cfg.params.get("texture_rotation", (0.0, 0.0))
# check to make sure replicate_physics is set to False, else raise warning # check to make sure replicate_physics is set to False, else raise error
# note: We add an explicit check here since texture randomization can happen outside of 'prestartup' mode
# and the event manager doesn't check in that case.
if env.cfg.scene.replicate_physics: if env.cfg.scene.replicate_physics:
raise RuntimeError( raise RuntimeError(
"Unable to randomize visual texture material with scene replication enabled." "Unable to randomize visual texture material with scene replication enabled."
...@@ -1260,7 +1262,9 @@ class randomize_visual_color(ManagerTermBase): ...@@ -1260,7 +1262,9 @@ class randomize_visual_color(ManagerTermBase):
event_name = cfg.params.get("event_name") event_name = cfg.params.get("event_name")
mesh_name: str = cfg.params.get("mesh_name", "") # type: ignore mesh_name: str = cfg.params.get("mesh_name", "") # type: ignore
# check to make sure replicate_physics is set to False, else raise warning # check to make sure replicate_physics is set to False, else raise error
# note: We add an explicit check here since color randomization can happen outside of 'prestartup' mode
# and the event manager doesn't check in that case.
if env.cfg.scene.replicate_physics: if env.cfg.scene.replicate_physics:
raise RuntimeError( raise RuntimeError(
"Unable to randomize visual color with scene replication enabled." "Unable to randomize visual color with scene replication enabled."
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
from __future__ import annotations from __future__ import annotations
import inspect
import torch import torch
from collections.abc import Sequence from collections.abc import Sequence
from prettytable import PrettyTable from prettytable import PrettyTable
...@@ -186,15 +187,6 @@ class EventManager(ManagerBase): ...@@ -186,15 +187,6 @@ class EventManager(ManagerBase):
if mode not in self._mode_term_names: if mode not in self._mode_term_names:
omni.log.warn(f"Event mode '{mode}' is not defined. Skipping event.") omni.log.warn(f"Event mode '{mode}' is not defined. Skipping event.")
return return
# check if mode is pre-startup and scene replication is enabled
if mode == "prestartup" and self._env.scene.cfg.replicate_physics:
omni.log.warn(
"Scene replication is enabled, which may affect USD-level randomization."
" When assets are replicated, their properties are shared across instances,"
" potentially leading to unintended behavior."
" For stable USD-level randomization, consider disabling scene replication"
" by setting 'replicate_physics' to False in 'InteractiveSceneCfg'."
)
# check if mode is interval and dt is not provided # check if mode is interval and dt is not provided
if mode == "interval" and dt is None: if mode == "interval" and dt is None:
...@@ -363,6 +355,24 @@ class EventManager(ManagerBase): ...@@ -363,6 +355,24 @@ class EventManager(ManagerBase):
# resolve common parameters # resolve common parameters
self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2) self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)
# check if mode is pre-startup and scene replication is enabled
if term_cfg.mode == "prestartup" and self._env.scene.cfg.replicate_physics:
raise RuntimeError(
"Scene replication is enabled, which may affect USD-level randomization."
" When assets are replicated, their properties are shared across instances,"
" potentially leading to unintended behavior."
" For stable USD-level randomization, please disable scene replication"
" by setting 'replicate_physics' to False in 'InteractiveSceneCfg'."
)
# for event terms with mode "prestartup", we assume a callable class term
# can be initialized before the simulation starts.
# this is done to ensure that the USD-level randomization is possible before the simulation starts.
if inspect.isclass(term_cfg.func) and term_cfg.mode == "prestartup":
omni.log.info(f"Initializing term '{term_name}' with class '{term_cfg.func.__name__}'.")
term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env)
# check if mode is a new mode # check if mode is a new mode
if term_cfg.mode not in self._mode_term_names: if term_cfg.mode not in self._mode_term_names:
# add new mode # add new mode
......
...@@ -112,7 +112,7 @@ class ManagerTermBase(ABC): ...@@ -112,7 +112,7 @@ class ManagerTermBase(ABC):
Returns: Returns:
The value of the term. The value of the term.
""" """
raise NotImplementedError raise NotImplementedError("The method '__call__' should be implemented by the subclass.")
class ManagerBase(ABC): class ManagerBase(ABC):
...@@ -136,32 +136,34 @@ class ManagerBase(ABC): ...@@ -136,32 +136,34 @@ class ManagerBase(ABC):
self.cfg = copy.deepcopy(cfg) self.cfg = copy.deepcopy(cfg)
self._env = env self._env = env
# parse config to create terms information
if self.cfg:
self._prepare_terms()
# if the simulation is not playing, we use callbacks to trigger the resolution of the scene # if the simulation is not playing, we use callbacks to trigger the resolution of the scene
# entities configuration. this is needed for cases where the manager is created after the # entities configuration. this is needed for cases where the manager is created after the
# simulation, but before the simulation is playing. # simulation, but before the simulation is playing.
# FIXME: Once Isaac Sim supports storing this information as USD schema, we can remove this
# callback and resolve the scene entities directly inside `_prepare_terms`.
if not self._env.sim.is_playing(): if not self._env.sim.is_playing():
# note: Use weakref on all callbacks to ensure that this object can be deleted when its destructor # note: Use weakref on all callbacks to ensure that this object can be deleted when its destructor
# is called # is called
# The order is set to 20 to allow asset/sensor initialization to complete before the scene entities # The order is set to 20 to allow asset/sensor initialization to complete before the scene entities
# are resolved. Those have the order 10. # are resolved. Those have the order 10.
timeline_event_stream = omni.timeline.get_timeline_interface().get_timeline_event_stream() timeline_event_stream = omni.timeline.get_timeline_interface().get_timeline_event_stream()
self._resolve_scene_entities_handle = timeline_event_stream.create_subscription_to_pop_by_type( self._resolve_terms_handle = timeline_event_stream.create_subscription_to_pop_by_type(
int(omni.timeline.TimelineEventType.PLAY), int(omni.timeline.TimelineEventType.PLAY),
lambda event, obj=weakref.proxy(self): obj._resolve_scene_entities_callback(event), lambda event, obj=weakref.proxy(self): obj._resolve_terms_callback(event),
order=20, order=20,
) )
else: else:
self._resolve_scene_entities_handle = None self._resolve_terms_handle = None
# parse config to create terms information
if self.cfg:
self._prepare_terms()
def __del__(self): def __del__(self):
"""Delete the manager.""" """Delete the manager."""
if self._resolve_scene_entities_handle: if self._resolve_terms_handle:
self._resolve_scene_entities_handle.unsubscribe() self._resolve_terms_handle.unsubscribe()
self._resolve_scene_entities_handle = None self._resolve_terms_handle = None
""" """
Properties. Properties.
...@@ -206,7 +208,7 @@ class ManagerBase(ABC): ...@@ -206,7 +208,7 @@ class ManagerBase(ABC):
specified as regular expressions or a list of regular expressions. The search is specified as regular expressions or a list of regular expressions. The search is
performed on the active terms in the manager. performed on the active terms in the manager.
Please check the :meth:`isaaclab.utils.string_utils.resolve_matching_names` function for more Please check the :meth:`~isaaclab.utils.string_utils.resolve_matching_names` function for more
information on the name matching. information on the name matching.
Args: Args:
...@@ -249,11 +251,10 @@ class ManagerBase(ABC): ...@@ -249,11 +251,10 @@ class ManagerBase(ABC):
Internal callbacks. Internal callbacks.
""" """
def _resolve_scene_entities_callback(self, event): def _resolve_terms_callback(self, event):
"""Resolve the scene entities configuration. """Resolve configurations of terms once the simulation starts.
This callback is called when the simulation starts. It is used to resolve the Please check the :meth:`_process_term_cfg_at_play` method for more information.
scene entities configuration for the terms.
""" """
# check if config is dict already # check if config is dict already
if isinstance(self.cfg, dict): if isinstance(self.cfg, dict):
...@@ -266,17 +267,26 @@ class ManagerBase(ABC): ...@@ -266,17 +267,26 @@ class ManagerBase(ABC):
# check for non config # check for non config
if term_cfg is None: if term_cfg is None:
continue continue
# resolve the scene entity configuration # process attributes at runtime
self._resolve_scene_entity_cfg(term_name, term_cfg) # these properties are only resolvable once the simulation starts playing
self._process_term_cfg_at_play(term_name, term_cfg)
""" """
Helper functions. Internal functions.
""" """
def _resolve_common_term_cfg(self, term_name: str, term_cfg: ManagerTermBaseCfg, min_argc: int = 1): def _resolve_common_term_cfg(self, term_name: str, term_cfg: ManagerTermBaseCfg, min_argc: int = 1):
"""Resolve common term configuration. """Resolve common attributes of the term configuration.
Usually, called by the :meth:`_prepare_terms` method to resolve common attributes of the term
configuration. These include:
Usually, called by the :meth:`_prepare_terms` method to resolve common term configuration. * Resolving the term function and checking if it is callable.
* Checking if the term function's arguments are matched by the parameters.
* Resolving special attributes of the term configuration like ``asset_cfg``, ``sensor_cfg``, etc.
* Initializing the term if it is a class.
The last two steps are only possible once the simulation starts playing.
By default, all term functions are expected to have at least one argument, which is the By default, all term functions are expected to have at least one argument, which is the
environment object. Some other managers may expect functions to take more arguments, for environment object. Some other managers may expect functions to take more arguments, for
...@@ -303,29 +313,31 @@ class ManagerBase(ABC): ...@@ -303,29 +313,31 @@ class ManagerBase(ABC):
f" Received: '{type(term_cfg)}'." f" Received: '{type(term_cfg)}'."
) )
# iterate over all the entities and parse the joint and body names
if self._env.sim.is_playing():
self._resolve_scene_entity_cfg(term_name, term_cfg)
# get the corresponding function or functional class # get the corresponding function or functional class
if isinstance(term_cfg.func, str): if isinstance(term_cfg.func, str):
term_cfg.func = string_to_callable(term_cfg.func) term_cfg.func = string_to_callable(term_cfg.func)
# check if function is callable
if not callable(term_cfg.func):
raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}")
# initialize the term if it is a class # check if the term is a class of valid type
if inspect.isclass(term_cfg.func): if inspect.isclass(term_cfg.func):
if not issubclass(term_cfg.func, ManagerTermBase): if not issubclass(term_cfg.func, ManagerTermBase):
raise TypeError( raise TypeError(
f"Configuration for the term '{term_name}' is not of type ManagerTermBase." f"Configuration for the term '{term_name}' is not of type ManagerTermBase."
f" Received: '{type(term_cfg.func)}'." f" Received: '{type(term_cfg.func)}'."
) )
term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env) func_static = term_cfg.func.__call__
min_argc += 1 # forward by 1 to account for 'self' argument
else:
func_static = term_cfg.func
# check if function is callable # check if function is callable
if not callable(term_cfg.func): if not callable(func_static):
raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}") raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}")
# check if term's arguments are matched by params # check statically if the term's arguments are matched by params
term_params = list(term_cfg.params.keys()) term_params = list(term_cfg.params.keys())
args = inspect.signature(term_cfg.func).parameters args = inspect.signature(func_static).parameters
args_with_defaults = [arg for arg in args if args[arg].default is not inspect.Parameter.empty] args_with_defaults = [arg for arg in args if args[arg].default is not inspect.Parameter.empty]
args_without_defaults = [arg for arg in args if args[arg].default is inspect.Parameter.empty] args_without_defaults = [arg for arg in args if args[arg].default is inspect.Parameter.empty]
args = args_without_defaults + args_with_defaults args = args_without_defaults + args_with_defaults
...@@ -338,8 +350,22 @@ class ManagerBase(ABC): ...@@ -338,8 +350,22 @@ class ManagerBase(ABC):
f" and optional parameters: {args_with_defaults}, but received: {term_params}." f" and optional parameters: {args_with_defaults}, but received: {term_params}."
) )
def _resolve_scene_entity_cfg(self, term_name: str, term_cfg: ManagerTermBaseCfg): # process attributes at runtime
"""Resolve the scene entity configuration for the term. # these properties are only resolvable once the simulation starts playing
if self._env.sim.is_playing():
self._process_term_cfg_at_play(term_name, term_cfg)
def _process_term_cfg_at_play(self, term_name: str, term_cfg: ManagerTermBaseCfg):
"""Process the term configuration at runtime.
This function is called when the simulation starts playing. It is used to process the term
configuration at runtime. This includes:
* Resolving the scene entity configuration for the term.
* Initializing the term if it is a class.
Since the above steps rely on PhysX to parse over the simulation scene, they are deferred
until the simulation starts playing.
Args: Args:
term_name: The name of the term. term_name: The name of the term.
...@@ -362,3 +388,8 @@ class ManagerBase(ABC): ...@@ -362,3 +388,8 @@ class ManagerBase(ABC):
omni.log.info(msg) omni.log.info(msg)
# store the entity # store the entity
term_cfg.params[key] = value term_cfg.params[key] = value
# initialize the term if it is a class
if inspect.isclass(term_cfg.func):
omni.log.info(f"Initializing term '{term_name}' with class '{term_cfg.func.__name__}'.")
term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env)
...@@ -352,6 +352,14 @@ class ObservationManager(ManagerBase): ...@@ -352,6 +352,14 @@ class ObservationManager(ManagerBase):
# we store it as a separate list to only call reset on them and prevent unnecessary calls # we store it as a separate list to only call reset on them and prevent unnecessary calls
self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list() self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list()
# make sure the simulation is playing since we compute obs dims which needs asset quantities
if not self._env.sim.is_playing():
raise RuntimeError(
"Simulation is not playing. Observation manager requires the simulation to be playing"
" to compute observation dimensions. Please start the simulation before using the"
" observation manager."
)
# check if config is dict already # check if config is dict already
if isinstance(self.cfg, dict): if isinstance(self.cfg, dict):
group_cfg_items = self.cfg.items() group_cfg_items = self.cfg.items()
...@@ -407,8 +415,10 @@ class ObservationManager(ManagerBase): ...@@ -407,8 +415,10 @@ class ObservationManager(ManagerBase):
# add term config to list to list # add term config to list to list
self._group_obs_term_names[group_name].append(term_name) self._group_obs_term_names[group_name].append(term_name)
self._group_obs_term_cfgs[group_name].append(term_cfg) self._group_obs_term_cfgs[group_name].append(term_cfg)
# call function the first time to fill up dimensions # call function the first time to fill up dimensions
obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape) obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)
# create history buffers and calculate history term dimensions # create history buffers and calculate history term dimensions
if term_cfg.history_length > 0: if term_cfg.history_length > 0:
group_entry_history_buffer[term_name] = CircularBuffer( group_entry_history_buffer[term_name] = CircularBuffer(
......
...@@ -275,70 +275,86 @@ class CubeEnvCfg(ManagerBasedEnvCfg): ...@@ -275,70 +275,86 @@ class CubeEnvCfg(ManagerBasedEnvCfg):
# simulation settings # simulation settings
self.sim.dt = 0.01 self.sim.dt = 0.01
self.sim.physics_material = self.scene.terrain.physics_material self.sim.physics_material = self.scene.terrain.physics_material
self.sim.render_interval = self.decimation
class TestScaleRandomization(unittest.TestCase): class TestScaleRandomization(unittest.TestCase):
"""Test for texture randomization""" """Test for scale randomization."""
""" """
Tests Tests
""" """
def test_scale_randomization(self): def test_scale_randomization(self):
"""Main function.""" """Test scale randomization for cube environment."""
for device in ["cpu", "cuda"]:
# setup base environment with self.subTest(device=device):
env = ManagerBasedEnv(cfg=CubeEnvCfg()) # create a new stage
# setup target position commands omni.usd.get_context().new_stage()
target_position = torch.rand(env.num_envs, 3, device=env.device) * 2
target_position[:, 2] += 2.0 # set the device
# offset all targets so that they move to the world origin env_cfg = CubeEnvCfg()
target_position -= env.scene.env_origins env_cfg.sim.device = device
stage = omni.usd.get_context().get_stage() # setup base environment
env = ManagerBasedEnv(cfg=env_cfg)
# test to make sure all assets in the scene are created # setup target position commands
all_prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube.*/.*") target_position = torch.rand(env.num_envs, 3, device=env.device) * 2
self.assertEqual(len(all_prim_paths), (env.num_envs * 2)) target_position[:, 2] += 2.0
# offset all targets so that they move to the world origin
# test to make sure randomized values are truly random target_position -= env.scene.env_origins
applied_scaling_randomization = set()
prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube1") # test to make sure all assets in the scene are created
all_prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube.*/.*")
for i in range(3): self.assertEqual(len(all_prim_paths), (env.num_envs * 2))
prim_spec = Sdf.CreatePrimInLayer(stage.GetRootLayer(), prim_paths[i])
scale_spec = prim_spec.GetAttributeAtPath(prim_paths[i] + ".xformOp:scale") # test to make sure randomized values are truly random
if scale_spec.default in applied_scaling_randomization: applied_scaling_randomization = set()
raise ValueError( prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube1")
"Detected repeat in applied scale values - indication scaling randomization is not working."
) # get the stage
applied_scaling_randomization.add(scale_spec.default) stage = omni.usd.get_context().get_stage()
# test to make sure that fixed values are assigned correctly # check if the scale values are truly random
prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube2") for i in range(3):
for i in range(3): prim_spec = Sdf.CreatePrimInLayer(stage.GetRootLayer(), prim_paths[i])
prim_spec = Sdf.CreatePrimInLayer(stage.GetRootLayer(), prim_paths[i]) scale_spec = prim_spec.GetAttributeAtPath(prim_paths[i] + ".xformOp:scale")
scale_spec = prim_spec.GetAttributeAtPath(prim_paths[i] + ".xformOp:scale") if scale_spec.default in applied_scaling_randomization:
self.assertEqual(tuple(scale_spec.default), (1.0, 1.0, 1.0)) raise ValueError(
"Detected repeat in applied scale values - indication scaling randomization is not working."
# simulate physics )
with torch.inference_mode(): applied_scaling_randomization.add(scale_spec.default)
for count in range(200):
# reset every few steps to check nothing breaks # test to make sure that fixed values are assigned correctly
if count % 100 == 0: prim_paths = sim_utils.find_matching_prim_paths("/World/envs/env_.*/cube2")
env.reset() for i in range(3):
# step the environment prim_spec = Sdf.CreatePrimInLayer(stage.GetRootLayer(), prim_paths[i])
env.step(target_position) scale_spec = prim_spec.GetAttributeAtPath(prim_paths[i] + ".xformOp:scale")
self.assertEqual(tuple(scale_spec.default), (1.0, 1.0, 1.0))
env.close()
# simulate physics
with torch.inference_mode():
for count in range(200):
# reset every few steps to check nothing breaks
if count % 100 == 0:
env.reset()
# step the environment
env.step(target_position)
env.close()
def test_scale_randomization_failure_replicate_physics(self): def test_scale_randomization_failure_replicate_physics(self):
with self.assertRaises(ValueError): """Test scale randomization failure when replicate physics is set to True."""
cfg_failure = CubeEnvCfg() # create a new stage
cfg_failure.scene.replicate_physics = True omni.usd.get_context().new_stage()
# set the arguments
cfg_failure = CubeEnvCfg()
cfg_failure.scene.replicate_physics = True
# run the test
with self.assertRaises(RuntimeError):
env = ManagerBasedEnv(cfg_failure) env = ManagerBasedEnv(cfg_failure)
env.close()
env.close()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -21,6 +21,8 @@ import math ...@@ -21,6 +21,8 @@ import math
import torch import torch
import unittest import unittest
import omni.usd
import isaaclab.envs.mdp as mdp import isaaclab.envs.mdp as mdp
from isaaclab.envs import ManagerBasedEnv, ManagerBasedEnvCfg from isaaclab.envs import ManagerBasedEnv, ManagerBasedEnvCfg
from isaaclab.managers import EventTermCfg as EventTerm from isaaclab.managers import EventTermCfg as EventTerm
...@@ -64,10 +66,12 @@ class ObservationsCfg: ...@@ -64,10 +66,12 @@ class ObservationsCfg:
class EventCfg: class EventCfg:
"""Configuration for events.""" """Configuration for events."""
# on reset apply a new set of textures # on prestartup apply a new set of textures
# note from @mayank: Changed from 'reset' to 'prestartup' to make the test pass.
# Otherwise, the error is raised on the Kit thread, which is not the main thread.
cart_texture_randomizer = EventTerm( cart_texture_randomizer = EventTerm(
func=mdp.randomize_visual_texture_material, func=mdp.randomize_visual_texture_material,
mode="reset", mode="prestartup",
params={ params={
"asset_cfg": SceneEntityCfg("robot", body_names=["cart"]), "asset_cfg": SceneEntityCfg("robot", body_names=["cart"]),
"texture_paths": [ "texture_paths": [
...@@ -83,6 +87,7 @@ class EventCfg: ...@@ -83,6 +87,7 @@ class EventCfg:
}, },
) )
# on reset apply a new set of textures
pole_texture_randomizer = EventTerm( pole_texture_randomizer = EventTerm(
func=mdp.randomize_visual_texture_material, func=mdp.randomize_visual_texture_material,
mode="reset", mode="reset",
...@@ -153,35 +158,47 @@ class TestTextureRandomization(unittest.TestCase): ...@@ -153,35 +158,47 @@ class TestTextureRandomization(unittest.TestCase):
""" """
def test_texture_randomization(self): def test_texture_randomization(self):
# set the arguments """Test texture randomization for cartpole environment."""
env_cfg = CartpoleEnvCfg() for device in ["cpu", "cuda"]:
env_cfg.scene.num_envs = 16 with self.subTest(device=device):
env_cfg.scene.replicate_physics = False # create a new stage
omni.usd.get_context().new_stage()
# setup base environment
env = ManagerBasedEnv(cfg=env_cfg) # set the arguments
env_cfg = CartpoleEnvCfg()
# simulate physics env_cfg.scene.num_envs = 16
with torch.inference_mode(): env_cfg.scene.replicate_physics = False
for count in range(50): env_cfg.sim.device = device
# reset every few steps to check nothing breaks
if count % 10 == 0: # setup base environment
env.reset() env = ManagerBasedEnv(cfg=env_cfg)
# sample random actions
joint_efforts = torch.randn_like(env.action_manager.action) # simulate physics
# step the environment with torch.inference_mode():
env.step(joint_efforts) for count in range(50):
# reset every few steps to check nothing breaks
env.close() if count % 10 == 0:
env.reset()
# sample random actions
joint_efforts = torch.randn_like(env.action_manager.action)
# step the environment
env.step(joint_efforts)
env.close()
def test_texture_randomization_failure_replicate_physics(self): def test_texture_randomization_failure_replicate_physics(self):
with self.assertRaises(ValueError): """Test texture randomization failure when replicate physics is set to True."""
cfg_failure = CartpoleEnvCfg() # create a new stage
cfg_failure.scene.num_envs = 16 omni.usd.get_context().new_stage()
cfg_failure.scene.replicate_physics = True
env = ManagerBasedEnv(cfg_failure) # set the arguments
cfg_failure = CartpoleEnvCfg()
cfg_failure.scene.num_envs = 16
cfg_failure.scene.replicate_physics = True
env.close() with self.assertRaises(RuntimeError):
env = ManagerBasedEnv(cfg_failure)
env.close()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -19,8 +19,8 @@ import torch ...@@ -19,8 +19,8 @@ import torch
import unittest import unittest
from collections import namedtuple from collections import namedtuple
import isaaclab.sim as sim_utils
from isaaclab.managers import ManagerTermBase, ObservationGroupCfg, ObservationManager, ObservationTermCfg from isaaclab.managers import ManagerTermBase, ObservationGroupCfg, ObservationManager, ObservationTermCfg
from isaaclab.sim import SimulationContext
from isaaclab.utils import configclass, modifiers from isaaclab.utils import configclass, modifiers
...@@ -100,11 +100,15 @@ class TestObservationManager(unittest.TestCase): ...@@ -100,11 +100,15 @@ class TestObservationManager(unittest.TestCase):
self.num_envs = 20 self.num_envs = 20
self.device = "cuda:0" self.device = "cuda:0"
# set up sim # set up sim
self.sim = SimulationContext() sim_cfg = sim_utils.SimulationCfg(dt=self.dt, device=self.device)
sim = sim_utils.SimulationContext(sim_cfg)
# create dummy environment # create dummy environment
self.env = namedtuple("ManagerBasedEnv", ["num_envs", "device", "data", "dt", "sim"])( self.env = namedtuple("ManagerBasedEnv", ["num_envs", "device", "data", "dt", "sim"])(
self.num_envs, self.device, MyDataClass(self.num_envs, self.device), self.dt, self.sim self.num_envs, self.device, MyDataClass(self.num_envs, self.device), self.dt, sim
) )
# let the simulation play (we need this for observation manager to compute obs dims)
self.env.sim._app_control_on_stop_handle = None
self.env.sim.reset()
def test_str(self): def test_str(self):
"""Test the string representation of the observation manager.""" """Test the string representation of the observation manager."""
...@@ -382,24 +386,25 @@ class TestObservationManager(unittest.TestCase): ...@@ -382,24 +386,25 @@ class TestObservationManager(unittest.TestCase):
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device) expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env) expected_obs_term_2_data = lin_vel_w_data(self.env)
expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1) expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
print(expected_obs_data_t0, obs_policy) torch.testing.assert_close(expected_obs_data_t0, obs_policy)
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test that the history buffer holds previous data # test that the history buffer holds previous data
for _ in range(HISTORY_LENGTH): for _ in range(HISTORY_LENGTH):
observations = self.obs_man.compute() observations = self.obs_man.compute()
obs_policy = observations["policy"] obs_policy = observations["policy"]
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device) expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device)
expected_obs_data_t5 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1) expected_obs_data_t5 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t5, obs_policy)) torch.testing.assert_close(expected_obs_data_t5, obs_policy)
# test reset # test reset
self.obs_man.reset() self.obs_man.reset()
observations = self.obs_man.compute() observations = self.obs_man.compute()
obs_policy = observations["policy"] obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy)) torch.testing.assert_close(expected_obs_data_t0, obs_policy)
# test reset of specific env ids # test reset of specific env ids
reset_env_ids = [2, 4, 16] reset_env_ids = [2, 4, 16]
self.obs_man.reset(reset_env_ids) self.obs_man.reset(reset_env_ids)
self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids])) torch.testing.assert_close(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids])
def test_compute_with_2d_history(self): def test_compute_with_2d_history(self):
"""Test the observation computation with history buffers for 2D observations.""" """Test the observation computation with history buffers for 2D observations."""
...@@ -482,7 +487,7 @@ class TestObservationManager(unittest.TestCase): ...@@ -482,7 +487,7 @@ class TestObservationManager(unittest.TestCase):
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device) expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH) expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH)
expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1) expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy)) torch.testing.assert_close(expected_obs_data_t0, obs_policy)
# test that the history buffer holds previous data # test that the history buffer holds previous data
for _ in range(GROUP_HISTORY_LENGTH): for _ in range(GROUP_HISTORY_LENGTH):
observations = self.obs_man.compute() observations = self.obs_man.compute()
...@@ -490,16 +495,16 @@ class TestObservationManager(unittest.TestCase): ...@@ -490,16 +495,16 @@ class TestObservationManager(unittest.TestCase):
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device) expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH) expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH)
expected_obs_data_t10 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1) expected_obs_data_t10 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t10, obs_policy)) torch.testing.assert_close(expected_obs_data_t10, obs_policy)
# test reset # test reset
self.obs_man.reset() self.obs_man.reset()
observations = self.obs_man.compute() observations = self.obs_man.compute()
obs_policy = observations["policy"] obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy)) torch.testing.assert_close(expected_obs_data_t0, obs_policy)
# test reset of specific env ids # test reset of specific env ids
reset_env_ids = [2, 4, 16] reset_env_ids = [2, 4, 16]
self.obs_man.reset(reset_env_ids) self.obs_man.reset(reset_env_ids)
self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids])) torch.testing.assert_close(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids])
def test_invalid_observation_config(self): def test_invalid_observation_config(self):
"""Test the invalid observation config.""" """Test the invalid observation config."""
......
...@@ -82,10 +82,10 @@ class joint_pos_limits_penalty_ratio(ManagerTermBase): ...@@ -82,10 +82,10 @@ class joint_pos_limits_penalty_ratio(ManagerTermBase):
def __init__(self, env: ManagerBasedRLEnv, cfg: RewardTermCfg): def __init__(self, env: ManagerBasedRLEnv, cfg: RewardTermCfg):
# add default argument # add default argument
if "asset_cfg" not in cfg.params: asset_cfg = cfg.params.get("asset_cfg", SceneEntityCfg("robot"))
cfg.params["asset_cfg"] = SceneEntityCfg("robot")
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[cfg.params["asset_cfg"].name] asset: Articulation = env.scene[asset_cfg.name]
# resolve the gear ratio for each joint # resolve the gear ratio for each joint
self.gear_ratio = torch.ones(env.num_envs, asset.num_joints, device=env.device) self.gear_ratio = torch.ones(env.num_envs, asset.num_joints, device=env.device)
index_list, _, value_list = string_utils.resolve_matching_names_values( index_list, _, value_list = string_utils.resolve_matching_names_values(
...@@ -95,7 +95,11 @@ class joint_pos_limits_penalty_ratio(ManagerTermBase): ...@@ -95,7 +95,11 @@ class joint_pos_limits_penalty_ratio(ManagerTermBase):
self.gear_ratio_scaled = self.gear_ratio / torch.max(self.gear_ratio) self.gear_ratio_scaled = self.gear_ratio / torch.max(self.gear_ratio)
def __call__( def __call__(
self, env: ManagerBasedRLEnv, threshold: float, gear_ratio: dict[str, float], asset_cfg: SceneEntityCfg self,
env: ManagerBasedRLEnv,
threshold: float,
gear_ratio: dict[str, float],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor: ) -> torch.Tensor:
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
...@@ -118,10 +122,10 @@ class power_consumption(ManagerTermBase): ...@@ -118,10 +122,10 @@ class power_consumption(ManagerTermBase):
def __init__(self, env: ManagerBasedRLEnv, cfg: RewardTermCfg): def __init__(self, env: ManagerBasedRLEnv, cfg: RewardTermCfg):
# add default argument # add default argument
if "asset_cfg" not in cfg.params: asset_cfg = cfg.params.get("asset_cfg", SceneEntityCfg("robot"))
cfg.params["asset_cfg"] = SceneEntityCfg("robot")
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[cfg.params["asset_cfg"].name] asset: Articulation = env.scene[asset_cfg.name]
# resolve the gear ratio for each joint # resolve the gear ratio for each joint
self.gear_ratio = torch.ones(env.num_envs, asset.num_joints, device=env.device) self.gear_ratio = torch.ones(env.num_envs, asset.num_joints, device=env.device)
index_list, _, value_list = string_utils.resolve_matching_names_values( index_list, _, value_list = string_utils.resolve_matching_names_values(
...@@ -130,7 +134,9 @@ class power_consumption(ManagerTermBase): ...@@ -130,7 +134,9 @@ class power_consumption(ManagerTermBase):
self.gear_ratio[:, index_list] = torch.tensor(value_list, device=env.device) self.gear_ratio[:, index_list] = torch.tensor(value_list, device=env.device)
self.gear_ratio_scaled = self.gear_ratio / torch.max(self.gear_ratio) self.gear_ratio_scaled = self.gear_ratio / torch.max(self.gear_ratio)
def __call__(self, env: ManagerBasedRLEnv, gear_ratio: dict[str, float], asset_cfg: SceneEntityCfg) -> torch.Tensor: def __call__(
self, env: ManagerBasedRLEnv, gear_ratio: dict[str, float], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
# return power = torque * velocity (here actions: joint torques) # return power = torque * velocity (here actions: joint torques)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment