Unverified Commit 01a25477 authored by CY Chen's avatar CY Chen Committed by GitHub

Adds recorder manager in manager-based environments (#1336)

# Description

<!--
Thank you for your interest in sending a pull request. Please make sure
to check the contribution guidelines.

Link: https://isaac-sim.github.io/IsaacLab/source/refs/contributing.html
-->

This PR adds a recorder manager (RecorderManager) and relevant utility
classes for recording data produced in various reset and step stages in
manager-based environments.

Wither the built-in recorder manager, users can create custom recorder
terms in their environment configurations with callback functions
returning tensors to be recorded as environments advance. It is
particularly useful for implementing an app that collects human-operated
demos and for those who want to record robot actions for
post-validation/replay in Isaac Lab environments.

The recorder manager works in both single- and multi-environment use
cases. An episode for an environment instance is exported to a dataset
file, via a dataset file handler, upon completion (a termination term is
signaled a reset to the environment instance is triggered).

By default, the recorder manager is inactive (by assigning no recorder
terms in the default configuration), which should have minimal
performance impact for existing apps that do not require data recording.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update -- to be updated in later
PRs

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Signed-off-by: 's avatarCY Chen <cyc@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
parent efc1a0b2
......@@ -9,3 +9,4 @@
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.jit filter=lfs diff=lfs merge=lfs -text
*.hdf5 filter=lfs diff=lfs merge=lfs -text
......@@ -41,6 +41,7 @@ Guidelines for modifications:
* Brayden Zhang
* Calvin Yu
* Chenyu Yang
* CY (Chien-Ying) Chen
* David Yang
* Dorsa Rohani
* Felix Yu
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.27.16"
version = "0.27.17"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.27.17 (2024-12-02)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added :class:`~omni.isaac.lab.managers.RecorderManager` and its utility classes to record data from the simulation.
* Added :class:`~omni.isaac.lab.utils.datasets.EpisodeData` to store data for an episode.
* Added :class:`~omni.isaac.lab.utils.datasets.DatasetFileHandlerBase` as a base class for handling dataset files.
* Added :class:`~omni.isaac.lab.utils.datasets.HDF5DatasetFileHandler` as a dataset file handler implementation to
export and load episodes from HDF5 files.
* Added ``record_demos.py`` script to record human-teleoperated demos for a specified task and export to an HDF5 file.
* Added ``replay_demos.py`` script to replay demos loaded from an HDF5 file.
0.27.16 (2024-11-21)
~~~~~~~~~~~~~~~~~~~~
......
......@@ -11,7 +11,7 @@ from typing import Any
import omni.isaac.core.utils.torch as torch_utils
import omni.log
from omni.isaac.lab.managers import ActionManager, EventManager, ObservationManager
from omni.isaac.lab.managers import ActionManager, EventManager, ObservationManager, RecorderManager
from omni.isaac.lab.scene import InteractiveScene
from omni.isaac.lab.sim import SimulationContext
from omni.isaac.lab.utils.timer import Timer
......@@ -45,6 +45,9 @@ class ManagerBasedEnv:
This includes resetting the scene to a default state, applying random pushes to the robot at different intervals
of time, or randomizing properties such as mass and friction coefficients. This is useful for training
and evaluating the robot in a variety of scenarios.
* **Recorder Manager**: The recorder manager that handles recording data produced during different steps
in the simulation. This includes recording in the beginning and end of a reset and a step. The recorded data
is distinguished per episode, per environment and can be exported through a dataset file handler to a file.
The environment provides a unified interface for interacting with the simulation. However, it does not
include task-specific quantities such as the reward function, or the termination conditions. These
......@@ -153,6 +156,9 @@ class ManagerBasedEnv:
# allocate dictionary to store metrics
self.extras = {}
# initialize observation buffers
self.obs_buf = {}
def __del__(self):
"""Cleanup for the environment."""
self.close()
......@@ -208,6 +214,9 @@ class ManagerBasedEnv:
"""
# prepare the managers
# -- recorder manager
self.recorder_manager = RecorderManager(self.cfg.recorders, self)
print("[INFO] Recorder Manager: ", self.recorder_manager)
# -- action manager
self.action_manager = ActionManager(self.cfg.actions, self)
print("[INFO] Action Manager: ", self.action_manager)
......@@ -228,15 +237,18 @@ class ManagerBasedEnv:
Operations - MDP.
"""
def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[VecEnvObs, dict]:
"""Resets all the environments and returns observations.
def reset(
self, seed: int | None = None, env_ids: Sequence[int] | None = None, options: dict[str, Any] | None = None
) -> tuple[VecEnvObs, dict]:
"""Resets the specified environments and returns observations.
This function calls the :meth:`_reset_idx` function to reset all the environments.
This function calls the :meth:`_reset_idx` function to reset the specified environments.
However, certain operations, such as procedural terrain generation, that happened during initialization
are not repeated.
Args:
seed: The seed to use for randomization. Defaults to None, in which case the seed is not set.
env_ids: The environment ids to reset. Defaults to None, in which case all environments are reset.
options: Additional information to specify how the environment is reset. Defaults to None.
Note:
......@@ -245,20 +257,78 @@ class ManagerBasedEnv:
Returns:
A tuple containing the observations and extras.
"""
if env_ids is None:
env_ids = torch.arange(self.num_envs, dtype=torch.int64, device=self.device)
# trigger recorder terms for pre-reset calls
self.recorder_manager.record_pre_reset(env_ids)
# set the seed
if seed is not None:
self.seed(seed)
# reset state of scene
indices = torch.arange(self.num_envs, dtype=torch.int64, device=self.device)
self._reset_idx(indices)
self._reset_idx(env_ids)
self.scene.write_data_to_sim()
# trigger recorder terms for post-reset calls
self.recorder_manager.record_post_reset(env_ids)
# if sensors are added to the scene, make sure we render to reflect changes in reset
if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset:
self.sim.render()
# compute observations
self.obs_buf = self.observation_manager.compute()
# return observations
return self.observation_manager.compute(), self.extras
return self.obs_buf, self.extras
def reset_to(
self,
state: dict[str, dict[str, dict[str, torch.Tensor]]],
env_ids: Sequence[int] | None,
seed: int | None = None,
is_relative: bool = False,
) -> None:
"""Resets specified environments to known states.
Note that this is different from reset() function as it resets the environments to specific states
Args:
state: The state to reset the specified environments to.
env_ids: The environment ids to reset. Defaults to None, in which case all environments are reset.
seed: The seed to use for randomization. Defaults to None, in which case the seed is not set.
is_relative: If set to True, the state is considered relative to the environment origins. Defaults to False.
"""
# reset all envs in the scene if env_ids is None
if env_ids is None:
env_ids = torch.arange(self.num_envs, dtype=torch.int64, device=self.device)
# trigger recorder terms for pre-reset calls
self.recorder_manager.record_pre_reset(env_ids)
# set the seed
if seed is not None:
self.seed(seed)
self._reset_idx(env_ids)
# set the state
self.scene.reset_to(state, env_ids, is_relative=is_relative)
# trigger recorder terms for post-reset calls
self.recorder_manager.record_post_reset(env_ids)
# if sensors are added to the scene, make sure we render to reflect changes in reset
if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset:
self.sim.render()
# compute observations
self.obs_buf = self.observation_manager.compute()
# return observations
return self.obs_buf, self.extras
def step(self, action: torch.Tensor) -> tuple[VecEnvObs, dict]:
"""Execute one time-step of the environment's dynamics.
......@@ -278,6 +348,8 @@ class ManagerBasedEnv:
# process actions
self.action_manager.process_action(action.to(self.device))
self.recorder_manager.record_pre_step()
# check if we need to do rendering within the physics loop
# note: checked here once to avoid multiple checks within the loop
is_rendering = self.sim.has_gui() or self.sim.has_rtx_sensors()
......@@ -303,8 +375,12 @@ class ManagerBasedEnv:
if "interval" in self.event_manager.available_modes:
self.event_manager.apply(mode="interval", dt=self.step_dt)
# -- compute observations
self.obs_buf = self.observation_manager.compute()
self.recorder_manager.record_post_step()
# return observations and extras
return self.observation_manager.compute(), self.extras
return self.obs_buf, self.extras
@staticmethod
def seed(seed: int = -1) -> int:
......@@ -334,6 +410,7 @@ class ManagerBasedEnv:
del self.action_manager
del self.observation_manager
del self.event_manager
del self.recorder_manager
del self.scene
# clear callbacks and instance
self.sim.clear_all_callbacks()
......@@ -375,3 +452,6 @@ class ManagerBasedEnv:
# -- event manager
info = self.event_manager.reset(env_ids)
self.extras["log"].update(info)
# -- recorder manager
info = self.recorder_manager.reset(env_ids)
self.extras["log"].update(info)
......@@ -13,6 +13,7 @@ from dataclasses import MISSING
import omni.isaac.lab.envs.mdp as mdp
from omni.isaac.lab.managers import EventTermCfg as EventTerm
from omni.isaac.lab.managers import RecorderManagerBaseCfg as DefaultEmptyRecorderManagerCfg
from omni.isaac.lab.scene import InteractiveSceneCfg
from omni.isaac.lab.sim import SimulationCfg
from omni.isaac.lab.utils import configclass
......@@ -78,6 +79,12 @@ class ManagerBasedEnvCfg:
Please refer to the :class:`omni.isaac.lab.scene.InteractiveSceneCfg` class for more details.
"""
recorders: object = DefaultEmptyRecorderManagerCfg()
"""Recorder settings. Defaults to recording nothing.
Please refer to the :class:`omni.isaac.lab.managers.RecorderManager` class for more details.
"""
observations: object = MISSING
"""Observation space settings.
......
......@@ -158,6 +158,8 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
# process actions
self.action_manager.process_action(action.to(self.device))
self.recorder_manager.record_pre_step()
# check if we need to do rendering within the physics loop
# note: checked here once to avoid multiple checks within the loop
is_rendering = self.sim.has_gui() or self.sim.has_rtx_sensors()
......@@ -190,14 +192,29 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
# -- reward computation
self.reward_buf = self.reward_manager.compute(dt=self.step_dt)
if len(self.recorder_manager.active_terms) > 0:
# update observations for recording if needed
self.obs_buf = self.observation_manager.compute()
self.recorder_manager.record_post_step()
# -- reset envs that terminated/timed-out and log the episode information
reset_env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1)
if len(reset_env_ids) > 0:
# trigger recorder terms for pre-reset calls
self.recorder_manager.record_pre_reset(reset_env_ids)
self._reset_idx(reset_env_ids)
# this is needed to make joint positions set from reset events effective
self.scene.write_data_to_sim()
# if sensors are added to the scene, make sure we render to reflect changes in reset
if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset:
self.sim.render()
# trigger recorder terms for post-reset calls
self.recorder_manager.record_post_reset(reset_env_ids)
# -- update command
self.command_manager.compute(dt=self.step_dt)
# -- step interval events
......@@ -353,6 +370,9 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
# -- termination manager
info = self.termination_manager.reset(env_ids)
self.extras["log"].update(info)
# -- recorder manager
info = self.recorder_manager.reset(env_ids)
self.extras["log"].update(info)
# reset the episode length buffer
self.episode_length_buf[env_ids] = 0
......@@ -20,5 +20,6 @@ from .commands import * # noqa: F401, F403
from .curriculums import * # noqa: F401, F403
from .events import * # noqa: F401, F403
from .observations import * # noqa: F401, F403
from .recorders import * # noqa: F401, F403
from .rewards import * # noqa: F401, F403
from .terminations import * # noqa: F401, F403
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Various recorder terms that can be used in the environment."""
from .recorders import *
from .recorders_cfg import *
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
from collections.abc import Sequence
from omni.isaac.lab.managers.recorder_manager import RecorderTerm
class InitialStateRecorder(RecorderTerm):
"""Recorder term that records the initial state of the environment after reset."""
def record_post_reset(self, env_ids: Sequence[int] | None):
return "initial_state", self._env.scene.get_state(is_relative=True)
class PostStepStatesRecorder(RecorderTerm):
"""Recorder term that records the state of the environment at the end of each step."""
def record_post_step(self):
return "states", self._env.scene.get_state(is_relative=True)
class PreStepActionsRecorder(RecorderTerm):
"""Recorder term that records the actions in the beginning of each step."""
def record_pre_step(self):
return "actions", self._env.action_manager.action
class PreStepFlatPolicyObservationsRecorder(RecorderTerm):
"""Recorder term that records the policy group observations in each step."""
def record_pre_step(self):
return "obs", self._env.obs_buf["policy"]
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from omni.isaac.lab.managers.recorder_manager import RecorderManagerBaseCfg, RecorderTerm, RecorderTermCfg
from omni.isaac.lab.utils import configclass
from . import recorders
##
# State recorders.
##
@configclass
class InitialStateRecorderCfg(RecorderTermCfg):
"""Configuration for the initial state recorder term."""
class_type: type[RecorderTerm] = recorders.InitialStateRecorder
@configclass
class PostStepStatesRecorderCfg(RecorderTermCfg):
"""Configuration for the step state recorder term."""
class_type: type[RecorderTerm] = recorders.PostStepStatesRecorder
@configclass
class PreStepActionsRecorderCfg(RecorderTermCfg):
"""Configuration for the step action recorder term."""
class_type: type[RecorderTerm] = recorders.PreStepActionsRecorder
@configclass
class PreStepFlatPolicyObservationsRecorderCfg(RecorderTermCfg):
"""Configuration for the step policy observation recorder term."""
class_type: type[RecorderTerm] = recorders.PreStepFlatPolicyObservationsRecorder
##
# Recorder manager configurations.
##
@configclass
class ActionStateRecorderManagerCfg(RecorderManagerBaseCfg):
"""Recorder configurations for recording actions and states."""
record_initial_state = InitialStateRecorderCfg()
record_post_step_states = PostStepStatesRecorderCfg()
record_pre_step_actions = PreStepActionsRecorderCfg()
record_pre_step_flat_policy_observations = PreStepFlatPolicyObservationsRecorderCfg()
......@@ -23,10 +23,12 @@ from .manager_term_cfg import (
ManagerTermBaseCfg,
ObservationGroupCfg,
ObservationTermCfg,
RecorderTermCfg,
RewardTermCfg,
TerminationTermCfg,
)
from .observation_manager import ObservationManager
from .recorder_manager import DatasetExportMode, RecorderManager, RecorderManagerBaseCfg, RecorderTerm
from .reward_manager import RewardManager
from .scene_entity_cfg import SceneEntityCfg
from .termination_manager import TerminationManager
......@@ -22,6 +22,7 @@ if TYPE_CHECKING:
from .action_manager import ActionTerm
from .command_manager import CommandTerm
from .manager_base import ManagerTermBase
from .recorder_manager import RecorderTerm
@configclass
......@@ -51,6 +52,22 @@ class ManagerTermBaseCfg:
"""
##
# Recorder manager.
##
@configclass
class RecorderTermCfg:
"""Configuration for an recorder term."""
class_type: type[RecorderTerm] = MISSING
"""The associated recorder term class.
The class should inherit from :class:`omni.isaac.lab.managers.action_manager.RecorderTerm`.
"""
##
# Action manager.
##
......
......@@ -341,6 +341,56 @@ class InteractiveScene:
"""
return self._extras
@property
def state(self) -> dict[str, dict[str, dict[str, torch.Tensor]]]:
"""Returns the state of the scene entities.
Returns:
A dictionary of the state of the scene entities.
"""
return self.get_state(is_relative=False)
def get_state(self, is_relative: bool = False) -> dict[str, dict[str, dict[str, torch.Tensor]]]:
"""Returns the state of the scene entities.
Args:
is_relative: If set to True, the state is considered relative to the environment origins.
Returns:
A dictionary of the state of the scene entities.
"""
state = dict()
# articulations
state["articulation"] = dict()
for asset_name, articulation in self._articulations.items():
asset_state = dict()
asset_state["root_pose"] = articulation.data.root_state_w[:, :7].clone()
if is_relative:
asset_state["root_pose"][:, :3] -= self.env_origins
asset_state["root_velocity"] = articulation.data.root_vel_w.clone()
asset_state["joint_position"] = articulation.data.joint_pos.clone()
asset_state["joint_velocity"] = articulation.data.joint_vel.clone()
state["articulation"][asset_name] = asset_state
# deformable objects
state["deformable_object"] = dict()
for asset_name, deformable_object in self._deformable_objects.items():
asset_state = dict()
asset_state["nodal_position"] = deformable_object.data.nodal_pos_w.clone()
if is_relative:
asset_state["nodal_position"][:, :3] -= self.env_origins
asset_state["nodal_velocity"] = deformable_object.data.nodal_vel_w.clone()
state["deformable_object"][asset_name] = asset_state
# rigid objects
state["rigid_object"] = dict()
for asset_name, rigid_object in self._rigid_objects.items():
asset_state = dict()
asset_state["root_pose"] = rigid_object.data.root_state_w[:, :7].clone()
if is_relative:
asset_state["root_pose"][:, :3] -= self.env_origins
asset_state["root_velocity"] = rigid_object.data.root_vel_w.clone()
state["rigid_object"][asset_name] = asset_state
return state
"""
Operations.
"""
......@@ -365,6 +415,58 @@ class InteractiveScene:
for sensor in self._sensors.values():
sensor.reset(env_ids)
def reset_to(
self,
state: dict[str, dict[str, dict[str, torch.Tensor]]],
env_ids: Sequence[int] | None = None,
is_relative: bool = False,
):
"""Resets the scene entities to the given state.
Args:
state: The state to reset the scene entities to.
env_ids: The indices of the environments to reset.
Defaults to None (all instances).
is_relative: If set to True, the state is considered relative to the environment origins.
"""
if env_ids is None:
env_ids = slice(None)
# articulations
for asset_name, articulation in self._articulations.items():
asset_state = state["articulation"][asset_name]
# root state
root_pose = asset_state["root_pose"].clone()
if is_relative:
root_pose[:, :3] += self.env_origins[env_ids]
root_velocity = asset_state["root_velocity"].clone()
articulation.write_root_pose_to_sim(root_pose, env_ids=env_ids)
articulation.write_root_velocity_to_sim(root_velocity, env_ids=env_ids)
# joint state
joint_position = asset_state["joint_position"].clone()
joint_velocity = asset_state["joint_velocity"].clone()
articulation.write_joint_state_to_sim(joint_position, joint_velocity, env_ids=env_ids)
articulation.set_joint_position_target(joint_position, env_ids=env_ids)
articulation.set_joint_velocity_target(joint_velocity, env_ids=env_ids)
# deformable objects
for asset_name, deformable_object in self._deformable_objects.items():
asset_state = state["deformable_object"][asset_name]
nodal_position = asset_state["nodal_position"].clone()
if is_relative:
nodal_position[:, :3] += self.env_origins[env_ids]
nodal_velocity = asset_state["nodal_velocity"].clone()
deformable_object.write_nodal_pos_to_sim(nodal_position, env_ids=env_ids)
deformable_object.write_nodal_velocity_to_sim(nodal_velocity, env_ids=env_ids)
# rigid objects
for asset_name, rigid_object in self._rigid_objects.items():
asset_state = state["rigid_object"][asset_name]
root_pose = asset_state["root_pose"].clone()
if is_relative:
root_pose[:, :3] += self.env_origins[env_ids]
root_velocity = asset_state["root_velocity"].clone()
rigid_object.write_root_pose_to_sim(root_pose, env_ids=env_ids)
rigid_object.write_root_velocity_to_sim(root_velocity, env_ids=env_ids)
self.write_data_to_sim()
def write_data_to_sim(self):
"""Writes the data of the scene entities to the simulation."""
# -- assets
......
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""
Submodule for datasets classes and methods.
"""
from .dataset_file_handler_base import DatasetFileHandlerBase
from .episode_data import EpisodeData
from .hdf5_dataset_file_handler import HDF5DatasetFileHandler
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
from abc import ABC, abstractmethod
from .episode_data import EpisodeData
class DatasetFileHandlerBase(ABC):
"""Abstract class for handling dataset files."""
def __init__(self):
"""Initializes the dataset file handler."""
pass
@abstractmethod
def open(self, file_path: str, mode: str = "r"):
"""Open a file."""
return NotImplementedError
@abstractmethod
def create(self, file_path: str, env_name: str = None):
"""Create a new file."""
return NotImplementedError
@abstractmethod
def get_env_name(self) -> str | None:
"""Get the environment name."""
return NotImplementedError
@abstractmethod
def write_episode(self, episode: EpisodeData):
"""Write episode data to the file."""
return NotImplementedError
@abstractmethod
def flush(self):
"""Flush the file."""
return NotImplementedError
@abstractmethod
def close(self):
"""Close the file."""
return NotImplementedError
@abstractmethod
def load_episode(self, episode_name: str) -> EpisodeData | None:
"""Load episode data from the file."""
return NotImplementedError
@abstractmethod
def get_num_episodes(self) -> int:
"""Get number of episodes in the file."""
return NotImplementedError
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch
class EpisodeData:
"""Class to store episode data."""
def __init__(self) -> None:
"""Initializes episode data class."""
self._data = dict()
self._next_action_index = 0
self._next_state_index = 0
self._seed = None
self._env_id = None
self._success = None
@property
def data(self):
"""Returns the episode data."""
return self._data
@data.setter
def data(self, data: dict):
"""Set the episode data."""
self._data = data
@property
def seed(self):
"""Returns the random number generator seed."""
return self._seed
@seed.setter
def seed(self, seed: int):
"""Set the random number generator seed."""
self._seed = seed
@property
def env_id(self):
"""Returns the environment ID."""
return self._env_id
@env_id.setter
def env_id(self, env_id: int):
"""Set the environment ID."""
self._env_id = env_id
@property
def next_action_index(self):
"""Returns the next action index."""
return self._next_action_index
@next_action_index.setter
def next_action_index(self, index: int):
"""Set the next action index."""
self._next_action_index = index
@property
def next_state_index(self):
"""Returns the next state index."""
return self._next_state_index
@next_state_index.setter
def next_state_index(self, index: int):
"""Set the next state index."""
self._next_state_index = index
@property
def success(self):
"""Returns the success value."""
return self._success
@success.setter
def success(self, success: bool):
"""Set the success value."""
self._success = success
def is_empty(self):
"""Check if the episode data is empty."""
return not bool(self._data)
def add(self, key: str, value: torch.Tensor | dict):
"""Add a key-value pair to the dataset.
The key can be nested by using the "/" character.
For example: "obs/joint_pos".
Args:
key: The key name.
value: The corresponding value of tensor type or of dict type.
"""
# check datatype
if isinstance(value, dict):
for sub_key, sub_value in value.items():
self.add(f"{key}/{sub_key}", sub_value)
return
sub_keys = key.split("/")
current_dataset_pointer = self._data
for sub_key_index in range(len(sub_keys)):
if sub_key_index == len(sub_keys) - 1:
# Add value to the final dict layer
if sub_keys[sub_key_index] not in current_dataset_pointer:
current_dataset_pointer[sub_keys[sub_key_index]] = value.unsqueeze(0).clone()
else:
current_dataset_pointer[sub_keys[sub_key_index]] = torch.cat(
(current_dataset_pointer[sub_keys[sub_key_index]], value.unsqueeze(0))
)
break
# key index
if sub_keys[sub_key_index] not in current_dataset_pointer:
current_dataset_pointer[sub_keys[sub_key_index]] = dict()
current_dataset_pointer = current_dataset_pointer[sub_keys[sub_key_index]]
def get_initial_state(self) -> torch.Tensor | None:
"""Get the initial state from the dataset."""
if "initial_state" not in self._data:
return None
return self._data["initial_state"]
def get_action(self, action_index) -> torch.Tensor | None:
"""Get the action of the specified index from the dataset."""
if "actions" not in self._data:
return None
if action_index >= len(self._data["actions"]):
return None
return self._data["actions"][action_index]
def get_next_action(self) -> torch.Tensor | None:
"""Get the next action from the dataset."""
action = self.get_action(self._next_action_index)
if action is not None:
self._next_action_index += 1
return action
def get_state(self, state_index) -> dict | None:
"""Get the state of the specified index from the dataset."""
if "states" not in self._data:
return None
states = self._data["states"]
def get_state_helper(states, state_index) -> dict | torch.Tensor | None:
if isinstance(states, dict):
output_state = dict()
for key, value in states.items():
output_state[key] = get_state_helper(value, state_index)
if output_state[key] is None:
return None
elif isinstance(states, torch.Tensor):
if state_index >= len(states):
return None
output_state = states[state_index]
else:
raise ValueError(f"Invalid state type: {type(states)}")
return output_state
output_state = get_state_helper(states, state_index)
return output_state
def get_next_state(self) -> dict | None:
"""Get the next state from the dataset."""
state = self.get_state(self._next_state_index)
if state is not None:
self._next_state_index += 1
return state
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import h5py
import json
import numpy as np
import os
import torch
from collections.abc import Iterable
from .dataset_file_handler_base import DatasetFileHandlerBase
from .episode_data import EpisodeData
class HDF5DatasetFileHandler(DatasetFileHandlerBase):
"""HDF5 dataset file handler for storing and loading episode data."""
def __init__(self):
"""Initializes the HDF5 dataset file handler."""
self._hdf5_file_stream = None
self._hdf5_data_group = None
self._demo_count = 0
self._env_args = {}
def open(self, file_path: str, mode: str = "r"):
"""Open an existing dataset file."""
if self._hdf5_file_stream is not None:
raise RuntimeError("HDF5 dataset file stream is already in use")
self._hdf5_file_stream = h5py.File(file_path, mode)
self._hdf5_data_group = self._hdf5_file_stream["data"]
self._demo_count = len(self._hdf5_data_group)
def create(self, file_path: str, env_name: str = None):
"""Create a new dataset file."""
if self._hdf5_file_stream is not None:
raise RuntimeError("HDF5 dataset file stream is already in use")
if not file_path.endswith(".hdf5"):
file_path += ".hdf5"
dir_path = os.path.dirname(file_path)
if not os.path.isdir(dir_path):
os.makedirs(dir_path)
self._hdf5_file_stream = h5py.File(file_path, "w")
# set up a data group in the file
self._hdf5_data_group = self._hdf5_file_stream.create_group("data")
self._hdf5_data_group.attrs["total"] = 0
self._demo_count = 0
# set environment arguments
# the environment type (we use gym environment type) is set to be compatible with robomimic
# Ref: https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/envs/env_base.py#L15
env_name = env_name if env_name is not None else ""
self.add_env_args({"env_name": env_name, "type": 2})
def __del__(self):
"""Destructor for the file handler."""
self.close()
"""
Properties
"""
def add_env_args(self, env_args: dict):
"""Add environment arguments to the dataset."""
self._raise_if_not_initialized()
self._env_args.update(env_args)
self._hdf5_data_group.attrs["env_args"] = json.dumps(self._env_args)
def set_env_name(self, env_name: str):
"""Set the environment name."""
self._raise_if_not_initialized()
self.add_env_args({"env_name": env_name})
def get_env_name(self) -> str | None:
"""Get the environment name."""
self._raise_if_not_initialized()
env_args = json.loads(self._hdf5_data_group.attrs["env_args"])
if "env_name" in env_args:
return env_args["env_name"]
return None
def get_episode_names(self) -> Iterable[str]:
"""Get the names of the episodes in the file."""
self._raise_if_not_initialized()
return self._hdf5_data_group.keys()
def get_num_episodes(self) -> int:
"""Get number of episodes in the file."""
return self._demo_count
@property
def demo_count(self) -> int:
"""The number of demos collected so far."""
return self._demo_count
"""
Operations.
"""
def load_episode(self, episode_name: str, device: str) -> EpisodeData | None:
"""Load episode data from the file."""
self._raise_if_not_initialized()
if episode_name not in self._hdf5_data_group:
return None
episode = EpisodeData()
h5_episode_group = self._hdf5_data_group[episode_name]
def load_dataset_helper(group):
"""Helper method to load dataset that contains recursive dict objects."""
data = {}
for key in group:
if isinstance(group[key], h5py.Group):
data[key] = load_dataset_helper(group[key])
else:
# Converting group[key] to numpy array greatly improves the performance
# when converting to torch tensor
data[key] = torch.tensor(np.array(group[key]), device=device)
return data
episode.data = load_dataset_helper(h5_episode_group)
if "seed" in h5_episode_group.attrs:
episode.seed = h5_episode_group.attrs["seed"]
if "success" in h5_episode_group.attrs:
episode.success = h5_episode_group.attrs["success"]
episode.env_id = self.get_env_name()
return episode
def write_episode(self, episode: EpisodeData):
"""Add an episode to the dataset.
Args:
episode: The episode data to add.
"""
self._raise_if_not_initialized()
if episode.is_empty():
return
# create episode group based on demo count
h5_episode_group = self._hdf5_data_group.create_group(f"demo_{self._demo_count}")
# store number of steps taken
if "actions" in episode.data:
h5_episode_group.attrs["num_samples"] = len(episode.data["actions"])
else:
h5_episode_group.attrs["num_samples"] = 0
if episode.seed is not None:
h5_episode_group.attrs["seed"] = episode.seed
if episode.success is not None:
h5_episode_group.attrs["success"] = episode.success
def create_dataset_helper(group, key, value):
"""Helper method to create dataset that contains recursive dict objects."""
if isinstance(value, dict):
key_group = group.create_group(key)
for sub_key, sub_value in value.items():
create_dataset_helper(key_group, sub_key, sub_value)
else:
group.create_dataset(key, data=value.cpu().numpy())
for key, value in episode.data.items():
create_dataset_helper(h5_episode_group, key, value)
# increment total step counts
self._hdf5_data_group.attrs["total"] += h5_episode_group.attrs["num_samples"]
# increment total demo counts
self._demo_count += 1
def flush(self):
"""Flush the episode data to disk."""
self._raise_if_not_initialized()
self._hdf5_file_stream.flush()
def close(self):
"""Close the dataset file handler."""
if self._hdf5_file_stream is not None:
self._hdf5_file_stream.close()
self._hdf5_file_stream = None
def _raise_if_not_initialized(self):
"""Raise an error if the dataset file handler is not initialized."""
if self._hdf5_file_stream is None:
raise RuntimeError("HDF5 dataset file stream is not initialized")
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
# needed to import for allowing type-hinting: torch.Tensor | None
from __future__ import annotations
"""Launch Isaac Sim Simulator first."""
from omni.isaac.lab.app import AppLauncher, run_tests
# launch omniverse app
simulation_app = AppLauncher(headless=True).app
"""Rest everything follows."""
import os
import shutil
import tempfile
import torch
import unittest
import uuid
from collections import namedtuple
from collections.abc import Sequence
from omni.isaac.lab.envs import ManagerBasedEnv
from omni.isaac.lab.managers import (
DatasetExportMode,
RecorderManager,
RecorderManagerBaseCfg,
RecorderTerm,
RecorderTermCfg,
)
from omni.isaac.lab.utils import configclass
class DummyResetRecorderTerm(RecorderTerm):
"""Dummy recorder term that records dummy data."""
def __init__(self, cfg: RecorderTermCfg, env: ManagerBasedEnv) -> None:
super().__init__(cfg, env)
def record_pre_reset(self, env_ids: Sequence[int] | None) -> tuple[str | None, torch.Tensor | None]:
return "record_pre_reset", torch.ones(self._env.num_envs, 2, device=self._env.device)
def record_post_reset(self, env_ids: Sequence[int] | None) -> tuple[str | None, torch.Tensor | None]:
return "record_post_reset", torch.ones(self._env.num_envs, 3, device=self._env.device)
class DummyStepRecorderTerm(RecorderTerm):
"""Dummy recorder term that records dummy data."""
def __init__(self, cfg: RecorderTermCfg, env: ManagerBasedEnv) -> None:
super().__init__(cfg, env)
def record_pre_step(self) -> tuple[str | None, torch.Tensor | None]:
return "record_pre_step", torch.ones(self._env.num_envs, 4, device=self._env.device)
def record_post_step(self) -> tuple[str | None, torch.Tensor | None]:
return "record_post_step", torch.ones(self._env.num_envs, 5, device=self._env.device)
@configclass
class DummyRecorderManagerCfg(RecorderManagerBaseCfg):
"""Dummy recorder configurations."""
@configclass
class DummyResetRecorderTermCfg(RecorderTermCfg):
"""Configuration for the dummy reset recorder term."""
class_type: type[RecorderTerm] = DummyResetRecorderTerm
@configclass
class DummyStepRecorderTermCfg(RecorderTermCfg):
"""Configuration for the dummy step recorder term."""
class_type: type[RecorderTerm] = DummyStepRecorderTerm
record_reset_term = DummyResetRecorderTermCfg()
record_step_term = DummyStepRecorderTermCfg()
dataset_export_mode = DatasetExportMode.EXPORT_ALL
def create_dummy_env(device: str = "cpu") -> ManagerBasedEnv:
"""Create a dummy environment."""
class DummyTerminationManager:
active_terms = []
dummy_termination_manager = DummyTerminationManager()
return namedtuple("ManagerBasedEnv", ["num_envs", "device", "cfg", "termination_manager"])(
20, device, dict(), dummy_termination_manager
)
class TestRecorderManager(unittest.TestCase):
"""Test cases for various situations with recorder manager."""
def setUp(self) -> None:
self.dataset_dir = tempfile.mkdtemp()
def tearDown(self):
# delete the temporary directory after the test
shutil.rmtree(self.dataset_dir)
def create_dummy_recorder_manager_cfg(self) -> DummyRecorderManagerCfg:
"""Get the dummy recorder manager configurations."""
cfg = DummyRecorderManagerCfg()
cfg.dataset_export_dir_path = self.dataset_dir
cfg.dataset_filename = f"{uuid.uuid4()}.hdf5"
return cfg
def test_str(self):
"""Test the string representation of the recorder manager."""
# create recorder manager
cfg = DummyRecorderManagerCfg()
recorder_manager = RecorderManager(cfg, create_dummy_env())
self.assertEqual(len(recorder_manager.active_terms), 2)
# print the expected string
print()
print(recorder_manager)
def test_initialize_dataset_file(self):
"""Test the initialization of the dataset file."""
# create recorder manager
cfg = self.create_dummy_recorder_manager_cfg()
_ = RecorderManager(cfg, create_dummy_env())
# check if the dataset is created
self.assertTrue(os.path.exists(os.path.join(cfg.dataset_export_dir_path, cfg.dataset_filename)))
def test_record(self):
"""Test the recording of the data."""
for device in ("cuda:0", "cpu"):
with self.subTest(device=device):
env = create_dummy_env(device)
# create recorder manager
recorder_manager = RecorderManager(self.create_dummy_recorder_manager_cfg(), env)
# record the step data
recorder_manager.record_pre_step()
recorder_manager.record_post_step()
recorder_manager.record_pre_step()
recorder_manager.record_post_step()
# check the recorded data
for env_id in range(env.num_envs):
episode = recorder_manager.get_episode(env_id)
self.assertEqual(episode.data["record_pre_step"].shape, (2, 4))
self.assertEqual(episode.data["record_post_step"].shape, (2, 5))
# Trigger pre-reset callbacks which then export and clean the episode data
recorder_manager.record_pre_reset(env_ids=None)
for env_id in range(env.num_envs):
episode = recorder_manager.get_episode(env_id)
self.assertTrue(episode.is_empty())
recorder_manager.record_post_reset(env_ids=None)
for env_id in range(env.num_envs):
episode = recorder_manager.get_episode(env_id)
self.assertEqual(episode.data["record_post_reset"].shape, (1, 3))
if __name__ == "__main__":
run_tests()
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Launch Isaac Sim Simulator first."""
from omni.isaac.lab.app import AppLauncher, run_tests
# launch omniverse app in headless mode
simulation_app = AppLauncher(headless=True).app
"""Rest everything follows from here."""
import torch
import unittest
from omni.isaac.lab.utils.datasets import EpisodeData
class TestEpisodeData(unittest.TestCase):
"""Test EpisodeData implementation."""
"""
Test cases for EpisodeData class.
"""
def test_is_empty(self):
"""Test checking whether the episode is empty."""
for device in ("cuda:0", "cpu"):
with self.subTest(device=device):
episode = EpisodeData()
self.assertTrue(episode.is_empty())
episode.add("key", torch.tensor([1, 2, 3], device=device))
self.assertFalse(episode.is_empty())
def test_add_tensors(self):
"""Test appending tensor data to the episode."""
for device in ("cuda:0", "cpu"):
with self.subTest(device=device):
dummy_data_0 = torch.tensor([0], device=device)
dummy_data_1 = torch.tensor([1], device=device)
expected_added_data = torch.cat((dummy_data_0.unsqueeze(0), dummy_data_1.unsqueeze(0)))
episode = EpisodeData()
# test adding data to a key that does not exist
episode.add("key", dummy_data_0)
self.assertTrue(torch.equal(episode.data.get("key"), dummy_data_0.unsqueeze(0)))
# test adding data to a key that exists
episode.add("key", dummy_data_1)
self.assertTrue(torch.equal(episode.data.get("key"), expected_added_data))
# test adding data to a key with "/" in the name
episode.add("first/second", dummy_data_0)
self.assertTrue(torch.equal(episode.data.get("first").get("second"), dummy_data_0.unsqueeze(0)))
# test adding data to a key with "/" in the name that already exists
episode.add("first/second", dummy_data_1)
self.assertTrue(torch.equal(episode.data.get("first").get("second"), expected_added_data))
def test_add_dict_tensors(self):
"""Test appending dict data to the episode."""
for device in ("cuda:0", "cpu"):
with self.subTest(device=device):
dummy_dict_data_0 = {
"key_0": torch.tensor([0], device=device),
"key_1": {"key_1_0": torch.tensor([1], device=device), "key_1_1": torch.tensor([2], device=device)},
}
dummy_dict_data_1 = {
"key_0": torch.tensor([3], device=device),
"key_1": {"key_1_0": torch.tensor([4], device=device), "key_1_1": torch.tensor([5], device=device)},
}
episode = EpisodeData()
# test adding dict data to a key that does not exist
episode.add("key", dummy_dict_data_0)
self.assertTrue(torch.equal(episode.data.get("key").get("key_0"), torch.tensor([[0]], device=device)))
self.assertTrue(
torch.equal(episode.data.get("key").get("key_1").get("key_1_0"), torch.tensor([[1]], device=device))
)
self.assertTrue(
torch.equal(episode.data.get("key").get("key_1").get("key_1_1"), torch.tensor([[2]], device=device))
)
# test adding dict data to a key that exists
episode.add("key", dummy_dict_data_1)
self.assertTrue(
torch.equal(episode.data.get("key").get("key_0"), torch.tensor([[0], [3]], device=device))
)
self.assertTrue(
torch.equal(
episode.data.get("key").get("key_1").get("key_1_0"), torch.tensor([[1], [4]], device=device)
)
)
self.assertTrue(
torch.equal(
episode.data.get("key").get("key_1").get("key_1_1"), torch.tensor([[2], [5]], device=device)
)
)
def test_get_initial_state(self):
"""Test getting the initial state of the episode."""
for device in ("cuda:0", "cpu"):
with self.subTest(device=device):
dummy_initial_state = torch.tensor([1, 2, 3], device=device)
episode = EpisodeData()
episode.add("initial_state", dummy_initial_state)
self.assertTrue(torch.equal(episode.get_initial_state(), dummy_initial_state.unsqueeze(0)))
def test_get_next_action(self):
"""Test getting next actions."""
for device in ("cuda:0", "cpu"):
with self.subTest(device=device):
# dummy actions
action1 = torch.tensor([1, 2, 3], device=device)
action2 = torch.tensor([4, 5, 6], device=device)
action3 = torch.tensor([7, 8, 9], device=device)
episode = EpisodeData()
self.assertIsNone(episode.get_next_action())
episode.add("actions", action1)
episode.add("actions", action2)
episode.add("actions", action3)
# check if actions are returned in the correct order
self.assertTrue(torch.equal(episode.get_next_action(), action1))
self.assertTrue(torch.equal(episode.get_next_action(), action2))
self.assertTrue(torch.equal(episode.get_next_action(), action3))
# check if None is returned when all actions are exhausted
self.assertIsNone(episode.get_next_action())
if __name__ == "__main__":
run_tests()
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Launch Isaac Sim Simulator first."""
from omni.isaac.lab.app import AppLauncher, run_tests
# launch omniverse app in headless mode
simulation_app = AppLauncher(headless=True).app
"""Rest everything follows from here."""
import os
import shutil
import tempfile
import torch
import unittest
import uuid
from omni.isaac.lab.utils.datasets import EpisodeData, HDF5DatasetFileHandler
def create_test_episode(device):
"""create a test episode with dummy data."""
test_episode = EpisodeData()
test_episode.seed = 0
test_episode.success = True
test_episode.add("initial_state", torch.tensor([1, 2, 3], device=device))
test_episode.add("actions", torch.tensor([1, 2, 3], device=device))
test_episode.add("actions", torch.tensor([4, 5, 6], device=device))
test_episode.add("actions", torch.tensor([7, 8, 9], device=device))
test_episode.add("obs/policy/term1", torch.tensor([1, 2, 3, 4, 5], device=device))
test_episode.add("obs/policy/term1", torch.tensor([6, 7, 8, 9, 10], device=device))
test_episode.add("obs/policy/term1", torch.tensor([11, 12, 13, 14, 15], device=device))
return test_episode
class TestHDF5DatasetFileHandler(unittest.TestCase):
"""Test HDF5 dataset filer handler implementation."""
"""
Test cases for HDF5DatasetFileHandler class.
"""
def setUp(self):
# create a temporary directory to store the test datasets
self.temp_dir = tempfile.mkdtemp()
def tearDown(self):
# delete the temporary directory after the test
shutil.rmtree(self.temp_dir)
def test_create_dataset_file(self):
"""Test creating a new dataset file."""
# create a dataset file given a file name with extension
dataset_file_path = os.path.join(self.temp_dir, f"{uuid.uuid4()}.hdf5")
dataset_file_handler = HDF5DatasetFileHandler()
dataset_file_handler.create(dataset_file_path, "test_env_name")
dataset_file_handler.close()
# check if the dataset is created
self.assertTrue(os.path.exists(dataset_file_path))
# create a dataset file given a file name without extension
dataset_file_path = os.path.join(self.temp_dir, f"{uuid.uuid4()}")
dataset_file_handler = HDF5DatasetFileHandler()
dataset_file_handler.create(dataset_file_path, "test_env_name")
dataset_file_handler.close()
# check if the dataset is created
self.assertTrue(os.path.exists(dataset_file_path + ".hdf5"))
def test_write_and_load_episode(self):
"""Test writing and loading an episode to and from the dataset file."""
for device in ("cuda:0", "cpu"):
with self.subTest(device=device):
dataset_file_path = os.path.join(self.temp_dir, f"{uuid.uuid4()}.hdf5")
dataset_file_handler = HDF5DatasetFileHandler()
dataset_file_handler.create(dataset_file_path, "test_env_name")
test_episode = create_test_episode(device)
# write the episode to the dataset
dataset_file_handler.write_episode(test_episode)
dataset_file_handler.flush()
self.assertEqual(dataset_file_handler.get_num_episodes(), 1)
# write the episode again to test writing 2nd episode
dataset_file_handler.write_episode(test_episode)
dataset_file_handler.flush()
self.assertEqual(dataset_file_handler.get_num_episodes(), 2)
# close the dataset file to prepare for testing the load function
dataset_file_handler.close()
# load the episode from the dataset
dataset_file_handler = HDF5DatasetFileHandler()
dataset_file_handler.open(dataset_file_path)
self.assertEqual(dataset_file_handler.get_env_name(), "test_env_name")
loaded_episode_names = dataset_file_handler.get_episode_names()
self.assertEqual(len(list(loaded_episode_names)), 2)
for episode_name in loaded_episode_names:
loaded_episode = dataset_file_handler.load_episode(episode_name, device=device)
self.assertEqual(loaded_episode.env_id, "test_env_name")
self.assertEqual(loaded_episode.seed, test_episode.seed)
self.assertEqual(loaded_episode.success, test_episode.success)
self.assertTrue(torch.equal(loaded_episode.get_initial_state(), test_episode.get_initial_state()))
for action in test_episode.data["actions"]:
self.assertTrue(torch.equal(loaded_episode.get_next_action(), action))
dataset_file_handler.close()
if __name__ == "__main__":
run_tests()
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""
Script to record demonstrations with Isaac Lab environments using human teleoperation.
This script allows users to record demonstrations operated by human teleoperation for a specified task.
The recorded demonstrations are stored as episodes in a hdf5 file. Users can specify the task, teleoperation
device, dataset directory, and environment stepping rate through command-line arguments.
required arguments:
--task Name of the task.
optional arguments:
-h, --help Show this help message and exit
--teleop_device Device for interacting with environment. (default: keyboard)
--dataset_file File path to export recorded demos. (default: "./datasets/dataset.hdf5")
--step_hz Environment stepping rate in Hz. (default: 30)
"""
"""Launch Isaac Sim Simulator first."""
import argparse
from omni.isaac.lab.app import AppLauncher
# add argparse arguments
parser = argparse.ArgumentParser(description="Record demonstrations for Isaac Lab environments.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--teleop_device", type=str, default="keyboard", help="Device for interacting with environment.")
parser.add_argument(
"--dataset_file", type=str, default="./datasets/dataset.hdf5", help="File path to export recorded demos."
)
parser.add_argument("--step_hz", type=int, default=30, help="Environment stepping rate in Hz.")
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args()
# launch the simulator
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
"""Rest everything follows."""
import contextlib
import gymnasium as gym
import os
import time
import torch
from omni.isaac.lab.devices import Se3Keyboard, Se3SpaceMouse
from omni.isaac.lab.envs.mdp.recorders.recorders_cfg import ActionStateRecorderManagerCfg
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg
class RateLimiter:
"""Convenience class for enforcing rates in loops."""
def __init__(self, hz):
"""
Args:
hz (int): frequency to enforce
"""
self.hz = hz
self.last_time = time.time()
self.sleep_duration = 1.0 / hz
self.render_period = min(0.033, self.sleep_duration)
def sleep(self, env):
"""Attempt to sleep at the specified rate in hz."""
next_wakeup_time = self.last_time + self.sleep_duration
while time.time() < next_wakeup_time:
time.sleep(self.render_period)
env.unwrapped.sim.render()
self.last_time = self.last_time + self.sleep_duration
# detect time jumping forwards (e.g. loop is too slow)
if self.last_time < time.time():
while self.last_time < time.time():
self.last_time += self.sleep_duration
def pre_process_actions(delta_pose: torch.Tensor, gripper_command: bool) -> torch.Tensor:
"""Pre-process actions for the environment."""
# compute actions based on environment
if "Reach" in args_cli.task:
# note: reach is the only one that uses a different action space
# compute actions
return delta_pose
else:
# resolve gripper command
gripper_vel = torch.zeros((delta_pose.shape[0], 1), dtype=torch.float, device=delta_pose.device)
gripper_vel[:] = -1 if gripper_command else 1
# compute actions
return torch.concat([delta_pose, gripper_vel], dim=1)
def main():
"""Collect demonstrations from the environment using teleop interfaces."""
rate_limiter = RateLimiter(args_cli.step_hz)
# get directory path and file name (without extension) from cli arguments
output_dir = os.path.dirname(args_cli.dataset_file)
output_file_name = os.path.splitext(os.path.basename(args_cli.dataset_file))[0]
# create directory if it does not exist
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# parse configuration
env_cfg = parse_env_cfg(args_cli.task, device=args_cli.device, num_envs=1)
env_cfg.env_name = args_cli.task
# modify configuration such that the environment runs indefinitely
# until goal is reached
env_cfg.terminations.time_out = None
env_cfg.observations.policy.concatenate_terms = False
env_cfg.recorders: ActionStateRecorderManagerCfg = ActionStateRecorderManagerCfg()
env_cfg.recorders.dataset_export_dir_path = output_dir
env_cfg.recorders.dataset_filename = output_file_name
# create environment
env = gym.make(args_cli.task, cfg=env_cfg)
# create controller
if args_cli.teleop_device.lower() == "keyboard":
teleop_interface = Se3Keyboard(pos_sensitivity=0.2, rot_sensitivity=0.5)
elif args_cli.teleop_device.lower() == "spacemouse":
teleop_interface = Se3SpaceMouse(pos_sensitivity=0.2, rot_sensitivity=0.5)
else:
raise ValueError(f"Invalid device interface '{args_cli.teleop_device}'. Supported: 'keyboard', 'spacemouse'.")
# add teleoperation key for reset current recording instance
should_reset_recording_instance = False
def reset_recording_instance():
nonlocal should_reset_recording_instance
should_reset_recording_instance = True
teleop_interface.add_callback("R", reset_recording_instance)
print(teleop_interface)
# reset before starting
env.reset()
teleop_interface.reset()
# simulate environment -- run everything in inference mode
with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode():
while True:
# get keyboard command
delta_pose, gripper_command = teleop_interface.advance()
# convert to torch
delta_pose = torch.tensor(delta_pose, dtype=torch.float, device=env.device).repeat(env.num_envs, 1)
# compute actions based on environment
actions = pre_process_actions(delta_pose, gripper_command)
# perform action on environment
env.step(actions)
if should_reset_recording_instance:
env.unwrapped.recorder_manager.reset()
env.reset()
should_reset_recording_instance = False
# check that simulation is stopped or not
if env.unwrapped.sim.is_stopped():
break
rate_limiter.sleep(env.unwrapped)
env.close()
if __name__ == "__main__":
# run the main function
main()
# close sim app
simulation_app.close()
# Copyright (c) 2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Script to replay demonstrations with Isaac Lab environments."""
"""Launch Isaac Sim Simulator first."""
import argparse
from omni.isaac.lab.app import AppLauncher
# add argparse arguments
parser = argparse.ArgumentParser(description="Replay demonstrations in Isaac Lab environments.")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to replay episodes.")
parser.add_argument("--task", type=str, default=None, help="Force to use the specified task.")
parser.add_argument(
"--select_episodes",
type=int,
nargs="+",
default=[],
help="A list of episode indices to be replayed. Keep empty to replay all in the dataset file.",
)
parser.add_argument("--dataset_file", type=str, default="datasets/dataset.hdf5", help="Dataset file to be replayed.")
parser.add_argument(
"--validate_states",
action="store_true",
default=False,
help=(
"Validate if the states, if available, match between loaded from datasets and replayed. Only valid if"
" --num_envs is 1."
),
)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args()
# args_cli.headless = True
# launch the simulator
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
"""Rest everything follows."""
import contextlib
import gymnasium as gym
import os
import torch
from omni.isaac.lab.devices import Se3Keyboard
from omni.isaac.lab.utils.datasets import EpisodeData, HDF5DatasetFileHandler
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg
is_paused = False
def play_cb():
global is_paused
is_paused = False
def pause_cb():
global is_paused
is_paused = True
def compare_states(state_from_dataset, runtime_state, runtime_env_index) -> (bool, str):
"""Compare states from dataset and runtime.
Args:
state_from_dataset: State from dataset.
runtime_state: State from runtime.
runtime_env_index: Index of the environment in the runtime states to be compared.
Returns:
bool: True if states match, False otherwise.
str: Log message if states don't match.
"""
states_matched = True
output_log = ""
for asset_type in ["articulation", "rigid_object"]:
for asset_name in runtime_state[asset_type].keys():
for state_name in runtime_state[asset_type][asset_name].keys():
runtime_asset_state = runtime_state[asset_type][asset_name][state_name][runtime_env_index]
dataset_asset_state = state_from_dataset[asset_type][asset_name][state_name]
if len(dataset_asset_state) != len(runtime_asset_state):
raise ValueError(f"State shape of {state_name} for asset {asset_name} don't match")
for i in range(len(dataset_asset_state)):
if abs(dataset_asset_state[i] - runtime_asset_state[i]) > 0.01:
states_matched = False
output_log += f'\tState ["{asset_type}"]["{asset_name}"]["{state_name}"][{i}] don\'t match\r\n'
output_log += f"\t Dataset:\t{dataset_asset_state[i]}\r\n"
output_log += f"\t Runtime: \t{runtime_asset_state[i]}\r\n"
return states_matched, output_log
def main():
"""Replay episodes loaded from a file."""
global is_paused
# Load dataset
if not os.path.exists(args_cli.dataset_file):
raise FileNotFoundError(f"The dataset file {args_cli.dataset_file} does not exist.")
dataset_file_handler = HDF5DatasetFileHandler()
dataset_file_handler.open(args_cli.dataset_file)
env_name = dataset_file_handler.get_env_name()
episode_count = dataset_file_handler.get_num_episodes()
if episode_count == 0:
print("No episodes found in the dataset.")
exit()
episode_indices_to_replay = args_cli.select_episodes
if len(episode_indices_to_replay) == 0:
episode_indices_to_replay = list(range(episode_count))
if args_cli.task is not None:
env_name = args_cli.task
if env_name is None:
raise ValueError("Task/env name was not specified nor found in the dataset.")
num_envs = args_cli.num_envs
env_cfg = parse_env_cfg(env_name, device=args_cli.device, num_envs=num_envs)
# Disable all recorders and terminations
env_cfg.recorders = {}
env_cfg.terminations = {}
# create environment from loaded config
env = gym.make(env_name, cfg=env_cfg)
teleop_interface = Se3Keyboard(pos_sensitivity=0.1, rot_sensitivity=0.1)
teleop_interface.add_callback("N", play_cb)
teleop_interface.add_callback("B", pause_cb)
print(teleop_interface)
# Determine if state validation should be conducted
state_validation_enabled = False
if args_cli.validate_states and num_envs == 1:
state_validation_enabled = True
elif args_cli.validate_states and num_envs > 1:
print("Warning: State validation is only supported with a single environment. Skipping state validation.")
# reset before starting
env.reset()
teleop_interface.reset()
# simulate environment -- run everything in inference mode
episode_names = list(dataset_file_handler.get_episode_names())
replayed_episode_count = 0
with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode():
while simulation_app.is_running() and not simulation_app.is_exiting():
env_episode_data_map = {index: EpisodeData() for index in range(num_envs)}
first_loop = True
has_next_action = True
while has_next_action:
# initialize actions with zeros so those without next action will not move
actions = torch.zeros(env.unwrapped.action_space.shape)
has_next_action = False
for env_id in range(num_envs):
env_next_action = env_episode_data_map[env_id].get_next_action()
if env_next_action is None:
next_episode_index = None
while episode_indices_to_replay:
next_episode_index = episode_indices_to_replay.pop(0)
if next_episode_index < episode_count:
break
next_episode_index = None
if next_episode_index is not None:
replayed_episode_count += 1
print(f"{replayed_episode_count :4}: Loading #{next_episode_index} episode to env_{env_id}")
episode_data = dataset_file_handler.load_episode(
episode_names[next_episode_index], env.unwrapped.device
)
env_episode_data_map[env_id] = episode_data
# Set initial state for the new episode
initial_state = episode_data.get_initial_state()
env.unwrapped.reset_to(
initial_state, torch.tensor([env_id], device=env.unwrapped.device), is_relative=True
)
# Get the first action for the new episode
env_next_action = env_episode_data_map[env_id].get_next_action()
has_next_action = True
else:
continue
else:
has_next_action = True
actions[env_id] = env_next_action
if first_loop:
first_loop = False
else:
while is_paused:
env.unwrapped.sim.render()
continue
env.step(actions)
if state_validation_enabled:
state_from_dataset = env_episode_data_map[0].get_next_state()
if state_from_dataset is not None:
print(
f"Validating states at action-index: {env_episode_data_map[0].next_state_index - 1 :4}",
end="",
)
current_runtime_state = env.unwrapped.scene.get_state(is_relative=True)
states_matched, comparison_log = compare_states(state_from_dataset, current_runtime_state, 0)
if states_matched:
print("\t- matched.")
else:
print("\t- mismatched.")
print(comparison_log)
break
# Close environment after replay in complete
plural_trailing_s = "s" if replayed_episode_count > 1 else ""
print(f"Finished replaying {replayed_episode_count} episode{plural_trailing_s}.")
env.close()
if __name__ == "__main__":
# run the main function
main()
# close sim app
simulation_app.close()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment