Unverified Commit 8fa76ddc authored by peterd-NV's avatar peterd-NV Committed by GitHub

Updates Mimic test cases to pytest format (#550)

# Description

<!--
Thank you for your interest in sending a pull request. Please make sure
to check the contribution guidelines.

Link:
https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html
-->

Updates the following tests to pytest format:

- test_pink_ik.py
- test_selection_strategy.py
- test_generate_dataset.py

Updates test_generate_dataset.py to check that the expected number of
annotations are successfully generated in the HDF5 during the annotation
phase. If not, then test returns failure. This ensures that physics is
behaving correctly and that the correct annotated demos are being
generated. Previously, if a physics issue is introduced, the test would
timeout instead of failing.

The annotate_demos.py script was updated to return the number of
successfully annotated demos in order to support the above test.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent a958ac56
......@@ -159,7 +159,7 @@ def main():
if episode_count == 0:
print("No episodes found in the dataset.")
exit()
return 0
# get output directory path and file name (without extension) from cli arguments
output_dir = os.path.dirname(args_cli.output_file)
......@@ -236,6 +236,7 @@ def main():
# simulate environment -- run everything in inference mode
exported_episode_count = 0
processed_episode_count = 0
successful_task_count = 0 # Counter for successful task completions
with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode():
while simulation_app.is_running() and not simulation_app.is_exiting():
# Iterate over the episodes in the loaded dataset file
......@@ -259,6 +260,7 @@ def main():
)
env.recorder_manager.export_episodes()
exported_episode_count += 1
successful_task_count += 1 # Increment successful task counter
print("\tExported the annotated episode.")
else:
print("\tSkipped exporting the episode due to incomplete subtask annotations.")
......@@ -268,11 +270,16 @@ def main():
f"\nExported {exported_episode_count} (out of {processed_episode_count}) annotated"
f" episode{'s' if exported_episode_count > 1 else ''}."
)
print(
f"Successful task completions: {successful_task_count}"
) # This line is used by the dataset generation test case to check if the expected number of demos were annotated
print("Exiting the app.")
# Close environment after annotation is complete
env.close()
return successful_task_count
def replay_episode(
env: ManagerBasedRLMimicEnv,
......@@ -440,6 +447,8 @@ def annotate_episode_in_manual_mode(
if __name__ == "__main__":
# run the main function
main()
successful_task_count = main()
# close sim app
simulation_app.close()
# exit with the number of successful task completions as return code
exit(successful_task_count)
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.42.24"
version = "0.42.25"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.42.25 (2025-07-17)
~~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Updated test_pink_ik.py test case to pytest format.
0.42.24 (2025-06-25)
~~~~~~~~~~~~~~~~~~~~
......
......@@ -21,7 +21,8 @@ simulation_app = AppLauncher(headless=True).app
import contextlib
import gymnasium as gym
import torch
import unittest
import pytest
from isaaclab.utils.math import axis_angle_from_quat, matrix_from_quat, quat_from_matrix, quat_inv
......@@ -30,59 +31,66 @@ import isaaclab_tasks.manager_based.manipulation.pick_place # noqa: F401
from isaaclab_tasks.utils.parse_cfg import parse_env_cfg
class TestPinkIKController(unittest.TestCase):
"""Test fixture for the Pink IK controller with the GR1T2 humanoid robot.
This test validates that the Pink IK controller can accurately track commanded
end-effector poses for a humanoid robot. It specifically:
1. Creates a GR1T2 humanoid robot with the Pink IK controller
2. Sends target pose commands to the left and right hand roll links
3. Checks that the observed poses of the links match the target poses within tolerance
4. Tests adaptability by moving the hands up and down multiple times
The test succeeds when the controller can accurately converge to each new target
position, demonstrating both accuracy and adaptability to changing targets.
"""
def setUp(self):
@pytest.fixture
def pink_ik_test_config():
"""Test configuration for Pink IK controller tests."""
# End effector position mean square error tolerance in meters
self.pos_tolerance = 0.03 # 2 cm
pos_tolerance = 0.03 # 3 cm
# End effector orientation mean square error tolerance in radians
self.rot_tolerance = 0.17 # 10 degrees
rot_tolerance = 0.17 # 10 degrees
# Number of environments
self.num_envs = 1
num_envs = 1
# Number of joints in the 2 robot hands
self.num_joints_in_robot_hands = 22
num_joints_in_robot_hands = 22
# Number of steps to wait for controller convergence
self.num_steps_controller_convergence = 25
num_steps_controller_convergence = 25
self.num_times_to_move_hands_up = 3
self.num_times_to_move_hands_down = 3
num_times_to_move_hands_up = 3
num_times_to_move_hands_down = 3
# Create starting setpoints with respect to the env origin frame
# These are the setpoints for the forward kinematics result of the
# InitialStateCfg specified in `PickPlaceGR1T2EnvCfg`
y_axis_z_axis_90_rot_quaternion = [0.5, 0.5, -0.5, 0.5]
left_hand_roll_link_pos = [-0.23, 0.28, 1.1]
self.left_hand_roll_link_pose = left_hand_roll_link_pos + y_axis_z_axis_90_rot_quaternion
left_hand_roll_link_pose = left_hand_roll_link_pos + y_axis_z_axis_90_rot_quaternion
right_hand_roll_link_pos = [0.23, 0.28, 1.1]
self.right_hand_roll_link_pose = right_hand_roll_link_pos + y_axis_z_axis_90_rot_quaternion
right_hand_roll_link_pose = right_hand_roll_link_pos + y_axis_z_axis_90_rot_quaternion
"""
Test fixtures.
"""
return {
"pos_tolerance": pos_tolerance,
"rot_tolerance": rot_tolerance,
"num_envs": num_envs,
"num_joints_in_robot_hands": num_joints_in_robot_hands,
"num_steps_controller_convergence": num_steps_controller_convergence,
"num_times_to_move_hands_up": num_times_to_move_hands_up,
"num_times_to_move_hands_down": num_times_to_move_hands_down,
"left_hand_roll_link_pose": left_hand_roll_link_pose,
"right_hand_roll_link_pose": right_hand_roll_link_pose,
}
def test_gr1t2_ik_pose_abs(self):
"""Test IK controller for GR1T2 humanoid."""
def test_gr1t2_ik_pose_abs(pink_ik_test_config):
"""Test IK controller for GR1T2 humanoid.
This test validates that the Pink IK controller can accurately track commanded
end-effector poses for a humanoid robot. It specifically:
1. Creates a GR1T2 humanoid robot with the Pink IK controller
2. Sends target pose commands to the left and right hand roll links
3. Checks that the observed poses of the links match the target poses within tolerance
4. Tests adaptability by moving the hands up and down multiple times
The test succeeds when the controller can accurately converge to each new target
position, demonstrating both accuracy and adaptability to changing targets.
"""
env_name = "Isaac-PickPlace-GR1T2-Abs-v0"
device = "cuda:0"
env_cfg = parse_env_cfg(env_name, device=device, num_envs=self.num_envs)
env_cfg = parse_env_cfg(env_name, device=device, num_envs=pink_ik_test_config["num_envs"])
# create environment from loaded config
env = gym.make(env_name, cfg=env_cfg).unwrapped
......@@ -95,13 +103,17 @@ class TestPinkIKController(unittest.TestCase):
move_hands_up = True
test_counter = 0
# Get poses from config
left_hand_roll_link_pose = pink_ik_test_config["left_hand_roll_link_pose"].copy()
right_hand_roll_link_pose = pink_ik_test_config["right_hand_roll_link_pose"].copy()
# simulate environment -- run everything in inference mode
with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode():
while simulation_app.is_running() and not simulation_app.is_exiting():
num_runs += 1
setpoint_poses = self.left_hand_roll_link_pose + self.right_hand_roll_link_pose
actions = setpoint_poses + [0.0] * self.num_joints_in_robot_hands
setpoint_poses = left_hand_roll_link_pose + right_hand_roll_link_pose
actions = setpoint_poses + [0.0] * pink_ik_test_config["num_joints_in_robot_hands"]
actions = torch.tensor(actions, device=device)
actions = torch.stack([actions for _ in range(env.num_envs)])
......@@ -118,11 +130,9 @@ class TestPinkIKController(unittest.TestCase):
# The observations are also wrt the env origin frame
left_hand_roll_link_feedback = left_hand_roll_link_pose_obs
left_hand_roll_link_setpoint = (
torch.tensor(self.left_hand_roll_link_pose, device=device).unsqueeze(0).repeat(env.num_envs, 1)
)
left_hand_roll_link_pos_error = (
left_hand_roll_link_setpoint[:, :3] - left_hand_roll_link_feedback[:, :3]
torch.tensor(left_hand_roll_link_pose, device=device).unsqueeze(0).repeat(env.num_envs, 1)
)
left_hand_roll_link_pos_error = left_hand_roll_link_setpoint[:, :3] - left_hand_roll_link_feedback[:, :3]
left_hand_roll_link_rot_error = axis_angle_from_quat(
quat_from_matrix(
matrix_from_quat(left_hand_roll_link_setpoint[:, 3:])
......@@ -132,11 +142,9 @@ class TestPinkIKController(unittest.TestCase):
right_hand_roll_link_feedback = right_hand_roll_link_pose_obs
right_hand_roll_link_setpoint = (
torch.tensor(self.right_hand_roll_link_pose, device=device).unsqueeze(0).repeat(env.num_envs, 1)
)
right_hand_roll_link_pos_error = (
right_hand_roll_link_setpoint[:, :3] - right_hand_roll_link_feedback[:, :3]
torch.tensor(right_hand_roll_link_pose, device=device).unsqueeze(0).repeat(env.num_envs, 1)
)
right_hand_roll_link_pos_error = right_hand_roll_link_setpoint[:, :3] - right_hand_roll_link_feedback[:, :3]
right_hand_roll_link_rot_error = axis_angle_from_quat(
quat_from_matrix(
matrix_from_quat(right_hand_roll_link_setpoint[:, 3:])
......@@ -144,13 +152,13 @@ class TestPinkIKController(unittest.TestCase):
)
)
if num_runs % self.num_steps_controller_convergence == 0:
if num_runs % pink_ik_test_config["num_steps_controller_convergence"] == 0:
# Check if the left hand roll link is at the target position
torch.testing.assert_close(
torch.mean(torch.abs(left_hand_roll_link_pos_error), dim=1),
torch.zeros(env.num_envs, device="cuda:0"),
rtol=0.0,
atol=self.pos_tolerance,
atol=pink_ik_test_config["pos_tolerance"],
)
# Check if the right hand roll link is at the target position
......@@ -158,7 +166,7 @@ class TestPinkIKController(unittest.TestCase):
torch.mean(torch.abs(right_hand_roll_link_pos_error), dim=1),
torch.zeros(env.num_envs, device="cuda:0"),
rtol=0.0,
atol=self.pos_tolerance,
atol=pink_ik_test_config["pos_tolerance"],
)
# Check if the left hand roll link is at the target orientation
......@@ -166,7 +174,7 @@ class TestPinkIKController(unittest.TestCase):
torch.mean(torch.abs(left_hand_roll_link_rot_error), dim=1),
torch.zeros(env.num_envs, device="cuda:0"),
rtol=0.0,
atol=self.rot_tolerance,
atol=pink_ik_test_config["rot_tolerance"],
)
# Check if the right hand roll link is at the target orientation
......@@ -174,27 +182,28 @@ class TestPinkIKController(unittest.TestCase):
torch.mean(torch.abs(right_hand_roll_link_rot_error), dim=1),
torch.zeros(env.num_envs, device="cuda:0"),
rtol=0.0,
atol=self.rot_tolerance,
atol=pink_ik_test_config["rot_tolerance"],
)
# Change the setpoints to move the hands up and down as per the counter
test_counter += 1
if move_hands_up and test_counter > self.num_times_to_move_hands_up:
if move_hands_up and test_counter > pink_ik_test_config["num_times_to_move_hands_up"]:
move_hands_up = False
elif not move_hands_up and test_counter > (
self.num_times_to_move_hands_down + self.num_times_to_move_hands_up
pink_ik_test_config["num_times_to_move_hands_down"]
+ pink_ik_test_config["num_times_to_move_hands_up"]
):
# Test is done after moving the hands up and down
break
if move_hands_up:
self.left_hand_roll_link_pose[1] += 0.05
self.left_hand_roll_link_pose[2] += 0.05
self.right_hand_roll_link_pose[1] += 0.05
self.right_hand_roll_link_pose[2] += 0.05
left_hand_roll_link_pose[1] += 0.05
left_hand_roll_link_pose[2] += 0.05
right_hand_roll_link_pose[1] += 0.05
right_hand_roll_link_pose[2] += 0.05
else:
self.left_hand_roll_link_pose[1] -= 0.05
self.left_hand_roll_link_pose[2] -= 0.05
self.right_hand_roll_link_pose[1] -= 0.05
self.right_hand_roll_link_pose[2] -= 0.05
left_hand_roll_link_pose[1] -= 0.05
left_hand_roll_link_pose[2] -= 0.05
right_hand_roll_link_pose[1] -= 0.05
right_hand_roll_link_pose[2] -= 0.05
env.close()
[package]
# Semantic Versioning is used: https://semver.org/
version = "1.0.10"
version = "1.0.11"
# Description
category = "isaaclab"
......
Changelog
---------
1.0.11 (2025-07-17)
~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Updated test_selection_strategy.py and test_generate_dataset.py test cases to pytest format.
* Updated annotate_demos.py script to return the number of successful task completions as the exit code to support check in test_generate_dataset.py test case.
1.0.10 (2025-07-08)
~~~~~~~~~~~~~~~~~~
......
......@@ -13,35 +13,36 @@ simulation_app = AppLauncher(headless=True).app
import os
import subprocess
import tempfile
import unittest
import pytest
from isaaclab.utils.assets import ISAACLAB_NUCLEUS_DIR, retrieve_file_path
DATASETS_DOWNLOAD_DIR = tempfile.mkdtemp(suffix="_Isaac-Stack-Cube-Franka-IK-Rel-Mimic-v0")
NUCLEUS_DATASET_PATH = os.path.join(ISAACLAB_NUCLEUS_DIR, "Tests", "Mimic", "dataset.hdf5")
EXPECTED_SUCCESSFUL_ANNOTATIONS = 10
class TestGenerateDataset(unittest.TestCase):
"""Test the dataset generation behavior of the Isaac Lab Mimic workflow."""
def setUp(self):
@pytest.fixture
def setup_test_environment():
"""Set up the environment for testing."""
# Create the datasets directory if it does not exist
if not os.path.exists(DATASETS_DOWNLOAD_DIR):
print("Creating directory : ", DATASETS_DOWNLOAD_DIR)
os.makedirs(DATASETS_DOWNLOAD_DIR)
# Try to download the dataset from Nucleus
try:
retrieve_file_path(NUCLEUS_DATASET_PATH, DATASETS_DOWNLOAD_DIR)
except Exception as e:
print(e)
print("Could not download dataset from Nucleus")
self.fail(
pytest.fail(
"The dataset required for this test is currently unavailable. Dataset path: " + NUCLEUS_DATASET_PATH
)
# Set the environment variable PYTHONUNBUFFERED to 1 to get all text outputs in result.stdout
self.pythonunbuffered_env_var_ = os.environ.get("PYTHONUNBUFFERED")
pythonunbuffered_env_var_ = os.environ.get("PYTHONUNBUFFERED")
os.environ["PYTHONUNBUFFERED"] = "1"
# Automatically detect the workflow root (backtrack from current file location)
......@@ -67,26 +68,48 @@ class TestGenerateDataset(unittest.TestCase):
# Execute the command and capture the result
result = subprocess.run(config_command, capture_output=True, text=True)
print(f"Annotate demos result: {result.returncode}\n\n\n\n\n\n\n\n\n\n\n\n")
# Print the result for debugging purposes
print("Config generation result:")
print(result.stdout) # Print standard output from the command
print(result.stderr) # Print standard error from the command
# Check if the config generation was successful
self.assertEqual(result.returncode, 0, msg=result.stderr)
assert result.returncode == 0, result.stderr
def tearDown(self):
"""Clean up after tests."""
if self.pythonunbuffered_env_var_:
os.environ["PYTHONUNBUFFERED"] = self.pythonunbuffered_env_var_
# Check that at least one task was completed successfully by parsing stdout
# Look for the line that reports successful task completions
success_line = None
for line in result.stdout.split("\n"):
if "Successful task completions:" in line:
success_line = line
break
assert success_line is not None, "Could not find 'Successful task completions:' in output"
# Extract the number from the line
try:
successful_count = int(success_line.split(":")[-1].strip())
assert (
successful_count == EXPECTED_SUCCESSFUL_ANNOTATIONS
), f"Expected 10 successful annotations but got {successful_count}"
except (ValueError, IndexError) as e:
pytest.fail(f"Could not parse successful task count from line: '{success_line}'. Error: {e}")
# Yield the workflow root for use in tests
yield workflow_root
# Cleanup: restore the original environment variable
if pythonunbuffered_env_var_:
os.environ["PYTHONUNBUFFERED"] = pythonunbuffered_env_var_
else:
del os.environ["PYTHONUNBUFFERED"]
def test_generate_dataset(self):
def test_generate_dataset(setup_test_environment):
"""Test the dataset generation script."""
# Automatically detect the workflow root (backtrack from current file location)
current_dir = os.path.dirname(os.path.abspath(__file__))
workflow_root = os.path.abspath(os.path.join(current_dir, "../../.."))
workflow_root = setup_test_environment
# Define the command to run the dataset generation script
command = [
......@@ -111,12 +134,8 @@ class TestGenerateDataset(unittest.TestCase):
print(result.stderr) # Print standard error from the command
# Check if the script executed successfully
self.assertEqual(result.returncode, 0, msg=result.stderr)
assert result.returncode == 0, result.stderr
# Check for specific output
expected_output = "successes/attempts. Exiting"
self.assertIn(expected_output, result.stdout)
if __name__ == "__main__":
unittest.main()
assert expected_output in result.stdout
......@@ -10,7 +10,8 @@ simulation_app = AppLauncher(headless=True).app
import numpy as np
import torch
import unittest
import pytest
import isaaclab.utils.math as PoseUtils
......@@ -26,15 +27,19 @@ from isaaclab_mimic.datagen.selection_strategy import (
NUM_ITERS = 1000
class TestNearestNeighborObjectStrategy(unittest.TestCase):
"""Test the NearestNeighborObjectStrategy class."""
@pytest.fixture
def nearest_neighbor_object_strategy():
"""Fixture for NearestNeighborObjectStrategy."""
return NearestNeighborObjectStrategy()
def setUp(self):
"""Set up test cases for the NearestNeighborObjectStrategy."""
# Initialize the strategy object for selecting nearest neighbors
self.strategy = NearestNeighborObjectStrategy()
@pytest.fixture
def nearest_neighbor_robot_distance_strategy():
"""Fixture for NearestNeighborRobotDistanceStrategy."""
return NearestNeighborRobotDistanceStrategy()
def test_select_source_demo_identity_orientations(self):
def test_select_source_demo_identity_orientations_object_strategy(nearest_neighbor_object_strategy):
"""Test the selection of source demonstrations using two distinct object_pose clusters.
This method generates two clusters of object poses and randomly adjusts the current object pose within
......@@ -50,15 +55,13 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase):
# Generate object poses for cluster 1 with varying translations
src_object_poses_in_world_cluster_1 = [
torch.eye(4)
+ torch.tensor([[0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, -1.0]])
torch.eye(4) + torch.tensor([[0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, -1.0]])
for i in range(cluster_1_range_min, cluster_1_range_max)
]
# Generate object poses for cluster 2 similarly
src_object_poses_in_world_cluster_2 = [
torch.eye(4)
+ torch.tensor([[0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, -1.0]])
torch.eye(4) + torch.tensor([[0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, i], [0.0, 0.0, 0.0, -1.0]])
for i in range(cluster_2_range_min, cluster_2_range_max)
]
......@@ -89,7 +92,7 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase):
# Select source demonstrations multiple times to check randomness
selected_indices = [
self.strategy.select_source_demo(
nearest_neighbor_object_strategy.select_source_demo(
eef_pose,
cluster_1_curr_object_pose,
src_subtask_datagen_infos,
......@@ -101,10 +104,9 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase):
]
# Assert that all selected indices are valid indices within cluster 1
self.assertTrue(
np.all(np.array(selected_indices) < len(src_object_poses_in_world_cluster_1)),
"Some selected indices are not part of cluster 1.",
)
assert np.all(
np.array(selected_indices) < len(src_object_poses_in_world_cluster_1)
), "Some selected indices are not part of cluster 1."
# Test 2:
# Set the current object pose to the first value of cluster 2 and add some noise
......@@ -122,7 +124,7 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase):
# Select source demonstrations multiple times to check randomness
selected_indices = [
self.strategy.select_source_demo(
nearest_neighbor_object_strategy.select_source_demo(
eef_pose,
cluster_2_curr_object_pose,
src_subtask_datagen_infos,
......@@ -134,25 +136,15 @@ class TestNearestNeighborObjectStrategy(unittest.TestCase):
]
# Assert that all selected indices are valid indices within cluster 2
self.assertTrue(
np.all(np.array(selected_indices) < len(src_object_poses_in_world)),
"Some selected indices are not part of cluster 2.",
)
self.assertTrue(
np.all(np.array(selected_indices) > (len(src_object_poses_in_world_cluster_1) - 1)),
"Some selected indices are not part of cluster 2.",
)
class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
"""Test the NearestNeighborRobotDistanceStrategy class."""
assert np.all(
np.array(selected_indices) < len(src_object_poses_in_world)
), "Some selected indices are not part of cluster 2."
assert np.all(
np.array(selected_indices) > (len(src_object_poses_in_world_cluster_1) - 1)
), "Some selected indices are not part of cluster 2."
def setUp(self):
"""Set up test cases for the NearestNeighborRobotDistanceStrategy."""
# Initialize the strategy object for selecting nearest neighbors
self.strategy = NearestNeighborRobotDistanceStrategy()
def test_select_source_demo_identity_orientations(self):
def test_select_source_demo_identity_orientations_robot_distance_strategy(nearest_neighbor_robot_distance_strategy):
"""Test the selection of source demonstrations based on identity-oriented poses with varying positions.
This method generates two clusters of object poses and randomly adjusts the current object pose within
......@@ -181,9 +173,7 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
# Combine the poses from both clusters into a single list
# This represents the first end effector pose for the transformed subtask segment for each source demo
transformed_eef_in_world_poses_tensor = torch.stack(
transformed_eef_pose_cluster_1 + transformed_eef_pose_cluster_2
)
transformed_eef_in_world_poses_tensor = torch.stack(transformed_eef_pose_cluster_1 + transformed_eef_pose_cluster_2)
# Create transformation matrices corresponding to each source object pose
src_obj_in_world_poses = torch.stack([
......@@ -210,17 +200,14 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
)
# Check that both lists have the same length
self.assertTrue(
src_obj_in_world_poses.shape[0] == src_eef_in_world_poses.shape[0],
"Source object poses and end effector poses does not have the same length."
"This is a bug in the test code and not the source code.",
assert src_obj_in_world_poses.shape[0] == src_eef_in_world_poses.shape[0], (
"Source object poses and end effector poses does not have the same length. "
"This is a bug in the test code and not the source code."
)
# Create DatagenInfo instances for these positions
src_subtask_datagen_infos = [
DatagenInfo(
eef_pose=src_eef_in_world_pose.unsqueeze(0), object_poses={0: src_obj_in_world_pose.unsqueeze(0)}
)
DatagenInfo(eef_pose=src_eef_in_world_pose.unsqueeze(0), object_poses={0: src_obj_in_world_pose.unsqueeze(0)})
for src_obj_in_world_pose, src_eef_in_world_pose in zip(src_obj_in_world_poses, src_eef_in_world_poses)
]
......@@ -239,7 +226,7 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
# Select source demonstrations multiple times to check randomness
selected_indices = [
self.strategy.select_source_demo(
nearest_neighbor_robot_distance_strategy.select_source_demo(
curr_eef_in_world_pose,
curr_object_in_world_pose,
src_subtask_datagen_infos,
......@@ -251,10 +238,9 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
]
# Assert that all selected indices are valid indices within cluster 1
self.assertTrue(
np.all(np.array(selected_indices) < len(transformed_eef_pose_cluster_1)),
"Some selected indices are not part of cluster 1.",
)
assert np.all(
np.array(selected_indices) < len(transformed_eef_pose_cluster_1)
), "Some selected indices are not part of cluster 1."
# Test 2: Ensure the nearest neighbor is always part of cluster 2
max_deviation = 3 # Define a maximum deviation for the current pose
......@@ -271,7 +257,7 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
# Select source demonstrations multiple times to check randomness
selected_indices = [
self.strategy.select_source_demo(
nearest_neighbor_robot_distance_strategy.select_source_demo(
curr_eef_in_world_pose,
curr_object_in_world_pose,
src_subtask_datagen_infos,
......@@ -283,15 +269,9 @@ class TestNearestNeighborRobotDistanceStrategy(unittest.TestCase):
]
# Assert that all selected indices are valid indices within cluster 2
self.assertTrue(
np.all(np.array(selected_indices) < transformed_eef_in_world_poses_tensor.shape[0]),
"Some selected indices are not part of cluster 2.",
)
self.assertTrue(
np.all(np.array(selected_indices) > (len(transformed_eef_pose_cluster_1) - 1)),
"Some selected indices are not part of cluster 2.",
)
if __name__ == "__main__":
unittest.main()
assert np.all(
np.array(selected_indices) < transformed_eef_in_world_poses_tensor.shape[0]
), "Some selected indices are not part of cluster 2."
assert np.all(
np.array(selected_indices) > (len(transformed_eef_pose_cluster_1) - 1)
), "Some selected indices are not part of cluster 2."
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment