Commit 464631fa authored by CY Chen's avatar CY Chen Committed by Kelly Guo

Updates mimic to support multi-eef (DexMimicGen) data generation (#287)

This PR updates mimic to support multi-eef (DexMimicGen) data
generation.
It consists of the following major changes:
- Updated mimic code to support environments with multiple end effectors
- Added support for setting subtask constraints based on DexMimicGen
- Updated annotate_demos.py to support annotating subtask term signals
for multiple end effectors
- Updated mimic API target_eef_pose_to_action() to take noise as
dictionary of eef noise values instead of a single value

- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- This change requires a documentation update

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 78a70bb5
......@@ -7,7 +7,7 @@
Main data generation script.
"""
# Launching Isaac Sim Simulator first.
"""Launch Isaac Sim Simulator first."""
import argparse
......@@ -45,10 +45,15 @@ simulation_app = app_launcher.app
import asyncio
import gymnasium as gym
import inspect
import numpy as np
import random
import torch
import omni
from isaaclab.envs import ManagerBasedRLMimicEnv
import isaaclab_mimic.envs # noqa: F401
from isaaclab_mimic.datagen.generation import env_loop, setup_async_generation, setup_env_config
from isaaclab_mimic.datagen.utils import get_env_name_from_dataset, setup_output_paths
......@@ -74,12 +79,22 @@ def main():
)
# create environment
env = gym.make(env_name, cfg=env_cfg)
env = gym.make(env_name, cfg=env_cfg).unwrapped
if not isinstance(env, ManagerBasedRLMimicEnv):
raise ValueError("The environment should be derived from ManagerBasedRLMimicEnv")
# check if the mimic API from this environment contains deprecated signatures
if "action_noise_dict" not in inspect.signature(env.target_eef_pose_to_action).parameters:
omni.log.warn(
f'The "noise" parameter in the "{env_name}" environment\'s mimic API "target_eef_pose_to_action", '
"is deprecated. Please update the API to take action_noise_dict instead."
)
# set seed for generation
random.seed(env.unwrapped.cfg.datagen_config.seed)
np.random.seed(env.unwrapped.cfg.datagen_config.seed)
torch.manual_seed(env.unwrapped.cfg.datagen_config.seed)
random.seed(env.cfg.datagen_config.seed)
np.random.seed(env.cfg.datagen_config.seed)
torch.manual_seed(env.cfg.datagen_config.seed)
# reset before starting
env.reset()
......@@ -95,7 +110,13 @@ def main():
try:
asyncio.ensure_future(asyncio.gather(*async_components["tasks"]))
env_loop(env, async_components["action_queue"], async_components["info_pool"], async_components["event_loop"])
env_loop(
env,
async_components["reset_queue"],
async_components["action_queue"],
async_components["info_pool"],
async_components["event_loop"],
)
except asyncio.CancelledError:
print("Tasks were cancelled.")
......
......@@ -118,6 +118,18 @@ Changed
* ``set_fixed_tendon_limit`` → ``set_fixed_tendon_position_limit``
0.34.12 (2025-03-06)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Updated the mimic API :meth:`target_eef_pose_to_action` in :class:`isaaclab.envs.ManagerBasedRLMimicEnv` to take a dictionary of
eef noise values instead of a single noise value.
* Added support for optional subtask constraints based on DexMimicGen to the mimic configuration class :class:`isaaclab.envs.MimicEnvCfg`.
* Enabled data compression in HDF5 dataset file handler :class:`isaaclab.utils.datasets.hdf5_dataset_file_handler.HDF5DatasetFileHandler`.
0.34.11 (2025-03-04)
~~~~~~~~~~~~~~~~~~~~
......
......@@ -47,7 +47,11 @@ class ManagerBasedRLMimicEnv(ManagerBasedRLEnv):
raise NotImplementedError
def target_eef_pose_to_action(
self, target_eef_pose_dict: dict, gripper_action_dict: dict, noise: float | None = None, env_id: int = 0
self,
target_eef_pose_dict: dict,
gripper_action_dict: dict,
action_noise_dict: dict | None = None,
env_id: int = 0,
) -> torch.Tensor:
"""
Takes a target pose and gripper action for the end effector controller and returns an action
......@@ -57,7 +61,7 @@ class ManagerBasedRLMimicEnv(ManagerBasedRLEnv):
Args:
target_eef_pose_dict: Dictionary of 4x4 target eef pose for each end-effector.
gripper_action_dict: Dictionary of gripper actions for each end-effector.
noise: Noise to add to the action. If None, no noise is added.
action_noise_dict: Noise to add to the action. If None, no noise is added.
env_id: Environment index to compute the action for.
Returns:
......
......@@ -163,7 +163,7 @@ class HDF5DatasetFileHandler(DatasetFileHandlerBase):
for sub_key, sub_value in value.items():
create_dataset_helper(key_group, sub_key, sub_value)
else:
group.create_dataset(key, data=value.cpu().numpy())
group.create_dataset(key, data=value.cpu().numpy(), compression="gzip")
for key, value in episode.data.items():
create_dataset_helper(h5_episode_group, key, value)
......
[package]
# Semantic Versioning is used: https://semver.org/
version = "1.0.4"
version = "1.0.5"
# Description
category = "isaaclab"
......
Changelog
---------
1.0.4 (2025-03-10)
1.0.5 (2025-03-10)
~~~~~~~~~~~~~~~~~~
Changed
......@@ -15,6 +15,16 @@ Added
* Added ``Isaac-Stack-Cube-Franka-IK-Rel-Blueprint-Mimic-v0`` environment for blueprint vision stacking.
1.0.4 (2025-03-07)
~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Updated data generator to support environments with multiple end effectors.
* Updated data generator to support subtask constraints based on DexMimicGen.
1.0.3 (2025-03-06)
~~~~~~~~~~~~~~~~~~
......
......@@ -6,7 +6,6 @@
"""
Defines structure of information that is needed from an environment for data generation.
"""
import torch
from copy import deepcopy
......@@ -46,40 +45,6 @@ class DatagenInfo:
gripper_action (torch.Tensor or None): gripper actions of shape [..., D] where D
is the dimension of the gripper actuation action for the robot arm
"""
# Type checks using assert
if eef_pose is not None:
assert isinstance(
eef_pose, torch.Tensor
), f"Expected 'eef_pose' to be of type torch.Tensor, but got {type(eef_pose)}"
if object_poses is not None:
assert isinstance(
object_poses, dict
), f"Expected 'object_poses' to be a dictionary, but got {type(object_poses)}"
for k, v in object_poses.items():
assert isinstance(
v, torch.Tensor
), f"Expected 'object_poses[{k}]' to be of type torch.Tensor, but got {type(v)}"
if subtask_term_signals is not None:
assert isinstance(
subtask_term_signals, dict
), f"Expected 'subtask_term_signals' to be a dictionary, but got {type(subtask_term_signals)}"
for k, v in subtask_term_signals.items():
assert isinstance(
v, (torch.Tensor, int, float)
), f"Expected 'subtask_term_signals[{k}]' to be of type torch.Tensor, int, or float, but got {type(v)}"
if target_eef_pose is not None:
assert isinstance(
target_eef_pose, torch.Tensor
), f"Expected 'target_eef_pose' to be of type torch.Tensor, but got {type(target_eef_pose)}"
if gripper_action is not None:
assert isinstance(
gripper_action, torch.Tensor
), f"Expected 'gripper_action' to be of type torch.Tensor, but got {type(gripper_action)}"
self.eef_pose = None
if eef_pose is not None:
self.eef_pose = eef_pose
......
......@@ -5,7 +5,6 @@
import asyncio
import isaaclab.utils.math as PoseUtils
from isaaclab.utils.datasets import EpisodeData, HDF5DatasetFileHandler
from isaaclab_mimic.datagen.datagen_info import DatagenInfo
......@@ -28,7 +27,9 @@ class DataGenInfoPool:
asyncio_lock (asyncio.Lock or None): asyncio lock to use for thread safety
"""
self._datagen_infos = []
self._subtask_indices = []
# Start and end step indices of each subtask in each episode for each eef
self._subtask_boundaries: dict[str, list[list[tuple[int, int]]]] = {}
self.env = env
self.env_cfg = env_cfg
......@@ -36,13 +37,15 @@ class DataGenInfoPool:
self._asyncio_lock = asyncio_lock
if len(env_cfg.subtask_configs) != 1:
raise ValueError("Data generation currently supports only one end-effector.")
(subtask_configs,) = env_cfg.subtask_configs.values()
self.subtask_term_signals = [subtask_config.subtask_term_signal for subtask_config in subtask_configs]
self.subtask_term_offset_ranges = [
subtask_config.subtask_term_offset_range for subtask_config in subtask_configs
# Subtask termination infos for the given environment
self.subtask_term_signal_names: dict[str, list[str]] = {}
self.subtask_term_offset_ranges: dict[str, list[tuple[int, int]]] = {}
for eef_name, eef_subtask_configs in env_cfg.subtask_configs.items():
self.subtask_term_signal_names[eef_name] = [
subtask_config.subtask_term_signal for subtask_config in eef_subtask_configs
]
self.subtask_term_offset_ranges[eef_name] = [
subtask_config.subtask_term_offset_range for subtask_config in eef_subtask_configs
]
@property
......@@ -51,9 +54,9 @@ class DataGenInfoPool:
return self._datagen_infos
@property
def subtask_indices(self):
"""Returns the subtask indices."""
return self._subtask_indices
def subtask_boundaries(self) -> dict[str, list[list[tuple[int, int]]]]:
"""Returns the subtask boundaries."""
return self._subtask_boundaries
@property
def asyncio_lock(self):
......@@ -86,43 +89,18 @@ class DataGenInfoPool:
episode (EpisodeData): episode to add
"""
ep_grp = episode.data
eef_name = list(self.env.cfg.subtask_configs.keys())[0]
# extract datagen info
if "datagen_info" in ep_grp["obs"]:
eef_pose = ep_grp["obs"]["datagen_info"]["eef_pose"][eef_name]
eef_pose = ep_grp["obs"]["datagen_info"]["eef_pose"]
object_poses_dict = ep_grp["obs"]["datagen_info"]["object_pose"]
target_eef_pose = ep_grp["obs"]["datagen_info"]["target_eef_pose"][eef_name]
target_eef_pose = ep_grp["obs"]["datagen_info"]["target_eef_pose"]
subtask_term_signals_dict = ep_grp["obs"]["datagen_info"]["subtask_term_signals"]
else:
# Extract eef poses
eef_pos = ep_grp["obs"]["eef_pos"]
eef_quat = ep_grp["obs"]["eef_quat"] # format (w, x, y, z)
eef_rot_matrices = PoseUtils.matrix_from_quat(eef_quat) # shape (N, 3, 3)
# Create pose matrices for all environments
eef_pose = PoseUtils.make_pose(eef_pos, eef_rot_matrices) # shape (N, 4, 4)
# Object poses
object_poses_dict = dict()
for object_name, value in ep_grp["obs"]["object_pose"].items():
# object_pose
value = value["root_pose"]
# Root state ``[pos, quat, lin_vel, ang_vel]`` in simulation world frame. Shape is (num_steps, 13).
# Quaternion ordering is wxyz
# Convert to rotation matrices
object_rot_matrices = PoseUtils.matrix_from_quat(value[:, 3:7]) # shape (N, 3, 3)
object_rot_positions = value[:, 0:3] # shape (N, 3)
object_poses_dict[object_name] = PoseUtils.make_pose(object_rot_positions, object_rot_matrices)
# Target eef pose
target_eef_pose = ep_grp["obs"]["target_eef_pose"]
# Subtask termination signals
subtask_term_signals_dict = (ep_grp["obs"]["subtask_term_signals"],)
raise ValueError("Episode to be loaded to DatagenInfo pool lacks datagen_info annotations")
# Extract gripper actions
gripper_actions = self.env.actions_to_gripper_actions(ep_grp["actions"])[eef_name]
gripper_actions = self.env.actions_to_gripper_actions(ep_grp["actions"])
ep_datagen_info_obj = DatagenInfo(
eef_pose=eef_pose,
......@@ -133,21 +111,32 @@ class DataGenInfoPool:
)
self._datagen_infos.append(ep_datagen_info_obj)
# parse subtask indices using subtask termination signals
ep_subtask_indices = []
# parse subtask ranges using subtask termination signals and store
# the start and end indices of each subtask for each eef
for eef_name in self.subtask_term_signal_names.keys():
if eef_name not in self._subtask_boundaries:
self._subtask_boundaries[eef_name] = []
prev_subtask_term_ind = 0
for subtask_ind in range(len(self.subtask_term_signals)):
subtask_term_signal = self.subtask_term_signals[subtask_ind]
if subtask_term_signal is None:
# final subtask, finishes at end of demo
eef_subtask_boundaries = []
for subtask_term_signal_name in self.subtask_term_signal_names[eef_name]:
if subtask_term_signal_name is None:
# None refers to the final subtask, so finishes at end of demo
subtask_term_ind = ep_grp["actions"].shape[0]
else:
# trick to detect index where first 0 -> 1 transition occurs - this will be the end of the subtask
subtask_indicators = ep_datagen_info_obj.subtask_term_signals[subtask_term_signal].flatten().int()
subtask_indicators = (
ep_datagen_info_obj.subtask_term_signals[subtask_term_signal_name].flatten().int()
)
diffs = subtask_indicators[1:] - subtask_indicators[:-1]
end_ind = int(diffs.nonzero()[0][0]) + 1
subtask_term_ind = end_ind + 1 # increment to support indexing like demo[start:end]
ep_subtask_indices.append([prev_subtask_term_ind, subtask_term_ind])
if subtask_term_ind <= prev_subtask_term_ind:
raise ValueError(
f"subtask termination signal is not increasing: {subtask_term_ind} should be greater than"
f" {prev_subtask_term_ind}"
)
eef_subtask_boundaries.append((prev_subtask_term_ind, subtask_term_ind))
prev_subtask_term_ind = subtask_term_ind
# run sanity check on subtask_term_offset_range in task spec to make sure we can never
......@@ -155,29 +144,26 @@ class DataGenInfoPool:
#
# end index of subtask i + max offset of subtask i < end index of subtask i + 1 + min offset of subtask i + 1
#
assert len(ep_subtask_indices) == len(
self.subtask_term_signals
), "mismatch in length of extracted subtask info and number of subtasks"
for i in range(1, len(ep_subtask_indices)):
prev_max_offset_range = self.subtask_term_offset_ranges[i - 1][1]
for i in range(1, len(eef_subtask_boundaries)):
prev_max_offset_range = self.subtask_term_offset_ranges[eef_name][i - 1][1]
assert (
ep_subtask_indices[i - 1][1] + prev_max_offset_range
< ep_subtask_indices[i][1] + self.subtask_term_offset_ranges[i][0]
eef_subtask_boundaries[i - 1][1] + prev_max_offset_range
< eef_subtask_boundaries[i][1] + self.subtask_term_offset_ranges[eef_name][i][0]
), (
"subtask sanity check violation in demo with subtask {} end ind {}, subtask {} max offset {},"
" subtask {} end ind {}, and subtask {} min offset {}".format(
i - 1,
ep_subtask_indices[i - 1][1],
eef_subtask_boundaries[i - 1][1],
i - 1,
prev_max_offset_range,
i,
ep_subtask_indices[i][1],
eef_subtask_boundaries[i][1],
i,
self.subtask_term_offset_ranges[i][0],
self.subtask_term_offset_ranges[eef_name][i][0],
)
)
self._subtask_indices.append(ep_subtask_indices)
self._subtask_boundaries[eef_name].append(eef_subtask_boundaries)
def load_from_dataset_file(self, file_path, select_demo_keys: str | None = None):
"""
......
......@@ -8,9 +8,9 @@ import contextlib
import torch
from typing import Any
from isaaclab.envs import ManagerBasedEnv
from isaaclab.envs import ManagerBasedRLMimicEnv
from isaaclab.envs.mdp.recorders.recorders_cfg import ActionStateRecorderManagerCfg
from isaaclab.managers import DatasetExportMode
from isaaclab.managers import DatasetExportMode, TerminationTermCfg
from isaaclab_mimic.datagen.data_generator import DataGenerator
from isaaclab_mimic.datagen.datagen_info_pool import DataGenInfoPool
......@@ -24,23 +24,32 @@ num_attempts = 0
async def run_data_generator(
env: ManagerBasedEnv,
env: ManagerBasedRLMimicEnv,
env_id: int,
env_reset_queue: asyncio.Queue,
env_action_queue: asyncio.Queue,
data_generator: DataGenerator,
success_term: Any,
success_term: TerminationTermCfg,
pause_subtask: bool = False,
):
"""Run data generator."""
"""Run mimic data generation from the given data generator in the specified environment index.
Args:
env: The environment to run the data generator on.
env_id: The environment index to run the data generation on.
env_reset_queue: The asyncio queue to send environment (for this particular env_id) reset requests to.
env_action_queue: The asyncio queue to send actions to for execution.
data_generator: The data generator instance to use.
success_term: The success termination term to use.
pause_subtask: Whether to pause the subtask during generation.
"""
global num_success, num_failures, num_attempts
while True:
results = await data_generator.generate(
env_id=env_id,
success_term=success_term,
env_reset_queue=env_reset_queue,
env_action_queue=env_action_queue,
select_src_per_subtask=env.unwrapped.cfg.datagen_config.generation_select_src_per_subtask,
transform_first_robot_pose=env.unwrapped.cfg.datagen_config.generation_transform_first_robot_pose,
interpolate_from_last_target_pose=env.unwrapped.cfg.datagen_config.generation_interpolate_from_last_target_pose,
pause_subtask=pause_subtask,
)
if bool(results["success"]):
......@@ -51,22 +60,40 @@ async def run_data_generator(
def env_loop(
env: ManagerBasedEnv,
env: ManagerBasedRLMimicEnv,
env_reset_queue: asyncio.Queue,
env_action_queue: asyncio.Queue,
shared_datagen_info_pool: DataGenInfoPool,
asyncio_event_loop: asyncio.AbstractEventLoop,
) -> None:
"""Main loop for the environment."""
):
"""Main asyncio loop for the environment.
Args:
env: The environment to run the main step loop on.
env_reset_queue: The asyncio queue that carries per-environment reset requests to be handled by this loop.
env_action_queue: The asyncio queue that carries the actions to be executed in the environment.
shared_datagen_info_pool: The shared datagen info pool that stores source demo info.
asyncio_event_loop: The main asyncio event loop.
"""
global num_success, num_failures, num_attempts
env_id_tensor = torch.tensor([0], dtype=torch.int64, device=env.device)
prev_num_attempts = 0
# simulate environment -- run everything in inference mode
with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode():
while True:
actions = torch.zeros(env.unwrapped.action_space.shape)
# check if any environment needs to be reset while waiting for actions
while env_action_queue.qsize() != env.num_envs:
asyncio_event_loop.run_until_complete(asyncio.sleep(0))
while not env_reset_queue.empty():
env_id_tensor[0] = env_reset_queue.get_nowait()
env.reset(env_ids=env_id_tensor)
env_reset_queue.task_done()
actions = torch.zeros(env.action_space.shape)
# get actions from all the data generators
for i in range(env.unwrapped.num_envs):
for i in range(env.num_envs):
# an async-blocking call to get an action from a data generator
env_id, action = asyncio_event_loop.run_until_complete(env_action_queue.get())
actions[env_id] = action
......@@ -75,27 +102,30 @@ def env_loop(
env.step(actions)
# mark done so the data generators can continue with the step results
for i in range(env.unwrapped.num_envs):
for i in range(env.num_envs):
env_action_queue.task_done()
if prev_num_attempts != num_attempts:
prev_num_attempts = num_attempts
generated_sucess_rate = 100 * num_success / num_attempts if num_attempts > 0 else 0.0
print("")
print("*" * 50)
print(f"have {num_success} successes out of {num_attempts} trials so far")
print(f"have {num_failures} failures out of {num_attempts} trials so far")
print("*" * 50)
print("*" * 50, "\033[K")
print(
f"{num_success}/{num_attempts} ({generated_sucess_rate:.1f}%) successful demos generated by"
" mimic\033[K"
)
print("*" * 50, "\033[K")
# termination condition is on enough successes if @guarantee_success or enough attempts otherwise
generation_guarantee = env.unwrapped.cfg.datagen_config.generation_guarantee
generation_num_trials = env.unwrapped.cfg.datagen_config.generation_num_trials
generation_guarantee = env.cfg.datagen_config.generation_guarantee
generation_num_trials = env.cfg.datagen_config.generation_num_trials
check_val = num_success if generation_guarantee else num_attempts
if check_val >= generation_num_trials:
print(f"Reached {generation_num_trials} successes/attempts. Exiting.")
break
# check that simulation is stopped or not
if env.unwrapped.sim.is_stopped():
if env.sim.is_stopped():
break
env.close()
......@@ -175,26 +205,28 @@ def setup_async_generation(
List of asyncio tasks for data generation
"""
asyncio_event_loop = asyncio.get_event_loop()
env_reset_queue = asyncio.Queue()
env_action_queue = asyncio.Queue()
shared_datagen_info_pool_lock = asyncio.Lock()
shared_datagen_info_pool = DataGenInfoPool(
env.unwrapped, env.unwrapped.cfg, env.unwrapped.device, asyncio_lock=shared_datagen_info_pool_lock
)
shared_datagen_info_pool = DataGenInfoPool(env, env.cfg, env.device, asyncio_lock=shared_datagen_info_pool_lock)
shared_datagen_info_pool.load_from_dataset_file(input_file)
print(f"Loaded {shared_datagen_info_pool.num_datagen_infos} to datagen info pool")
# Create and schedule data generator tasks
data_generator = DataGenerator(env=env.unwrapped, src_demo_datagen_info_pool=shared_datagen_info_pool)
data_generator = DataGenerator(env=env, src_demo_datagen_info_pool=shared_datagen_info_pool)
data_generator_asyncio_tasks = []
for i in range(num_envs):
task = asyncio_event_loop.create_task(
run_data_generator(env, i, env_action_queue, data_generator, success_term, pause_subtask=pause_subtask)
run_data_generator(
env, i, env_reset_queue, env_action_queue, data_generator, success_term, pause_subtask=pause_subtask
)
)
data_generator_asyncio_tasks.append(task)
return {
"tasks": data_generator_asyncio_tasks,
"event_loop": asyncio_event_loop,
"reset_queue": env_reset_queue,
"action_queue": env_action_queue,
"info_pool": shared_datagen_info_pool,
}
......@@ -7,10 +7,13 @@
A collection of classes used to represent waypoints and trajectories.
"""
import asyncio
import inspect
import torch
from copy import deepcopy
import isaaclab.utils.math as PoseUtils
from isaaclab.envs import ManagerBasedRLMimicEnv
from isaaclab.managers import TerminationTermCfg
class Waypoint:
......@@ -18,7 +21,7 @@ class Waypoint:
Represents a single desired 6-DoF waypoint, along with corresponding gripper actuation for this point.
"""
def __init__(self, eef_names, pose, gripper_action, noise=None):
def __init__(self, pose, gripper_action, noise=None):
"""
Args:
pose (torch.Tensor): 4x4 pose target for robot controller
......@@ -26,7 +29,6 @@ class Waypoint:
noise (float or None): action noise amplitude to apply during execution at this timestep
(for arm actions, not gripper actions)
"""
self.eef_names = eef_names
self.pose = pose
self.gripper_action = gripper_action
self.noise = noise
......@@ -54,7 +56,7 @@ class WaypointSequence:
self.sequence = deepcopy(sequence)
@classmethod
def from_poses(cls, eef_names, poses, gripper_actions, action_noise):
def from_poses(cls, poses, gripper_actions, action_noise):
"""
Instantiate a WaypointSequence object given a sequence of poses,
gripper actions, and action noise.
......@@ -79,7 +81,6 @@ class WaypointSequence:
# make WaypointSequence instance
sequence = [
Waypoint(
eef_names=eef_names,
pose=poses[t],
gripper_action=gripper_actions[t],
noise=action_noise[t, 0],
......@@ -202,7 +203,6 @@ class WaypointTrajectory:
def add_waypoint_sequence_for_target_pose(
self,
eef_names,
pose,
gripper_action,
num_steps,
......@@ -254,7 +254,6 @@ class WaypointTrajectory:
# add waypoint sequence for this set of poses
sequence = WaypointSequence.from_poses(
eef_names=eef_names,
poses=poses,
gripper_actions=gripper_actions,
action_noise=action_noise,
......@@ -281,7 +280,6 @@ class WaypointTrajectory:
def merge(
self,
other,
eef_names,
num_steps_interp=None,
num_steps_fixed=None,
action_noise=0.0,
......@@ -315,7 +313,6 @@ class WaypointTrajectory:
if need_interp:
# interpolation segment
self.add_waypoint_sequence_for_target_pose(
eef_names=eef_names,
pose=target_for_interpolation.pose,
gripper_action=target_for_interpolation.gripper_action,
num_steps=num_steps_interp,
......@@ -329,7 +326,6 @@ class WaypointTrajectory:
# account for the fact that we pop'd the first element of @other in anticipation of an interpolation segment
num_steps_fixed_to_use = num_steps_fixed if need_interp else (num_steps_fixed + 1)
self.add_waypoint_sequence_for_target_pose(
eef_names=eef_names,
pose=target_for_interpolation.pose,
gripper_action=target_for_interpolation.gripper_action,
num_steps=num_steps_fixed_to_use,
......@@ -343,67 +339,75 @@ class WaypointTrajectory:
# concatenate the trajectories
self.waypoint_sequences += other.waypoint_sequences
def get_full_sequence(self):
"""
Returns the full sequence of waypoints in the trajectory.
Returns:
sequence (WaypointSequence instance)
"""
return WaypointSequence(sequence=[waypoint for seq in self.waypoint_sequences for waypoint in seq.sequence])
class MultiWaypoint:
"""
A collection of Waypoint objects for multiple end effectors in the environment.
"""
def __init__(self, waypoints: dict[str, Waypoint]):
"""
Args:
waypoints (dict): a dictionary of waypoints, keyed by end-effector name
"""
self.waypoints = waypoints
async def execute(
self,
env,
env_id,
success_term,
env: ManagerBasedRLMimicEnv,
success_term: TerminationTermCfg,
env_id: int = 0,
env_action_queue: asyncio.Queue | None = None,
):
"""
Main function to execute the trajectory. Will use env_interface.target_eef_pose_to_action to
convert each target pose at each waypoint to an action command, and pass that along to
env.step.
Executes the multi-waypoint eef actions in the environment.
Args:
env (Isaac Lab ManagerBasedEnv instance): environment to use for executing trajectory
env_id (int): environment index
success_term: success term to check if the task is successful
env_action_queue (asyncio.Queue): queue for sending actions to the environment
env: The environment to execute the multi-waypoint actions in.
success_term: The termination term to check for task success.
env_id: The environment ID to execute the multi-waypoint actions in.
env_action_queue: The asyncio queue to put the action into.
Returns:
results (dict): dictionary with the following items for the executed trajectory:
states (list): simulator state at each timestep
observations (list): observation dictionary at each timestep
datagen_infos (list): datagen_info at each timestep
actions (list): action executed at each timestep
success (bool): whether the trajectory successfully solved the task or not
A dictionary containing the state, observation, action, and success of the multi-waypoint actions.
"""
states = []
actions = []
observations = []
success = False
# iterate over waypoint sequences
for seq in self.waypoint_sequences:
# iterate over waypoints in each sequence
for j in range(len(seq)):
# current waypoint
waypoint = seq[j]
# current state and observation
obs = env.obs_buf
# current state
state = env.scene.get_state(is_relative=True)
# convert target pose and gripper action to env action
target_eef_pose_dict = {waypoint.eef_names[0]: waypoint.pose}
gripper_action_dict = {waypoint.eef_names[0]: waypoint.gripper_action}
# construct action from target poses and gripper actions
target_eef_pose_dict = {eef_name: waypoint.pose for eef_name, waypoint in self.waypoints.items()}
gripper_action_dict = {eef_name: waypoint.gripper_action for eef_name, waypoint in self.waypoints.items()}
if "action_noise_dict" in inspect.signature(env.target_eef_pose_to_action).parameters:
action_noise_dict = {eef_name: waypoint.noise for eef_name, waypoint in self.waypoints.items()}
play_action = env.target_eef_pose_to_action(
target_eef_pose_dict=target_eef_pose_dict,
gripper_action_dict=gripper_action_dict,
action_noise_dict=action_noise_dict,
env_id=env_id,
)
else:
# calling user-defined env.target_eef_pose_to_action() with noise parameter is deprecated
# (replaced by action_noise_dict)
play_action = env.target_eef_pose_to_action(
target_eef_pose_dict=target_eef_pose_dict,
gripper_action_dict=gripper_action_dict,
noise=waypoint.noise,
noise=max([waypoint.noise for waypoint in self.waypoints.values()]),
env_id=env_id,
)
# step environment
if not isinstance(play_action, torch.Tensor):
play_action = torch.tensor(play_action)
if play_action.dim() == 1 and play_action.size(0) == 7:
play_action = play_action.unsqueeze(0) # Reshape to [1, 7]
if play_action.dim() == 1:
play_action = play_action.unsqueeze(0) # Reshape with additional env dimension
# step environment
if env_action_queue is None:
obs, _, _, _, _ = env.step(play_action)
else:
......@@ -411,20 +415,12 @@ class WaypointTrajectory:
await env_action_queue.join()
obs = env.obs_buf
# collect data
states.append(state)
actions.append(play_action)
observations.append(obs)
cur_success_metric = bool(success_term.func(env, **success_term.params)[env_id])
# If the task success metric is True once during the execution, then the task is considered successful
success = success or cur_success_metric
success = bool(success_term.func(env, **success_term.params)[env_id])
results = dict(
states=states,
observations=observations,
actions=torch.stack(actions),
result = dict(
states=[state],
observations=[obs],
actions=[play_action],
success=success,
)
return results
return result
......@@ -59,10 +59,6 @@ class TestGenerateDataset(unittest.TestCase):
DATASETS_DOWNLOAD_DIR + "/dataset.hdf5",
"--output_file",
DATASETS_DOWNLOAD_DIR + "/annotated_dataset.hdf5",
"--signals",
"grasp_1",
"stack_1",
"grasp_2",
"--auto",
"--headless",
]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment