Commit fb270ab5, authored by peterd-NV, committed by GitHub

Improves recorder performance and adds additional recording capability (#3302)

# Description


This PR adds fixes from LightWheel Labs and additional functionality to
the IsaacLab recorder.

Fixes # (issue)

- Fixes a performance issue when recording long episodes by replacing
the use of `torch.cat` at every timestep with a list append.
- Fixes `configclass` validation when a dict key is not a string
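
To see why per-step concatenation hurts on long episodes, here is a rough cost model, not a benchmark. It counts how many elements are copied under each strategy, assuming every concatenation copies the whole accumulated buffer plus the new value, as the changelog entry describes for `torch.cat`. The function names are hypothetical, and real timings depend on tensor sizes and device.

```python
def copies_with_concat(num_steps: int, step_size: int) -> int:
    """Count elements copied when the buffer is rebuilt on every step."""
    copied = 0
    buffer_len = 0
    for _ in range(num_steps):
        # Concatenation allocates a new buffer holding old + new elements.
        buffer_len += step_size
        copied += buffer_len
    return copied


def copies_with_append(num_steps: int, step_size: int) -> int:
    """Count elements copied when values are appended and stacked once."""
    # Each step clones only the new value before appending it to a list.
    copied = num_steps * step_size
    # A single stack at export time copies every element one more time.
    copied += num_steps * step_size
    return copied


concat_cost = copies_with_concat(num_steps=1000, step_size=10)
append_cost = copies_with_append(num_steps=1000, step_size=10)
print(concat_cost, append_cost)
```

Under this model the concatenation cost grows quadratically with episode length, while append-then-stack grows linearly.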

Adds Functionality

- Adds optional episode metadata to HDF5 recorder
- Adds option to record data pre-physics step
- Adds joint target data to episode data. Joint target data can be
optionally recorded by users and replayed to bypass action term
controllers and improve replay determinism.
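
As an illustration of joint-target replay, the sketch below mirrors the cursor pattern of `get_next_joint_target` in the diff. It is a minimal stand-in, not the recorder itself: the class name `JointTargetReplay` and the keys `robot` and `joint_pos_target` are hypothetical, and plain Python lists take the place of tensors.

```python
from __future__ import annotations


class JointTargetReplay:
    """Hypothetical stand-in for the joint-target cursor on the episode data."""

    def __init__(self, joint_targets: dict | list):
        # Nested dicts whose leaves are per-step lists (tensors in the recorder).
        self._joint_targets = joint_targets
        self._next_index = 0

    def _get(self, node, index):
        if isinstance(node, dict):
            out = {}
            for key, value in node.items():
                out[key] = self._get(value, index)
                if out[key] is None:
                    return None  # any exhausted leaf ends the replay
            return out
        if isinstance(node, list):
            return node[index] if index < len(node) else None
        raise ValueError(f"Invalid joint target type: {type(node)}")

    def next(self):
        """Return the next recorded target, or None once the data runs out."""
        target = self._get(self._joint_targets, self._next_index)
        if target is not None:
            self._next_index += 1
        return target


recorded = {"robot": {"joint_pos_target": [[0.0, 0.1], [0.2, 0.3]]}}
replay = JointTargetReplay(recorded)
steps = []
while (target := replay.next()) is not None:
    steps.append(target)
```

A replay loop would apply each returned target directly to the articulation instead of re-running the action terms, and stop when `next()` returns `None`.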


## Type of change


- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)


## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there


---------
Signed-off-by: Kelly Guo <kellyg@nvidia.com>
Co-authored-by: Kelly Guo <kellyg@nvidia.com>
parent 4eae06fc
@@ -358,6 +358,7 @@ def annotate_episode_in_auto_mode(
     annotated_episode = env.recorder_manager.get_episode(0)
     subtask_term_signal_dict = annotated_episode.data["obs"]["datagen_info"]["subtask_term_signals"]
     for signal_name, signal_flags in subtask_term_signal_dict.items():
+        signal_flags = torch.tensor(signal_flags, device=env.device)
         if not torch.any(signal_flags):
             is_episode_annotated_successfully = False
             print(f'\tDid not detect completion for the subtask "{signal_name}".')
......
 [package]
 # Note: Semantic Versioning is used: https://semver.org/
-version = "0.45.10"
+version = "0.45.11"
 # Description
 title = "Isaac Lab framework for Robot Learning"
......
 Changelog
 ---------
-0.45.10 (2025-09-02)
+0.45.11 (2025-09-04)
 ~~~~~~~~~~~~~~~~~~~~
 Fixed
 ^^^^^
+* Fixes a high memory usage and perf slowdown issue in episode data by removing the use of torch.cat when appending to the episode data
+  at each timestep. The use of torch.cat was causing the episode data to be copied at each timestep, which causes high memory usage and
+  significant performance slowdown when recording longer episode data.
+* Patches the configclass to allow validate dict with key is not a string.
+Added
+^^^^^
+* Added optional episode metadata (ep_meta) to be stored in the HDF5 data attributes.
+* Added option to record data pre-physics step.
+* Added joint_target data to episode data. Joint target data can be optionally recorded by the user and replayed to improve
+  determinism of replay.
+0.45.10 (2025-09-02)
+~~~~~~~~~~~~~~~~~~~
+Fixed
+^^^^^
 * Fixed regression in reach task configuration where the gripper command was being returned.
 * Added :attr:`~isaaclab.devices.Se3GamepadCfg.gripper_term` to :class:`~isaaclab.devices.Se3GamepadCfg`
   to control whether the gamepad device should return a gripper command.
......
@@ -188,6 +188,7 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
             self.scene.write_data_to_sim()
             # simulate
             self.sim.step(render=False)
+            self.recorder_manager.record_post_physics_decimation_step()
             # render between steps only if the GUI or an RTX sensor needs it
             # note: we assume the render interval to be the shortest accepted rendering interval.
             # If a camera needs rendering at a faster frequency, this will lead to unexpected behavior.
......
@@ -123,6 +123,15 @@ class RecorderTerm(ManagerTermBase):
         """
         return None, None
+    def record_post_physics_decimation_step(self) -> tuple[str | None, torch.Tensor | dict | None]:
+        """Record data after the physics step is executed in the decimation loop.
+        Returns:
+            A tuple of key and value to be recorded.
+            Please refer to the `record_pre_reset` function for more details.
+        """
+        return None, None
 class RecorderManager(ManagerBase):
     """Manager for recording data from recorder terms."""
@@ -362,6 +371,16 @@ class RecorderManager(ManagerBase):
             key, value = term.record_post_step()
             self.add_to_episodes(key, value)
+    def record_post_physics_decimation_step(self) -> None:
+        """Trigger recorder terms for post-physics step functions in the decimation loop."""
+        # Do nothing if no active recorder terms are provided
+        if len(self.active_terms) == 0:
+            return
+        for term in self._terms.values():
+            key, value = term.record_post_physics_decimation_step()
+            self.add_to_episodes(key, value)
     def record_pre_reset(self, env_ids: Sequence[int] | None, force_export_or_skip=None) -> None:
         """Trigger recorder terms for pre-reset functions.
@@ -406,6 +425,23 @@ class RecorderManager(ManagerBase):
             key, value = term.record_post_reset(env_ids)
             self.add_to_episodes(key, value, env_ids)
+    def get_ep_meta(self) -> dict:
+        """Get the episode metadata."""
+        if not hasattr(self._env.cfg, "get_ep_meta"):
+            # Add basic episode metadata
+            ep_meta = dict()
+            ep_meta["sim_args"] = {
+                "dt": self._env.cfg.sim.dt,
+                "decimation": self._env.cfg.decimation,
+                "render_interval": self._env.cfg.sim.render_interval,
+                "num_envs": self._env.cfg.scene.num_envs,
+            }
+            return ep_meta
+        # Add custom episode metadata if available
+        ep_meta = self._env.cfg.get_ep_meta()
+        return ep_meta
     def export_episodes(self, env_ids: Sequence[int] | None = None) -> None:
         """Concludes and exports the episodes for the given environment ids.
@@ -424,8 +460,18 @@ class RecorderManager(ManagerBase):
         # Export episode data through dataset exporter
         need_to_flush = False
+        if any(env_id in self._episodes and not self._episodes[env_id].is_empty() for env_id in env_ids):
+            ep_meta = self.get_ep_meta()
+            if self._dataset_file_handler is not None:
+                self._dataset_file_handler.add_env_args(ep_meta)
+            if self._failed_episode_dataset_file_handler is not None:
+                self._failed_episode_dataset_file_handler.add_env_args(ep_meta)
         for env_id in env_ids:
             if env_id in self._episodes and not self._episodes[env_id].is_empty():
+                self._episodes[env_id].pre_export()
                 episode_succeeded = self._episodes[env_id].success
                 target_dataset_file_handler = None
                 if (self.cfg.dataset_export_mode == DatasetExportMode.EXPORT_ALL) or (
......
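
The metadata lookup in `get_ep_meta` above can be tried in isolation. The sketch below reproduces the same fallback logic as a free function over a config object. The function signature, the `SimpleNamespace` configs and the `task` key are assumptions for illustration, not part of the recorder API.

```python
from types import SimpleNamespace


def get_ep_meta(cfg) -> dict:
    """Return cfg-provided metadata if available, else basic sim arguments."""
    if hasattr(cfg, "get_ep_meta"):
        # A custom hook on the env cfg takes over completely.
        return cfg.get_ep_meta()
    return {
        "sim_args": {
            "dt": cfg.sim.dt,
            "decimation": cfg.decimation,
            "render_interval": cfg.sim.render_interval,
            "num_envs": cfg.scene.num_envs,
        }
    }


# Hypothetical config without a custom hook: falls back to sim arguments.
basic_cfg = SimpleNamespace(
    decimation=2,
    sim=SimpleNamespace(dt=0.01, render_interval=2),
    scene=SimpleNamespace(num_envs=4),
)
# Hypothetical config with a custom hook: its return value is used as-is.
custom_cfg = SimpleNamespace(get_ep_meta=lambda: {"task": "stack_cubes"})

print(get_ep_meta(basic_cfg))
print(get_ep_meta(custom_cfg))
```

Note that a custom `get_ep_meta` replaces the default `sim_args` entry rather than extending it, so an env cfg that wants both needs to include the sim arguments itself.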
@@ -268,7 +268,11 @@ def _validate(obj: object, prefix: str = "") -> list[str]:
             missing_fields.extend(_validate(item, prefix=current_path))
         return missing_fields
     elif isinstance(obj, dict):
-        obj_dict = obj
+        # Convert any non-string keys to strings to allow validation of dict with non-string keys
+        if any(not isinstance(key, str) for key in obj.keys()):
+            obj_dict = {str(key): value for key, value in obj.items()}
+        else:
+            obj_dict = obj
     elif hasattr(obj, "__dict__"):
         obj_dict = obj.__dict__
     else:
......
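
The key conversion in the `_validate` hunk above is small enough to run on its own. This sketch extracts it into a helper; the name `stringify_keys` and the sample dicts are hypothetical.

```python
def stringify_keys(obj: dict) -> dict:
    """Return a dict whose keys are all strings, copying only when needed."""
    if any(not isinstance(key, str) for key in obj):
        return {str(key): value for key, value in obj.items()}
    # All keys are already strings: hand back the same object, no copy.
    return obj


by_index = {0: "left_arm", 1: "right_arm"}
already_ok = {"a": 1}
colliding = {1: "int key", "1": "str key"}

print(stringify_keys(by_index))
print(stringify_keys(colliding))
```

One consequence of this approach is that keys which differ only by type, such as `1` and `"1"`, collapse into a single entry, with the later item winning. For validation that only walks the values this is unlikely to matter, but it is worth knowing when debugging a missing-field report.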
@@ -21,6 +21,7 @@ class EpisodeData:
         self._data = dict()
         self._next_action_index = 0
         self._next_state_index = 0
+        self._next_joint_target_index = 0
         self._seed = None
         self._env_id = None
         self._success = None
@@ -110,12 +111,11 @@ class EpisodeData:
         for sub_key_index in range(len(sub_keys)):
             if sub_key_index == len(sub_keys) - 1:
                 # Add value to the final dict layer
+                # Use lists to prevent slow tensor copy during concatenation
                 if sub_keys[sub_key_index] not in current_dataset_pointer:
-                    current_dataset_pointer[sub_keys[sub_key_index]] = value.unsqueeze(0).clone()
+                    current_dataset_pointer[sub_keys[sub_key_index]] = [value.clone()]
                 else:
-                    current_dataset_pointer[sub_keys[sub_key_index]] = torch.cat(
-                        (current_dataset_pointer[sub_keys[sub_key_index]], value.unsqueeze(0))
-                    )
+                    current_dataset_pointer[sub_keys[sub_key_index]].append(value.clone())
                 break
             # key index
             if sub_keys[sub_key_index] not in current_dataset_pointer:
@@ -160,7 +160,7 @@ class EpisodeData:
             elif isinstance(states, torch.Tensor):
                 if state_index >= len(states):
                     return None
-                output_state = states[state_index]
+                output_state = states[state_index, None]
             else:
                 raise ValueError(f"Invalid state type: {type(states)}")
             return output_state
@@ -174,3 +174,47 @@ class EpisodeData:
         if state is not None:
             self._next_state_index += 1
         return state
+    def get_joint_target(self, joint_target_index) -> dict | torch.Tensor | None:
+        """Get the joint target of the specified index from the dataset."""
+        if "joint_targets" not in self._data:
+            return None
+        joint_targets = self._data["joint_targets"]
+        def get_joint_target_helper(joint_targets, joint_target_index) -> dict | torch.Tensor | None:
+            if isinstance(joint_targets, dict):
+                output_joint_targets = dict()
+                for key, value in joint_targets.items():
+                    output_joint_targets[key] = get_joint_target_helper(value, joint_target_index)
+                    if output_joint_targets[key] is None:
+                        return None
+            elif isinstance(joint_targets, torch.Tensor):
+                if joint_target_index >= len(joint_targets):
+                    return None
+                output_joint_targets = joint_targets[joint_target_index]
+            else:
+                raise ValueError(f"Invalid joint target type: {type(joint_targets)}")
+            return output_joint_targets
+        output_joint_targets = get_joint_target_helper(joint_targets, joint_target_index)
+        return output_joint_targets
+    def get_next_joint_target(self) -> dict | torch.Tensor | None:
+        """Get the next joint target from the dataset."""
+        joint_target = self.get_joint_target(self._next_joint_target_index)
+        if joint_target is not None:
+            self._next_joint_target_index += 1
+        return joint_target
+    def pre_export(self):
+        """Prepare data for export by converting lists to tensors."""
+        def pre_export_helper(data):
+            for key, value in data.items():
+                if isinstance(value, list):
+                    data[key] = torch.stack(value)
+                elif isinstance(value, dict):
+                    pre_export_helper(value)
+        pre_export_helper(self._data)
@@ -78,6 +78,28 @@ class DummyRecorderManagerCfg(RecorderManagerBaseCfg):
     dataset_export_mode = DatasetExportMode.EXPORT_ALL
+@configclass
+class DummyEnvCfg:
+    """Dummy environment configuration."""
+    @configclass
+    class DummySimCfg:
+        """Configuration for the dummy sim."""
+        dt = 0.01
+        render_interval = 1
+    @configclass
+    class DummySceneCfg:
+        """Configuration for the dummy scene."""
+        num_envs = 1
+    decimation = 1
+    sim = DummySimCfg()
+    scene = DummySceneCfg()
 def create_dummy_env(device: str = "cpu") -> ManagerBasedEnv:
     """Create a dummy environment."""
@@ -86,8 +108,10 @@ def create_dummy_env(device: str = "cpu") -> ManagerBasedEnv:
     dummy_termination_manager = DummyTerminationManager()
     sim = SimulationContext()
+    dummy_cfg = DummyEnvCfg()
     return namedtuple("ManagerBasedEnv", ["num_envs", "device", "sim", "cfg", "termination_manager"])(
-        20, device, sim, dict(), dummy_termination_manager
+        20, device, sim, dummy_cfg, dummy_termination_manager
     )
@@ -142,8 +166,8 @@ def test_record(dataset_dir):
     # check the recorded data
     for env_id in range(env.num_envs):
         episode = recorder_manager.get_episode(env_id)
-        assert episode.data["record_pre_step"].shape == (2, 4)
-        assert episode.data["record_post_step"].shape == (2, 5)
+        assert torch.stack(episode.data["record_pre_step"]).shape == (2, 4)
+        assert torch.stack(episode.data["record_post_step"]).shape == (2, 5)
     # Trigger pre-reset callbacks which then export and clean the episode data
     recorder_manager.record_pre_reset(env_ids=None)
@@ -154,4 +178,4 @@ def test_record(dataset_dir):
     recorder_manager.record_post_reset(env_ids=None)
     for env_id in range(env.num_envs):
         episode = recorder_manager.get_episode(env_id)
-        assert episode.data["record_post_reset"].shape == (1, 3)
+        assert torch.stack(episode.data["record_post_reset"]).shape == (1, 3)
@@ -38,13 +38,13 @@ def test_add_tensors(device):
     # test adding data to a key that does not exist
     episode.add("key", dummy_data_0)
-    key_data = episode.data.get("key")
+    key_data = torch.stack(episode.data.get("key"))
     assert key_data is not None
     assert torch.equal(key_data, dummy_data_0.unsqueeze(0))
     # test adding data to a key that exists
     episode.add("key", dummy_data_1)
-    key_data = episode.data.get("key")
+    key_data = torch.stack(episode.data.get("key"))
     assert key_data is not None
     assert torch.equal(key_data, expected_added_data)
@@ -52,7 +52,7 @@ def test_add_tensors(device):
     episode.add("first/second", dummy_data_0)
     first_data = episode.data.get("first")
     assert first_data is not None
-    second_data = first_data.get("second")
+    second_data = torch.stack(first_data.get("second"))
     assert second_data is not None
     assert torch.equal(second_data, dummy_data_0.unsqueeze(0))
@@ -60,7 +60,7 @@ def test_add_tensors(device):
     episode.add("first/second", dummy_data_1)
     first_data = episode.data.get("first")
     assert first_data is not None
-    second_data = first_data.get("second")
+    second_data = torch.stack(first_data.get("second"))
     assert second_data is not None
     assert torch.equal(second_data, expected_added_data)
@@ -83,15 +83,15 @@ def test_add_dict_tensors(device):
     episode.add("key", dummy_dict_data_0)
     key_data = episode.data.get("key")
     assert key_data is not None
-    key_0_data = key_data.get("key_0")
+    key_0_data = torch.stack(key_data.get("key_0"))
     assert key_0_data is not None
     assert torch.equal(key_0_data, torch.tensor([[0]], device=device))
     key_1_data = key_data.get("key_1")
     assert key_1_data is not None
-    key_1_0_data = key_1_data.get("key_1_0")
+    key_1_0_data = torch.stack(key_1_data.get("key_1_0"))
     assert key_1_0_data is not None
     assert torch.equal(key_1_0_data, torch.tensor([[1]], device=device))
-    key_1_1_data = key_1_data.get("key_1_1")
+    key_1_1_data = torch.stack(key_1_data.get("key_1_1"))
     assert key_1_1_data is not None
     assert torch.equal(key_1_1_data, torch.tensor([[2]], device=device))
@@ -99,15 +99,15 @@ def test_add_dict_tensors(device):
     episode.add("key", dummy_dict_data_1)
     key_data = episode.data.get("key")
     assert key_data is not None
-    key_0_data = key_data.get("key_0")
+    key_0_data = torch.stack(key_data.get("key_0"))
     assert key_0_data is not None
     assert torch.equal(key_0_data, torch.tensor([[0], [3]], device=device))
     key_1_data = key_data.get("key_1")
     assert key_1_data is not None
-    key_1_0_data = key_1_data.get("key_1_0")
+    key_1_0_data = torch.stack(key_1_data.get("key_1_0"))
     assert key_1_0_data is not None
     assert torch.equal(key_1_0_data, torch.tensor([[1], [4]], device=device))
-    key_1_1_data = key_1_data.get("key_1_1")
+    key_1_1_data = torch.stack(key_1_data.get("key_1_1"))
     assert key_1_1_data is not None
     assert torch.equal(key_1_1_data, torch.tensor([[2], [5]], device=device))
@@ -119,7 +119,7 @@ def test_get_initial_state(device):
     episode = EpisodeData()
     episode.add("initial_state", dummy_initial_state)
-    initial_state = episode.get_initial_state()
+    initial_state = torch.stack(episode.get_initial_state())
     assert initial_state is not None
     assert torch.equal(initial_state, dummy_initial_state.unsqueeze(0))
......
@@ -82,6 +82,7 @@ def test_write_and_load_episode(temp_dir, device):
     test_episode = create_test_episode(device)
     # write the episode to the dataset
+    test_episode.pre_export()
     dataset_file_handler.write_episode(test_episode)
     dataset_file_handler.flush()
......