Unverified Commit ad441d97 authored by rwiltz's avatar rwiltz Committed by GitHub

Refactors teleop device factory to follow config class style (#3897)

# Description

Refactors the teleop factory to shift declaration of teleop devices and
retargeters out of the factory and into themselves.

Fixes # (issue)

## Type of change

- New feature (non-breaking change which adds functionality)

## Screenshots

## Checklist

- [x] I have read and understood the [contribution
guidelines](https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html)
- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent fdadc90e
...@@ -935,18 +935,15 @@ The retargeting system is designed to be extensible. You can create custom retar ...@@ -935,18 +935,15 @@ The retargeting system is designed to be extensible. You can create custom retar
# Return control commands in appropriate format # Return control commands in appropriate format
return torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) # Example output return torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) # Example output
3. Register your retargeter with the factory by adding it to the ``RETARGETER_MAP``: 3. Register your retargeter by setting ``retargeter_type`` on the config class:
.. code-block:: python .. code-block:: python
# Import your retargeter at the top of your module # Import your retargeter at the top of your module
from my_package.retargeters import MyCustomRetargeter, MyCustomRetargeterCfg from my_package.retargeters import MyCustomRetargeter, MyCustomRetargeterCfg
# Add your retargeter to the factory # Link the config to the implementation for factory construction
from isaaclab.devices.teleop_device_factory import RETARGETER_MAP MyCustomRetargeterCfg.retargeter_type = MyCustomRetargeter
# Register your retargeter type with its constructor
RETARGETER_MAP[MyCustomRetargeterCfg] = MyCustomRetargeter
4. Now you can use your custom retargeter in teleop device configurations: 4. Now you can use your custom retargeter in teleop device configurations:
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.48.2" version = "0.48.3"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.48.3 (2025-11-13)
~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Moved retargeter and device declaration out of factory and into the devices/retargeters themselves.
0.48.2 (2025-11-13) 0.48.2 (2025-11-13)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -18,8 +18,14 @@ from isaaclab.devices.retargeter_base import RetargeterBase, RetargeterCfg ...@@ -18,8 +18,14 @@ from isaaclab.devices.retargeter_base import RetargeterBase, RetargeterCfg
class DeviceCfg: class DeviceCfg:
"""Configuration for teleoperation devices.""" """Configuration for teleoperation devices."""
# Whether teleoperation should start active by default
teleoperation_active_default: bool = True
# Torch device string to place output tensors on
sim_device: str = "cpu" sim_device: str = "cpu"
# Retargeters that transform device data into robot commands
retargeters: list[RetargeterCfg] = field(default_factory=list) retargeters: list[RetargeterCfg] = field(default_factory=list)
# Concrete device class to construct for this config. Set by each device module.
class_type: type["DeviceBase"] | None = None
@dataclass @dataclass
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
"""Gamepad controller for SE(2) control.""" """Gamepad controller for SE(2) control."""
from __future__ import annotations
import numpy as np import numpy as np
import torch import torch
import weakref import weakref
...@@ -18,16 +20,6 @@ import omni ...@@ -18,16 +20,6 @@ import omni
from ..device_base import DeviceBase, DeviceCfg from ..device_base import DeviceBase, DeviceCfg
@dataclass
class Se2GamepadCfg(DeviceCfg):
"""Configuration for SE2 gamepad devices."""
v_x_sensitivity: float = 1.0
v_y_sensitivity: float = 1.0
omega_z_sensitivity: float = 1.0
dead_zone: float = 0.01
class Se2Gamepad(DeviceBase): class Se2Gamepad(DeviceBase):
r"""A gamepad controller for sending SE(2) commands as velocity commands. r"""A gamepad controller for sending SE(2) commands as velocity commands.
...@@ -209,3 +201,14 @@ class Se2Gamepad(DeviceBase): ...@@ -209,3 +201,14 @@ class Se2Gamepad(DeviceBase):
command[command_sign] *= -1 command[command_sign] *= -1
return command return command
@dataclass
class Se2GamepadCfg(DeviceCfg):
"""Configuration for SE2 gamepad devices."""
v_x_sensitivity: float = 1.0
v_y_sensitivity: float = 1.0
omega_z_sensitivity: float = 1.0
dead_zone: float = 0.01
class_type: type[DeviceBase] = Se2Gamepad
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
"""Gamepad controller for SE(3) control.""" """Gamepad controller for SE(3) control."""
from __future__ import annotations
import numpy as np import numpy as np
import torch import torch
import weakref import weakref
...@@ -18,17 +20,6 @@ import omni ...@@ -18,17 +20,6 @@ import omni
from ..device_base import DeviceBase, DeviceCfg from ..device_base import DeviceBase, DeviceCfg
@dataclass
class Se3GamepadCfg(DeviceCfg):
"""Configuration for SE3 gamepad devices."""
gripper_term: bool = True
dead_zone: float = 0.01 # For gamepad devices
pos_sensitivity: float = 1.0
rot_sensitivity: float = 1.6
retargeters: None = None
class Se3Gamepad(DeviceBase): class Se3Gamepad(DeviceBase):
"""A gamepad controller for sending SE(3) commands as delta poses and binary command (open/close). """A gamepad controller for sending SE(3) commands as delta poses and binary command (open/close).
...@@ -264,3 +255,14 @@ class Se3Gamepad(DeviceBase): ...@@ -264,3 +255,14 @@ class Se3Gamepad(DeviceBase):
delta_command[delta_command_sign] *= -1 delta_command[delta_command_sign] *= -1
return delta_command return delta_command
@dataclass
class Se3GamepadCfg(DeviceCfg):
"""Configuration for SE3 gamepad devices."""
gripper_term: bool = True
dead_zone: float = 0.01 # For gamepad devices
pos_sensitivity: float = 1.0
rot_sensitivity: float = 1.6
class_type: type[DeviceBase] = Se3Gamepad
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
"""Haply device controller for SE3 control with force feedback.""" """Haply device controller for SE3 control with force feedback."""
from __future__ import annotations
import asyncio import asyncio
import json import json
import numpy as np import numpy as np
...@@ -25,23 +27,6 @@ from ..device_base import DeviceBase, DeviceCfg ...@@ -25,23 +27,6 @@ from ..device_base import DeviceBase, DeviceCfg
from ..retargeter_base import RetargeterBase from ..retargeter_base import RetargeterBase
@dataclass
class HaplyDeviceCfg(DeviceCfg):
"""Configuration for Haply device.
Attributes:
websocket_uri: WebSocket URI for Haply SDK connection
pos_sensitivity: Position sensitivity scaling factor
data_rate: Data exchange rate in Hz
limit_force: Maximum force magnitude in Newtons (safety limit)
"""
websocket_uri: str = "ws://localhost:10001"
pos_sensitivity: float = 1.0
data_rate: float = 200.0
limit_force: float = 2.0
class HaplyDevice(DeviceBase): class HaplyDevice(DeviceBase):
"""A Haply device controller for sending SE(3) commands with force feedback. """A Haply device controller for sending SE(3) commands with force feedback.
...@@ -387,3 +372,21 @@ class HaplyDevice(DeviceBase): ...@@ -387,3 +372,21 @@ class HaplyDevice(DeviceBase):
await asyncio.sleep(2.0) await asyncio.sleep(2.0)
else: else:
break break
@dataclass
class HaplyDeviceCfg(DeviceCfg):
"""Configuration for Haply device.
Attributes:
websocket_uri: WebSocket URI for Haply SDK connection
pos_sensitivity: Position sensitivity scaling factor
data_rate: Data exchange rate in Hz
limit_force: Maximum force magnitude in Newtons (safety limit)
"""
websocket_uri: str = "ws://localhost:10001"
pos_sensitivity: float = 1.0
data_rate: float = 200.0
limit_force: float = 2.0
class_type: type[DeviceBase] = HaplyDevice
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
"""Keyboard controller for SE(2) control.""" """Keyboard controller for SE(2) control."""
from __future__ import annotations
import numpy as np import numpy as np
import torch import torch
import weakref import weakref
...@@ -17,15 +19,6 @@ import omni ...@@ -17,15 +19,6 @@ import omni
from ..device_base import DeviceBase, DeviceCfg from ..device_base import DeviceBase, DeviceCfg
@dataclass
class Se2KeyboardCfg(DeviceCfg):
"""Configuration for SE2 keyboard devices."""
v_x_sensitivity: float = 0.8
v_y_sensitivity: float = 0.4
omega_z_sensitivity: float = 1.0
class Se2Keyboard(DeviceBase): class Se2Keyboard(DeviceBase):
r"""A keyboard controller for sending SE(2) commands as velocity commands. r"""A keyboard controller for sending SE(2) commands as velocity commands.
...@@ -178,3 +171,13 @@ class Se2Keyboard(DeviceBase): ...@@ -178,3 +171,13 @@ class Se2Keyboard(DeviceBase):
"NUMPAD_9": np.asarray([0.0, 0.0, -1.0]) * self.omega_z_sensitivity, "NUMPAD_9": np.asarray([0.0, 0.0, -1.0]) * self.omega_z_sensitivity,
"X": np.asarray([0.0, 0.0, -1.0]) * self.omega_z_sensitivity, "X": np.asarray([0.0, 0.0, -1.0]) * self.omega_z_sensitivity,
} }
@dataclass
class Se2KeyboardCfg(DeviceCfg):
"""Configuration for SE2 keyboard devices."""
v_x_sensitivity: float = 0.8
v_y_sensitivity: float = 0.4
omega_z_sensitivity: float = 1.0
class_type: type[DeviceBase] = Se2Keyboard
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
"""Keyboard controller for SE(3) control.""" """Keyboard controller for SE(3) control."""
from __future__ import annotations
import numpy as np import numpy as np
import torch import torch
import weakref import weakref
...@@ -18,16 +20,6 @@ import omni ...@@ -18,16 +20,6 @@ import omni
from ..device_base import DeviceBase, DeviceCfg from ..device_base import DeviceBase, DeviceCfg
@dataclass
class Se3KeyboardCfg(DeviceCfg):
"""Configuration for SE3 keyboard devices."""
gripper_term: bool = True
pos_sensitivity: float = 0.4
rot_sensitivity: float = 0.8
retargeters: None = None
class Se3Keyboard(DeviceBase): class Se3Keyboard(DeviceBase):
"""A keyboard controller for sending SE(3) commands as delta poses and binary command (open/close). """A keyboard controller for sending SE(3) commands as delta poses and binary command (open/close).
...@@ -206,3 +198,14 @@ class Se3Keyboard(DeviceBase): ...@@ -206,3 +198,14 @@ class Se3Keyboard(DeviceBase):
"C": np.asarray([0.0, 0.0, 1.0]) * self.rot_sensitivity, "C": np.asarray([0.0, 0.0, 1.0]) * self.rot_sensitivity,
"V": np.asarray([0.0, 0.0, -1.0]) * self.rot_sensitivity, "V": np.asarray([0.0, 0.0, -1.0]) * self.rot_sensitivity,
} }
@dataclass
class Se3KeyboardCfg(DeviceCfg):
"""Configuration for SE3 keyboard devices."""
gripper_term: bool = True
pos_sensitivity: float = 0.4
rot_sensitivity: float = 0.8
retargeters: None = None
class_type: type[DeviceBase] = Se3Keyboard
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
Manus and Vive for teleoperation and interaction. Manus and Vive for teleoperation and interaction.
""" """
from __future__ import annotations
import contextlib import contextlib
import numpy as np import numpy as np
from collections.abc import Callable from collections.abc import Callable
...@@ -34,13 +36,6 @@ from isaacsim.core.prims import SingleXFormPrim ...@@ -34,13 +36,6 @@ from isaacsim.core.prims import SingleXFormPrim
from .manus_vive_utils import HAND_JOINT_MAP, ManusViveIntegration from .manus_vive_utils import HAND_JOINT_MAP, ManusViveIntegration
@dataclass
class ManusViveCfg(DeviceCfg):
"""Configuration for Manus and Vive."""
xr_cfg: XrCfg | None = None
class ManusVive(DeviceBase): class ManusVive(DeviceBase):
"""Manus gloves and Vive trackers for teleoperation and interaction. """Manus gloves and Vive trackers for teleoperation and interaction.
...@@ -246,3 +241,11 @@ class ManusVive(DeviceBase): ...@@ -246,3 +241,11 @@ class ManusVive(DeviceBase):
elif "reset" in msg: elif "reset" in msg:
if "RESET" in self._additional_callbacks: if "RESET" in self._additional_callbacks:
self._additional_callbacks["RESET"]() self._additional_callbacks["RESET"]()
@dataclass
class ManusViveCfg(DeviceCfg):
"""Configuration for Manus and Vive."""
xr_cfg: XrCfg | None = None
class_type: type[DeviceBase] = ManusVive
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
"""OpenXR-powered device for teleoperation and interaction.""" """OpenXR-powered device for teleoperation and interaction."""
from __future__ import annotations
import contextlib import contextlib
import numpy as np import numpy as np
from collections.abc import Callable from collections.abc import Callable
...@@ -26,14 +28,8 @@ XRPoseValidityFlags = None ...@@ -26,14 +28,8 @@ XRPoseValidityFlags = None
with contextlib.suppress(ModuleNotFoundError): with contextlib.suppress(ModuleNotFoundError):
from omni.kit.xr.core import XRCore, XRPoseValidityFlags from omni.kit.xr.core import XRCore, XRPoseValidityFlags
from isaacsim.core.prims import SingleXFormPrim
@dataclass
class OpenXRDeviceCfg(DeviceCfg):
"""Configuration for OpenXR devices."""
xr_cfg: XrCfg | None = None from isaacsim.core.prims import SingleXFormPrim
class OpenXRDevice(DeviceBase): class OpenXRDevice(DeviceBase):
...@@ -303,3 +299,11 @@ class OpenXRDevice(DeviceBase): ...@@ -303,3 +299,11 @@ class OpenXRDevice(DeviceBase):
elif "reset" in msg: elif "reset" in msg:
if "RESET" in self._additional_callbacks: if "RESET" in self._additional_callbacks:
self._additional_callbacks["RESET"]() self._additional_callbacks["RESET"]()
@dataclass
class OpenXRDeviceCfg(DeviceCfg):
"""Configuration for OpenXR devices."""
xr_cfg: XrCfg | None = None
class_type: type[DeviceBase] = OpenXRDevice
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
from typing import Any
from isaaclab.devices.retargeter_base import RetargeterBase
class DexRetargeter(RetargeterBase):
"""Retargets OpenXR hand joint data to DEX robot joint commands.
This class implements the RetargeterBase interface to convert hand tracking data
into a format suitable for controlling DEX robot hands.
"""
def __init__(self):
"""Initialize the DEX retargeter."""
super().__init__()
# TODO: Add any initialization parameters and state variables needed
pass
def retarget(self, joint_data: dict[str, np.ndarray]) -> Any:
"""Convert OpenXR hand joint poses to DEX robot commands.
Args:
joint_data: Dictionary mapping OpenXR joint names to their pose data.
Each pose is a numpy array of shape (7,) containing
[x, y, z, qx, qy, qz, qw] for absolute mode or
[x, y, z, roll, pitch, yaw] for relative mode.
Returns:
Retargeted data in the format expected by DEX robot control interface.
TODO: Specify the exact return type and format
"""
# TODO: Implement the retargeting logic
raise NotImplementedError("DexRetargeter.retarget() not implemented")
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import contextlib import contextlib
import numpy as np import numpy as np
import torch import torch
...@@ -19,15 +21,6 @@ with contextlib.suppress(Exception): ...@@ -19,15 +21,6 @@ with contextlib.suppress(Exception):
from .gr1_t2_dex_retargeting_utils import GR1TR2DexRetargeting from .gr1_t2_dex_retargeting_utils import GR1TR2DexRetargeting
@dataclass
class GR1T2RetargeterCfg(RetargeterCfg):
"""Configuration for the GR1T2 retargeter."""
enable_visualization: bool = False
num_open_xr_hand_joints: int = 100
hand_joint_names: list[str] | None = None # List of robot hand joint names
class GR1T2Retargeter(RetargeterBase): class GR1T2Retargeter(RetargeterBase):
"""Retargets OpenXR hand tracking data to GR1T2 hand end-effector commands. """Retargets OpenXR hand tracking data to GR1T2 hand end-effector commands.
...@@ -156,3 +149,13 @@ class GR1T2Retargeter(RetargeterBase): ...@@ -156,3 +149,13 @@ class GR1T2Retargeter(RetargeterBase):
usd_right_roll_link_in_world_quat = PoseUtils.quat_from_matrix(usd_right_roll_link_in_world_mat) usd_right_roll_link_in_world_quat = PoseUtils.quat_from_matrix(usd_right_roll_link_in_world_mat)
return np.concatenate([usd_right_roll_link_in_world_pos, usd_right_roll_link_in_world_quat]) return np.concatenate([usd_right_roll_link_in_world_pos, usd_right_roll_link_in_world_quat])
@dataclass
class GR1T2RetargeterCfg(RetargeterCfg):
"""Configuration for the GR1T2 retargeter."""
enable_visualization: bool = False
num_open_xr_hand_joints: int = 100
hand_joint_names: list[str] | None = None # List of robot hand joint names
retargeter_type: type[RetargeterBase] = GR1T2Retargeter
...@@ -3,20 +3,14 @@ ...@@ -3,20 +3,14 @@
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
from isaaclab.devices.retargeter_base import RetargeterBase, RetargeterCfg from isaaclab.devices.retargeter_base import RetargeterBase, RetargeterCfg
@dataclass
class G1LowerBodyStandingRetargeterCfg(RetargeterCfg):
"""Configuration for the G1 lower body standing retargeter."""
hip_height: float = 0.72
"""Height of the G1 robot hip in meters. The value is a fixed height suitable for G1 to do tabletop manipulation."""
class G1LowerBodyStandingRetargeter(RetargeterBase): class G1LowerBodyStandingRetargeter(RetargeterBase):
"""Provides lower body standing commands for the G1 robot.""" """Provides lower body standing commands for the G1 robot."""
...@@ -26,3 +20,12 @@ class G1LowerBodyStandingRetargeter(RetargeterBase): ...@@ -26,3 +20,12 @@ class G1LowerBodyStandingRetargeter(RetargeterBase):
def retarget(self, data: dict) -> torch.Tensor: def retarget(self, data: dict) -> torch.Tensor:
return torch.tensor([0.0, 0.0, 0.0, self.cfg.hip_height], device=self.cfg.sim_device) return torch.tensor([0.0, 0.0, 0.0, self.cfg.hip_height], device=self.cfg.sim_device)
@dataclass
class G1LowerBodyStandingRetargeterCfg(RetargeterCfg):
"""Configuration for the G1 lower body standing retargeter."""
hip_height: float = 0.72
"""Height of the G1 robot hip in meters. The value is a fixed height suitable for G1 to do tabletop manipulation."""
retargeter_type: type[RetargeterBase] = G1LowerBodyStandingRetargeter
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import contextlib import contextlib
import numpy as np import numpy as np
import torch import torch
...@@ -19,15 +21,6 @@ with contextlib.suppress(Exception): ...@@ -19,15 +21,6 @@ with contextlib.suppress(Exception):
from .g1_dex_retargeting_utils import UnitreeG1DexRetargeting from .g1_dex_retargeting_utils import UnitreeG1DexRetargeting
@dataclass
class UnitreeG1RetargeterCfg(RetargeterCfg):
"""Configuration for the UnitreeG1 retargeter."""
enable_visualization: bool = False
num_open_xr_hand_joints: int = 100
hand_joint_names: list[str] | None = None # List of robot hand joint names
class UnitreeG1Retargeter(RetargeterBase): class UnitreeG1Retargeter(RetargeterBase):
"""Retargets OpenXR hand tracking data to GR1T2 hand end-effector commands. """Retargets OpenXR hand tracking data to GR1T2 hand end-effector commands.
...@@ -152,3 +145,13 @@ class UnitreeG1Retargeter(RetargeterBase): ...@@ -152,3 +145,13 @@ class UnitreeG1Retargeter(RetargeterBase):
quat = PoseUtils.quat_from_matrix(rot_mat) quat = PoseUtils.quat_from_matrix(rot_mat)
return np.concatenate([pos.numpy(), quat.numpy()]) return np.concatenate([pos.numpy(), quat.numpy()])
@dataclass
class UnitreeG1RetargeterCfg(RetargeterCfg):
"""Configuration for the UnitreeG1 retargeter."""
enable_visualization: bool = False
num_open_xr_hand_joints: int = 100
hand_joint_names: list[str] | None = None # List of robot hand joint names
retargeter_type: type[RetargeterBase] = UnitreeG1Retargeter
...@@ -21,15 +21,6 @@ with contextlib.suppress(Exception): ...@@ -21,15 +21,6 @@ with contextlib.suppress(Exception):
from .g1_dex_retargeting_utils import G1TriHandDexRetargeting from .g1_dex_retargeting_utils import G1TriHandDexRetargeting
@dataclass
class G1TriHandUpperBodyRetargeterCfg(RetargeterCfg):
"""Configuration for the G1UpperBody retargeter."""
enable_visualization: bool = False
num_open_xr_hand_joints: int = 100
hand_joint_names: list[str] | None = None # List of robot hand joint names
class G1TriHandUpperBodyRetargeter(RetargeterBase): class G1TriHandUpperBodyRetargeter(RetargeterBase):
"""Retargets OpenXR data to G1 upper body commands. """Retargets OpenXR data to G1 upper body commands.
...@@ -164,3 +155,13 @@ class G1TriHandUpperBodyRetargeter(RetargeterBase): ...@@ -164,3 +155,13 @@ class G1TriHandUpperBodyRetargeter(RetargeterBase):
quat = PoseUtils.quat_from_matrix(rot_mat) quat = PoseUtils.quat_from_matrix(rot_mat)
return np.concatenate([pos.numpy(), quat.numpy()]) return np.concatenate([pos.numpy(), quat.numpy()])
@dataclass
class G1TriHandUpperBodyRetargeterCfg(RetargeterCfg):
"""Configuration for the G1UpperBody retargeter."""
enable_visualization: bool = False
num_open_xr_hand_joints: int = 100
hand_joint_names: list[str] | None = None # List of robot hand joint names
retargeter_type: type[RetargeterBase] = G1TriHandUpperBodyRetargeter
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# All rights reserved. # All rights reserved.
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import numpy as np import numpy as np
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
...@@ -11,13 +13,6 @@ from isaaclab.devices import OpenXRDevice ...@@ -11,13 +13,6 @@ from isaaclab.devices import OpenXRDevice
from isaaclab.devices.retargeter_base import RetargeterBase, RetargeterCfg from isaaclab.devices.retargeter_base import RetargeterBase, RetargeterCfg
@dataclass
class GripperRetargeterCfg(RetargeterCfg):
"""Configuration for gripper retargeter."""
bound_hand: OpenXRDevice.TrackingTarget = OpenXRDevice.TrackingTarget.HAND_RIGHT
class GripperRetargeter(RetargeterBase): class GripperRetargeter(RetargeterBase):
"""Retargeter specifically for gripper control based on hand tracking data. """Retargeter specifically for gripper control based on hand tracking data.
...@@ -90,3 +85,11 @@ class GripperRetargeter(RetargeterBase): ...@@ -90,3 +85,11 @@ class GripperRetargeter(RetargeterBase):
self._previous_gripper_command = True self._previous_gripper_command = True
return self._previous_gripper_command return self._previous_gripper_command
@dataclass
class GripperRetargeterCfg(RetargeterCfg):
"""Configuration for gripper retargeter."""
bound_hand: OpenXRDevice.TrackingTarget = OpenXRDevice.TrackingTarget.HAND_RIGHT
retargeter_type: type[RetargeterBase] = GripperRetargeter
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# All rights reserved. # All rights reserved.
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import numpy as np import numpy as np
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
...@@ -13,17 +15,6 @@ from isaaclab.markers import VisualizationMarkers ...@@ -13,17 +15,6 @@ from isaaclab.markers import VisualizationMarkers
from isaaclab.markers.config import FRAME_MARKER_CFG from isaaclab.markers.config import FRAME_MARKER_CFG
@dataclass
class Se3AbsRetargeterCfg(RetargeterCfg):
"""Configuration for absolute position retargeter."""
zero_out_xy_rotation: bool = True
use_wrist_rotation: bool = False
use_wrist_position: bool = True
enable_visualization: bool = False
bound_hand: OpenXRDevice.TrackingTarget = OpenXRDevice.TrackingTarget.HAND_RIGHT
class Se3AbsRetargeter(RetargeterBase): class Se3AbsRetargeter(RetargeterBase):
"""Retargets OpenXR hand tracking data to end-effector commands using absolute positioning. """Retargets OpenXR hand tracking data to end-effector commands using absolute positioning.
...@@ -164,3 +155,15 @@ class Se3AbsRetargeter(RetargeterBase): ...@@ -164,3 +155,15 @@ class Se3AbsRetargeter(RetargeterBase):
quat = Rotation.from_matrix(self._visualization_rot).as_quat() quat = Rotation.from_matrix(self._visualization_rot).as_quat()
rot = np.array([np.array([quat[3], quat[0], quat[1], quat[2]])]) rot = np.array([np.array([quat[3], quat[0], quat[1], quat[2]])])
self._goal_marker.visualize(translations=trans, orientations=rot) self._goal_marker.visualize(translations=trans, orientations=rot)
@dataclass
class Se3AbsRetargeterCfg(RetargeterCfg):
"""Configuration for absolute position retargeter."""
zero_out_xy_rotation: bool = True
use_wrist_rotation: bool = False
use_wrist_position: bool = True
enable_visualization: bool = False
bound_hand: OpenXRDevice.TrackingTarget = OpenXRDevice.TrackingTarget.HAND_RIGHT
retargeter_type: type[RetargeterBase] = Se3AbsRetargeter
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# All rights reserved. # All rights reserved.
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import numpy as np import numpy as np
import torch import torch
from dataclasses import dataclass from dataclasses import dataclass
...@@ -13,21 +15,6 @@ from isaaclab.markers import VisualizationMarkers ...@@ -13,21 +15,6 @@ from isaaclab.markers import VisualizationMarkers
from isaaclab.markers.config import FRAME_MARKER_CFG from isaaclab.markers.config import FRAME_MARKER_CFG
@dataclass
class Se3RelRetargeterCfg(RetargeterCfg):
"""Configuration for relative position retargeter."""
zero_out_xy_rotation: bool = True
use_wrist_rotation: bool = False
use_wrist_position: bool = True
delta_pos_scale_factor: float = 10.0
delta_rot_scale_factor: float = 10.0
alpha_pos: float = 0.5
alpha_rot: float = 0.5
enable_visualization: bool = False
bound_hand: OpenXRDevice.TrackingTarget = OpenXRDevice.TrackingTarget.HAND_RIGHT
class Se3RelRetargeter(RetargeterBase): class Se3RelRetargeter(RetargeterBase):
"""Retargets OpenXR hand tracking data to end-effector commands using relative positioning. """Retargets OpenXR hand tracking data to end-effector commands using relative positioning.
...@@ -206,3 +193,19 @@ class Se3RelRetargeter(RetargeterBase): ...@@ -206,3 +193,19 @@ class Se3RelRetargeter(RetargeterBase):
quat = Rotation.from_matrix(self._visualization_rot).as_quat() quat = Rotation.from_matrix(self._visualization_rot).as_quat()
rot = np.array([np.array([quat[3], quat[0], quat[1], quat[2]])]) rot = np.array([np.array([quat[3], quat[0], quat[1], quat[2]])])
self._goal_marker.visualize(translations=trans, orientations=rot) self._goal_marker.visualize(translations=trans, orientations=rot)
@dataclass
class Se3RelRetargeterCfg(RetargeterCfg):
"""Configuration for relative position retargeter."""
zero_out_xy_rotation: bool = True
use_wrist_rotation: bool = False
use_wrist_position: bool = True
delta_pos_scale_factor: float = 10.0
delta_rot_scale_factor: float = 10.0
alpha_pos: float = 0.5
alpha_rot: float = 0.5
enable_visualization: bool = False
bound_hand: OpenXRDevice.TrackingTarget = OpenXRDevice.TrackingTarget.HAND_RIGHT
retargeter_type: type[RetargeterBase] = Se3RelRetargeter
...@@ -13,6 +13,8 @@ class RetargeterCfg: ...@@ -13,6 +13,8 @@ class RetargeterCfg:
"""Base configuration for hand tracking retargeters.""" """Base configuration for hand tracking retargeters."""
sim_device: str = "cpu" sim_device: str = "cpu"
# Concrete retargeter class to construct for this config. Set by each retargeter module.
retargeter_type: type["RetargeterBase"] | None = None
class RetargeterBase(ABC): class RetargeterBase(ABC):
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
"""Spacemouse controller for SE(2) control.""" """Spacemouse controller for SE(2) control."""
from __future__ import annotations
import hid import hid
import numpy as np import numpy as np
import threading import threading
...@@ -19,16 +21,6 @@ from ..device_base import DeviceBase, DeviceCfg ...@@ -19,16 +21,6 @@ from ..device_base import DeviceBase, DeviceCfg
from .utils import convert_buffer from .utils import convert_buffer
@dataclass
class Se2SpaceMouseCfg(DeviceCfg):
"""Configuration for SE2 space mouse devices."""
v_x_sensitivity: float = 0.8
v_y_sensitivity: float = 0.4
omega_z_sensitivity: float = 1.0
sim_device: str = "cpu"
class Se2SpaceMouse(DeviceBase): class Se2SpaceMouse(DeviceBase):
r"""A space-mouse controller for sending SE(2) commands as delta poses. r"""A space-mouse controller for sending SE(2) commands as delta poses.
...@@ -168,3 +160,13 @@ class Se2SpaceMouse(DeviceBase): ...@@ -168,3 +160,13 @@ class Se2SpaceMouse(DeviceBase):
# additional callbacks # additional callbacks
if "R" in self._additional_callbacks: if "R" in self._additional_callbacks:
self._additional_callbacks["R"] self._additional_callbacks["R"]
@dataclass
class Se2SpaceMouseCfg(DeviceCfg):
"""Configuration for SE2 space mouse devices."""
v_x_sensitivity: float = 0.8
v_y_sensitivity: float = 0.4
omega_z_sensitivity: float = 1.0
class_type: type[DeviceBase] = Se2SpaceMouse
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
"""Spacemouse controller for SE(3) control.""" """Spacemouse controller for SE(3) control."""
from __future__ import annotations
import hid import hid
import numpy as np import numpy as np
import threading import threading
...@@ -18,16 +20,6 @@ from ..device_base import DeviceBase, DeviceCfg ...@@ -18,16 +20,6 @@ from ..device_base import DeviceBase, DeviceCfg
from .utils import convert_buffer from .utils import convert_buffer
@dataclass
class Se3SpaceMouseCfg(DeviceCfg):
"""Configuration for SE3 space mouse devices."""
gripper_term: bool = True
pos_sensitivity: float = 0.4
rot_sensitivity: float = 0.8
retargeters: None = None
class Se3SpaceMouse(DeviceBase): class Se3SpaceMouse(DeviceBase):
"""A space-mouse controller for sending SE(3) commands as delta poses. """A space-mouse controller for sending SE(3) commands as delta poses.
...@@ -210,3 +202,14 @@ class Se3SpaceMouse(DeviceBase): ...@@ -210,3 +202,14 @@ class Se3SpaceMouse(DeviceBase):
self._additional_callbacks["R"]() self._additional_callbacks["R"]()
if data[1] == 3: if data[1] == 3:
self._read_rotation = not self._read_rotation self._read_rotation = not self._read_rotation
@dataclass
class Se3SpaceMouseCfg(DeviceCfg):
"""Configuration for SE3 space mouse devices."""
gripper_term: bool = True
pos_sensitivity: float = 0.4
rot_sensitivity: float = 0.8
retargeters: None = None
class_type: type[DeviceBase] = Se3SpaceMouse
...@@ -4,67 +4,17 @@ ...@@ -4,67 +4,17 @@
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
"""Factory to create teleoperation devices from configuration.""" """Factory to create teleoperation devices from configuration."""
import contextlib
import inspect import inspect
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from typing import cast
from isaaclab.devices import DeviceBase, DeviceCfg from isaaclab.devices import DeviceBase, DeviceCfg
from isaaclab.devices.gamepad import Se2Gamepad, Se2GamepadCfg, Se3Gamepad, Se3GamepadCfg from isaaclab.devices.retargeter_base import RetargeterBase
from isaaclab.devices.haply import HaplyDevice, HaplyDeviceCfg
from isaaclab.devices.keyboard import Se2Keyboard, Se2KeyboardCfg, Se3Keyboard, Se3KeyboardCfg
from isaaclab.devices.openxr.retargeters import (
G1LowerBodyStandingRetargeter,
G1LowerBodyStandingRetargeterCfg,
G1TriHandUpperBodyRetargeter,
G1TriHandUpperBodyRetargeterCfg,
GR1T2Retargeter,
GR1T2RetargeterCfg,
GripperRetargeter,
GripperRetargeterCfg,
Se3AbsRetargeter,
Se3AbsRetargeterCfg,
Se3RelRetargeter,
Se3RelRetargeterCfg,
UnitreeG1Retargeter,
UnitreeG1RetargeterCfg,
)
from isaaclab.devices.retargeter_base import RetargeterBase, RetargeterCfg
from isaaclab.devices.spacemouse import Se2SpaceMouse, Se2SpaceMouseCfg, Se3SpaceMouse, Se3SpaceMouseCfg
with contextlib.suppress(ModuleNotFoundError):
# May fail if xr is not in use
from isaaclab.devices.openxr import ManusVive, ManusViveCfg, OpenXRDevice, OpenXRDeviceCfg
# import logger # import logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Map device types to their constructor and expected config type
DEVICE_MAP: dict[type[DeviceCfg], type[DeviceBase]] = {
Se3KeyboardCfg: Se3Keyboard,
Se3SpaceMouseCfg: Se3SpaceMouse,
Se3GamepadCfg: Se3Gamepad,
Se2KeyboardCfg: Se2Keyboard,
Se2GamepadCfg: Se2Gamepad,
Se2SpaceMouseCfg: Se2SpaceMouse,
HaplyDeviceCfg: HaplyDevice,
OpenXRDeviceCfg: OpenXRDevice,
ManusViveCfg: ManusVive,
}
# Map configuration types to their corresponding retargeter classes
RETARGETER_MAP: dict[type[RetargeterCfg], type[RetargeterBase]] = {
Se3AbsRetargeterCfg: Se3AbsRetargeter,
Se3RelRetargeterCfg: Se3RelRetargeter,
GripperRetargeterCfg: GripperRetargeter,
GR1T2RetargeterCfg: GR1T2Retargeter,
G1TriHandUpperBodyRetargeterCfg: G1TriHandUpperBodyRetargeter,
G1LowerBodyStandingRetargeterCfg: G1LowerBodyStandingRetargeter,
UnitreeG1RetargeterCfg: UnitreeG1Retargeter,
}
def create_teleop_device( def create_teleop_device(
device_name: str, devices_cfg: dict[str, DeviceCfg], callbacks: dict[str, Callable] | None = None device_name: str, devices_cfg: dict[str, DeviceCfg], callbacks: dict[str, Callable] | None = None
...@@ -90,35 +40,44 @@ def create_teleop_device( ...@@ -90,35 +40,44 @@ def create_teleop_device(
device_cfg = devices_cfg[device_name] device_cfg = devices_cfg[device_name]
callbacks = callbacks or {} callbacks = callbacks or {}
# Check if device config type is supported # Determine constructor from the configuration itself
cfg_type = type(device_cfg) device_constructor = getattr(device_cfg, "class_type", None)
if cfg_type not in DEVICE_MAP: if device_constructor is None:
raise ValueError(f"Unsupported device configuration type: {cfg_type.__name__}") raise ValueError(
f"Device configuration '{device_name}' does not declare class_type. "
# Get the constructor for this config type "Set cfg.class_type to the concrete DeviceBase subclass."
constructor = DEVICE_MAP[cfg_type] )
if not issubclass(device_constructor, DeviceBase):
raise TypeError(f"class_type for '{device_name}' must be a subclass of DeviceBase; got {device_constructor}")
# Try to create retargeters if they are configured # Try to create retargeters if they are configured
retargeters = [] retargeters = []
if hasattr(device_cfg, "retargeters") and device_cfg.retargeters is not None: if hasattr(device_cfg, "retargeters") and device_cfg.retargeters is not None:
try: try:
# Create retargeters based on configuration # Create retargeters based on configuration using per-config retargeter_type
for retargeter_cfg in device_cfg.retargeters: for retargeter_cfg in device_cfg.retargeters:
cfg_type = type(retargeter_cfg) retargeter_constructor = getattr(retargeter_cfg, "retargeter_type", None)
if cfg_type in RETARGETER_MAP: if retargeter_constructor is None:
retargeters.append(RETARGETER_MAP[cfg_type](retargeter_cfg)) raise ValueError(
else: f"Retargeter configuration {type(retargeter_cfg).__name__} does not declare retargeter_type. "
raise ValueError(f"Unknown retargeter configuration type: {cfg_type.__name__}") "Set cfg.retargeter_type to the concrete RetargeterBase subclass."
)
if not issubclass(retargeter_constructor, RetargeterBase):
raise TypeError(
f"retargeter_type for {type(retargeter_cfg).__name__} must be a subclass of RetargeterBase; got"
f" {retargeter_constructor}"
)
retargeters.append(retargeter_constructor(retargeter_cfg))
except NameError as e: except NameError as e:
raise ValueError(f"Failed to create retargeters: {e}") raise ValueError(f"Failed to create retargeters: {e}")
# Check if the constructor accepts retargeters parameter # Build constructor kwargs based on signature
constructor_params = inspect.signature(constructor).parameters constructor_params = inspect.signature(device_constructor).parameters
if "retargeters" in constructor_params and retargeters: params: dict = {"cfg": device_cfg}
device = constructor(cfg=device_cfg, retargeters=retargeters) if "retargeters" in constructor_params:
else: params["retargeters"] = retargeters
device = constructor(cfg=device_cfg) device = cast(DeviceBase, device_constructor(**params))
# Register callbacks # Register callbacks
for key, callback in callbacks.items(): for key, callback in callbacks.items():
......
...@@ -15,11 +15,13 @@ simulation_app = AppLauncher(headless=True).app ...@@ -15,11 +15,13 @@ simulation_app = AppLauncher(headless=True).app
import importlib import importlib
import json import json
import torch import torch
from typing import cast
import pytest import pytest
# Import device classes to test # Import device classes to test
from isaaclab.devices import ( from isaaclab.devices import (
DeviceCfg,
HaplyDevice, HaplyDevice,
HaplyDeviceCfg, HaplyDeviceCfg,
OpenXRDevice, OpenXRDevice,
...@@ -69,6 +71,11 @@ def mock_environment(mocker): ...@@ -69,6 +71,11 @@ def mock_environment(mocker):
carb_mock.input.KeyboardEventType.KEY_PRESS = 1 carb_mock.input.KeyboardEventType.KEY_PRESS = 1
carb_mock.input.KeyboardEventType.KEY_RELEASE = 2 carb_mock.input.KeyboardEventType.KEY_RELEASE = 2
# Mock carb events used by OpenXRDevice
events_mock = mocker.MagicMock()
events_mock.type_from_string.return_value = 0
carb_mock.events = events_mock
# Mock the SpaceMouse # Mock the SpaceMouse
hid_mock.enumerate.return_value = [{"product_string": "SpaceMouse Compact", "vendor_id": 123, "product_id": 456}] hid_mock.enumerate.return_value = [{"product_string": "SpaceMouse Compact", "vendor_id": 123, "product_id": 456}]
hid_mock.device.return_value = device_mock hid_mock.device.return_value = device_mock
...@@ -300,6 +307,7 @@ def test_openxr_constructors(mock_environment, mocker): ...@@ -300,6 +307,7 @@ def test_openxr_constructors(mock_environment, mocker):
"isaacsim.core.prims": mocker.MagicMock(), "isaacsim.core.prims": mocker.MagicMock(),
}, },
) )
mocker.patch.object(device_mod, "carb", mock_environment["carb"])
mocker.patch.object(device_mod, "XRCore", mock_environment["omni"].kit.xr.core.XRCore) mocker.patch.object(device_mod, "XRCore", mock_environment["omni"].kit.xr.core.XRCore)
mocker.patch.object(device_mod, "XRPoseValidityFlags", mock_environment["omni"].kit.xr.core.XRPoseValidityFlags) mocker.patch.object(device_mod, "XRPoseValidityFlags", mock_environment["omni"].kit.xr.core.XRPoseValidityFlags)
mock_single_xform = mocker.patch.object(device_mod, "SingleXFormPrim") mock_single_xform = mocker.patch.object(device_mod, "SingleXFormPrim")
...@@ -477,7 +485,7 @@ def test_create_teleop_device_basic(mock_environment, mocker): ...@@ -477,7 +485,7 @@ def test_create_teleop_device_basic(mock_environment, mocker):
keyboard_cfg = Se3KeyboardCfg(pos_sensitivity=0.8, rot_sensitivity=1.2) keyboard_cfg = Se3KeyboardCfg(pos_sensitivity=0.8, rot_sensitivity=1.2)
# Create devices configuration dictionary # Create devices configuration dictionary
devices_cfg = {"test_keyboard": keyboard_cfg} devices_cfg: dict[str, DeviceCfg] = {"test_keyboard": keyboard_cfg}
# Mock Se3Keyboard class # Mock Se3Keyboard class
device_mod = importlib.import_module("isaaclab.devices.keyboard.se3_keyboard") device_mod = importlib.import_module("isaaclab.devices.keyboard.se3_keyboard")
...@@ -501,7 +509,7 @@ def test_create_teleop_device_with_callbacks(mock_environment, mocker): ...@@ -501,7 +509,7 @@ def test_create_teleop_device_with_callbacks(mock_environment, mocker):
openxr_cfg = OpenXRDeviceCfg(xr_cfg=xr_cfg) openxr_cfg = OpenXRDeviceCfg(xr_cfg=xr_cfg)
# Create devices configuration dictionary # Create devices configuration dictionary
devices_cfg = {"test_xr": openxr_cfg} devices_cfg: dict[str, DeviceCfg] = {"test_xr": openxr_cfg}
# Create mock callbacks # Create mock callbacks
button_a_callback = mocker.MagicMock() button_a_callback = mocker.MagicMock()
...@@ -518,6 +526,7 @@ def test_create_teleop_device_with_callbacks(mock_environment, mocker): ...@@ -518,6 +526,7 @@ def test_create_teleop_device_with_callbacks(mock_environment, mocker):
"isaacsim.core.prims": mocker.MagicMock(), "isaacsim.core.prims": mocker.MagicMock(),
}, },
) )
mocker.patch.object(device_mod, "carb", mock_environment["carb"])
mocker.patch.object(device_mod, "XRCore", mock_environment["omni"].kit.xr.core.XRCore) mocker.patch.object(device_mod, "XRCore", mock_environment["omni"].kit.xr.core.XRCore)
mocker.patch.object(device_mod, "XRPoseValidityFlags", mock_environment["omni"].kit.xr.core.XRPoseValidityFlags) mocker.patch.object(device_mod, "XRPoseValidityFlags", mock_environment["omni"].kit.xr.core.XRPoseValidityFlags)
mock_single_xform = mocker.patch.object(device_mod, "SingleXFormPrim") mock_single_xform = mocker.patch.object(device_mod, "SingleXFormPrim")
...@@ -532,10 +541,8 @@ def test_create_teleop_device_with_callbacks(mock_environment, mocker): ...@@ -532,10 +541,8 @@ def test_create_teleop_device_with_callbacks(mock_environment, mocker):
# Verify the device was created correctly # Verify the device was created correctly
assert isinstance(device, OpenXRDevice) assert isinstance(device, OpenXRDevice)
# Verify callbacks were registered # Verify callbacks were registered by the factory
device.add_callback("button_a", button_a_callback) assert set(device._additional_callbacks.keys()) == {"button_a", "button_b"}
device.add_callback("button_b", button_b_callback)
assert len(device._additional_callbacks) == 2
def test_create_teleop_device_with_retargeters(mock_environment, mocker): def test_create_teleop_device_with_retargeters(mock_environment, mocker):
...@@ -549,7 +556,7 @@ def test_create_teleop_device_with_retargeters(mock_environment, mocker): ...@@ -549,7 +556,7 @@ def test_create_teleop_device_with_retargeters(mock_environment, mocker):
device_cfg = OpenXRDeviceCfg(xr_cfg=xr_cfg, retargeters=[retargeter_cfg1, retargeter_cfg2]) device_cfg = OpenXRDeviceCfg(xr_cfg=xr_cfg, retargeters=[retargeter_cfg1, retargeter_cfg2])
# Create devices configuration dictionary # Create devices configuration dictionary
devices_cfg = {"test_xr": device_cfg} devices_cfg: dict[str, DeviceCfg] = {"test_xr": device_cfg}
# Mock OpenXRDevice class and dependencies # Mock OpenXRDevice class and dependencies
device_mod = importlib.import_module("isaaclab.devices.openxr.openxr_device") device_mod = importlib.import_module("isaaclab.devices.openxr.openxr_device")
...@@ -561,6 +568,7 @@ def test_create_teleop_device_with_retargeters(mock_environment, mocker): ...@@ -561,6 +568,7 @@ def test_create_teleop_device_with_retargeters(mock_environment, mocker):
"isaacsim.core.prims": mocker.MagicMock(), "isaacsim.core.prims": mocker.MagicMock(),
}, },
) )
mocker.patch.object(device_mod, "carb", mock_environment["carb"])
mocker.patch.object(device_mod, "XRCore", mock_environment["omni"].kit.xr.core.XRCore) mocker.patch.object(device_mod, "XRCore", mock_environment["omni"].kit.xr.core.XRCore)
mocker.patch.object(device_mod, "XRPoseValidityFlags", mock_environment["omni"].kit.xr.core.XRPoseValidityFlags) mocker.patch.object(device_mod, "XRPoseValidityFlags", mock_environment["omni"].kit.xr.core.XRPoseValidityFlags)
mock_single_xform = mocker.patch.object(device_mod, "SingleXFormPrim") mock_single_xform = mocker.patch.object(device_mod, "SingleXFormPrim")
...@@ -569,11 +577,6 @@ def test_create_teleop_device_with_retargeters(mock_environment, mocker): ...@@ -569,11 +577,6 @@ def test_create_teleop_device_with_retargeters(mock_environment, mocker):
mock_instance = mock_single_xform.return_value mock_instance = mock_single_xform.return_value
mock_instance.prim_path = "/XRAnchor" mock_instance.prim_path = "/XRAnchor"
# Mock retargeter classes
retargeter_mod = importlib.import_module("isaaclab.devices.openxr.retargeters")
mocker.patch.object(retargeter_mod, "Se3AbsRetargeter")
mocker.patch.object(retargeter_mod, "GripperRetargeter")
# Create the device using the factory # Create the device using the factory
device = create_teleop_device("test_xr", devices_cfg) device = create_teleop_device("test_xr", devices_cfg)
...@@ -584,7 +587,7 @@ def test_create_teleop_device_with_retargeters(mock_environment, mocker): ...@@ -584,7 +587,7 @@ def test_create_teleop_device_with_retargeters(mock_environment, mocker):
def test_create_teleop_device_device_not_found(): def test_create_teleop_device_device_not_found():
"""Test error when device name is not found in configuration.""" """Test error when device name is not found in configuration."""
# Create devices configuration dictionary # Create devices configuration dictionary
devices_cfg = {"keyboard": Se3KeyboardCfg()} devices_cfg: dict[str, DeviceCfg] = {"keyboard": Se3KeyboardCfg()}
# Try to create a non-existent device # Try to create a non-existent device
with pytest.raises(ValueError, match="Device 'gamepad' not found"): with pytest.raises(ValueError, match="Device 'gamepad' not found"):
...@@ -599,8 +602,8 @@ def test_create_teleop_device_unsupported_config(): ...@@ -599,8 +602,8 @@ def test_create_teleop_device_unsupported_config():
pass pass
# Create devices configuration dictionary with unsupported config # Create devices configuration dictionary with unsupported config
devices_cfg = {"unsupported": UnsupportedCfg()} devices_cfg: dict[str, DeviceCfg] = cast(dict[str, DeviceCfg], {"unsupported": UnsupportedCfg()})
# Try to create a device with unsupported configuration # Try to create a device with unsupported configuration
with pytest.raises(ValueError, match="Unsupported device configuration type"): with pytest.raises(ValueError, match="does not declare class_type"):
create_teleop_device("unsupported", devices_cfg) create_teleop_device("unsupported", devices_cfg)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment