Unverified Commit 0ef582ba authored by James Smith's avatar James Smith Committed by GitHub

Expands functionality of FrameTransformer to allow multi-body transforms (#858)

# Description

Update FrameTransformer to handle 2 new functionalities:
* Target frames that aren't children of the source frame prim_path
* Target frames that are based upon the source frame prim_path

These new changes mean that the frame names will most likely be
different than the configured order - but this was always a possibility
to the way the regex is parsed. To be safe, users need to use
`frame_names` to determine indexing into `FrameTransformerData`.

Test cases have been added for both of these new functionalities -
thanks @Mayankm96!

Also, the run script has been updated slightly as the previous indexing
was off by 1.


Fixes #857 #294

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 2a9198c8
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.24.18"
version = "0.24.19"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.24.19 (2024-10-05)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added new functionalities to the FrameTransformer to make it more general. It is now possible to track:
* Target frames that aren't children of the source frame prim_path
* Target frames that are based upon the source frame prim_path
0.24.18 (2024-10-04)
~~~~~~~~~~~~~~~~~~~~
......
......@@ -5,6 +5,7 @@
from __future__ import annotations
import re
import torch
from collections.abc import Sequence
from typing import TYPE_CHECKING
......@@ -50,21 +51,6 @@ class FrameTransformer(SensorBase):
typically a fictitious body, the user may need to specify an offset from the end-effector to the body of the
manipulator.
.. note::
Currently, this implementation only handles frames within an articulation. This is because the frame
regex expressions are resolved based on their parent prim path. This can be extended to handle
frames outside of articulation by using the frame prim path instead. However, this would require
additional checks to ensure that the user-specified frames are valid which is not currently implemented.
.. warning::
The implementation assumes that the parent body of a target frame is not the same as that
of the source frame (i.e. :attr:`FrameTransformerCfg.prim_path`). While a corner case, this can occur
if the user specifies the same prim path for both the source frame and target frame. In this case,
the target frame will be ignored and not reported. This is a limitation of the current implementation
and will be fixed in a future release.
"""
cfg: FrameTransformerCfg
......@@ -136,9 +122,9 @@ class FrameTransformer(SensorBase):
self._source_frame_offset_pos = source_frame_offset_pos.unsqueeze(0).repeat(self._num_envs, 1)
self._source_frame_offset_quat = source_frame_offset_quat.unsqueeze(0).repeat(self._num_envs, 1)
# Keep track of mapping from the rigid body name to the desired frame, as there may be multiple frames
# Keep track of mapping from the rigid body name to the desired frames and prim path, as there may be multiple frames
# based upon the same body name and we don't want to create unnecessary views
body_names_to_frames: dict[str, set[str]] = {}
body_names_to_frames: dict[str, dict[str, set[str] | str]] = {}
# The offsets associated with each target frame
target_offsets: dict[str, dict[str, torch.Tensor]] = {}
# The frames whose offsets are not identity
......@@ -148,6 +134,9 @@ class FrameTransformer(SensorBase):
# rotation offsets are not the identity quaternion for efficiency in _update_buffer_impl
self._apply_target_frame_offset = False
# Need to keep track of whether the source frame is also a target frame
self._source_is_also_target_frame = False
# Collect all target frames, their associated body prim paths and their offsets so that we can extract
# the prim, check that it has the appropriate rigid body API in a single loop.
# First element is None because user can't specify source frame name
......@@ -155,7 +144,8 @@ class FrameTransformer(SensorBase):
frame_prim_paths = [self.cfg.prim_path] + [target_frame.prim_path for target_frame in self.cfg.target_frames]
# First element is None because source frame offset is handled separately
frame_offsets = [None] + [target_frame.offset for target_frame in self.cfg.target_frames]
for frame, prim_path, offset in zip(frames, frame_prim_paths, frame_offsets):
frame_types = ["source"] + ["target"] * len(self.cfg.target_frames)
for frame, prim_path, offset, frame_type in zip(frames, frame_prim_paths, frame_offsets, frame_types):
# Find correct prim
matching_prims = sim_utils.find_matching_prims(prim_path)
if len(matching_prims) == 0:
......@@ -180,9 +170,19 @@ class FrameTransformer(SensorBase):
# Keep track of which frames are associated with which bodies
if body_name in body_names_to_frames:
body_names_to_frames[body_name].add(frame_name)
body_names_to_frames[body_name]["frames"].add(frame_name)
# This is a corner case where the source frame is also a target frame
if body_names_to_frames[body_name]["type"] == "source" and frame_type == "target":
self._source_is_also_target_frame = True
else:
body_names_to_frames[body_name] = {frame_name}
# Store the first matching prim path and the type of frame
body_names_to_frames[body_name] = {
"frames": {frame_name},
"prim_path": matching_prim_path,
"type": frame_type,
}
if offset is not None:
offset_pos = torch.tensor(offset.pos, device=self.device)
......@@ -191,7 +191,6 @@ class FrameTransformer(SensorBase):
if not is_identity_pose(offset_pos, offset_quat):
non_identity_offset_frames.append(frame_name)
self._apply_target_frame_offset = True
target_offsets[frame_name] = {"pos": offset_pos, "quat": offset_quat}
if not self._apply_target_frame_offset:
......@@ -206,37 +205,75 @@ class FrameTransformer(SensorBase):
)
# The names of bodies that RigidPrimView will be tracking to later extract transforms from
tracked_body_names = list(body_names_to_frames.keys())
# Construct regex expression for the body names
body_names_regex = r"(" + "|".join(tracked_body_names) + r")"
body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}"
tracked_prim_paths = [body_names_to_frames[body_name]["prim_path"] for body_name in body_names_to_frames.keys()]
tracked_body_names = [body_name for body_name in body_names_to_frames.keys()]
body_names_regex = [tracked_prim_path.replace("env_0", "env_*") for tracked_prim_path in tracked_prim_paths]
# Create simulation view
self._physics_sim_view = physx.create_simulation_view(self._backend)
self._physics_sim_view.set_subspace_roots("/")
# Create a prim view for all frames and initialize it
# order of transforms coming out of view will be source frame followed by target frame(s)
self._frame_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex.replace(".*", "*"))
self._frame_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex)
# Determine the order in which regex evaluated body names so we can later index into frame transforms
# by frame name correctly
all_prim_paths = self._frame_physx_view.prim_paths
# Only need first env as the names and their ordering are the same across environments
first_env_prim_paths = all_prim_paths[0 : len(tracked_body_names)]
first_env_body_names = [first_env_prim_path.split("/")[-1] for first_env_prim_path in first_env_prim_paths]
if "env_" in all_prim_paths[0]:
def extract_env_num_and_prim_path(item: str) -> tuple[int, str]:
"""Separates the environment number and prim_path from the item.
Args:
item: The item to extract the environment number from. Assumes item is of the form
`/World/envs/env_1/blah` or `/World/envs/env_11/blah`.
Returns:
The environment number and the prim_path.
"""
match = re.search(r"env_(\d+)(.*)", item)
return (int(match.group(1)), match.group(2))
# Find the indices that would reorganize output to be per environment. We want `env_1/blah` to come before `env_11/blah`
# and env_1/Robot/base to come before env_1/Robot/foot so we need to use custom key function
self._per_env_indices = [
index
for index, _ in sorted(
list(enumerate(all_prim_paths)), key=lambda x: extract_env_num_and_prim_path(x[1])
)
]
# Only need 0th env as the names and their ordering are the same across environments
sorted_prim_paths = [
all_prim_paths[index] for index in self._per_env_indices if "env_0" in all_prim_paths[index]
]
else:
# If no environment is present, then the order of the body names is the same as the order of the prim paths sorted alphabetically
self._per_env_indices = [index for index, _ in sorted(enumerate(all_prim_paths), key=lambda x: x[1])]
sorted_prim_paths = [all_prim_paths[index] for index in self._per_env_indices]
# -- target frames
self._target_frame_body_names = [prim_path.split("/")[-1] for prim_path in sorted_prim_paths]
# Re-parse the list as it may have moved when resolving regex above
# -- source frame
self._source_frame_body_name = self.cfg.prim_path.split("/")[-1]
source_frame_index = first_env_body_names.index(self._source_frame_body_name)
# -- target frames
self._target_frame_body_names = first_env_body_names[:]
self._target_frame_body_names.remove(self._source_frame_body_name)
source_frame_index = self._target_frame_body_names.index(self._source_frame_body_name)
# Only remove source frame from tracked bodies if it is not also a target frame
if not self._source_is_also_target_frame:
self._target_frame_body_names.remove(self._source_frame_body_name)
# Determine indices into all tracked body frames for both source and target frames
all_ids = torch.arange(self._num_envs * len(tracked_body_names))
self._source_frame_body_ids = torch.arange(self._num_envs) * len(tracked_body_names) + source_frame_index
self._target_frame_body_ids = all_ids[~torch.isin(all_ids, self._source_frame_body_ids)]
# If source frame is also a target frame, then the target frame body ids are the same as the source frame body ids
if self._source_is_also_target_frame:
self._target_frame_body_ids = all_ids
else:
self._target_frame_body_ids = all_ids[~torch.isin(all_ids, self._source_frame_body_ids)]
# The name of each of the target frame(s) - either user specified or defaulted to the body name
self._target_frame_names: list[str] = []
......@@ -249,26 +286,34 @@ class FrameTransformer(SensorBase):
duplicate_frame_indices = []
# Go through each body name and determine the number of duplicates we need for that frame
# and extract the offsets. This is all done to handles the case where multiple frames
# and extract the offsets. This is all done to handle the case where multiple frames
# reference the same body, but have different names and/or offsets
for i, body_name in enumerate(self._target_frame_body_names):
for frame in body_names_to_frames[body_name]:
target_frame_offset_pos.append(target_offsets[frame]["pos"])
target_frame_offset_quat.append(target_offsets[frame]["quat"])
self._target_frame_names.append(frame)
duplicate_frame_indices.append(i)
for frame in body_names_to_frames[body_name]["frames"]:
# Only need to handle target frames here as source frame is handled separately
if frame in target_offsets:
target_frame_offset_pos.append(target_offsets[frame]["pos"])
target_frame_offset_quat.append(target_offsets[frame]["quat"])
self._target_frame_names.append(frame)
duplicate_frame_indices.append(i)
# To handle multiple environments, need to expand so [0, 1, 1, 2] with 2 environments becomes
# [0, 1, 1, 2, 3, 4, 4, 5]. Again, this is a optimization to make _update_buffer_impl more efficient
duplicate_frame_indices = torch.tensor(duplicate_frame_indices, device=self.device)
num_target_body_frames = len(tracked_body_names) - 1
if self._source_is_also_target_frame:
num_target_body_frames = len(tracked_body_names)
else:
num_target_body_frames = len(tracked_body_names) - 1
self._duplicate_frame_indices = torch.cat(
[duplicate_frame_indices + num_target_body_frames * env_num for env_num in range(self._num_envs)]
)
# Stack up all the frame offsets for shape (num_envs, num_frames, 3) and (num_envs, num_frames, 4)
self._target_frame_offset_pos = torch.stack(target_frame_offset_pos).repeat(self._num_envs, 1)
self._target_frame_offset_quat = torch.stack(target_frame_offset_quat).repeat(self._num_envs, 1)
# Target frame offsets are only applied if at least one of the offsets are non-identity
if self._apply_target_frame_offset:
# Stack up all the frame offsets for shape (num_envs, num_frames, 3) and (num_envs, num_frames, 4)
self._target_frame_offset_pos = torch.stack(target_frame_offset_pos).repeat(self._num_envs, 1)
self._target_frame_offset_quat = torch.stack(target_frame_offset_quat).repeat(self._num_envs, 1)
# fill the data buffer
self._data.target_frame_names = self._target_frame_names
......@@ -288,6 +333,10 @@ class FrameTransformer(SensorBase):
# Extract transforms from view - shape is:
# (the total number of source and target body frames being tracked * self._num_envs, 7)
transforms = self._frame_physx_view.get_transforms()
# Reorder the transforms to be per environment as is expected of SensorData
transforms = transforms[self._per_env_indices]
# Convert quaternions as PhysX uses xyzw form
transforms[:, 3:] = convert_quat(transforms[:, 3:], to="wxyz")
......@@ -309,6 +358,7 @@ class FrameTransformer(SensorBase):
target_frames = transforms[self._target_frame_body_ids]
duplicated_target_frame_pos_w = target_frames[self._duplicate_frame_indices, :3]
duplicated_target_frame_quat_w = target_frames[self._duplicate_frame_indices, 3:]
# Only apply offset if the offsets will result in a coordinate frame transform
if self._apply_target_frame_offset:
target_pos_w, target_quat_w = combine_frame_transforms(
......
......@@ -31,10 +31,15 @@ class FrameTransformerCfg(SensorBaseCfg):
"""Information specific to a coordinate frame."""
prim_path: str = MISSING
"""The prim path corresponding to the parent rigid body.
"""The prim path corresponding to a rigid body.
This prim should be part of the same articulation as :attr:`FrameTransformerCfg.prim_path`.
This can be a regex pattern to match multiple prims. For example, "/Robot/.*" will match all prims under "/Robot".
This means that if the source :attr:`FrameTransformerCfg.prim_path` is "/Robot/base", and the target :attr:`FrameTransformerCfg.FrameCfg.prim_path` is "/Robot/.*",
then the frame transformer will track the poses of all the prims under "/Robot",
including "/Robot/base" (even though this will result in an identity pose w.r.t. the source frame).
"""
name: str | None = None
"""User-defined name for the new coordinate frame. Defaults to None.
......
......@@ -15,8 +15,8 @@ class FrameTransformerData:
"""Target frame names (this denotes the order in which that frame data is ordered).
The frame names are resolved from the :attr:`FrameTransformerCfg.FrameCfg.name` field.
This usually follows the order in which the frames are defined in the config. However, in
the case of regex matching, the order may be different.
This does not necessarily follow the order in which the frames are defined in the config due to
the regex matching.
"""
target_pos_source: torch.Tensor = None
......
......@@ -3,10 +3,6 @@
#
# SPDX-License-Identifier: BSD-3-Clause
"""
This script checks the FrameTransformer sensor by visualizing the frames that it creates.
"""
"""Launch Isaac Sim Simulator first."""
from omni.isaac.lab.app import AppLauncher, run_tests
......@@ -26,6 +22,7 @@ import omni.isaac.core.utils.stage as stage_utils
import omni.isaac.lab.sim as sim_utils
import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.assets import RigidObjectCfg
from omni.isaac.lab.scene import InteractiveScene, InteractiveSceneCfg
from omni.isaac.lab.sensors import FrameTransformerCfg, OffsetCfg
from omni.isaac.lab.terrains import TerrainImporterCfg
......@@ -62,6 +59,19 @@ class MySceneCfg(InteractiveSceneCfg):
# sensors - frame transformer (filled inside unit test)
frame_transformer: FrameTransformerCfg = None
# block
cube: RigidObjectCfg = RigidObjectCfg(
prim_path="{ENV_REGEX_NS}/cube",
spawn=sim_utils.CuboidCfg(
size=(0.2, 0.2, 0.2),
rigid_props=sim_utils.RigidBodyPropertiesCfg(max_depenetration_velocity=1.0),
mass_props=sim_utils.MassPropertiesCfg(mass=1.0),
physics_material=sim_utils.RigidBodyMaterialCfg(),
visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=(0.5, 0.0, 0.0)),
),
init_state=RigidObjectCfg.InitialStateCfg(pos=(2.0, 0.0, 5)),
)
class TestFrameTransformer(unittest.TestCase):
"""Test for frame transformer sensor."""
......@@ -71,7 +81,7 @@ class TestFrameTransformer(unittest.TestCase):
# Create a new stage
stage_utils.create_new_stage()
# Load kit helper
self.sim = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.005))
self.sim = sim_utils.SimulationContext(sim_utils.SimulationCfg(dt=0.005, device="cpu"))
# Set main camera
self.sim.set_camera_view(eye=[5, 5, 5], target=[0.0, 0.0, 0.0])
......@@ -90,8 +100,7 @@ class TestFrameTransformer(unittest.TestCase):
def test_frame_transformer_feet_wrt_base(self):
"""Test feet transformations w.r.t. base source frame.
In this test, the source frame is the robot base. This frame is at index 0, when
the frame bodies are sorted in the order of the regex matching in the frame transformer.
In this test, the source frame is the robot base.
"""
# Spawn things into stage
scene_cfg = MySceneCfg(num_envs=32, env_spacing=5.0, lazy_sensor_update=False)
......@@ -141,9 +150,15 @@ class TestFrameTransformer(unittest.TestCase):
feet_indices, feet_names = scene.articulations["robot"].find_bodies(
["LF_FOOT", "RF_FOOT", "LH_FOOT", "RH_FOOT"]
)
# Check names are parsed the same order
user_feet_names = [f"{name}_USER" for name in feet_names]
self.assertListEqual(scene.sensors["frame_transformer"].data.target_frame_names, user_feet_names)
target_frame_names = scene.sensors["frame_transformer"].data.target_frame_names
# Reorder the feet indices to match the order of the target frames with _USER suffix removed
target_frame_names = [name.split("_USER")[0] for name in target_frame_names]
# Find the indices of the feet in the order of the target frames
reordering_indices = [feet_names.index(name) for name in target_frame_names]
feet_indices = [feet_indices[i] for i in reordering_indices]
# default joint targets
default_actions = scene.articulations["robot"].data.default_joint_pos.clone()
......@@ -185,11 +200,12 @@ class TestFrameTransformer(unittest.TestCase):
source_quat_w_tf = scene.sensors["frame_transformer"].data.source_quat_w
feet_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w
feet_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w
# check if they are same
torch.testing.assert_close(root_pose_w[:, :3], source_pos_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(root_pose_w[:, 3:], source_quat_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(feet_pos_w_gt, feet_pos_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(feet_quat_w_gt, feet_quat_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(root_pose_w[:, :3], source_pos_w_tf)
torch.testing.assert_close(root_pose_w[:, 3:], source_quat_w_tf)
torch.testing.assert_close(feet_pos_w_gt, feet_pos_w_tf)
torch.testing.assert_close(feet_quat_w_gt, feet_quat_w_tf)
# check if relative transforms are same
feet_pos_source_tf = scene.sensors["frame_transformer"].data.target_pos_source
......@@ -200,8 +216,8 @@ class TestFrameTransformer(unittest.TestCase):
root_pose_w[:, :3], root_pose_w[:, 3:], feet_pos_w_tf[:, index], feet_quat_w_tf[:, index]
)
# check if they are same
torch.testing.assert_close(feet_pos_source_tf[:, index], foot_pos_b, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(feet_quat_source_tf[:, index], foot_quat_b, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(feet_pos_source_tf[:, index], foot_pos_b)
torch.testing.assert_close(feet_quat_source_tf[:, index], foot_quat_b)
def test_frame_transformer_feet_wrt_thigh(self):
"""Test feet transformation w.r.t. thigh source frame.
......@@ -285,10 +301,10 @@ class TestFrameTransformer(unittest.TestCase):
feet_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w
feet_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w
# check if they are same
torch.testing.assert_close(source_pose_w_gt[:, :3], source_pos_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(source_pose_w_gt[:, 3:], source_quat_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(feet_pos_w_gt, feet_pos_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(feet_quat_w_gt, feet_quat_w_tf, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(source_pose_w_gt[:, :3], source_pos_w_tf)
torch.testing.assert_close(source_pose_w_gt[:, 3:], source_quat_w_tf)
torch.testing.assert_close(feet_pos_w_gt, feet_pos_w_tf)
torch.testing.assert_close(feet_quat_w_gt, feet_quat_w_tf)
# check if relative transforms are same
feet_pos_source_tf = scene.sensors["frame_transformer"].data.target_pos_source
......@@ -299,8 +315,269 @@ class TestFrameTransformer(unittest.TestCase):
source_pose_w_gt[:, :3], source_pose_w_gt[:, 3:], feet_pos_w_tf[:, index], feet_quat_w_tf[:, index]
)
# check if they are same
torch.testing.assert_close(feet_pos_source_tf[:, index], foot_pos_b, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(feet_quat_source_tf[:, index], foot_quat_b, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(feet_pos_source_tf[:, index], foot_pos_b)
torch.testing.assert_close(feet_quat_source_tf[:, index], foot_quat_b)
def test_frame_transformer_robot_body_to_external_cube(self):
"""Test transformation from robot body to a cube in the scene.
In this test, the source frame is the robot base.
The target_frame is a cube in the scene, external to the robot.
"""
# Spawn things into stage
scene_cfg = MySceneCfg(num_envs=2, env_spacing=5.0, lazy_sensor_update=False)
scene_cfg.frame_transformer = FrameTransformerCfg(
prim_path="{ENV_REGEX_NS}/Robot/base",
target_frames=[
FrameTransformerCfg.FrameCfg(
name="CUBE_USER",
prim_path="{ENV_REGEX_NS}/cube",
),
],
)
scene = InteractiveScene(scene_cfg)
# Play the simulator
self.sim.reset()
# default joint targets
default_actions = scene.articulations["robot"].data.default_joint_pos.clone()
# Define simulation stepping
sim_dt = self.sim.get_physics_dt()
# Simulate physics
for count in range(100):
# # reset
if count % 25 == 0:
# reset root state
root_state = scene.articulations["robot"].data.default_root_state.clone()
root_state[:, :3] += scene.env_origins
joint_pos = scene.articulations["robot"].data.default_joint_pos
joint_vel = scene.articulations["robot"].data.default_joint_vel
# -- set root state
# -- robot
scene.articulations["robot"].write_root_state_to_sim(root_state)
scene.articulations["robot"].write_joint_state_to_sim(joint_pos, joint_vel)
# reset buffers
scene.reset()
# set joint targets
robot_actions = default_actions + 0.5 * torch.randn_like(default_actions)
scene.articulations["robot"].set_joint_position_target(robot_actions)
# write data to sim
scene.write_data_to_sim()
# perform step
self.sim.step()
# read data from sim
scene.update(sim_dt)
# check absolute frame transforms in world frame
# -- ground-truth
root_pose_w = scene.articulations["robot"].data.root_state_w[:, :7]
cube_pos_w_gt = scene.rigid_objects["cube"].data.root_state_w[:, :3]
cube_quat_w_gt = scene.rigid_objects["cube"].data.root_state_w[:, 3:7]
# -- frame transformer
source_pos_w_tf = scene.sensors["frame_transformer"].data.source_pos_w
source_quat_w_tf = scene.sensors["frame_transformer"].data.source_quat_w
cube_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w.squeeze()
cube_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w.squeeze()
# check if they are same
torch.testing.assert_close(root_pose_w[:, :3], source_pos_w_tf)
torch.testing.assert_close(root_pose_w[:, 3:], source_quat_w_tf)
torch.testing.assert_close(cube_pos_w_gt, cube_pos_w_tf)
torch.testing.assert_close(cube_quat_w_gt, cube_quat_w_tf)
# check if relative transforms are same
cube_pos_source_tf = scene.sensors["frame_transformer"].data.target_pos_source
cube_quat_source_tf = scene.sensors["frame_transformer"].data.target_quat_source
# ground-truth
cube_pos_b, cube_quat_b = math_utils.subtract_frame_transforms(
root_pose_w[:, :3], root_pose_w[:, 3:], cube_pos_w_tf, cube_quat_w_tf
)
# check if they are same
torch.testing.assert_close(cube_pos_source_tf[:, 0], cube_pos_b)
torch.testing.assert_close(cube_quat_source_tf[:, 0], cube_quat_b)
def test_frame_transformer_offset_frames(self):
"""Test body transformation w.r.t. base source frame.
In this test, the source frame is the cube frame.
"""
# Spawn things into stage
scene_cfg = MySceneCfg(num_envs=2, env_spacing=5.0, lazy_sensor_update=False)
scene_cfg.frame_transformer = FrameTransformerCfg(
prim_path="{ENV_REGEX_NS}/cube",
target_frames=[
FrameTransformerCfg.FrameCfg(
name="CUBE_CENTER",
prim_path="{ENV_REGEX_NS}/cube",
),
FrameTransformerCfg.FrameCfg(
name="CUBE_TOP",
prim_path="{ENV_REGEX_NS}/cube",
offset=OffsetCfg(
pos=(0.0, 0.0, 0.1),
rot=(1.0, 0.0, 0.0, 0.0),
),
),
FrameTransformerCfg.FrameCfg(
name="CUBE_BOTTOM",
prim_path="{ENV_REGEX_NS}/cube",
offset=OffsetCfg(
pos=(0.0, 0.0, -0.1),
rot=(1.0, 0.0, 0.0, 0.0),
),
),
],
)
scene = InteractiveScene(scene_cfg)
# Play the simulator
self.sim.reset()
# Define simulation stepping
sim_dt = self.sim.get_physics_dt()
# Simulate physics
for count in range(100):
# # reset
if count % 25 == 0:
# reset root state
root_state = scene["cube"].data.default_root_state.clone()
root_state[:, :3] += scene.env_origins
# -- set root state
# -- cube
scene["cube"].write_root_state_to_sim(root_state)
# reset buffers
scene.reset()
# write data to sim
scene.write_data_to_sim()
# perform step
self.sim.step()
# read data from sim
scene.update(sim_dt)
# check absolute frame transforms in world frame
# -- ground-truth
cube_pos_w_gt = scene["cube"].data.root_state_w[:, :3]
cube_quat_w_gt = scene["cube"].data.root_state_w[:, 3:7]
# -- frame transformer
source_pos_w_tf = scene.sensors["frame_transformer"].data.source_pos_w
source_quat_w_tf = scene.sensors["frame_transformer"].data.source_quat_w
target_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w.squeeze()
target_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w.squeeze()
target_frame_names = scene.sensors["frame_transformer"].data.target_frame_names
cube_center_idx = target_frame_names.index("CUBE_CENTER")
cube_bottom_idx = target_frame_names.index("CUBE_BOTTOM")
cube_top_idx = target_frame_names.index("CUBE_TOP")
# check if they are same
torch.testing.assert_close(cube_pos_w_gt, source_pos_w_tf)
torch.testing.assert_close(cube_quat_w_gt, source_quat_w_tf)
torch.testing.assert_close(cube_pos_w_gt, target_pos_w_tf[:, cube_center_idx])
torch.testing.assert_close(cube_quat_w_gt, target_quat_w_tf[:, cube_center_idx])
# test offsets are applied correctly
# -- cube top
cube_pos_top = target_pos_w_tf[:, cube_top_idx]
cube_quat_top = target_quat_w_tf[:, cube_top_idx]
torch.testing.assert_close(cube_pos_top, cube_pos_w_gt + torch.tensor([0.0, 0.0, 0.1]))
torch.testing.assert_close(cube_quat_top, cube_quat_w_gt)
# -- cube bottom
cube_pos_bottom = target_pos_w_tf[:, cube_bottom_idx]
cube_quat_bottom = target_quat_w_tf[:, cube_bottom_idx]
torch.testing.assert_close(cube_pos_bottom, cube_pos_w_gt + torch.tensor([0.0, 0.0, -0.1]))
torch.testing.assert_close(cube_quat_bottom, cube_quat_w_gt)
def test_frame_transformer_all_bodies(self):
"""Test transformation of all bodies w.r.t. base source frame.
In this test, the source frame is the robot base.
The target_frames are all bodies in the robot, implemented using .* pattern.
"""
# Spawn things into stage
scene_cfg = MySceneCfg(num_envs=2, env_spacing=5.0, lazy_sensor_update=False)
scene_cfg.frame_transformer = FrameTransformerCfg(
prim_path="{ENV_REGEX_NS}/Robot/base",
target_frames=[
FrameTransformerCfg.FrameCfg(
prim_path="{ENV_REGEX_NS}/Robot/.*",
),
],
)
scene = InteractiveScene(scene_cfg)
# Play the simulator
self.sim.reset()
target_frame_names = scene.sensors["frame_transformer"].data.target_frame_names
articulation_body_names = scene.articulations["robot"].data.body_names
reordering_indices = [target_frame_names.index(name) for name in articulation_body_names]
# default joint targets
default_actions = scene.articulations["robot"].data.default_joint_pos.clone()
# Define simulation stepping
sim_dt = self.sim.get_physics_dt()
# Simulate physics
for count in range(100):
# # reset
if count % 25 == 0:
# reset root state
root_state = scene.articulations["robot"].data.default_root_state.clone()
root_state[:, :3] += scene.env_origins
joint_pos = scene.articulations["robot"].data.default_joint_pos
joint_vel = scene.articulations["robot"].data.default_joint_vel
# -- set root state
# -- robot
scene.articulations["robot"].write_root_state_to_sim(root_state)
scene.articulations["robot"].write_joint_state_to_sim(joint_pos, joint_vel)
# reset buffers
scene.reset()
# set joint targets
robot_actions = default_actions + 0.5 * torch.randn_like(default_actions)
scene.articulations["robot"].set_joint_position_target(robot_actions)
# write data to sim
scene.write_data_to_sim()
# perform step
self.sim.step()
# read data from sim
scene.update(sim_dt)
# check absolute frame transforms in world frame
# -- ground-truth
root_pose_w = scene.articulations["robot"].data.root_state_w[:, :7]
bodies_pos_w_gt = scene.articulations["robot"].data.body_pos_w
bodies_quat_w_gt = scene.articulations["robot"].data.body_quat_w
# -- frame transformer
source_pos_w_tf = scene.sensors["frame_transformer"].data.source_pos_w
source_quat_w_tf = scene.sensors["frame_transformer"].data.source_quat_w
bodies_pos_w_tf = scene.sensors["frame_transformer"].data.target_pos_w
bodies_quat_w_tf = scene.sensors["frame_transformer"].data.target_quat_w
# check if they are same
torch.testing.assert_close(root_pose_w[:, :3], source_pos_w_tf)
torch.testing.assert_close(root_pose_w[:, 3:], source_quat_w_tf)
torch.testing.assert_close(bodies_pos_w_gt, bodies_pos_w_tf[:, reordering_indices])
torch.testing.assert_close(bodies_quat_w_gt, bodies_quat_w_tf[:, reordering_indices])
bodies_pos_source_tf = scene.sensors["frame_transformer"].data.target_pos_source
bodies_quat_source_tf = scene.sensors["frame_transformer"].data.target_quat_source
# Go through each body and check if relative transforms are same
for index in range(len(articulation_body_names)):
body_pos_b, body_quat_b = math_utils.subtract_frame_transforms(
root_pose_w[:, :3], root_pose_w[:, 3:], bodies_pos_w_tf[:, index], bodies_quat_w_tf[:, index]
)
torch.testing.assert_close(bodies_pos_source_tf[:, index], body_pos_b)
torch.testing.assert_close(bodies_quat_source_tf[:, index], body_quat_b)
if __name__ == "__main__":
......
......@@ -139,10 +139,10 @@ def run_simulator(sim: sim_utils.SimulationContext, scene_entities: dict):
if count % 50 == 0:
# get frame names
frame_names = frame_transformer.data.target_frame_names
print(f"Displaying Frame ID {frame_index}: {frame_names[frame_index]}")
# increment frame index
frame_index += 1
frame_index = frame_index % len(frame_names)
print(f"Displaying Frame ID {frame_index}: {frame_names[frame_index]}")
# visualize frame
source_pos = frame_transformer.data.source_pos_w
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment