Commit 8d9ef6fa authored by CY Chen, committed by Kelly Guo

Adds subtask annotation checks in annotate_demos.py (#243)

# Description

Added additional checks for subtask annotations in mimic's
`annotate_demos.py` to make sure that all exported demos have the
valid annotations required for running data generation in mimic.

Added a script to merge multiple HDF5 files into one dataset. This lets
users combine the files produced by separate runs of `annotate_demos.py`.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)
- New feature

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Co-authored-by: Peter Du <peterd@nvidia.com>
parent 25965b74
......@@ -164,11 +164,20 @@ def main():
env_cfg = parse_env_cfg(env_name, device=args_cli.device, num_envs=1)
env_cfg.env_name = args_cli.task
# extract success checking function to invoke manually
success_term = None
if hasattr(env_cfg.terminations, "success"):
success_term = env_cfg.terminations.success
env_cfg.terminations.success = None
else:
raise NotImplementedError("No success termination term was found in the environment.")
# Disable all termination terms
env_cfg.terminations = {}
env_cfg.terminations = None
# Set up recorder terms for mimic annotations
env_cfg.env_name = args_cli.task
env_cfg.recorders: MimicRecorderManagerCfg = MimicRecorderManagerCfg()
if not args_cli.auto:
# disable subtask term signals recorder term if in manual mode
......@@ -203,10 +212,13 @@ def main():
keyboard_interface.reset()
# simulate environment -- run everything in inference mode
exported_episode_count = 0
processed_episode_count = 0
with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode():
while simulation_app.is_running() and not simulation_app.is_exiting():
# Iterate over the episodes in the loaded dataset file
for episode_index, episode_name in enumerate(dataset_file_handler.get_episode_names()):
processed_episode_count += 1
subtask_indices = []
print(f"\nAnnotating episode #{episode_index} ({episode_name})")
episode = dataset_file_handler.load_episode(episode_name, env.unwrapped.device)
......@@ -231,6 +243,7 @@ def main():
action_tensor = torch.Tensor(action).reshape([1, action.shape[0]])
env.step(torch.Tensor(action_tensor))
is_episode_annotated_successfully = False
if not args_cli.auto:
print(f"\tSubtasks marked at action indices: {subtask_indices}")
if len(args_cli.signals) != len(subtask_indices):
......@@ -246,16 +259,39 @@ def main():
annotated_episode.add(
f"obs/datagen_info/subtask_term_signals/{args_cli.signals[subtask_index]}", subtask_signals
)
is_episode_annotated_successfully = True
else:
# check if all the subtask term signals are annotated
annotated_episode = env.unwrapped.recorder_manager.get_episode(0)
subtask_term_signal_dict = annotated_episode.data["obs"]["datagen_info"]["subtask_term_signals"]
is_episode_annotated_successfully = True
for signal_name, signal_flags in subtask_term_signal_dict.items():
if not torch.any(signal_flags):
is_episode_annotated_successfully = False
print(f'\tDid not detect completion for the subtask "{signal_name}".')
if not bool(success_term.func(env, **success_term.params)[0]):
is_episode_annotated_successfully = False
print("\tThe final task was not completed.")
if is_episode_annotated_successfully:
# set success to the recorded episode data and export to file
env.unwrapped.recorder_manager.set_success_to_episodes(
None, torch.tensor([[True]], dtype=torch.bool, device=env.unwrapped.device)
)
env.unwrapped.recorder_manager.export_episodes()
print("\tExported annotated episode.")
exported_episode_count += 1
print("\tExported the annotated episode.")
else:
print("\tSkipped exporting the episode due to incomplete subtask annotations.")
break
print(
f"\nExported {exported_episode_count} (out of {processed_episode_count}) annotated"
f" episode{'s' if exported_episode_count > 1 else ''}."
)
print("Exiting the app.")
# Close environment after annotation is complete
env.close()
......
# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import argparse
import h5py
import os
# Command-line interface: the HDF5 datasets to merge and where to write the result.
parser = argparse.ArgumentParser(description="Merge a set of HDF5 datasets.")
# One or more input paths; defaults to an empty list when the flag is omitted.
parser.add_argument(
    "--input_files", type=str, nargs="+", default=[], help="A list of paths to HDF5 files to merge."
)
# Destination of the merged dataset.
parser.add_argument(
    "--output_file",
    type=str,
    default="merged_dataset.hdf5",
    help="File path to merged output.",
)
# Parsed at import time so that `merge_datasets` can read the options from module scope.
args_cli = parser.parse_args()
def merge_datasets():
    """Merge the episodes of all input HDF5 datasets into a single output file.

    Reads the input paths from ``args_cli.input_files`` and writes to
    ``args_cli.output_file``. Every entry under the ``data`` group of each
    input file is copied to the output as ``data/demo_<i>``, where ``<i>`` is a
    running index across all input files in the order they were given. The
    ``env_args`` attribute of the first input file's ``data`` group is copied to
    the output's ``data`` group.

    Raises:
        FileNotFoundError: If any of the input files does not exist.
        ValueError: If the output file is also listed as an input file.
    """
    output_path = os.path.realpath(args_cli.output_file)
    # Validate every input before the output file is created or truncated.
    for filepath in args_cli.input_files:
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"The dataset file {filepath} does not exist.")
        # Opening the output in "w" mode would wipe this input before it is read.
        if os.path.realpath(filepath) == output_path:
            raise ValueError(f"The output file {args_cli.output_file} is also listed as an input file.")

    with h5py.File(args_cli.output_file, "w") as output:
        # Create the group up front so the attribute copy below also works when
        # the first input file contains no episodes.
        output_data = output.require_group("data")
        episode_idx = 0
        copy_attributes = True

        for filepath in args_cli.input_files:
            with h5py.File(filepath, "r") as input_file:
                # Renumber episodes so that names from different files never collide.
                for episode in input_file["data"]:
                    input_file.copy(f"data/{episode}", output, f"data/demo_{episode_idx}")
                    episode_idx += 1

                # Only the first input file provides the environment arguments.
                if copy_attributes:
                    output_data.attrs["env_args"] = input_file["data"].attrs["env_args"]
                    copy_attributes = False

    print(f"Merged dataset saved to {args_cli.output_file}")


if __name__ == "__main__":
    merge_datasets()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment