Unverified Commit 8fa76ddc authored by peterd-NV's avatar peterd-NV Committed by GitHub

Updates Mimic test cases to pytest format (#550)

# Description

<!--
Thank you for your interest in sending a pull request. Please make sure
to check the contribution guidelines.

Link:
https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html
-->

Updates the following tests to pytest format:

- test_pink_ik.py
- test_selection_strategy.py
- test_generate_dataset.py

Updates test_generate_dataset.py to check that the expected number of
annotations are successfully generated in the HDF5 during the annotation
phase. If not, then test returns failure. This ensures that physics is
behaving correctly and that the correct annotated demos are being
generated. Previously, if a physics issue is introduced, the test would
timeout instead of failing.

The annotate_demos.py script was updated to return the number of
successfully annotated demos in order to support the above test.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent a958ac56
...@@ -159,7 +159,7 @@ def main(): ...@@ -159,7 +159,7 @@ def main():
if episode_count == 0: if episode_count == 0:
print("No episodes found in the dataset.") print("No episodes found in the dataset.")
exit() return 0
# get output directory path and file name (without extension) from cli arguments # get output directory path and file name (without extension) from cli arguments
output_dir = os.path.dirname(args_cli.output_file) output_dir = os.path.dirname(args_cli.output_file)
...@@ -236,6 +236,7 @@ def main(): ...@@ -236,6 +236,7 @@ def main():
# simulate environment -- run everything in inference mode # simulate environment -- run everything in inference mode
exported_episode_count = 0 exported_episode_count = 0
processed_episode_count = 0 processed_episode_count = 0
successful_task_count = 0 # Counter for successful task completions
with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode(): with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode():
while simulation_app.is_running() and not simulation_app.is_exiting(): while simulation_app.is_running() and not simulation_app.is_exiting():
# Iterate over the episodes in the loaded dataset file # Iterate over the episodes in the loaded dataset file
...@@ -259,6 +260,7 @@ def main(): ...@@ -259,6 +260,7 @@ def main():
) )
env.recorder_manager.export_episodes() env.recorder_manager.export_episodes()
exported_episode_count += 1 exported_episode_count += 1
successful_task_count += 1 # Increment successful task counter
print("\tExported the annotated episode.") print("\tExported the annotated episode.")
else: else:
print("\tSkipped exporting the episode due to incomplete subtask annotations.") print("\tSkipped exporting the episode due to incomplete subtask annotations.")
...@@ -268,11 +270,16 @@ def main(): ...@@ -268,11 +270,16 @@ def main():
f"\nExported {exported_episode_count} (out of {processed_episode_count}) annotated" f"\nExported {exported_episode_count} (out of {processed_episode_count}) annotated"
f" episode{'s' if exported_episode_count > 1 else ''}." f" episode{'s' if exported_episode_count > 1 else ''}."
) )
print(
f"Successful task completions: {successful_task_count}"
) # This line is used by the dataset generation test case to check if the expected number of demos were annotated
print("Exiting the app.") print("Exiting the app.")
# Close environment after annotation is complete # Close environment after annotation is complete
env.close() env.close()
return successful_task_count
def replay_episode( def replay_episode(
env: ManagerBasedRLMimicEnv, env: ManagerBasedRLMimicEnv,
...@@ -440,6 +447,8 @@ def annotate_episode_in_manual_mode( ...@@ -440,6 +447,8 @@ def annotate_episode_in_manual_mode(
if __name__ == "__main__": if __name__ == "__main__":
# run the main function # run the main function
main() successful_task_count = main()
# close sim app # close sim app
simulation_app.close() simulation_app.close()
# exit with the number of successful task completions as return code
exit(successful_task_count)
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.42.24" version = "0.42.25"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.42.25 (2025-07-17)
~~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Updated test_pink_ik.py test case to pytest format.
0.42.24 (2025-06-25) 0.42.24 (2025-06-25)
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
......
...@@ -21,7 +21,8 @@ simulation_app = AppLauncher(headless=True).app ...@@ -21,7 +21,8 @@ simulation_app = AppLauncher(headless=True).app
import contextlib import contextlib
import gymnasium as gym import gymnasium as gym
import torch import torch
import unittest
import pytest
from isaaclab.utils.math import axis_angle_from_quat, matrix_from_quat, quat_from_matrix, quat_inv from isaaclab.utils.math import axis_angle_from_quat, matrix_from_quat, quat_from_matrix, quat_inv
...@@ -30,59 +31,66 @@ import isaaclab_tasks.manager_based.manipulation.pick_place # noqa: F401 ...@@ -30,59 +31,66 @@ import isaaclab_tasks.manager_based.manipulation.pick_place # noqa: F401
from isaaclab_tasks.utils.parse_cfg import parse_env_cfg from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
class TestPinkIKController(unittest.TestCase): @pytest.fixture
"""Test fixture for the Pink IK controller with the GR1T2 humanoid robot. def pink_ik_test_config():
"""Test configuration for Pink IK controller tests."""
This test validates that the Pink IK controller can accurately track commanded
end-effector poses for a humanoid robot. It specifically:
1. Creates a GR1T2 humanoid robot with the Pink IK controller
2. Sends target pose commands to the left and right hand roll links
3. Checks that the observed poses of the links match the target poses within tolerance
4. Tests adaptability by moving the hands up and down multiple times
The test succeeds when the controller can accurately converge to each new target
position, demonstrating both accuracy and adaptability to changing targets.
"""
def setUp(self):
# End effector position mean square error tolerance in meters # End effector position mean square error tolerance in meters
self.pos_tolerance = 0.03 # 2 cm pos_tolerance = 0.03 # 3 cm
# End effector orientation mean square error tolerance in radians # End effector orientation mean square error tolerance in radians
self.rot_tolerance = 0.17 # 10 degrees rot_tolerance = 0.17 # 10 degrees
# Number of environments # Number of environments
self.num_envs = 1 num_envs = 1
# Number of joints in the 2 robot hands # Number of joints in the 2 robot hands
self.num_joints_in_robot_hands = 22 num_joints_in_robot_hands = 22
# Number of steps to wait for controller convergence # Number of steps to wait for controller convergence
self.num_steps_controller_convergence = 25 num_steps_controller_convergence = 25
self.num_times_to_move_hands_up = 3 num_times_to_move_hands_up = 3
self.num_times_to_move_hands_down = 3 num_times_to_move_hands_down = 3
# Create starting setpoints with respect to the env origin frame # Create starting setpoints with respect to the env origin frame
# These are the setpoints for the forward kinematics result of the # These are the setpoints for the forward kinematics result of the
# InitialStateCfg specified in `PickPlaceGR1T2EnvCfg` # InitialStateCfg specified in `PickPlaceGR1T2EnvCfg`
y_axis_z_axis_90_rot_quaternion = [0.5, 0.5, -0.5, 0.5] y_axis_z_axis_90_rot_quaternion = [0.5, 0.5, -0.5, 0.5]
left_hand_roll_link_pos = [-0.23, 0.28, 1.1] left_hand_roll_link_pos = [-0.23, 0.28, 1.1]
self.left_hand_roll_link_pose = left_hand_roll_link_pos + y_axis_z_axis_90_rot_quaternion left_hand_roll_link_pose = left_hand_roll_link_pos + y_axis_z_axis_90_rot_quaternion
right_hand_roll_link_pos = [0.23, 0.28, 1.1] right_hand_roll_link_pos = [0.23, 0.28, 1.1]
self.right_hand_roll_link_pose = right_hand_roll_link_pos + y_axis_z_axis_90_rot_quaternion right_hand_roll_link_pose = right_hand_roll_link_pos + y_axis_z_axis_90_rot_quaternion
""" return {
Test fixtures. "pos_tolerance": pos_tolerance,
""" "rot_tolerance": rot_tolerance,
"num_envs": num_envs,
"num_joints_in_robot_hands": num_joints_in_robot_hands,
"num_steps_controller_convergence": num_steps_controller_convergence,
"num_times_to_move_hands_up": num_times_to_move_hands_up,
"num_times_to_move_hands_down": num_times_to_move_hands_down,
"left_hand_roll_link_pose": left_hand_roll_link_pose,
"right_hand_roll_link_pose": right_hand_roll_link_pose,
}
def test_gr1t2_ik_pose_abs(self):
"""Test IK controller for GR1T2 humanoid.""" def test_gr1t2_ik_pose_abs(pink_ik_test_config):
"""Test IK controller for GR1T2 humanoid.
This test validates that the Pink IK controller can accurately track commanded
end-effector poses for a humanoid robot. It specifically:
1. Creates a GR1T2 humanoid robot with the Pink IK controller
2. Sends target pose commands to the left and right hand roll links
3. Checks that the observed poses of the links match the target poses within tolerance
4. Tests adaptability by moving the hands up and down multiple times
The test succeeds when the controller can accurately converge to each new target
position, demonstrating both accuracy and adaptability to changing targets.
"""
env_name = "Isaac-PickPlace-GR1T2-Abs-v0" env_name = "Isaac-PickPlace-GR1T2-Abs-v0"
device = "cuda:0" device = "cuda:0"
env_cfg = parse_env_cfg(env_name, device=device, num_envs=self.num_envs) env_cfg = parse_env_cfg(env_name, device=device, num_envs=pink_ik_test_config["num_envs"])
# create environment from loaded config # create environment from loaded config
env = gym.make(env_name, cfg=env_cfg).unwrapped env = gym.make(env_name, cfg=env_cfg).unwrapped
...@@ -95,13 +103,17 @@ class TestPinkIKController(unittest.TestCase): ...@@ -95,13 +103,17 @@ class TestPinkIKController(unittest.TestCase):
move_hands_up = True move_hands_up = True
test_counter = 0 test_counter = 0
# Get poses from config
left_hand_roll_link_pose = pink_ik_test_config["left_hand_roll_link_pose"].copy()
right_hand_roll_link_pose = pink_ik_test_config["right_hand_roll_link_pose"].copy()
# simulate environment -- run everything in inference mode # simulate environment -- run everything in inference mode
with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode(): with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode():
while simulation_app.is_running() and not simulation_app.is_exiting(): while simulation_app.is_running() and not simulation_app.is_exiting():
num_runs += 1 num_runs += 1
setpoint_poses = self.left_hand_roll_link_pose + self.right_hand_roll_link_pose setpoint_poses = left_hand_roll_link_pose + right_hand_roll_link_pose
actions = setpoint_poses + [0.0] * self.num_joints_in_robot_hands actions = setpoint_poses + [0.0] * pink_ik_test_config["num_joints_in_robot_hands"]
actions = torch.tensor(actions, device=device) actions = torch.tensor(actions, device=device)
actions = torch.stack([actions for _ in range(env.num_envs)]) actions = torch.stack([actions for _ in range(env.num_envs)])
...@@ -118,11 +130,9 @@ class TestPinkIKController(unittest.TestCase): ...@@ -118,11 +130,9 @@ class TestPinkIKController(unittest.TestCase):
# The observations are also wrt the env origin frame # The observations are also wrt the env origin frame
left_hand_roll_link_feedback = left_hand_roll_link_pose_obs left_hand_roll_link_feedback = left_hand_roll_link_pose_obs
left_hand_roll_link_setpoint = ( left_hand_roll_link_setpoint = (
torch.tensor(self.left_hand_roll_link_pose, device=device).unsqueeze(0).repeat(env.num_envs, 1) torch.tensor(left_hand_roll_link_pose, device=device).unsqueeze(0).repeat(env.num_envs, 1)
)
left_hand_roll_link_pos_error = (
left_hand_roll_link_setpoint[:, :3] - left_hand_roll_link_feedback[:, :3]
) )
left_hand_roll_link_pos_error = left_hand_roll_link_setpoint[:, :3] - left_hand_roll_link_feedback[:, :3]
left_hand_roll_link_rot_error = axis_angle_from_quat( left_hand_roll_link_rot_error = axis_angle_from_quat(
quat_from_matrix( quat_from_matrix(
matrix_from_quat(left_hand_roll_link_setpoint[:, 3:]) matrix_from_quat(left_hand_roll_link_setpoint[:, 3:])
...@@ -132,11 +142,9 @@ class TestPinkIKController(unittest.TestCase): ...@@ -132,11 +142,9 @@ class TestPinkIKController(unittest.TestCase):
right_hand_roll_link_feedback = right_hand_roll_link_pose_obs right_hand_roll_link_feedback = right_hand_roll_link_pose_obs
right_hand_roll_link_setpoint = ( right_hand_roll_link_setpoint = (
torch.tensor(self.right_hand_roll_link_pose, device=device).unsqueeze(0).repeat(env.num_envs, 1) torch.tensor(right_hand_roll_link_pose, device=device).unsqueeze(0).repeat(env.num_envs, 1)
)
right_hand_roll_link_pos_error = (
right_hand_roll_link_setpoint[:, :3] - right_hand_roll_link_feedback[:, :3]
) )
right_hand_roll_link_pos_error = right_hand_roll_link_setpoint[:, :3] - right_hand_roll_link_feedback[:, :3]
right_hand_roll_link_rot_error = axis_angle_from_quat( right_hand_roll_link_rot_error = axis_angle_from_quat(
quat_from_matrix( quat_from_matrix(
matrix_from_quat(right_hand_roll_link_setpoint[:, 3:]) matrix_from_quat(right_hand_roll_link_setpoint[:, 3:])
...@@ -144,13 +152,13 @@ class TestPinkIKController(unittest.TestCase): ...@@ -144,13 +152,13 @@ class TestPinkIKController(unittest.TestCase):
) )
) )
if num_runs % self.num_steps_controller_convergence == 0: if num_runs % pink_ik_test_config["num_steps_controller_convergence"] == 0:
# Check if the left hand roll link is at the target position # Check if the left hand roll link is at the target position
torch.testing.assert_close( torch.testing.assert_close(
torch.mean(torch.abs(left_hand_roll_link_pos_error), dim=1), torch.mean(torch.abs(left_hand_roll_link_pos_error), dim=1),
torch.zeros(env.num_envs, device="cuda:0"), torch.zeros(env.num_envs, device="cuda:0"),
rtol=0.0, rtol=0.0,
atol=self.pos_tolerance, atol=pink_ik_test_config["pos_tolerance"],
) )
# Check if the right hand roll link is at the target position # Check if the right hand roll link is at the target position
...@@ -158,7 +166,7 @@ class TestPinkIKController(unittest.TestCase): ...@@ -158,7 +166,7 @@ class TestPinkIKController(unittest.TestCase):
torch.mean(torch.abs(right_hand_roll_link_pos_error), dim=1), torch.mean(torch.abs(right_hand_roll_link_pos_error), dim=1),
torch.zeros(env.num_envs, device="cuda:0"), torch.zeros(env.num_envs, device="cuda:0"),
rtol=0.0, rtol=0.0,
atol=self.pos_tolerance, atol=pink_ik_test_config["pos_tolerance"],
) )
# Check if the left hand roll link is at the target orientation # Check if the left hand roll link is at the target orientation
...@@ -166,7 +174,7 @@ class TestPinkIKController(unittest.TestCase): ...@@ -166,7 +174,7 @@ class TestPinkIKController(unittest.TestCase):
torch.mean(torch.abs(left_hand_roll_link_rot_error), dim=1), torch.mean(torch.abs(left_hand_roll_link_rot_error), dim=1),
torch.zeros(env.num_envs, device="cuda:0"), torch.zeros(env.num_envs, device="cuda:0"),
rtol=0.0, rtol=0.0,
atol=self.rot_tolerance, atol=pink_ik_test_config["rot_tolerance"],
) )
# Check if the right hand roll link is at the target orientation # Check if the right hand roll link is at the target orientation
...@@ -174,27 +182,28 @@ class TestPinkIKController(unittest.TestCase): ...@@ -174,27 +182,28 @@ class TestPinkIKController(unittest.TestCase):
torch.mean(torch.abs(right_hand_roll_link_rot_error), dim=1), torch.mean(torch.abs(right_hand_roll_link_rot_error), dim=1),
torch.zeros(env.num_envs, device="cuda:0"), torch.zeros(env.num_envs, device="cuda:0"),
rtol=0.0, rtol=0.0,
atol=self.rot_tolerance, atol=pink_ik_test_config["rot_tolerance"],
) )
# Change the setpoints to move the hands up and down as per the counter # Change the setpoints to move the hands up and down as per the counter
test_counter += 1 test_counter += 1
if move_hands_up and test_counter > self.num_times_to_move_hands_up: if move_hands_up and test_counter > pink_ik_test_config["num_times_to_move_hands_up"]:
move_hands_up = False move_hands_up = False
elif not move_hands_up and test_counter > ( elif not move_hands_up and test_counter > (
self.num_times_to_move_hands_down + self.num_times_to_move_hands_up pink_ik_test_config["num_times_to_move_hands_down"]
+ pink_ik_test_config["num_times_to_move_hands_up"]
): ):
# Test is done after moving the hands up and down # Test is done after moving the hands up and down
break break
if move_hands_up: if move_hands_up:
self.left_hand_roll_link_pose[1] += 0.05 left_hand_roll_link_pose[1] += 0.05
self.left_hand_roll_link_pose[2] += 0.05 left_hand_roll_link_pose[2] += 0.05
self.right_hand_roll_link_pose[1] += 0.05 right_hand_roll_link_pose[1] += 0.05
self.right_hand_roll_link_pose[2] += 0.05 right_hand_roll_link_pose[2] += 0.05
else: else:
self.left_hand_roll_link_pose[1] -= 0.05 left_hand_roll_link_pose[1] -= 0.05
self.left_hand_roll_link_pose[2] -= 0.05 left_hand_roll_link_pose[2] -= 0.05
self.right_hand_roll_link_pose[1] -= 0.05 right_hand_roll_link_pose[1] -= 0.05
self.right_hand_roll_link_pose[2] -= 0.05 right_hand_roll_link_pose[2] -= 0.05
env.close() env.close()
[package] [package]
# Semantic Versioning is used: https://semver.org/ # Semantic Versioning is used: https://semver.org/
version = "1.0.10" version = "1.0.11"
# Description # Description
category = "isaaclab" category = "isaaclab"
......
Changelog Changelog
--------- ---------
1.0.11 (2025-07-17)
~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Updated test_selection_strategy.py and test_generate_dataset.py test cases to pytest format.
* Updated annotate_demos.py script to return the number of successful task completions as the exit code to support check in test_generate_dataset.py test case.
1.0.10 (2025-07-08) 1.0.10 (2025-07-08)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -13,35 +13,36 @@ simulation_app = AppLauncher(headless=True).app ...@@ -13,35 +13,36 @@ simulation_app = AppLauncher(headless=True).app
import os import os
import subprocess import subprocess
import tempfile import tempfile
import unittest
import pytest
from isaaclab.utils.assets import ISAACLAB_NUCLEUS_DIR, retrieve_file_path from isaaclab.utils.assets import ISAACLAB_NUCLEUS_DIR, retrieve_file_path
DATASETS_DOWNLOAD_DIR = tempfile.mkdtemp(suffix="_Isaac-Stack-Cube-Franka-IK-Rel-Mimic-v0") DATASETS_DOWNLOAD_DIR = tempfile.mkdtemp(suffix="_Isaac-Stack-Cube-Franka-IK-Rel-Mimic-v0")
NUCLEUS_DATASET_PATH = os.path.join(ISAACLAB_NUCLEUS_DIR, "Tests", "Mimic", "dataset.hdf5") NUCLEUS_DATASET_PATH = os.path.join(ISAACLAB_NUCLEUS_DIR, "Tests", "Mimic", "dataset.hdf5")
EXPECTED_SUCCESSFUL_ANNOTATIONS = 10
class TestGenerateDataset(unittest.TestCase): @pytest.fixture
"""Test the dataset generation behavior of the Isaac Lab Mimic workflow.""" def setup_test_environment():
def setUp(self):
"""Set up the environment for testing.""" """Set up the environment for testing."""
# Create the datasets directory if it does not exist # Create the datasets directory if it does not exist
if not os.path.exists(DATASETS_DOWNLOAD_DIR): if not os.path.exists(DATASETS_DOWNLOAD_DIR):
print("Creating directory : ", DATASETS_DOWNLOAD_DIR) print("Creating directory : ", DATASETS_DOWNLOAD_DIR)
os.makedirs(DATASETS_DOWNLOAD_DIR) os.makedirs(DATASETS_DOWNLOAD_DIR)
# Try to download the dataset from Nucleus # Try to download the dataset from Nucleus
try: try:
retrieve_file_path(NUCLEUS_DATASET_PATH, DATASETS_DOWNLOAD_DIR) retrieve_file_path(NUCLEUS_DATASET_PATH, DATASETS_DOWNLOAD_DIR)
except Exception as e: except Exception as e:
print(e) print(e)
print("Could not download dataset from Nucleus") print("Could not download dataset from Nucleus")
self.fail( pytest.fail(
"The dataset required for this test is currently unavailable. Dataset path: " + NUCLEUS_DATASET_PATH "The dataset required for this test is currently unavailable. Dataset path: " + NUCLEUS_DATASET_PATH
) )
# Set the environment variable PYTHONUNBUFFERED to 1 to get all text outputs in result.stdout # Set the environment variable PYTHONUNBUFFERED to 1 to get all text outputs in result.stdout
self.pythonunbuffered_env_var_ = os.environ.get("PYTHONUNBUFFERED") pythonunbuffered_env_var_ = os.environ.get("PYTHONUNBUFFERED")
os.environ["PYTHONUNBUFFERED"] = "1" os.environ["PYTHONUNBUFFERED"] = "1"
# Automatically detect the workflow root (backtrack from current file location) # Automatically detect the workflow root (backtrack from current file location)
...@@ -67,26 +68,48 @@ class TestGenerateDataset(unittest.TestCase): ...@@ -67,26 +68,48 @@ class TestGenerateDataset(unittest.TestCase):
# Execute the command and capture the result # Execute the command and capture the result
result = subprocess.run(config_command, capture_output=True, text=True) result = subprocess.run(config_command, capture_output=True, text=True)
print(f"Annotate demos result: {result.returncode}\n\n\n\n\n\n\n\n\n\n\n\n")
# Print the result for debugging purposes # Print the result for debugging purposes
print("Config generation result:") print("Config generation result:")
print(result.stdout) # Print standard output from the command print(result.stdout) # Print standard output from the command
print(result.stderr) # Print standard error from the command print(result.stderr) # Print standard error from the command
# Check if the config generation was successful # Check if the config generation was successful
self.assertEqual(result.returncode, 0, msg=result.stderr) assert result.returncode == 0, result.stderr
def tearDown(self): # Check that at least one task was completed successfully by parsing stdout
"""Clean up after tests.""" # Look for the line that reports successful task completions
if self.pythonunbuffered_env_var_: success_line = None
os.environ["PYTHONUNBUFFERED"] = self.pythonunbuffered_env_var_ for line in result.stdout.split("\n"):
if "Successful task completions:" in line:
success_line = line
break
assert success_line is not None, "Could not find 'Successful task completions:' in output"
# Extract the number from the line
try:
successful_count = int(success_line.split(":")[-1].strip())
assert (
successful_count == EXPECTED_SUCCESSFUL_ANNOTATIONS
), f"Expected 10 successful annotations but got {successful_count}"
except (ValueError, IndexError) as e:
pytest.fail(f"Could not parse successful task count from line: '{success_line}'. Error: {e}")
# Yield the workflow root for use in tests
yield workflow_root
# Cleanup: restore the original environment variable
if pythonunbuffered_env_var_:
os.environ["PYTHONUNBUFFERED"] = pythonunbuffered_env_var_
else: else:
del os.environ["PYTHONUNBUFFERED"] del os.environ["PYTHONUNBUFFERED"]
def test_generate_dataset(self):
def test_generate_dataset(setup_test_environment):
"""Test the dataset generation script.""" """Test the dataset generation script."""
# Automatically detect the workflow root (backtrack from current file location) workflow_root = setup_test_environment
current_dir = os.path.dirname(os.path.abspath(__file__))
workflow_root = os.path.abspath(os.path.join(current_dir, "../../.."))
# Define the command to run the dataset generation script # Define the command to run the dataset generation script
command = [ command = [
...@@ -111,12 +134,8 @@ class TestGenerateDataset(unittest.TestCase): ...@@ -111,12 +134,8 @@ class TestGenerateDataset(unittest.TestCase):
print(result.stderr) # Print standard error from the command print(result.stderr) # Print standard error from the command
# Check if the script executed successfully # Check if the script executed successfully
self.assertEqual(result.returncode, 0, msg=result.stderr) assert result.returncode == 0, result.stderr
# Check for specific output # Check for specific output
expected_output = "successes/attempts. Exiting" expected_output = "successes/attempts. Exiting"
self.assertIn(expected_output, result.stdout) assert expected_output in result.stdout
if __name__ == "__main__":
unittest.main()
...@@ -10,7 +10,8 @@ simulation_app = AppLauncher(headless=True).app ...@@ -10,7 +10,8 @@ simulation_app = AppLauncher(headless=True).app
import numpy as np import numpy as np
import torch import torch
import unittest
import pytest
import isaaclab.utils.math as PoseUtils import isaaclab.utils.math as PoseUtils
...@@ -26,15 +27,19 @@ from isaaclab_mimic.datagen.selection_strategy import ( ...@@ -26,15 +27,19 @@ from isaaclab_mimic.datagen.selection_strategy import (
NUM_ITERS = 1000 NUM_ITERS = 1000
class TestNearestNeighborObjectStrategy(unittest.TestCase): @pytest.fixture
"""Test the NearestNeighborObjectStrategy class.""" def nearest_neighbor_object_strategy():
"""Fixture for NearestNeighborObjectStrategy."""
return NearestNeighborObjectStrategy()
def setUp(self): @pytest.fixture
"""Set up test cases for the NearestNeighborObjectStrategy.""" def nearest_neighbor_robot_distance_strategy():
# Initialize the strategy object for selecting nearest neighbors """Fixture for NearestNeighborRobotDistanceStrategy."""
self.strategy = NearestNeighborObjectStrategy() return NearestNeighborRobotDistanceStrategy()
def test_select_source_demo_identity_orientations(self):
def test_select_source_demo_identity_orientations_object_strategy(nearest_neighbor_object_strategy):
"""Test the selection of source demonstrations using two distinct object_pose clusters. """Test the selection of source demonstrations using two distinct object_pose clusters.
This method generates two clusters of object poses and randomly adjusts the current object pose within This method generates two clusters of object poses and randomly adjusts the current object pose within
...@@ -50,15 +55,13 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase): ...@@ -50,15 +55,13 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase):
# Generate object poses for cluster 1 with varying translations # Generate object poses for cluster 1 with varying translations
src_object_poses_in_world_cluster_1 = [ src_object_poses_in_world_cluster_1 = [
torch.eye(4) torch.eye(4) + torch.tensor([[0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, -1.0]])
+ torch.tensor([[0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, -1.0]])
for i in range(cluster_1_range_min, cluster_1_range_max) for i in range(cluster_1_range_min, cluster_1_range_max)
] ]
# Generate object poses for cluster 2 similarly # Generate object poses for cluster 2 similarly
src_object_poses_in_world_cluster_2 = [ src_object_poses_in_world_cluster_2 = [
torch.eye(4) torch.eye(4) + torch.tensor([[0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, -1.0]])
+ torch.tensor([[0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, -1.0]])
for i in range(cluster_2_range_min, cluster_2_range_max) for i in range(cluster_2_range_min, cluster_2_range_max)
] ]
...@@ -89,7 +92,7 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase): ...@@ -89,7 +92,7 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase):
# Select source demonstrations multiple times to check randomness # Select source demonstrations multiple times to check randomness
selected_indices = [ selected_indices = [
self.strategy.select_source_demo( nearest_neighbor_object_strategy.select_source_demo(
eef_pose, eef_pose,
cluster_1_curr_object_pose, cluster_1_curr_object_pose,
src_subtask_datagen_infos, src_subtask_datagen_infos,
...@@ -101,10 +104,9 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase): ...@@ -101,10 +104,9 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase):
] ]
# Assert that all selected indices are valid indices within cluster 1 # Assert that all selected indices are valid indices within cluster 1
self.assertTrue( assert np.all(
np.all(np.array(selected_indices) < len(src_object_poses_in_world_cluster_1)), np.array(selected_indices) < len(src_object_poses_in_world_cluster_1)
"Some selected indices are not part of cluster 1.", ), "Some selected indices are not part of cluster 1."
)
# Test 2: # Test 2:
# Set the current object pose to the first value of cluster 2 and add some noise # Set the current object pose to the first value of cluster 2 and add some noise
...@@ -122,7 +124,7 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase): ...@@ -122,7 +124,7 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase):
# Select source demonstrations multiple times to check randomness # Select source demonstrations multiple times to check randomness
selected_indices = [ selected_indices = [
self.strategy.select_source_demo( nearest_neighbor_object_strategy.select_source_demo(
eef_pose, eef_pose,
cluster_2_curr_object_pose, cluster_2_curr_object_pose,
src_subtask_datagen_infos, src_subtask_datagen_infos,
...@@ -134,25 +136,15 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase): ...@@ -134,25 +136,15 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase):
] ]
# Assert that all selected indices are valid indices within cluster 2 # Assert that all selected indices are valid indices within cluster 2
self.assertTrue( assert np.all(
np.all(np.array(selected_indices) < len(src_object_poses_in_world)), np.array(selected_indices) < len(src_object_poses_in_world)
"Some selected indices are not part of cluster 2.", ), "Some selected indices are not part of cluster 2."
) assert np.all(
self.assertTrue( np.array(selected_indices) > (len(src_object_poses_in_world_cluster_1) - 1)
np.all(np.array(selected_indices) > (len(src_object_poses_in_world_cluster_1) - 1)), ), "Some selected indices are not part of cluster 2."
"Some selected indices are not part of cluster 2.",
)
class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
"""Test the NearestNeighborRobotDistanceStrategy class."""
def setUp(self):
"""Set up test cases for the NearestNeighborRobotDistanceStrategy."""
# Initialize the strategy object for selecting nearest neighbors
self.strategy = NearestNeighborRobotDistanceStrategy()
def test_select_source_demo_identity_orientations(self): def test_select_source_demo_identity_orientations_robot_distance_strategy(nearest_neighbor_robot_distance_strategy):
"""Test the selection of source demonstrations based on identity-oriented poses with varying positions. """Test the selection of source demonstrations based on identity-oriented poses with varying positions.
This method generates two clusters of object poses and randomly adjusts the current object pose within This method generates two clusters of object poses and randomly adjusts the current object pose within
...@@ -181,9 +173,7 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase): ...@@ -181,9 +173,7 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
# Combine the poses from both clusters into a single list # Combine the poses from both clusters into a single list
# This represents the first end effector pose for the transformed subtask segment for each source demo # This represents the first end effector pose for the transformed subtask segment for each source demo
transformed_eef_in_world_poses_tensor = torch.stack( transformed_eef_in_world_poses_tensor = torch.stack(transformed_eef_pose_cluster_1 + transformed_eef_pose_cluster_2)
transformed_eef_pose_cluster_1 + transformed_eef_pose_cluster_2
)
# Create transformation matrices corresponding to each source object pose # Create transformation matrices corresponding to each source object pose
src_obj_in_world_poses = torch.stack([ src_obj_in_world_poses = torch.stack([
...@@ -210,17 +200,14 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase): ...@@ -210,17 +200,14 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
) )
# Check that both lists have the same length # Check that both lists have the same length
self.assertTrue( assert src_obj_in_world_poses.shape[0] == src_eef_in_world_poses.shape[0], (
src_obj_in_world_poses.shape[0] == src_eef_in_world_poses.shape[0], "Source object poses and end effector poses does not have the same length. "
"Source object poses and end effector poses does not have the same length." "This is a bug in the test code and not the source code."
"This is a bug in the test code and not the source code.",
) )
# Create DatagenInfo instances for these positions # Create DatagenInfo instances for these positions
src_subtask_datagen_infos = [ src_subtask_datagen_infos = [
DatagenInfo( DatagenInfo(eef_pose=src_eef_in_world_pose.unsqueeze(0), object_poses={0: src_obj_in_world_pose.unsqueeze(0)})
eef_pose=src_eef_in_world_pose.unsqueeze(0), object_poses={0: src_obj_in_world_pose.unsqueeze(0)}
)
for src_obj_in_world_pose, src_eef_in_world_pose in zip(src_obj_in_world_poses, src_eef_in_world_poses) for src_obj_in_world_pose, src_eef_in_world_pose in zip(src_obj_in_world_poses, src_eef_in_world_poses)
] ]
...@@ -239,7 +226,7 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase): ...@@ -239,7 +226,7 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
# Select source demonstrations multiple times to check randomness # Select source demonstrations multiple times to check randomness
selected_indices = [ selected_indices = [
self.strategy.select_source_demo( nearest_neighbor_robot_distance_strategy.select_source_demo(
curr_eef_in_world_pose, curr_eef_in_world_pose,
curr_object_in_world_pose, curr_object_in_world_pose,
src_subtask_datagen_infos, src_subtask_datagen_infos,
...@@ -251,10 +238,9 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase): ...@@ -251,10 +238,9 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
] ]
# Assert that all selected indices are valid indices within cluster 1 # Assert that all selected indices are valid indices within cluster 1
self.assertTrue( assert np.all(
np.all(np.array(selected_indices) < len(transformed_eef_pose_cluster_1)), np.array(selected_indices) < len(transformed_eef_pose_cluster_1)
"Some selected indices are not part of cluster 1.", ), "Some selected indices are not part of cluster 1."
)
# Test 2: Ensure the nearest neighbor is always part of cluster 2 # Test 2: Ensure the nearest neighbor is always part of cluster 2
max_deviation = 3 # Define a maximum deviation for the current pose max_deviation = 3 # Define a maximum deviation for the current pose
...@@ -271,7 +257,7 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase): ...@@ -271,7 +257,7 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
# Select source demonstrations multiple times to check randomness # Select source demonstrations multiple times to check randomness
selected_indices = [ selected_indices = [
self.strategy.select_source_demo( nearest_neighbor_robot_distance_strategy.select_source_demo(
curr_eef_in_world_pose, curr_eef_in_world_pose,
curr_object_in_world_pose, curr_object_in_world_pose,
src_subtask_datagen_infos, src_subtask_datagen_infos,
...@@ -283,15 +269,9 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase): ...@@ -283,15 +269,9 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
] ]
# Assert that all selected indices are valid indices within cluster 2 # Assert that all selected indices are valid indices within cluster 2
self.assertTrue( assert np.all(
np.all(np.array(selected_indices) < transformed_eef_in_world_poses_tensor.shape[0]), np.array(selected_indices) < transformed_eef_in_world_poses_tensor.shape[0]
"Some selected indices are not part of cluster 2.", ), "Some selected indices are not part of cluster 2."
) assert np.all(
self.assertTrue( np.array(selected_indices) > (len(transformed_eef_pose_cluster_1) - 1)
np.all(np.array(selected_indices) > (len(transformed_eef_pose_cluster_1) - 1)), ), "Some selected indices are not part of cluster 2."
"Some selected indices are not part of cluster 2.",
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment