Unverified Commit fb270ab5 authored by peterd-NV's avatar peterd-NV Committed by GitHub

Improves recorder performance and add additional recording capability (#3302)

# Description

<!--
Thank you for your interest in sending a pull request. Please make sure
to check the contribution guidelines.

Link:
https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html
-->

This PR adds fixes from LightWheel Labs and additional functionality to
the IsaacLab recorder.

Fixes # (issue)

- Fixes performance issue when recording long episode data by replacing
the use of torch.cat at every timestep with list append.
- Fixes configclass validation when key is not a string

Adds Functionality

- Adds optional episode meta data to HDF5 recorder
- Adds option to record data pre-physics step
- Adds joint target data to episode data. Joint target data can be
optionally recorded by users and replayed to bypass action term
controllers and improve replay determinism.


## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)

## Screenshots

Please attach before and after screenshots of the change if applicable.

<!--
Example:

| Before | After |
| ------ | ----- |
| _gif/png before_ | _gif/png after_ |

To upload images to a PR -- simply drag and drop an image while in edit
mode and it should upload the image directly. You can then paste that
source into the above before/after sections.
-->

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent 4eae06fc
...@@ -358,6 +358,7 @@ def annotate_episode_in_auto_mode( ...@@ -358,6 +358,7 @@ def annotate_episode_in_auto_mode(
annotated_episode = env.recorder_manager.get_episode(0) annotated_episode = env.recorder_manager.get_episode(0)
subtask_term_signal_dict = annotated_episode.data["obs"]["datagen_info"]["subtask_term_signals"] subtask_term_signal_dict = annotated_episode.data["obs"]["datagen_info"]["subtask_term_signals"]
for signal_name, signal_flags in subtask_term_signal_dict.items(): for signal_name, signal_flags in subtask_term_signal_dict.items():
signal_flags = torch.tensor(signal_flags, device=env.device)
if not torch.any(signal_flags): if not torch.any(signal_flags):
is_episode_annotated_successfully = False is_episode_annotated_successfully = False
print(f'\tDid not detect completion for the subtask "{signal_name}".') print(f'\tDid not detect completion for the subtask "{signal_name}".')
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.45.10" version = "0.45.11"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.45.10 (2025-09-02) 0.45.11 (2025-09-04)
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
Fixed Fixed
^^^^^ ^^^^^
* Fixes a high memory usage and perf slowdown issue in episode data by removing the use of torch.cat when appending to the episode data
at each timestep. The use of torch.cat was causing the episode data to be copied at each timestep, which causes high memory usage and
significant performance slowdown when recording longer episode data.
* Patches the configclass to allow validate dict with key is not a string.
Added
^^^^^
* Added optional episode metadata (ep_meta) to be stored in the HDF5 data attributes.
* Added option to record data pre-physics step.
* Added joint_target data to episode data. Joint target data can be optionally recorded by the user and replayed to improve
determinism of replay.
0.45.10 (2025-09-02)
~~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Fixed regression in reach task configuration where the gripper command was being returned. * Fixed regression in reach task configuration where the gripper command was being returned.
* Added :attr:`~isaaclab.devices.Se3GamepadCfg.gripper_term` to :class:`~isaaclab.devices.Se3GamepadCfg` * Added :attr:`~isaaclab.devices.Se3GamepadCfg.gripper_term` to :class:`~isaaclab.devices.Se3GamepadCfg`
to control whether the gamepad device should return a gripper command. to control whether the gamepad device should return a gripper command.
......
...@@ -188,6 +188,7 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env): ...@@ -188,6 +188,7 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
self.scene.write_data_to_sim() self.scene.write_data_to_sim()
# simulate # simulate
self.sim.step(render=False) self.sim.step(render=False)
self.recorder_manager.record_post_physics_decimation_step()
# render between steps only if the GUI or an RTX sensor needs it # render between steps only if the GUI or an RTX sensor needs it
# note: we assume the render interval to be the shortest accepted rendering interval. # note: we assume the render interval to be the shortest accepted rendering interval.
# If a camera needs rendering at a faster frequency, this will lead to unexpected behavior. # If a camera needs rendering at a faster frequency, this will lead to unexpected behavior.
......
...@@ -123,6 +123,15 @@ class RecorderTerm(ManagerTermBase): ...@@ -123,6 +123,15 @@ class RecorderTerm(ManagerTermBase):
""" """
return None, None return None, None
def record_post_physics_decimation_step(self) -> tuple[str | None, torch.Tensor | dict | None]:
"""Record data after the physics step is executed in the decimation loop.
Returns:
A tuple of key and value to be recorded.
Please refer to the `record_pre_reset` function for more details.
"""
return None, None
class RecorderManager(ManagerBase): class RecorderManager(ManagerBase):
"""Manager for recording data from recorder terms.""" """Manager for recording data from recorder terms."""
...@@ -362,6 +371,16 @@ class RecorderManager(ManagerBase): ...@@ -362,6 +371,16 @@ class RecorderManager(ManagerBase):
key, value = term.record_post_step() key, value = term.record_post_step()
self.add_to_episodes(key, value) self.add_to_episodes(key, value)
def record_post_physics_decimation_step(self) -> None:
"""Trigger recorder terms for post-physics step functions in the decimation loop."""
# Do nothing if no active recorder terms are provided
if len(self.active_terms) == 0:
return
for term in self._terms.values():
key, value = term.record_post_physics_decimation_step()
self.add_to_episodes(key, value)
def record_pre_reset(self, env_ids: Sequence[int] | None, force_export_or_skip=None) -> None: def record_pre_reset(self, env_ids: Sequence[int] | None, force_export_or_skip=None) -> None:
"""Trigger recorder terms for pre-reset functions. """Trigger recorder terms for pre-reset functions.
...@@ -406,6 +425,23 @@ class RecorderManager(ManagerBase): ...@@ -406,6 +425,23 @@ class RecorderManager(ManagerBase):
key, value = term.record_post_reset(env_ids) key, value = term.record_post_reset(env_ids)
self.add_to_episodes(key, value, env_ids) self.add_to_episodes(key, value, env_ids)
def get_ep_meta(self) -> dict:
"""Get the episode metadata."""
if not hasattr(self._env.cfg, "get_ep_meta"):
# Add basic episode metadata
ep_meta = dict()
ep_meta["sim_args"] = {
"dt": self._env.cfg.sim.dt,
"decimation": self._env.cfg.decimation,
"render_interval": self._env.cfg.sim.render_interval,
"num_envs": self._env.cfg.scene.num_envs,
}
return ep_meta
# Add custom episode metadata if available
ep_meta = self._env.cfg.get_ep_meta()
return ep_meta
def export_episodes(self, env_ids: Sequence[int] | None = None) -> None: def export_episodes(self, env_ids: Sequence[int] | None = None) -> None:
"""Concludes and exports the episodes for the given environment ids. """Concludes and exports the episodes for the given environment ids.
...@@ -424,8 +460,18 @@ class RecorderManager(ManagerBase): ...@@ -424,8 +460,18 @@ class RecorderManager(ManagerBase):
# Export episode data through dataset exporter # Export episode data through dataset exporter
need_to_flush = False need_to_flush = False
if any(env_id in self._episodes and not self._episodes[env_id].is_empty() for env_id in env_ids):
ep_meta = self.get_ep_meta()
if self._dataset_file_handler is not None:
self._dataset_file_handler.add_env_args(ep_meta)
if self._failed_episode_dataset_file_handler is not None:
self._failed_episode_dataset_file_handler.add_env_args(ep_meta)
for env_id in env_ids: for env_id in env_ids:
if env_id in self._episodes and not self._episodes[env_id].is_empty(): if env_id in self._episodes and not self._episodes[env_id].is_empty():
self._episodes[env_id].pre_export()
episode_succeeded = self._episodes[env_id].success episode_succeeded = self._episodes[env_id].success
target_dataset_file_handler = None target_dataset_file_handler = None
if (self.cfg.dataset_export_mode == DatasetExportMode.EXPORT_ALL) or ( if (self.cfg.dataset_export_mode == DatasetExportMode.EXPORT_ALL) or (
......
...@@ -268,7 +268,11 @@ def _validate(obj: object, prefix: str = "") -> list[str]: ...@@ -268,7 +268,11 @@ def _validate(obj: object, prefix: str = "") -> list[str]:
missing_fields.extend(_validate(item, prefix=current_path)) missing_fields.extend(_validate(item, prefix=current_path))
return missing_fields return missing_fields
elif isinstance(obj, dict): elif isinstance(obj, dict):
obj_dict = obj # Convert any non-string keys to strings to allow validation of dict with non-string keys
if any(not isinstance(key, str) for key in obj.keys()):
obj_dict = {str(key): value for key, value in obj.items()}
else:
obj_dict = obj
elif hasattr(obj, "__dict__"): elif hasattr(obj, "__dict__"):
obj_dict = obj.__dict__ obj_dict = obj.__dict__
else: else:
......
...@@ -21,6 +21,7 @@ class EpisodeData: ...@@ -21,6 +21,7 @@ class EpisodeData:
self._data = dict() self._data = dict()
self._next_action_index = 0 self._next_action_index = 0
self._next_state_index = 0 self._next_state_index = 0
self._next_joint_target_index = 0
self._seed = None self._seed = None
self._env_id = None self._env_id = None
self._success = None self._success = None
...@@ -110,12 +111,11 @@ class EpisodeData: ...@@ -110,12 +111,11 @@ class EpisodeData:
for sub_key_index in range(len(sub_keys)): for sub_key_index in range(len(sub_keys)):
if sub_key_index == len(sub_keys) - 1: if sub_key_index == len(sub_keys) - 1:
# Add value to the final dict layer # Add value to the final dict layer
# Use lists to prevent slow tensor copy during concatenation
if sub_keys[sub_key_index] not in current_dataset_pointer: if sub_keys[sub_key_index] not in current_dataset_pointer:
current_dataset_pointer[sub_keys[sub_key_index]] = value.unsqueeze(0).clone() current_dataset_pointer[sub_keys[sub_key_index]] = [value.clone()]
else: else:
current_dataset_pointer[sub_keys[sub_key_index]] = torch.cat( current_dataset_pointer[sub_keys[sub_key_index]].append(value.clone())
(current_dataset_pointer[sub_keys[sub_key_index]], value.unsqueeze(0))
)
break break
# key index # key index
if sub_keys[sub_key_index] not in current_dataset_pointer: if sub_keys[sub_key_index] not in current_dataset_pointer:
...@@ -160,7 +160,7 @@ class EpisodeData: ...@@ -160,7 +160,7 @@ class EpisodeData:
elif isinstance(states, torch.Tensor): elif isinstance(states, torch.Tensor):
if state_index >= len(states): if state_index >= len(states):
return None return None
output_state = states[state_index] output_state = states[state_index, None]
else: else:
raise ValueError(f"Invalid state type: {type(states)}") raise ValueError(f"Invalid state type: {type(states)}")
return output_state return output_state
...@@ -174,3 +174,47 @@ class EpisodeData: ...@@ -174,3 +174,47 @@ class EpisodeData:
if state is not None: if state is not None:
self._next_state_index += 1 self._next_state_index += 1
return state return state
def get_joint_target(self, joint_target_index) -> dict | torch.Tensor | None:
"""Get the joint target of the specified index from the dataset."""
if "joint_targets" not in self._data:
return None
joint_targets = self._data["joint_targets"]
def get_joint_target_helper(joint_targets, joint_target_index) -> dict | torch.Tensor | None:
if isinstance(joint_targets, dict):
output_joint_targets = dict()
for key, value in joint_targets.items():
output_joint_targets[key] = get_joint_target_helper(value, joint_target_index)
if output_joint_targets[key] is None:
return None
elif isinstance(joint_targets, torch.Tensor):
if joint_target_index >= len(joint_targets):
return None
output_joint_targets = joint_targets[joint_target_index]
else:
raise ValueError(f"Invalid joint target type: {type(joint_targets)}")
return output_joint_targets
output_joint_targets = get_joint_target_helper(joint_targets, joint_target_index)
return output_joint_targets
def get_next_joint_target(self) -> dict | torch.Tensor | None:
"""Get the next joint target from the dataset."""
joint_target = self.get_joint_target(self._next_joint_target_index)
if joint_target is not None:
self._next_joint_target_index += 1
return joint_target
def pre_export(self):
"""Prepare data for export by converting lists to tensors."""
def pre_export_helper(data):
for key, value in data.items():
if isinstance(value, list):
data[key] = torch.stack(value)
elif isinstance(value, dict):
pre_export_helper(value)
pre_export_helper(self._data)
...@@ -78,6 +78,28 @@ class DummyRecorderManagerCfg(RecorderManagerBaseCfg): ...@@ -78,6 +78,28 @@ class DummyRecorderManagerCfg(RecorderManagerBaseCfg):
dataset_export_mode = DatasetExportMode.EXPORT_ALL dataset_export_mode = DatasetExportMode.EXPORT_ALL
@configclass
class DummyEnvCfg:
"""Dummy environment configuration."""
@configclass
class DummySimCfg:
"""Configuration for the dummy sim."""
dt = 0.01
render_interval = 1
@configclass
class DummySceneCfg:
"""Configuration for the dummy scene."""
num_envs = 1
decimation = 1
sim = DummySimCfg()
scene = DummySceneCfg()
def create_dummy_env(device: str = "cpu") -> ManagerBasedEnv: def create_dummy_env(device: str = "cpu") -> ManagerBasedEnv:
"""Create a dummy environment.""" """Create a dummy environment."""
...@@ -86,8 +108,10 @@ def create_dummy_env(device: str = "cpu") -> ManagerBasedEnv: ...@@ -86,8 +108,10 @@ def create_dummy_env(device: str = "cpu") -> ManagerBasedEnv:
dummy_termination_manager = DummyTerminationManager() dummy_termination_manager = DummyTerminationManager()
sim = SimulationContext() sim = SimulationContext()
dummy_cfg = DummyEnvCfg()
return namedtuple("ManagerBasedEnv", ["num_envs", "device", "sim", "cfg", "termination_manager"])( return namedtuple("ManagerBasedEnv", ["num_envs", "device", "sim", "cfg", "termination_manager"])(
20, device, sim, dict(), dummy_termination_manager 20, device, sim, dummy_cfg, dummy_termination_manager
) )
...@@ -142,8 +166,8 @@ def test_record(dataset_dir): ...@@ -142,8 +166,8 @@ def test_record(dataset_dir):
# check the recorded data # check the recorded data
for env_id in range(env.num_envs): for env_id in range(env.num_envs):
episode = recorder_manager.get_episode(env_id) episode = recorder_manager.get_episode(env_id)
assert episode.data["record_pre_step"].shape == (2, 4) assert torch.stack(episode.data["record_pre_step"]).shape == (2, 4)
assert episode.data["record_post_step"].shape == (2, 5) assert torch.stack(episode.data["record_post_step"]).shape == (2, 5)
# Trigger pre-reset callbacks which then export and clean the episode data # Trigger pre-reset callbacks which then export and clean the episode data
recorder_manager.record_pre_reset(env_ids=None) recorder_manager.record_pre_reset(env_ids=None)
...@@ -154,4 +178,4 @@ def test_record(dataset_dir): ...@@ -154,4 +178,4 @@ def test_record(dataset_dir):
recorder_manager.record_post_reset(env_ids=None) recorder_manager.record_post_reset(env_ids=None)
for env_id in range(env.num_envs): for env_id in range(env.num_envs):
episode = recorder_manager.get_episode(env_id) episode = recorder_manager.get_episode(env_id)
assert episode.data["record_post_reset"].shape == (1, 3) assert torch.stack(episode.data["record_post_reset"]).shape == (1, 3)
...@@ -38,13 +38,13 @@ def test_add_tensors(device): ...@@ -38,13 +38,13 @@ def test_add_tensors(device):
# test adding data to a key that does not exist # test adding data to a key that does not exist
episode.add("key", dummy_data_0) episode.add("key", dummy_data_0)
key_data = episode.data.get("key") key_data = torch.stack(episode.data.get("key"))
assert key_data is not None assert key_data is not None
assert torch.equal(key_data, dummy_data_0.unsqueeze(0)) assert torch.equal(key_data, dummy_data_0.unsqueeze(0))
# test adding data to a key that exists # test adding data to a key that exists
episode.add("key", dummy_data_1) episode.add("key", dummy_data_1)
key_data = episode.data.get("key") key_data = torch.stack(episode.data.get("key"))
assert key_data is not None assert key_data is not None
assert torch.equal(key_data, expected_added_data) assert torch.equal(key_data, expected_added_data)
...@@ -52,7 +52,7 @@ def test_add_tensors(device): ...@@ -52,7 +52,7 @@ def test_add_tensors(device):
episode.add("first/second", dummy_data_0) episode.add("first/second", dummy_data_0)
first_data = episode.data.get("first") first_data = episode.data.get("first")
assert first_data is not None assert first_data is not None
second_data = first_data.get("second") second_data = torch.stack(first_data.get("second"))
assert second_data is not None assert second_data is not None
assert torch.equal(second_data, dummy_data_0.unsqueeze(0)) assert torch.equal(second_data, dummy_data_0.unsqueeze(0))
...@@ -60,7 +60,7 @@ def test_add_tensors(device): ...@@ -60,7 +60,7 @@ def test_add_tensors(device):
episode.add("first/second", dummy_data_1) episode.add("first/second", dummy_data_1)
first_data = episode.data.get("first") first_data = episode.data.get("first")
assert first_data is not None assert first_data is not None
second_data = first_data.get("second") second_data = torch.stack(first_data.get("second"))
assert second_data is not None assert second_data is not None
assert torch.equal(second_data, expected_added_data) assert torch.equal(second_data, expected_added_data)
...@@ -83,15 +83,15 @@ def test_add_dict_tensors(device): ...@@ -83,15 +83,15 @@ def test_add_dict_tensors(device):
episode.add("key", dummy_dict_data_0) episode.add("key", dummy_dict_data_0)
key_data = episode.data.get("key") key_data = episode.data.get("key")
assert key_data is not None assert key_data is not None
key_0_data = key_data.get("key_0") key_0_data = torch.stack(key_data.get("key_0"))
assert key_0_data is not None assert key_0_data is not None
assert torch.equal(key_0_data, torch.tensor([[0]], device=device)) assert torch.equal(key_0_data, torch.tensor([[0]], device=device))
key_1_data = key_data.get("key_1") key_1_data = key_data.get("key_1")
assert key_1_data is not None assert key_1_data is not None
key_1_0_data = key_1_data.get("key_1_0") key_1_0_data = torch.stack(key_1_data.get("key_1_0"))
assert key_1_0_data is not None assert key_1_0_data is not None
assert torch.equal(key_1_0_data, torch.tensor([[1]], device=device)) assert torch.equal(key_1_0_data, torch.tensor([[1]], device=device))
key_1_1_data = key_1_data.get("key_1_1") key_1_1_data = torch.stack(key_1_data.get("key_1_1"))
assert key_1_1_data is not None assert key_1_1_data is not None
assert torch.equal(key_1_1_data, torch.tensor([[2]], device=device)) assert torch.equal(key_1_1_data, torch.tensor([[2]], device=device))
...@@ -99,15 +99,15 @@ def test_add_dict_tensors(device): ...@@ -99,15 +99,15 @@ def test_add_dict_tensors(device):
episode.add("key", dummy_dict_data_1) episode.add("key", dummy_dict_data_1)
key_data = episode.data.get("key") key_data = episode.data.get("key")
assert key_data is not None assert key_data is not None
key_0_data = key_data.get("key_0") key_0_data = torch.stack(key_data.get("key_0"))
assert key_0_data is not None assert key_0_data is not None
assert torch.equal(key_0_data, torch.tensor([[0], [3]], device=device)) assert torch.equal(key_0_data, torch.tensor([[0], [3]], device=device))
key_1_data = key_data.get("key_1") key_1_data = key_data.get("key_1")
assert key_1_data is not None assert key_1_data is not None
key_1_0_data = key_1_data.get("key_1_0") key_1_0_data = torch.stack(key_1_data.get("key_1_0"))
assert key_1_0_data is not None assert key_1_0_data is not None
assert torch.equal(key_1_0_data, torch.tensor([[1], [4]], device=device)) assert torch.equal(key_1_0_data, torch.tensor([[1], [4]], device=device))
key_1_1_data = key_1_data.get("key_1_1") key_1_1_data = torch.stack(key_1_data.get("key_1_1"))
assert key_1_1_data is not None assert key_1_1_data is not None
assert torch.equal(key_1_1_data, torch.tensor([[2], [5]], device=device)) assert torch.equal(key_1_1_data, torch.tensor([[2], [5]], device=device))
...@@ -119,7 +119,7 @@ def test_get_initial_state(device): ...@@ -119,7 +119,7 @@ def test_get_initial_state(device):
episode = EpisodeData() episode = EpisodeData()
episode.add("initial_state", dummy_initial_state) episode.add("initial_state", dummy_initial_state)
initial_state = episode.get_initial_state() initial_state = torch.stack(episode.get_initial_state())
assert initial_state is not None assert initial_state is not None
assert torch.equal(initial_state, dummy_initial_state.unsqueeze(0)) assert torch.equal(initial_state, dummy_initial_state.unsqueeze(0))
......
...@@ -82,6 +82,7 @@ def test_write_and_load_episode(temp_dir, device): ...@@ -82,6 +82,7 @@ def test_write_and_load_episode(temp_dir, device):
test_episode = create_test_episode(device) test_episode = create_test_episode(device)
# write the episode to the dataset # write the episode to the dataset
test_episode.pre_export()
dataset_file_handler.write_episode(test_episode) dataset_file_handler.write_episode(test_episode)
dataset_file_handler.flush() dataset_file_handler.flush()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment