Unverified Commit c75bc5c5 authored by ooctipus's avatar ooctipus Committed by GitHub

Fixes inconsistent data reading in body, link, com for RigidObject,...

Fixes inconsistent data reading in body, link, com for RigidObject, RigidObjectCollection and Articulation (#2736)

# Description
When WriteState, WriteLink, WriteCOM, WriteJoint are invoked, there is a
inconsistency when reading values of ReadState, ReadLink, ReadCOM. The
Source of the bug is because of missing timestamp invalidation of
relative data or missing update to the related data within the write
function. Below I list the all functions that is problematics

RigitObject:
write_root_link_pose_to_sim
write_root_com_velocity_to_sim

RigitObjectCollection:
write_object_link_pose_to_sim
write_object_com_velocity_to_sim

Articulation:
write_joint_state_to_sim

The bug if fixed by invalidating the relevant data timestamps in
write_joint_state_to_sim function for articulation, and added direct
update to the dependent data in write_(state|link|com)_to_sim of
RigitObject and RigitObjectCollection.

I have added the tests cases that checks the consistency among
ReadState, ReadLink, ReadCOM when either WriteState, WriteLink,
WriteCOM, WriteJoint is called and passed all tests.

Fixes #2534 #2702 

<!-- As a practice, it is recommended to open an issue to have
discussions on the proposed pull request.
This makes it easier for the community to keep track of what is being
developed or added, and if a given feature
is demanded by more than one party. -->

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->
Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent ea717fa5
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.40.7" version = "0.40.8"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.40.8 (2025-06-18)
~~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Fixed data inconsistency between read_body, read_link, read_com when write_body, write_com, write_joint performed, in
:class:`~isaaclab.assets.Articulation`, :class:`~isaaclab.assets.RigidObject`, and
:class:`~isaaclab.assets.RigidObjectCollection`
* added pytest that check against these data consistencies
0.40.7 (2025-06-24) 0.40.7 (2025-06-24)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -517,6 +517,12 @@ class Articulation(AssetBase): ...@@ -517,6 +517,12 @@ class Articulation(AssetBase):
# set into internal buffers # set into internal buffers
self._data.joint_pos[env_ids, joint_ids] = position self._data.joint_pos[env_ids, joint_ids] = position
# Need to invalidate the buffer to trigger the update with the new root pose. # Need to invalidate the buffer to trigger the update with the new root pose.
self._data._body_com_vel_w.timestamp = -1.0
self._data._body_link_vel_w.timestamp = -1.0
self._data._body_com_pose_b.timestamp = -1.0
self._data._body_com_pose_w.timestamp = -1.0
self._data._body_link_pose_w.timestamp = -1.0
self._data._body_state_w.timestamp = -1.0 self._data._body_state_w.timestamp = -1.0
self._data._body_link_state_w.timestamp = -1.0 self._data._body_link_state_w.timestamp = -1.0
self._data._body_com_state_w.timestamp = -1.0 self._data._body_com_state_w.timestamp = -1.0
......
...@@ -220,11 +220,18 @@ class RigidObject(AssetBase): ...@@ -220,11 +220,18 @@ class RigidObject(AssetBase):
self._data.root_link_state_w[env_ids, :7] = self._data.root_link_pose_w[env_ids] self._data.root_link_state_w[env_ids, :7] = self._data.root_link_pose_w[env_ids]
if self._data._root_state_w.data is not None: if self._data._root_state_w.data is not None:
self._data.root_state_w[env_ids, :7] = self._data.root_link_pose_w[env_ids] self._data.root_state_w[env_ids, :7] = self._data.root_link_pose_w[env_ids]
if self._data._root_com_state_w.data is not None:
expected_com_pos, expected_com_quat = math_utils.combine_frame_transforms(
self._data.root_link_pose_w[env_ids, :3],
self._data.root_link_pose_w[env_ids, 3:7],
self.data.body_com_pos_b[env_ids, 0, :],
self.data.body_com_quat_b[env_ids, 0, :],
)
self._data.root_com_state_w[env_ids, :3] = expected_com_pos
self._data.root_com_state_w[env_ids, 3:7] = expected_com_quat
# convert root quaternion from wxyz to xyzw # convert root quaternion from wxyz to xyzw
root_poses_xyzw = self._data.root_link_pose_w.clone() root_poses_xyzw = self._data.root_link_pose_w.clone()
root_poses_xyzw[:, 3:] = math_utils.convert_quat(root_poses_xyzw[:, 3:], to="xyzw") root_poses_xyzw[:, 3:] = math_utils.convert_quat(root_poses_xyzw[:, 3:], to="xyzw")
# set into simulation # set into simulation
self.root_physx_view.set_transforms(root_poses_xyzw, indices=physx_env_ids) self.root_physx_view.set_transforms(root_poses_xyzw, indices=physx_env_ids)
...@@ -301,9 +308,10 @@ class RigidObject(AssetBase): ...@@ -301,9 +308,10 @@ class RigidObject(AssetBase):
self._data.root_com_state_w[env_ids, 7:] = self._data.root_com_vel_w[env_ids] self._data.root_com_state_w[env_ids, 7:] = self._data.root_com_vel_w[env_ids]
if self._data._root_state_w.data is not None: if self._data._root_state_w.data is not None:
self._data.root_state_w[env_ids, 7:] = self._data.root_com_vel_w[env_ids] self._data.root_state_w[env_ids, 7:] = self._data.root_com_vel_w[env_ids]
if self._data._root_link_state_w.data is not None:
self._data.root_link_state_w[env_ids, 7:] = self._data.root_com_vel_w[env_ids]
# make the acceleration zero to prevent reporting old values # make the acceleration zero to prevent reporting old values
self._data.body_com_acc_w[env_ids] = 0.0 self._data.body_com_acc_w[env_ids] = 0.0
# set into simulation # set into simulation
self.root_physx_view.set_velocities(self._data.root_com_vel_w, indices=physx_env_ids) self.root_physx_view.set_velocities(self._data.root_com_vel_w, indices=physx_env_ids)
......
...@@ -317,6 +317,18 @@ class RigidObjectCollection(AssetBase): ...@@ -317,6 +317,18 @@ class RigidObjectCollection(AssetBase):
self._data.object_link_state_w[env_ids[:, None], object_ids, :7] = object_pose.clone() self._data.object_link_state_w[env_ids[:, None], object_ids, :7] = object_pose.clone()
if self._data._object_state_w.data is not None: if self._data._object_state_w.data is not None:
self._data.object_state_w[env_ids[:, None], object_ids, :7] = object_pose.clone() self._data.object_state_w[env_ids[:, None], object_ids, :7] = object_pose.clone()
if self._data._object_com_state_w.data is not None:
# get CoM pose in link frame
com_pos_b = self.data.object_com_pos_b[env_ids[:, None], object_ids]
com_quat_b = self.data.object_com_quat_b[env_ids[:, None], object_ids]
com_pos, com_quat = math_utils.combine_frame_transforms(
object_pose[..., :3],
object_pose[..., 3:7],
com_pos_b,
com_quat_b,
)
self._data.object_com_state_w[env_ids[:, None], object_ids, :3] = com_pos
self._data.object_com_state_w[env_ids[:, None], object_ids, 3:7] = com_quat
# convert the quaternion from wxyz to xyzw # convert the quaternion from wxyz to xyzw
poses_xyzw = self._data.object_link_pose_w.clone() poses_xyzw = self._data.object_link_pose_w.clone()
...@@ -415,6 +427,8 @@ class RigidObjectCollection(AssetBase): ...@@ -415,6 +427,8 @@ class RigidObjectCollection(AssetBase):
self._data.object_com_state_w[env_ids[:, None], object_ids, 7:] = object_velocity.clone() self._data.object_com_state_w[env_ids[:, None], object_ids, 7:] = object_velocity.clone()
if self._data._object_state_w.data is not None: if self._data._object_state_w.data is not None:
self._data.object_state_w[env_ids[:, None], object_ids, 7:] = object_velocity.clone() self._data.object_state_w[env_ids[:, None], object_ids, 7:] = object_velocity.clone()
if self._data._object_link_state_w.data is not None:
self._data.object_link_state_w[env_ids[:, None], object_ids, 7:] = object_velocity.clone()
# make the acceleration zero to prevent reporting old values # make the acceleration zero to prevent reporting old values
self._data.object_com_acc_w[env_ids[:, None], object_ids] = 0.0 self._data.object_com_acc_w[env_ids[:, None], object_ids] = 0.0
......
...@@ -1608,5 +1608,88 @@ def test_body_incoming_joint_wrench_b_single_joint(sim, num_articulations, devic ...@@ -1608,5 +1608,88 @@ def test_body_incoming_joint_wrench_b_single_joint(sim, num_articulations, devic
sim.reset() sim.reset()
@pytest.mark.parametrize("num_articulations", [1, 2])
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
@pytest.mark.parametrize("gravity_enabled", [False])
def test_write_joint_state_data_consistency(sim, num_articulations, device, gravity_enabled):
"""Test the setters for root_state using both the link frame and center of mass as reference frame.
This test verifies that after write_joint_state_to_sim operations:
1. state, com_state, link_state value consistency
2. body_pose, link
Args:
sim: The simulation fixture
num_articulations: Number of articulations to test
device: The device to run the simulation on
"""
sim._app_control_on_stop_handle = None
articulation_cfg = generate_articulation_cfg(articulation_type="anymal")
articulation, env_pos = generate_articulation(articulation_cfg, num_articulations, device)
env_idx = torch.tensor([x for x in range(num_articulations)])
# Play sim
sim.reset()
limits = torch.zeros(num_articulations, articulation.num_joints, 2, device=device)
limits[..., 0] = (torch.rand(num_articulations, articulation.num_joints, device=device) + 5.0) * -1.0
limits[..., 1] = torch.rand(num_articulations, articulation.num_joints, device=device) + 5.0
articulation.write_joint_position_limit_to_sim(limits)
from torch.distributions import Uniform
pos_dist = Uniform(articulation.data.joint_pos_limits[..., 0], articulation.data.joint_pos_limits[..., 1])
vel_dist = Uniform(-articulation.data.joint_vel_limits, articulation.data.joint_vel_limits)
original_body_states = articulation.data.body_state_w.clone()
rand_joint_pos = pos_dist.sample()
rand_joint_vel = vel_dist.sample()
articulation.write_joint_state_to_sim(rand_joint_pos, rand_joint_vel)
articulation.root_physx_view.get_jacobians()
# make sure valued updated
assert torch.count_nonzero(original_body_states[:, 1:] != articulation.data.body_state_w[:, 1:]) > (
len(original_body_states[:, 1:]) / 2
)
# validate body - link consistency
torch.testing.assert_close(articulation.data.body_state_w[..., :7], articulation.data.body_link_state_w[..., :7])
# skip 7:10 because they differs from link frame, this should be fine because we are only checking
# if velocity update is triggered, which can be determined by comparing angular velocity
torch.testing.assert_close(articulation.data.body_state_w[..., 10:], articulation.data.body_link_state_w[..., 10:])
# validate link - com conistency
expected_com_pos, expected_com_quat = math_utils.combine_frame_transforms(
articulation.data.body_link_state_w[..., :3].view(-1, 3),
articulation.data.body_link_state_w[..., 3:7].view(-1, 4),
articulation.data.body_com_pos_b.view(-1, 3),
articulation.data.body_com_quat_b.view(-1, 4),
)
torch.testing.assert_close(expected_com_pos.view(len(env_idx), -1, 3), articulation.data.body_com_pos_w)
torch.testing.assert_close(expected_com_quat.view(len(env_idx), -1, 4), articulation.data.body_com_quat_w)
# validate body - com consistency
torch.testing.assert_close(articulation.data.body_state_w[..., 7:10], articulation.data.body_com_lin_vel_w)
torch.testing.assert_close(articulation.data.body_state_w[..., 10:], articulation.data.body_com_ang_vel_w)
# validate pos_w, quat_w, pos_b, quat_b is consistent with pose_w and pose_b
expected_com_pose_w = torch.cat((articulation.data.body_com_pos_w, articulation.data.body_com_quat_w), dim=2)
expected_com_pose_b = torch.cat((articulation.data.body_com_pos_b, articulation.data.body_com_quat_b), dim=2)
expected_body_pose_w = torch.cat((articulation.data.body_pos_w, articulation.data.body_quat_w), dim=2)
expected_body_link_pose_w = torch.cat(
(articulation.data.body_link_pos_w, articulation.data.body_link_quat_w), dim=2
)
torch.testing.assert_close(articulation.data.body_com_pose_w, expected_com_pose_w)
torch.testing.assert_close(articulation.data.body_com_pose_b, expected_com_pose_b)
torch.testing.assert_close(articulation.data.body_pose_w, expected_body_pose_w)
torch.testing.assert_close(articulation.data.body_link_pose_w, expected_body_link_pose_w)
# validate pose_w is consistent state[..., :7]
torch.testing.assert_close(articulation.data.body_pose_w, articulation.data.body_state_w[..., :7])
torch.testing.assert_close(articulation.data.body_vel_w, articulation.data.body_state_w[..., 7:])
torch.testing.assert_close(articulation.data.body_link_pose_w, articulation.data.body_link_state_w[..., :7])
torch.testing.assert_close(articulation.data.body_com_pose_w, articulation.data.body_com_state_w[..., :7])
torch.testing.assert_close(articulation.data.body_vel_w, articulation.data.body_state_w[..., 7:])
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__, "-v", "--maxfail=1"]) pytest.main([__file__, "-v", "--maxfail=1"])
...@@ -28,7 +28,15 @@ from isaaclab.assets import RigidObject, RigidObjectCfg ...@@ -28,7 +28,15 @@ from isaaclab.assets import RigidObject, RigidObjectCfg
from isaaclab.sim import build_simulation_context from isaaclab.sim import build_simulation_context
from isaaclab.sim.spawners import materials from isaaclab.sim.spawners import materials
from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR, ISAACLAB_NUCLEUS_DIR from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR, ISAACLAB_NUCLEUS_DIR
from isaaclab.utils.math import default_orientation, quat_apply_inverse, quat_mul, random_orientation from isaaclab.utils.math import (
combine_frame_transforms,
default_orientation,
quat_apply_inverse,
quat_inv,
quat_mul,
quat_rotate,
random_orientation,
)
def generate_cubes_scene( def generate_cubes_scene(
...@@ -910,3 +918,109 @@ def test_write_root_state(num_cubes, device, with_offset, state_location): ...@@ -910,3 +918,109 @@ def test_write_root_state(num_cubes, device, with_offset, state_location):
torch.testing.assert_close(rand_state, cube_object.data.root_com_state_w) torch.testing.assert_close(rand_state, cube_object.data.root_com_state_w)
elif state_location == "link": elif state_location == "link":
torch.testing.assert_close(rand_state, cube_object.data.root_link_state_w) torch.testing.assert_close(rand_state, cube_object.data.root_link_state_w)
@pytest.mark.parametrize("num_cubes", [1, 2])
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
@pytest.mark.parametrize("with_offset", [True])
@pytest.mark.parametrize("state_location", ["com", "link", "root"])
def test_write_state_functions_data_consistency(num_cubes, device, with_offset, state_location):
"""Test the setters for root_state using both the link frame and center of mass as reference frame."""
with build_simulation_context(device=device, gravity_enabled=False, auto_add_lighting=True) as sim:
sim._app_control_on_stop_handle = None
# Create a scene with random cubes
cube_object, env_pos = generate_cubes_scene(num_cubes=num_cubes, height=0.0, device=device)
env_idx = torch.tensor([x for x in range(num_cubes)])
# Play sim
sim.reset()
# Check if cube_object is initialized
assert cube_object.is_initialized
# change center of mass offset from link frame
if with_offset:
offset = torch.tensor([0.1, 0.0, 0.0], device=device).repeat(num_cubes, 1)
else:
offset = torch.tensor([0.0, 0.0, 0.0], device=device).repeat(num_cubes, 1)
com = cube_object.root_physx_view.get_coms()
com[..., :3] = offset.to("cpu")
cube_object.root_physx_view.set_coms(com, env_idx)
# check ceter of mass has been set
torch.testing.assert_close(cube_object.root_physx_view.get_coms(), com)
rand_state = torch.rand_like(cube_object.data.root_state_w)
# rand_state[..., :7] = cube_object.data.default_root_state[..., :7]
rand_state[..., :3] += env_pos
# make quaternion a unit vector
rand_state[..., 3:7] = torch.nn.functional.normalize(rand_state[..., 3:7], dim=-1)
env_idx = env_idx.to(device)
# perform step
sim.step()
# update buffers
cube_object.update(sim.cfg.dt)
if state_location == "com":
cube_object.write_root_com_state_to_sim(rand_state)
elif state_location == "link":
cube_object.write_root_link_state_to_sim(rand_state)
elif state_location == "root":
cube_object.write_root_state_to_sim(rand_state)
if state_location == "com":
expected_root_link_pos, expected_root_link_quat = combine_frame_transforms(
cube_object.data.root_com_state_w[:, :3],
cube_object.data.root_com_state_w[:, 3:7],
quat_rotate(
quat_inv(cube_object.data.body_com_pose_b[:, 0, 3:7]), -cube_object.data.body_com_pose_b[:, 0, :3]
),
quat_inv(cube_object.data.body_com_pose_b[:, 0, 3:7]),
)
expected_root_link_pose = torch.cat((expected_root_link_pos, expected_root_link_quat), dim=1)
# test both root_pose and root_link_state_w successfully updated when root_com_state_w updates
torch.testing.assert_close(expected_root_link_pose, cube_object.data.root_link_state_w[:, :7])
# skip 7:10 because they differs from link frame, this should be fine because we are only checking
# if velocity update is triggered, which can be determined by comparing angular velocity
torch.testing.assert_close(
cube_object.data.root_com_state_w[:, 10:], cube_object.data.root_link_state_w[:, 10:]
)
torch.testing.assert_close(expected_root_link_pose, cube_object.data.root_state_w[:, :7])
torch.testing.assert_close(cube_object.data.root_com_state_w[:, 10:], cube_object.data.root_state_w[:, 10:])
elif state_location == "link":
expected_com_pos, expected_com_quat = combine_frame_transforms(
cube_object.data.root_link_state_w[:, :3],
cube_object.data.root_link_state_w[:, 3:7],
cube_object.data.body_com_pose_b[:, 0, :3],
cube_object.data.body_com_pose_b[:, 0, 3:7],
)
expected_com_pose = torch.cat((expected_com_pos, expected_com_quat), dim=1)
# test both root_pose and root_com_state_w successfully updated when root_link_state_w updates
torch.testing.assert_close(expected_com_pose, cube_object.data.root_com_state_w[:, :7])
# skip 7:10 because they differs from link frame, this should be fine because we are only checking
# if velocity update is triggered, which can be determined by comparing angular velocity
torch.testing.assert_close(
cube_object.data.root_link_state_w[:, 10:], cube_object.data.root_com_state_w[:, 10:]
)
torch.testing.assert_close(cube_object.data.root_link_state_w[:, :7], cube_object.data.root_state_w[:, :7])
torch.testing.assert_close(
cube_object.data.root_link_state_w[:, 10:], cube_object.data.root_state_w[:, 10:]
)
elif state_location == "root":
expected_com_pos, expected_com_quat = combine_frame_transforms(
cube_object.data.root_state_w[:, :3],
cube_object.data.root_state_w[:, 3:7],
cube_object.data.body_com_pose_b[:, 0, :3],
cube_object.data.body_com_pose_b[:, 0, 3:7],
)
expected_com_pose = torch.cat((expected_com_pos, expected_com_quat), dim=1)
# test both root_com_state_w and root_link_state_w successfully updated when root_pose updates
torch.testing.assert_close(expected_com_pose, cube_object.data.root_com_state_w[:, :7])
torch.testing.assert_close(cube_object.data.root_state_w[:, 7:], cube_object.data.root_com_state_w[:, 7:])
torch.testing.assert_close(cube_object.data.root_state_w[:, :7], cube_object.data.root_link_state_w[:, :7])
torch.testing.assert_close(
cube_object.data.root_state_w[:, 10:], cube_object.data.root_link_state_w[:, 10:]
)
...@@ -26,7 +26,16 @@ import isaaclab.sim as sim_utils ...@@ -26,7 +26,16 @@ import isaaclab.sim as sim_utils
from isaaclab.assets import RigidObjectCfg, RigidObjectCollection, RigidObjectCollectionCfg from isaaclab.assets import RigidObjectCfg, RigidObjectCollection, RigidObjectCollectionCfg
from isaaclab.sim import build_simulation_context from isaaclab.sim import build_simulation_context
from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR
from isaaclab.utils.math import default_orientation, quat_apply_inverse, quat_mul, random_orientation from isaaclab.utils.math import (
combine_frame_transforms,
default_orientation,
quat_apply_inverse,
quat_inv,
quat_mul,
quat_rotate,
random_orientation,
subtract_frame_transforms,
)
def generate_cubes_scene( def generate_cubes_scene(
...@@ -601,3 +610,128 @@ def test_gravity_vec_w(sim, num_envs, num_cubes, device, gravity_enabled): ...@@ -601,3 +610,128 @@ def test_gravity_vec_w(sim, num_envs, num_cubes, device, gravity_enabled):
# Check the body accelerations are correct # Check the body accelerations are correct
torch.testing.assert_close(object_collection.data.object_acc_w, gravity) torch.testing.assert_close(object_collection.data.object_acc_w, gravity)
@pytest.mark.parametrize("num_envs", [1, 3])
@pytest.mark.parametrize("num_cubes", [1, 2])
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
@pytest.mark.parametrize("with_offset", [True])
@pytest.mark.parametrize("state_location", ["com", "link", "root"])
@pytest.mark.parametrize("gravity_enabled", [False])
def test_write_object_state_functions_data_consistency(
sim, num_envs, num_cubes, device, with_offset, state_location, gravity_enabled
):
"""Test the setters for object_state using both the link frame and center of mass as reference frame."""
# Create a scene with random cubes
cube_object, env_pos = generate_cubes_scene(num_envs=num_envs, num_cubes=num_cubes, height=0.0, device=device)
view_ids = torch.tensor([x for x in range(num_cubes * num_cubes)])
env_ids = torch.tensor([x for x in range(num_envs)])
object_ids = torch.tensor([x for x in range(num_cubes)])
sim.reset()
# Check if cube_object is initialized
assert cube_object.is_initialized
# change center of mass offset from link frame
offset = (
torch.tensor([0.1, 0.0, 0.0], device=device).repeat(num_envs, num_cubes, 1)
if with_offset
else torch.tensor([0.0, 0.0, 0.0], device=device).repeat(num_envs, num_cubes, 1)
)
com = cube_object.reshape_view_to_data(cube_object.root_physx_view.get_coms())
com[..., :3] = offset.to("cpu")
cube_object.root_physx_view.set_coms(cube_object.reshape_data_to_view(com.clone()), view_ids)
# check center of mass has been set
torch.testing.assert_close(cube_object.reshape_view_to_data(cube_object.root_physx_view.get_coms()), com)
rand_state = torch.rand_like(cube_object.data.object_link_state_w)
rand_state[..., :3] += cube_object.data.object_link_pos_w
# make quaternion a unit vector
rand_state[..., 3:7] = torch.nn.functional.normalize(rand_state[..., 3:7], dim=-1)
env_ids = env_ids.to(device)
object_ids = object_ids.to(device)
sim.step()
cube_object.update(sim.cfg.dt)
object_link_to_com_pos, object_link_to_com_quat = subtract_frame_transforms(
cube_object.data.object_link_state_w[..., :3].view(-1, 3),
cube_object.data.object_link_state_w[..., 3:7].view(-1, 4),
cube_object.data.object_com_state_w[..., :3].view(-1, 3),
cube_object.data.object_com_state_w[..., 3:7].view(-1, 4),
)
if state_location == "com":
cube_object.write_object_com_state_to_sim(rand_state, env_ids=env_ids, object_ids=object_ids)
elif state_location == "link":
cube_object.write_object_link_state_to_sim(rand_state, env_ids=env_ids, object_ids=object_ids)
elif state_location == "root":
cube_object.write_object_state_to_sim(rand_state, env_ids=env_ids, object_ids=object_ids)
if state_location == "com":
expected_root_link_pos, expected_root_link_quat = combine_frame_transforms(
cube_object.data.object_com_state_w[..., :3].view(-1, 3),
cube_object.data.object_com_state_w[..., 3:7].view(-1, 4),
quat_rotate(quat_inv(object_link_to_com_quat), -object_link_to_com_pos),
quat_inv(object_link_to_com_quat),
)
# torch.testing.assert_close(rand_state, cube_object.data.object_com_state_w)
expected_object_link_pose = torch.cat((expected_root_link_pos, expected_root_link_quat), dim=1).view(
num_envs, -1, 7
)
# test both root_pose and root_link_state_w successfully updated when root_com_state_w updates
torch.testing.assert_close(expected_object_link_pose, cube_object.data.object_link_state_w[..., :7])
# skip 7:10 because they differs from link frame, this should be fine because we are only checking
# if velocity update is triggered, which can be determined by comparing angular velocity
torch.testing.assert_close(
cube_object.data.object_com_state_w[..., 10:], cube_object.data.object_link_state_w[..., 10:]
)
torch.testing.assert_close(expected_object_link_pose, cube_object.data.object_state_w[..., :7])
torch.testing.assert_close(
cube_object.data.object_com_state_w[..., 10:], cube_object.data.object_state_w[..., 10:]
)
elif state_location == "link":
expected_com_pos, expected_com_quat = combine_frame_transforms(
cube_object.data.object_link_state_w[..., :3].view(-1, 3),
cube_object.data.object_link_state_w[..., 3:7].view(-1, 4),
object_link_to_com_pos,
object_link_to_com_quat,
)
expected_object_com_pose = torch.cat((expected_com_pos, expected_com_quat), dim=1).view(num_envs, -1, 7)
# test both root_pose and root_com_state_w successfully updated when root_link_state_w updates
torch.testing.assert_close(expected_object_com_pose, cube_object.data.object_com_state_w[..., :7])
# skip 7:10 because they differs from link frame, this should be fine because we are only checking
# if velocity update is triggered, which can be determined by comparing angular velocity
torch.testing.assert_close(
cube_object.data.object_link_state_w[..., 10:], cube_object.data.object_com_state_w[..., 10:]
)
torch.testing.assert_close(
cube_object.data.object_link_state_w[..., :7], cube_object.data.object_state_w[..., :7]
)
torch.testing.assert_close(
cube_object.data.object_link_state_w[..., 10:], cube_object.data.object_state_w[..., 10:]
)
elif state_location == "root":
expected_object_com_pos, expected_object_com_quat = combine_frame_transforms(
cube_object.data.object_state_w[..., :3].view(-1, 3),
cube_object.data.object_state_w[..., 3:7].view(-1, 4),
object_link_to_com_pos,
object_link_to_com_quat,
)
expected_object_com_pose = torch.cat((expected_object_com_pos, expected_object_com_quat), dim=1).view(
num_envs, -1, 7
)
# test both root_com_state_w and root_link_state_w successfully updated when root_pose updates
torch.testing.assert_close(expected_object_com_pose, cube_object.data.object_com_state_w[..., :7])
torch.testing.assert_close(
cube_object.data.object_state_w[..., 7:], cube_object.data.object_com_state_w[..., 7:]
)
torch.testing.assert_close(
cube_object.data.object_state_w[..., :7], cube_object.data.object_link_state_w[..., :7]
)
torch.testing.assert_close(
cube_object.data.object_state_w[..., 10:], cube_object.data.object_link_state_w[..., 10:]
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment