Unverified Commit cdc66407 authored by James Tigue's avatar James Tigue Committed by GitHub

Changes to `quat_apply` and `quat_apply_inverse` for speed (#2129)

# Description

As per findings in #1711, `quat_apply` was found to be faster that
`quat_rotate`. This PR:
- adds `quat_apply_inverse`
- changes all instances of `quat_rotate` and `quat_rotate_inverse` to
their apply counterparts.

Fixes #1711

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Screenshots

| Per 1000 | cpu | cuda |
|:----------|:-------:|:---------:|
|**quat_apply:** |			**217.91 us** |	**47.07 us**|
|einsum_quat_rotate: |		295.95 us |	127.62 us|
|iter_quat_apply: |		679.10 us |	850.25 us|
|iter_bmm_quat_rotate: |		829.62 us |	1.28 ms|
|iter_einsum_quat_rotate: |	937.73 us |	1.46 ms|
|**quat_apply_inverse:** |		**212.20 us** |	**48.43 us**|
|einsum_quat_rotate_inverse: |	278.43 us |	114.25 us|
|iter_quat_apply_inverse: |	681.85 us |	774.82 us|
|iter_bmm_quat_rotate_inverse: |	863.27 us |	1.23 ms|
|iter_einsum_quat_rotate_inverse: |	1.04 ms |	1.45 ms|

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarJames Tigue <166445701+jtigue-bdai@users.noreply.github.com>
Signed-off-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent 3b6d615f
......@@ -49,8 +49,8 @@ from isaaclab.utils import configclass
from isaaclab.utils.math import (
combine_frame_transforms,
matrix_from_quat,
quat_apply_inverse,
quat_inv,
quat_rotate_inverse,
subtract_frame_transforms,
)
......@@ -336,8 +336,8 @@ def update_states(
ee_vel_w = robot.data.body_vel_w[:, ee_frame_idx, :] # Extract end-effector velocity in the world frame
root_vel_w = robot.data.root_vel_w # Extract root velocity in the world frame
relative_vel_w = ee_vel_w - root_vel_w # Compute the relative velocity in the world frame
ee_lin_vel_b = quat_rotate_inverse(robot.data.root_quat_w, relative_vel_w[:, 0:3]) # From world to root frame
ee_ang_vel_b = quat_rotate_inverse(robot.data.root_quat_w, relative_vel_w[:, 3:6])
ee_lin_vel_b = quat_apply_inverse(robot.data.root_quat_w, relative_vel_w[:, 0:3]) # From world to root frame
ee_ang_vel_b = quat_apply_inverse(robot.data.root_quat_w, relative_vel_w[:, 3:6])
ee_vel_b = torch.cat([ee_lin_vel_b, ee_ang_vel_b], dim=-1)
# Calculate the contact force
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.39.7"
version = "0.40.0"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.40.0 (2025-05-16)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added deprecation warning for :meth:`~isaaclab.utils.math.quat_rotate` and
:meth:`~isaaclab.utils.math.quat_rotate_inverse`
Changed
^^^^^^^
* Changed all calls to :meth:`~isaaclab.utils.math.quat_rotate` and :meth:`~isaaclab.utils.math.quat_rotate_inverse` to
:meth:`~isaaclab.utils.math.quat_apply` and :meth:`~isaaclab.utils.math.quat_apply_inverse` for speed.
0.39.7 (2025-05-19)
~~~~~~~~~~~~~~~~~~~
......
......@@ -391,7 +391,7 @@ class Articulation(AssetBase):
root_link_pos, root_link_quat = math_utils.combine_frame_transforms(
root_pose[..., :3],
root_pose[..., 3:7],
math_utils.quat_rotate(math_utils.quat_inv(com_quat), -com_pos),
math_utils.quat_apply(math_utils.quat_inv(com_quat), -com_pos),
math_utils.quat_inv(com_quat),
)
......@@ -465,7 +465,7 @@ class Articulation(AssetBase):
com_pos_b = self.data.com_pos_b[env_ids, 0, :]
# transform given velocity to center of mass
root_com_velocity[:, :3] += torch.linalg.cross(
root_com_velocity[:, 3:], math_utils.quat_rotate(quat, com_pos_b), dim=-1
root_com_velocity[:, 3:], math_utils.quat_apply(quat, com_pos_b), dim=-1
)
# write center of mass velocity to sim
self.write_root_com_velocity_to_sim(root_velocity=root_com_velocity, env_ids=physx_env_ids)
......
......@@ -395,7 +395,7 @@ class ArticulationData:
# adjust linear velocity to link from center of mass
velocity[:, :3] += torch.linalg.cross(
velocity[:, 3:], math_utils.quat_rotate(pose[:, 3:7], -self.com_pos_b[:, 0, :]), dim=-1
velocity[:, 3:], math_utils.quat_apply(pose[:, 3:7], -self.com_pos_b[:, 0, :]), dim=-1
)
# set the buffer data and timestamp
self._root_link_state_w.data = torch.cat((pose, velocity), dim=-1)
......@@ -463,7 +463,7 @@ class ArticulationData:
# adjust linear velocity to link from center of mass
velocity[..., :3] += torch.linalg.cross(
velocity[..., 3:], math_utils.quat_rotate(pose[..., 3:7], -self.com_pos_b), dim=-1
velocity[..., 3:], math_utils.quat_apply(pose[..., 3:7], -self.com_pos_b), dim=-1
)
# set the buffer data and timestamp
self._body_link_state_w.data = torch.cat((pose, velocity), dim=-1)
......@@ -529,7 +529,7 @@ class ArticulationData:
@property
def projected_gravity_b(self):
"""Projection of the gravity direction on base frame. Shape is (num_instances, 3)."""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.GRAVITY_VEC_W)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.GRAVITY_VEC_W)
@property
def heading_w(self):
......@@ -624,7 +624,7 @@ class ArticulationData:
This quantity is the linear velocity of the articulation root's center of mass frame relative to the world
with respect to the articulation root's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_quat_w, self.root_lin_vel_w)
return math_utils.quat_apply_inverse(self.root_quat_w, self.root_lin_vel_w)
@property
def root_ang_vel_b(self) -> torch.Tensor:
......@@ -633,7 +633,7 @@ class ArticulationData:
This quantity is the angular velocity of the articulation root's center of mass frame relative to the world with
respect to the articulation root's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_quat_w, self.root_ang_vel_w)
return math_utils.quat_apply_inverse(self.root_quat_w, self.root_ang_vel_w)
##
# Derived Root Link Frame Properties
......@@ -696,7 +696,7 @@ class ArticulationData:
This quantity is the linear velocity of the actor frame of the root rigid body frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.root_link_lin_vel_w)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.root_link_lin_vel_w)
@property
def root_link_ang_vel_b(self) -> torch.Tensor:
......@@ -705,7 +705,7 @@ class ArticulationData:
This quantity is the angular velocity of the actor frame of the root rigid body frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.root_link_ang_vel_w)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.root_link_ang_vel_w)
##
# Root Center of Mass state properties
......@@ -771,7 +771,7 @@ class ArticulationData:
This quantity is the linear velocity of the root rigid body's center of mass frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.root_com_lin_vel_w)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.root_com_lin_vel_w)
@property
def root_com_ang_vel_b(self) -> torch.Tensor:
......@@ -780,7 +780,7 @@ class ArticulationData:
This quantity is the angular velocity of the root rigid body's center of mass frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.root_com_ang_vel_w)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.root_com_ang_vel_w)
@property
def body_pos_w(self) -> torch.Tensor:
......
......@@ -259,7 +259,7 @@ class RigidObject(AssetBase):
root_link_pos, root_link_quat = math_utils.combine_frame_transforms(
root_pose[..., :3],
root_pose[..., 3:7],
math_utils.quat_rotate(math_utils.quat_inv(com_quat), -com_pos),
math_utils.quat_apply(math_utils.quat_inv(com_quat), -com_pos),
math_utils.quat_inv(com_quat),
)
......@@ -333,7 +333,7 @@ class RigidObject(AssetBase):
com_pos_b = self.data.com_pos_b[local_env_ids, 0, :]
# transform given velocity to center of mass
root_com_velocity[:, :3] += torch.linalg.cross(
root_com_velocity[:, 3:], math_utils.quat_rotate(quat, com_pos_b), dim=-1
root_com_velocity[:, 3:], math_utils.quat_apply(quat, com_pos_b), dim=-1
)
# write center of mass velocity to sim
self.write_root_com_velocity_to_sim(root_velocity=root_com_velocity, env_ids=env_ids)
......
......@@ -142,7 +142,7 @@ class RigidObjectData:
# adjust linear velocity to link from center of mass
velocity[:, :3] += torch.linalg.cross(
velocity[:, 3:], math_utils.quat_rotate(pose[:, 3:7], -self.com_pos_b[:, 0, :]), dim=-1
velocity[:, 3:], math_utils.quat_apply(pose[:, 3:7], -self.com_pos_b[:, 0, :]), dim=-1
)
# set the buffer data and timestamp
self._root_link_state_w.data = torch.cat((pose, velocity), dim=-1)
......@@ -218,7 +218,7 @@ class RigidObjectData:
@property
def projected_gravity_b(self):
"""Projection of the gravity direction on base frame. Shape is (num_instances, 3)."""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.GRAVITY_VEC_W)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.GRAVITY_VEC_W)
@property
def heading_w(self):
......@@ -282,7 +282,7 @@ class RigidObjectData:
This quantity is the linear velocity of the root rigid body's center of mass frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.root_lin_vel_w)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.root_lin_vel_w)
@property
def root_ang_vel_b(self) -> torch.Tensor:
......@@ -291,7 +291,7 @@ class RigidObjectData:
This quantity is the angular velocity of the root rigid body's center of mass frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.root_ang_vel_w)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.root_ang_vel_w)
@property
def root_link_pos_w(self) -> torch.Tensor:
......@@ -350,7 +350,7 @@ class RigidObjectData:
This quantity is the linear velocity of the actor frame of the root rigid body frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.root_link_lin_vel_w)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.root_link_lin_vel_w)
@property
def root_link_ang_vel_b(self) -> torch.Tensor:
......@@ -359,7 +359,7 @@ class RigidObjectData:
This quantity is the angular velocity of the actor frame of the root rigid body frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.root_link_ang_vel_w)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.root_link_ang_vel_w)
@property
def root_com_pos_w(self) -> torch.Tensor:
......@@ -420,7 +420,7 @@ class RigidObjectData:
This quantity is the linear velocity of the root rigid body's center of mass frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.root_com_lin_vel_w)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.root_com_lin_vel_w)
@property
def root_com_ang_vel_b(self) -> torch.Tensor:
......@@ -429,7 +429,7 @@ class RigidObjectData:
This quantity is the angular velocity of the root rigid body's center of mass frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.root_link_quat_w, self.root_com_ang_vel_w)
return math_utils.quat_apply_inverse(self.root_link_quat_w, self.root_com_ang_vel_w)
@property
def body_pos_w(self) -> torch.Tensor:
......
......@@ -368,7 +368,7 @@ class RigidObjectCollection(AssetBase):
object_link_pos, object_link_quat = math_utils.combine_frame_transforms(
object_pose[..., :3],
object_pose[..., 3:7],
math_utils.quat_rotate(math_utils.quat_inv(com_quat), -com_pos),
math_utils.quat_apply(math_utils.quat_inv(com_quat), -com_pos),
math_utils.quat_inv(com_quat),
)
......@@ -465,7 +465,7 @@ class RigidObjectCollection(AssetBase):
com_pos_b = self.data.com_pos_b[local_env_ids][:, local_object_ids, :]
# transform given velocity to center of mass
object_com_velocity[..., :3] += torch.linalg.cross(
object_com_velocity[..., 3:], math_utils.quat_rotate(quat, com_pos_b), dim=-1
object_com_velocity[..., 3:], math_utils.quat_apply(quat, com_pos_b), dim=-1
)
# write center of mass velocity to sim
self.write_object_com_velocity_to_sim(
......
......@@ -150,7 +150,7 @@ class RigidObjectCollectionData:
# adjust linear velocity to link from center of mass
velocity[..., :3] += torch.linalg.cross(
velocity[..., 3:], math_utils.quat_rotate(pose[..., 3:7], -self.com_pos_b[..., :]), dim=-1
velocity[..., 3:], math_utils.quat_apply(pose[..., 3:7], -self.com_pos_b[..., :]), dim=-1
)
# set the buffer data and timestamp
......@@ -198,7 +198,7 @@ class RigidObjectCollectionData:
@property
def projected_gravity_b(self):
"""Projection of the gravity direction on base frame. Shape is (num_instances, num_objects, 3)."""
return math_utils.quat_rotate_inverse(self.object_link_quat_w, self.GRAVITY_VEC_W)
return math_utils.quat_apply_inverse(self.object_link_quat_w, self.GRAVITY_VEC_W)
@property
def heading_w(self):
......@@ -262,7 +262,7 @@ class RigidObjectCollectionData:
This quantity is the linear velocity of the rigid bodies' center of mass frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.object_quat_w, self.object_lin_vel_w)
return math_utils.quat_apply_inverse(self.object_quat_w, self.object_lin_vel_w)
@property
def object_ang_vel_b(self) -> torch.Tensor:
......@@ -271,7 +271,7 @@ class RigidObjectCollectionData:
This quantity is the angular velocity of the rigid bodies' center of mass frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.object_quat_w, self.object_ang_vel_w)
return math_utils.quat_apply_inverse(self.object_quat_w, self.object_ang_vel_w)
@property
def object_lin_acc_w(self) -> torch.Tensor:
......@@ -345,7 +345,7 @@ class RigidObjectCollectionData:
This quantity is the linear velocity of the actor frame of the root rigid body frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.object_link_quat_w, self.object_link_lin_vel_w)
return math_utils.quat_apply_inverse(self.object_link_quat_w, self.object_link_lin_vel_w)
@property
def object_link_ang_vel_b(self) -> torch.Tensor:
......@@ -354,7 +354,7 @@ class RigidObjectCollectionData:
This quantity is the angular velocity of the actor frame of the root rigid body frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.object_link_quat_w, self.object_link_ang_vel_w)
return math_utils.quat_apply_inverse(self.object_link_quat_w, self.object_link_ang_vel_w)
@property
def object_com_pos_w(self) -> torch.Tensor:
......@@ -415,7 +415,7 @@ class RigidObjectCollectionData:
This quantity is the linear velocity of the center of mass frame of the root rigid body frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.object_link_quat_w, self.object_com_lin_vel_w)
return math_utils.quat_apply_inverse(self.object_link_quat_w, self.object_com_lin_vel_w)
@property
def object_com_ang_vel_b(self) -> torch.Tensor:
......@@ -424,7 +424,7 @@ class RigidObjectCollectionData:
This quantity is the angular velocity of the center of mass frame of the root rigid body frame with respect to the
rigid body's actor frame.
"""
return math_utils.quat_rotate_inverse(self.object_link_quat_w, self.object_com_ang_vel_w)
return math_utils.quat_apply_inverse(self.object_link_quat_w, self.object_com_ang_vel_w)
@property
def com_pos_b(self) -> torch.Tensor:
......
......@@ -622,13 +622,13 @@ class OperationalSpaceControllerAction(ActionTerm):
relative_vel_w = self._ee_vel_w - self._asset.data.root_vel_w
# Convert ee velocities from world to root frame
self._ee_vel_b[:, 0:3] = math_utils.quat_rotate_inverse(self._asset.data.root_quat_w, relative_vel_w[:, 0:3])
self._ee_vel_b[:, 3:6] = math_utils.quat_rotate_inverse(self._asset.data.root_quat_w, relative_vel_w[:, 3:6])
self._ee_vel_b[:, 0:3] = math_utils.quat_apply_inverse(self._asset.data.root_quat_w, relative_vel_w[:, 0:3])
self._ee_vel_b[:, 3:6] = math_utils.quat_apply_inverse(self._asset.data.root_quat_w, relative_vel_w[:, 3:6])
# Account for the offset
if self.cfg.body_offset is not None:
# Compute offset vector in root frame
r_offset_b = math_utils.quat_rotate(self._ee_pose_b_no_offset[:, 3:7], self._offset_pos)
r_offset_b = math_utils.quat_apply(self._ee_pose_b_no_offset[:, 3:7], self._offset_pos)
# Adjust the linear velocity to account for the offset
self._ee_vel_b[:, :3] += torch.cross(self._ee_vel_b[:, 3:], r_offset_b, dim=-1)
# Angular velocity is not affected by the offset
......@@ -640,7 +640,7 @@ class OperationalSpaceControllerAction(ActionTerm):
self._contact_sensor.update(self._sim_dt)
self._ee_force_w[:] = self._contact_sensor.data.net_forces_w[:, 0, :] # type: ignore
# Rotate forces and torques into root frame
self._ee_force_b[:] = math_utils.quat_rotate_inverse(self._asset.data.root_quat_w, self._ee_force_w)
self._ee_force_b[:] = math_utils.quat_apply_inverse(self._asset.data.root_quat_w, self._ee_force_w)
def _compute_joint_states(self):
"""Computes the joint states for operational space control."""
......
......@@ -15,7 +15,7 @@ from isaaclab.assets import Articulation
from isaaclab.managers import CommandTerm
from isaaclab.markers import VisualizationMarkers
from isaaclab.terrains import TerrainImporter
from isaaclab.utils.math import quat_from_euler_xyz, quat_rotate_inverse, wrap_to_pi, yaw_quat
from isaaclab.utils.math import quat_apply_inverse, quat_from_euler_xyz, wrap_to_pi, yaw_quat
if TYPE_CHECKING:
from isaaclab.envs import ManagerBasedEnv
......@@ -117,7 +117,7 @@ class UniformPose2dCommand(CommandTerm):
def _update_command(self):
"""Re-target the position command to the current root state."""
target_vec = self.pos_command_w - self.robot.data.root_pos_w[:, :3]
self.pos_command_b[:] = quat_rotate_inverse(yaw_quat(self.robot.data.root_quat_w), target_vec)
self.pos_command_b[:] = quat_apply_inverse(yaw_quat(self.robot.data.root_quat_w), target_vec)
self.heading_command_b[:] = wrap_to_pi(self.heading_command_w - self.robot.data.heading_w)
def _set_debug_vis_impl(self, debug_vis: bool):
......
......@@ -153,7 +153,7 @@ class Imu(SensorBase):
quat_w = math_utils.convert_quat(quat_w, to="wxyz")
# store the poses
self._data.pos_w[env_ids] = pos_w + math_utils.quat_rotate(quat_w, self._offset_pos_b[env_ids])
self._data.pos_w[env_ids] = pos_w + math_utils.quat_apply(quat_w, self._offset_pos_b[env_ids])
self._data.quat_w[env_ids] = math_utils.quat_mul(quat_w, self._offset_quat_b[env_ids])
# get the offset from COM to link origin
......@@ -164,18 +164,18 @@ class Imu(SensorBase):
# if an offset is present or the COM does not agree with the link origin, the linear velocity has to be
# transformed taking the angular velocity into account
lin_vel_w += torch.linalg.cross(
ang_vel_w, math_utils.quat_rotate(quat_w, self._offset_pos_b[env_ids] - com_pos_b[env_ids]), dim=-1
ang_vel_w, math_utils.quat_apply(quat_w, self._offset_pos_b[env_ids] - com_pos_b[env_ids]), dim=-1
)
# numerical derivative
lin_acc_w = (lin_vel_w - self._prev_lin_vel_w[env_ids]) / self._dt + self._gravity_bias_w[env_ids]
ang_acc_w = (ang_vel_w - self._prev_ang_vel_w[env_ids]) / self._dt
# store the velocities
self._data.lin_vel_b[env_ids] = math_utils.quat_rotate_inverse(self._data.quat_w[env_ids], lin_vel_w)
self._data.ang_vel_b[env_ids] = math_utils.quat_rotate_inverse(self._data.quat_w[env_ids], ang_vel_w)
self._data.lin_vel_b[env_ids] = math_utils.quat_apply_inverse(self._data.quat_w[env_ids], lin_vel_w)
self._data.ang_vel_b[env_ids] = math_utils.quat_apply_inverse(self._data.quat_w[env_ids], ang_vel_w)
# store the accelerations
self._data.lin_acc_b[env_ids] = math_utils.quat_rotate_inverse(self._data.quat_w[env_ids], lin_acc_w)
self._data.ang_acc_b[env_ids] = math_utils.quat_rotate_inverse(self._data.quat_w[env_ids], ang_acc_w)
self._data.lin_acc_b[env_ids] = math_utils.quat_apply_inverse(self._data.quat_w[env_ids], lin_acc_w)
self._data.ang_acc_b[env_ids] = math_utils.quat_apply_inverse(self._data.quat_w[env_ids], ang_acc_w)
self._prev_lin_vel_w[env_ids] = lin_vel_w
self._prev_ang_vel_w[env_ids] = ang_vel_w
......@@ -232,7 +232,7 @@ class Imu(SensorBase):
quat_opengl = math_utils.quat_from_matrix(
math_utils.create_rotation_matrix_from_view(
self._data.pos_w,
self._data.pos_w + math_utils.quat_rotate(self._data.quat_w, self._data.lin_acc_b),
self._data.pos_w + math_utils.quat_apply(self._data.quat_w, self._data.lin_acc_b),
up_axis=up_axis,
device=self._device,
)
......
......@@ -14,6 +14,8 @@ import torch
import torch.nn.functional
from typing import Literal
import omni.log
"""
General
"""
......@@ -633,6 +635,28 @@ def quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
return (vec + quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape)
@torch.jit.script
def quat_apply_inverse(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
"""Apply an inverse quaternion rotation to a vector.
Args:
quat: The quaternion in (w, x, y, z). Shape is (..., 4).
vec: The vector in (x, y, z). Shape is (..., 3).
Returns:
The rotated vector in (x, y, z). Shape is (..., 3).
"""
# store shape
shape = vec.shape
# reshape to (N, 3) for multiplication
quat = quat.reshape(-1, 4)
vec = vec.reshape(-1, 3)
# extract components from quaternions
xyz = quat[:, 1:]
t = xyz.cross(vec, dim=-1) * 2
return (vec - quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape)
@torch.jit.script
def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
"""Rotate a vector only around the yaw-direction.
......@@ -648,9 +672,10 @@ def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
return quat_apply(quat_yaw, vec)
@torch.jit.script
def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Rotate a vector by a quaternion along the last dimension of q and v.
.. deprecated v2.1.0:
This function will be removed in a future release in favor of the faster implementation :meth:`quat_apply`.
Args:
q: The quaternion in (w, x, y, z). Shape is (..., 4).
......@@ -659,22 +684,19 @@ def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
Returns:
The rotated vector in (x, y, z). Shape is (..., 3).
"""
q_w = q[..., 0]
q_vec = q[..., 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
# for two-dimensional tensors, bmm is faster than einsum
if q_vec.dim() == 2:
c = q_vec * torch.bmm(q_vec.view(q.shape[0], 1, 3), v.view(q.shape[0], 3, 1)).squeeze(-1) * 2.0
else:
c = q_vec * torch.einsum("...i,...i->...", q_vec, v).unsqueeze(-1) * 2.0
return a + b + c
# deprecation
omni.log.warn(
"The function 'quat_rotate' will be deprecated in favor of the faster method 'quat_apply'."
" Please use 'quat_apply' instead...."
)
return quat_apply(q, v)
@torch.jit.script
def quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Rotate a vector by the inverse of a quaternion along the last dimension of q and v.
.. deprecated v2.1.0:
This function will be removed in a future release in favor of the faster implementation :meth:`quat_apply_inverse`.
Args:
q: The quaternion in (w, x, y, z). Shape is (..., 4).
v: The vector in (x, y, z). Shape is (..., 3).
......@@ -682,16 +704,11 @@ def quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
Returns:
The rotated vector in (x, y, z). Shape is (..., 3).
"""
q_w = q[..., 0]
q_vec = q[..., 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
# for two-dimensional tensors, bmm is faster than einsum
if q_vec.dim() == 2:
c = q_vec * torch.bmm(q_vec.view(q.shape[0], 1, 3), v.view(q.shape[0], 3, 1)).squeeze(-1) * 2.0
else:
c = q_vec * torch.einsum("...i,...i->...", q_vec, v).unsqueeze(-1) * 2.0
return a - b + c
omni.log.warn(
"The function 'quat_rotate_inverse' will be deprecated in favor of the faster method 'quat_apply_inverse'."
" Please use 'quat_apply_inverse' instead...."
)
return quat_apply_inverse(q, v)
@torch.jit.script
......
......@@ -28,7 +28,7 @@ from isaaclab.assets import RigidObject, RigidObjectCfg
from isaaclab.sim import build_simulation_context
from isaaclab.sim.spawners import materials
from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR, ISAACLAB_NUCLEUS_DIR
from isaaclab.utils.math import default_orientation, quat_mul, quat_rotate_inverse, random_orientation
from isaaclab.utils.math import default_orientation, quat_apply_inverse, quat_mul, random_orientation
def generate_cubes_scene(
......@@ -811,12 +811,12 @@ def test_body_root_state_properties(num_cubes, device, with_offset):
torch.testing.assert_close(env_pos + offset, root_com_state_w[..., :3])
torch.testing.assert_close(env_pos + offset, body_com_state_w[..., :3].squeeze(-2))
# link position will be moving but should stay constant away from center of mass
root_link_state_pos_rel_com = quat_rotate_inverse(
root_link_state_pos_rel_com = quat_apply_inverse(
root_link_state_w[..., 3:7],
root_link_state_w[..., :3] - root_com_state_w[..., :3],
)
torch.testing.assert_close(-offset, root_link_state_pos_rel_com)
body_link_state_pos_rel_com = quat_rotate_inverse(
body_link_state_pos_rel_com = quat_apply_inverse(
body_link_state_w[..., 3:7],
body_link_state_w[..., :3] - body_com_state_w[..., :3],
)
......@@ -837,8 +837,8 @@ def test_body_root_state_properties(num_cubes, device, with_offset):
torch.testing.assert_close(torch.zeros_like(root_com_state_w[..., 7:10]), root_com_state_w[..., 7:10])
torch.testing.assert_close(torch.zeros_like(body_com_state_w[..., 7:10]), body_com_state_w[..., 7:10])
# link frame will be moving, and should be equal to input angular velocity cross offset
lin_vel_rel_root_gt = quat_rotate_inverse(root_link_state_w[..., 3:7], root_link_state_w[..., 7:10])
lin_vel_rel_body_gt = quat_rotate_inverse(body_link_state_w[..., 3:7], body_link_state_w[..., 7:10])
lin_vel_rel_root_gt = quat_apply_inverse(root_link_state_w[..., 3:7], root_link_state_w[..., 7:10])
lin_vel_rel_body_gt = quat_apply_inverse(body_link_state_w[..., 3:7], body_link_state_w[..., 7:10])
lin_vel_rel_gt = torch.linalg.cross(spin_twist.repeat(num_cubes, 1)[..., 3:], -offset)
torch.testing.assert_close(lin_vel_rel_gt, lin_vel_rel_root_gt, atol=1e-4, rtol=1e-4)
torch.testing.assert_close(lin_vel_rel_gt, lin_vel_rel_body_gt.squeeze(-2), atol=1e-4, rtol=1e-4)
......
......@@ -26,7 +26,7 @@ import isaaclab.sim as sim_utils
from isaaclab.assets import RigidObjectCfg, RigidObjectCollection, RigidObjectCollectionCfg
from isaaclab.sim import build_simulation_context
from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR
from isaaclab.utils.math import default_orientation, quat_mul, quat_rotate_inverse, random_orientation
from isaaclab.utils.math import default_orientation, quat_apply_inverse, quat_mul, random_orientation
def generate_cubes_scene(
......@@ -417,7 +417,7 @@ def test_object_state_properties(sim, num_envs, num_cubes, device, with_offset,
torch.testing.assert_close(init_com, object_com_state_w[..., :3])
# link position will be moving but should stay constant away from center of mass
object_link_state_pos_rel_com = quat_rotate_inverse(
object_link_state_pos_rel_com = quat_apply_inverse(
object_link_state_w[..., 3:7],
object_link_state_w[..., :3] - object_com_state_w[..., :3],
)
......@@ -440,7 +440,7 @@ def test_object_state_properties(sim, num_envs, num_cubes, device, with_offset,
)
# link frame will be moving, and should be equal to input angular velocity cross offset
lin_vel_rel_object_gt = quat_rotate_inverse(object_link_state_w[..., 3:7], object_link_state_w[..., 7:10])
lin_vel_rel_object_gt = quat_apply_inverse(object_link_state_w[..., 3:7], object_link_state_w[..., 7:10])
lin_vel_rel_gt = torch.linalg.cross(spin_twist.repeat(num_envs, num_cubes, 1)[..., 3:], -offset)
torch.testing.assert_close(lin_vel_rel_gt, lin_vel_rel_object_gt, atol=1e-4, rtol=1e-3)
......
......@@ -30,8 +30,8 @@ from isaaclab.utils.math import (
combine_frame_transforms,
compute_pose_error,
matrix_from_quat,
quat_apply_inverse,
quat_inv,
quat_rotate_inverse,
subtract_frame_transforms,
)
......@@ -1422,8 +1422,8 @@ def _update_states(
ee_vel_w = robot.data.body_vel_w[:, ee_frame_idx, :] # Extract end-effector velocity in the world frame
root_vel_w = robot.data.root_vel_w # Extract root velocity in the world frame
relative_vel_w = ee_vel_w - root_vel_w # Compute the relative velocity in the world frame
ee_lin_vel_b = quat_rotate_inverse(robot.data.root_quat_w, relative_vel_w[:, 0:3]) # From world to root frame
ee_ang_vel_b = quat_rotate_inverse(robot.data.root_quat_w, relative_vel_w[:, 3:6])
ee_lin_vel_b = quat_apply_inverse(robot.data.root_quat_w, relative_vel_w[:, 0:3]) # From world to root frame
ee_ang_vel_b = quat_apply_inverse(robot.data.root_quat_w, relative_vel_w[:, 3:6])
ee_vel_b = torch.cat([ee_lin_vel_b, ee_ang_vel_b], dim=-1)
# Calculate the contact force
......
......@@ -287,7 +287,7 @@ def test_constant_acceleration(setup_sim):
# check the imu data
torch.testing.assert_close(
scene.sensors["imu_ball"].data.lin_acc_b,
math_utils.quat_rotate_inverse(
math_utils.quat_apply_inverse(
scene.rigid_objects["balls"].data.root_quat_w,
torch.tensor([[0.1, 0.0, 0.0]], dtype=torch.float32, device=scene.device).repeat(scene.num_envs, 1)
/ sim.get_physics_dt(),
......@@ -331,12 +331,12 @@ def test_single_dof_pendulum(setup_sim):
base_data = scene.sensors["imu_pendulum_base"].data
# extract imu_link imu_sensor dynamics
lin_vel_w_imu_link = math_utils.quat_rotate(imu_data.quat_w, imu_data.lin_vel_b)
lin_acc_w_imu_link = math_utils.quat_rotate(imu_data.quat_w, imu_data.lin_acc_b)
lin_vel_w_imu_link = math_utils.quat_apply(imu_data.quat_w, imu_data.lin_vel_b)
lin_acc_w_imu_link = math_utils.quat_apply(imu_data.quat_w, imu_data.lin_acc_b)
# calculate the joint dynamics from the imu_sensor (y axis of imu_link is parallel to joint axis of pendulum)
joint_vel_imu = math_utils.quat_rotate(imu_data.quat_w, imu_data.ang_vel_b)[..., 1].unsqueeze(-1)
joint_acc_imu = math_utils.quat_rotate(imu_data.quat_w, imu_data.ang_acc_b)[..., 1].unsqueeze(-1)
joint_vel_imu = math_utils.quat_apply(imu_data.quat_w, imu_data.ang_vel_b)[..., 1].unsqueeze(-1)
joint_acc_imu = math_utils.quat_apply(imu_data.quat_w, imu_data.ang_acc_b)[..., 1].unsqueeze(-1)
# calculate analytical solution
vx = -joint_vel * pend_length * torch.sin(joint_pos)
......
......@@ -293,159 +293,6 @@ def test_wrap_to_pi(device):
torch.testing.assert_close(wrapped_angle, expected_angle)
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_quat_rotate_and_quat_rotate_inverse(device):
"""Test for quat_rotate and quat_rotate_inverse methods.
The new implementation uses :meth:`torch.einsum` instead of `torch.bmm` which allows
for more flexibility in the input dimensions and is faster than `torch.bmm`.
"""
# define old implementation for quat_rotate and quat_rotate_inverse
# Based on commit: cdfa954fcc4394ca8daf432f61994e25a7b8e9e2
@torch.jit.script
def old_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
shape = q.shape
q_w = q[:, 0]
q_vec = q[:, 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
return a + b + c
@torch.jit.script
def old_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
shape = q.shape
q_w = q[:, 0]
q_vec = q[:, 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
return a - b + c
# check that implementation produces the same result as the new implementation
# prepare random quaternions and vectors
q_rand = math_utils.random_orientation(num=1024, device=device)
v_rand = math_utils.sample_uniform(-1000, 1000, (1024, 3), device=device)
# compute the result using the old implementation
old_result = old_quat_rotate(q_rand, v_rand)
old_result_inv = old_quat_rotate_inverse(q_rand, v_rand)
# compute the result using the new implementation
new_result = math_utils.quat_rotate(q_rand, v_rand)
new_result_inv = math_utils.quat_rotate_inverse(q_rand, v_rand)
# check that the result is close to the expected value
torch.testing.assert_close(old_result, new_result)
torch.testing.assert_close(old_result_inv, new_result_inv)
# check the performance of the new implementation
# prepare random quaternions and vectors
# new implementation supports batched inputs
q_shape = (1024, 2, 5, 4)
v_shape = (1024, 2, 5, 3)
# sample random quaternions and vectors
num_quats = math.prod(q_shape[:-1])
q_rand = math_utils.random_orientation(num=num_quats, device=device).reshape(q_shape)
v_rand = math_utils.sample_uniform(-1000, 1000, v_shape, device=device)
# create functions to test
def iter_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of new quat_rotate."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = math_utils.quat_rotate(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of new quat_rotate_inverse."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = math_utils.quat_rotate_inverse(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_old_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of old quat_rotate."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = old_quat_rotate(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_old_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of old quat_rotate_inverse."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = old_quat_rotate_inverse(q_rand[:, i, j], v_rand[:, i, j])
return out
# create benchmark
timer_iter_quat_rotate = benchmark.Timer(
stmt="iter_quat_rotate(q_rand, v_rand)",
globals={"iter_quat_rotate": iter_quat_rotate, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_quat_rotate_inverse = benchmark.Timer(
stmt="iter_quat_rotate_inverse(q_rand, v_rand)",
globals={"iter_quat_rotate_inverse": iter_quat_rotate_inverse, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_old_quat_rotate = benchmark.Timer(
stmt="iter_old_quat_rotate(q_rand, v_rand)",
globals={"iter_old_quat_rotate": iter_old_quat_rotate, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_old_quat_rotate_inverse = benchmark.Timer(
stmt="iter_old_quat_rotate_inverse(q_rand, v_rand)",
globals={
"iter_old_quat_rotate_inverse": iter_old_quat_rotate_inverse,
"q_rand": q_rand,
"v_rand": v_rand,
},
)
timer_quat_rotate = benchmark.Timer(
stmt="math_utils.quat_rotate(q_rand, v_rand)",
globals={"math_utils": math_utils, "q_rand": q_rand, "v_rand": v_rand},
)
timer_quat_rotate_inverse = benchmark.Timer(
stmt="math_utils.quat_rotate_inverse(q_rand, v_rand)",
globals={"math_utils": math_utils, "q_rand": q_rand, "v_rand": v_rand},
)
# run the benchmark
print("--------------------------------")
print(f"Device: {device}")
print("Time for quat_rotate:", timer_quat_rotate.timeit(number=1000))
print("Time for iter_quat_rotate:", timer_iter_quat_rotate.timeit(number=1000))
print("Time for iter_old_quat_rotate:", timer_iter_old_quat_rotate.timeit(number=1000))
print("--------------------------------")
print("Time for quat_rotate_inverse:", timer_quat_rotate_inverse.timeit(number=1000))
print("Time for iter_quat_rotate_inverse:", timer_iter_quat_rotate_inverse.timeit(number=1000))
print("Time for iter_old_quat_rotate_inverse:", timer_iter_old_quat_rotate_inverse.timeit(number=1000))
print("--------------------------------")
# check output values are the same
torch.testing.assert_close(
math_utils.quat_rotate(q_rand, v_rand), iter_quat_rotate(q_rand, v_rand), atol=1e-4, rtol=1e-3
)
torch.testing.assert_close(
math_utils.quat_rotate(q_rand, v_rand), iter_old_quat_rotate(q_rand, v_rand), atol=1e-4, rtol=1e-3
)
torch.testing.assert_close(
math_utils.quat_rotate_inverse(q_rand, v_rand), iter_quat_rotate_inverse(q_rand, v_rand), atol=1e-4, rtol=1e-3
)
torch.testing.assert_close(
math_utils.quat_rotate_inverse(q_rand, v_rand),
iter_old_quat_rotate_inverse(q_rand, v_rand),
atol=1e-4,
rtol=1e-3,
)
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_orthogonalize_perspective_depth(device):
"""Test for converting perspective depth to orthogonal depth."""
......@@ -529,6 +376,32 @@ def test_interpolate_poses(device):
np.testing.assert_array_almost_equal(result_pos, expected_pos, decimal=DECIMAL_PRECISION)
def test_pose_inv():
"""Test pose_inv function.
This test checks the output from the :meth:`~isaaclab.utils.math_utils.pose_inv` function against
the output from :func:`np.linalg.inv`. Two test cases are performed:
1. Checking the inverse of a random transformation matrix matches Numpy's built-in inverse.
2. Checking the inverse of a batch of random transformation matrices matches Numpy's built-in inverse.
"""
# Check against a single matrix
for _ in range(100):
test_mat = math_utils.generate_random_transformation_matrix(pos_boundary=10, rot_boundary=(2 * np.pi))
result = np.array(math_utils.pose_inv(test_mat))
expected = np.linalg.inv(np.array(test_mat))
np.testing.assert_array_almost_equal(result, expected, decimal=DECIMAL_PRECISION)
# Check against a batch of matrices
test_mats = torch.stack([
math_utils.generate_random_transformation_matrix(pos_boundary=10, rot_boundary=(2 * math.pi))
for _ in range(100)
])
result = np.array(math_utils.pose_inv(test_mats))
expected = np.linalg.inv(np.array(test_mats))
np.testing.assert_array_almost_equal(result, expected, decimal=DECIMAL_PRECISION)
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_quat_box_minus(device):
"""Test quat_box_minus method.
......@@ -669,6 +542,274 @@ def test_quat_slerp(device):
np.testing.assert_array_almost_equal(result.cpu(), expected, decimal=DECIMAL_PRECISION)
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_matrix_from_quat(device):
"""test matrix_from_quat against scipy."""
# prepare random quaternions and vectors
n = 1024
q_rand = math_utils.random_orientation(num=n, device=device)
rot_mat = math_utils.matrix_from_quat(quaternions=q_rand)
rot_mat_scipy = torch.tensor(
scipy_tf.Rotation.from_quat(math_utils.convert_quat(quat=q_rand.to(device="cpu"), to="xyzw")).as_matrix(),
device=device,
dtype=torch.float32,
)
print()
torch.testing.assert_close(rot_mat_scipy.to(device=device), rot_mat)
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_quat_apply(device):
"""Test for quat_apply against scipy."""
# prepare random quaternions and vectors
n = 1024
q_rand = math_utils.random_orientation(num=n, device=device)
Rotation = scipy_tf.Rotation.from_quat(math_utils.convert_quat(quat=q_rand.to(device="cpu").numpy(), to="xyzw"))
v_rand = math_utils.sample_uniform(-1000, 1000, (n, 3), device=device)
# compute the result using the new implementation
scipy_result = torch.tensor(Rotation.apply(v_rand.to(device="cpu").numpy()), device=device, dtype=torch.float)
apply_result = math_utils.quat_apply(q_rand, v_rand)
torch.testing.assert_close(scipy_result.to(device=device), apply_result, atol=2e-4, rtol=2e-4)
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_quat_apply_inverse(device):
"""Test for quat_apply against scipy."""
# prepare random quaternions and vectors
n = 1024
q_rand = math_utils.random_orientation(num=n, device=device)
Rotation = scipy_tf.Rotation.from_quat(math_utils.convert_quat(quat=q_rand.to(device="cpu").numpy(), to="xyzw"))
v_rand = math_utils.sample_uniform(-1000, 1000, (n, 3), device=device)
# compute the result using the new implementation
scipy_result = torch.tensor(
Rotation.apply(v_rand.to(device="cpu").numpy(), inverse=True), device=device, dtype=torch.float
)
apply_result = math_utils.quat_apply_inverse(q_rand, v_rand)
torch.testing.assert_close(scipy_result.to(device=device), apply_result, atol=2e-4, rtol=2e-4)
def test_quat_apply_benchmarks():
"""Test for quat_apply and quat_apply_inverse methods compared to old methods using torch.bmm and torch.einsum.
The new implementation uses :meth:`torch.einsum` instead of `torch.bmm` which allows
for more flexibility in the input dimensions and is faster than `torch.bmm`.
"""
# define old implementation for quat_rotate and quat_rotate_inverse
# Based on commit: cdfa954fcc4394ca8daf432f61994e25a7b8e9e2
@torch.jit.script
def bmm_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
shape = q.shape
q_w = q[:, 0]
q_vec = q[:, 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
return a + b + c
@torch.jit.script
def bmm_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
shape = q.shape
q_w = q[:, 0]
q_vec = q[:, 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
return a - b + c
@torch.jit.script
def einsum_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
q_w = q[..., 0]
q_vec = q[..., 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.einsum("...i,...i->...", q_vec, v).unsqueeze(-1) * 2.0
return a + b + c
@torch.jit.script
def einsum_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
q_w = q[..., 0]
q_vec = q[..., 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.einsum("...i,...i->...", q_vec, v).unsqueeze(-1) * 2.0
return a - b + c
# check that implementation produces the same result as the new implementation
for device in ["cpu", "cuda:0"]:
# prepare random quaternions and vectors
q_rand = math_utils.random_orientation(num=1024, device=device)
v_rand = math_utils.sample_uniform(-1000, 1000, (1024, 3), device=device)
# compute the result using the old implementation
bmm_result = bmm_quat_rotate(q_rand, v_rand)
bmm_result_inv = bmm_quat_rotate_inverse(q_rand, v_rand)
# compute the result using the old implementation
einsum_result = einsum_quat_rotate(q_rand, v_rand)
einsum_result_inv = einsum_quat_rotate_inverse(q_rand, v_rand)
# compute the result using the new implementation
new_result = math_utils.quat_apply(q_rand, v_rand)
new_result_inv = math_utils.quat_apply_inverse(q_rand, v_rand)
# check that the result is close to the expected value
torch.testing.assert_close(bmm_result, new_result, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(bmm_result_inv, new_result_inv, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(einsum_result, new_result, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(einsum_result_inv, new_result_inv, atol=1e-3, rtol=1e-3)
# check the performance of the new implementation
for device in ["cpu", "cuda:0"]:
# prepare random quaternions and vectors
# new implementation supports batched inputs
q_shape = (1024, 2, 5, 4)
v_shape = (1024, 2, 5, 3)
# sample random quaternions and vectors
num_quats = math.prod(q_shape[:-1])
q_rand = math_utils.random_orientation(num=num_quats, device=device).reshape(q_shape)
v_rand = math_utils.sample_uniform(-1000, 1000, v_shape, device=device)
# create functions to test
def iter_quat_apply(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of new quat_apply."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = math_utils.quat_apply(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_quat_apply_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of new quat_apply_inverse."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = math_utils.quat_apply_inverse(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_bmm_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of old quat_rotate using torch.bmm."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = bmm_quat_rotate(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_bmm_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of old quat_rotate_inverse using torch.bmm."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = bmm_quat_rotate_inverse(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_einsum_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of old quat_rotate using torch.einsum."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = einsum_quat_rotate(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_einsum_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of old quat_rotate_inverse using torch.einsum."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = einsum_quat_rotate_inverse(q_rand[:, i, j], v_rand[:, i, j])
return out
# benchmarks for iterative calls
timer_iter_quat_apply = benchmark.Timer(
stmt="iter_quat_apply(q_rand, v_rand)",
globals={"iter_quat_apply": iter_quat_apply, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_quat_apply_inverse = benchmark.Timer(
stmt="iter_quat_apply_inverse(q_rand, v_rand)",
globals={"iter_quat_apply_inverse": iter_quat_apply_inverse, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_bmm_quat_rotate = benchmark.Timer(
stmt="iter_bmm_quat_rotate(q_rand, v_rand)",
globals={"iter_bmm_quat_rotate": iter_bmm_quat_rotate, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_bmm_quat_rotate_inverse = benchmark.Timer(
stmt="iter_bmm_quat_rotate_inverse(q_rand, v_rand)",
globals={
"iter_bmm_quat_rotate_inverse": iter_bmm_quat_rotate_inverse,
"q_rand": q_rand,
"v_rand": v_rand,
},
)
timer_iter_einsum_quat_rotate = benchmark.Timer(
stmt="iter_einsum_quat_rotate(q_rand, v_rand)",
globals={"iter_einsum_quat_rotate": iter_einsum_quat_rotate, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_einsum_quat_rotate_inverse = benchmark.Timer(
stmt="iter_einsum_quat_rotate_inverse(q_rand, v_rand)",
globals={
"iter_einsum_quat_rotate_inverse": iter_einsum_quat_rotate_inverse,
"q_rand": q_rand,
"v_rand": v_rand,
},
)
# create benchmaks for size independent calls
timer_quat_apply = benchmark.Timer(
stmt="math_utils.quat_apply(q_rand, v_rand)",
globals={"math_utils": math_utils, "q_rand": q_rand, "v_rand": v_rand},
)
timer_quat_apply_inverse = benchmark.Timer(
stmt="math_utils.quat_apply_inverse(q_rand, v_rand)",
globals={"math_utils": math_utils, "q_rand": q_rand, "v_rand": v_rand},
)
timer_einsum_quat_rotate = benchmark.Timer(
stmt="einsum_quat_rotate(q_rand, v_rand)",
globals={"einsum_quat_rotate": einsum_quat_rotate, "q_rand": q_rand, "v_rand": v_rand},
)
timer_einsum_quat_rotate_inverse = benchmark.Timer(
stmt="einsum_quat_rotate_inverse(q_rand, v_rand)",
globals={"einsum_quat_rotate_inverse": einsum_quat_rotate_inverse, "q_rand": q_rand, "v_rand": v_rand},
)
# run the benchmark
print("--------------------------------")
print(f"Device: {device}")
print("Time for quat_apply:", timer_quat_apply.timeit(number=1000))
print("Time for einsum_quat_rotate:", timer_einsum_quat_rotate.timeit(number=1000))
print("Time for iter_quat_apply:", timer_iter_quat_apply.timeit(number=1000))
print("Time for iter_bmm_quat_rotate:", timer_iter_bmm_quat_rotate.timeit(number=1000))
print("Time for iter_einsum_quat_rotate:", timer_iter_einsum_quat_rotate.timeit(number=1000))
print("--------------------------------")
print("Time for quat_apply_inverse:", timer_quat_apply_inverse.timeit(number=1000))
print("Time for einsum_quat_rotate_inverse:", timer_einsum_quat_rotate_inverse.timeit(number=1000))
print("Time for iter_quat_apply_inverse:", timer_iter_quat_apply_inverse.timeit(number=1000))
print("Time for iter_bmm_quat_rotate_inverse:", timer_iter_bmm_quat_rotate_inverse.timeit(number=1000))
print("Time for iter_einsum_quat_rotate_inverse:", timer_iter_einsum_quat_rotate_inverse.timeit(number=1000))
print("--------------------------------")
# check output values are the same
torch.testing.assert_close(math_utils.quat_apply(q_rand, v_rand), iter_quat_apply(q_rand, v_rand))
torch.testing.assert_close(
math_utils.quat_apply(q_rand, v_rand), iter_bmm_quat_rotate(q_rand, v_rand), atol=1e-3, rtol=1e-3
)
torch.testing.assert_close(
math_utils.quat_apply_inverse(q_rand, v_rand), iter_quat_apply_inverse(q_rand, v_rand)
)
torch.testing.assert_close(
math_utils.quat_apply_inverse(q_rand, v_rand),
iter_bmm_quat_rotate_inverse(q_rand, v_rand),
atol=1e-3,
rtol=1e-3,
)
def test_interpolate_rotations():
"""Test interpolate_rotations function.
......
......@@ -13,7 +13,7 @@ import isaaclab.sim as sim_utils
from isaaclab.assets import Articulation
from isaaclab.envs import DirectRLEnv
from isaaclab.sim.spawners.from_files import GroundPlaneCfg, spawn_ground_plane
from isaaclab.utils.math import quat_rotate
from isaaclab.utils.math import quat_apply
from .humanoid_amp_env_cfg import HumanoidAmpEnvCfg
from .motions import MotionLoader
......@@ -208,8 +208,8 @@ def quaternion_to_tangent_and_normal(q: torch.Tensor) -> torch.Tensor:
ref_normal = torch.zeros_like(q[..., :3])
ref_tangent[..., 0] = 1
ref_normal[..., -1] = 1
tangent = quat_rotate(q, ref_tangent)
normal = quat_rotate(q, ref_normal)
tangent = quat_apply(q, ref_tangent)
normal = quat_apply(q, ref_normal)
return torch.cat([tangent, normal], dim=len(tangent.shape) - 1)
......
......@@ -50,7 +50,7 @@ def base_heading_proj(
to_target_pos[:, 2] = 0.0
to_target_dir = math_utils.normalize(to_target_pos)
# compute base forward vector
heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.data.FORWARD_VEC_B)
heading_vec = math_utils.quat_apply(asset.data.root_quat_w, asset.data.FORWARD_VEC_B)
# compute dot product between heading and target direction
heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1))
......
......@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
from isaaclab.managers import SceneEntityCfg
from isaaclab.sensors import ContactSensor
from isaaclab.utils.math import quat_rotate_inverse, yaw_quat
from isaaclab.utils.math import quat_apply_inverse, yaw_quat
if TYPE_CHECKING:
from isaaclab.envs import ManagerBasedRLEnv
......@@ -89,7 +89,7 @@ def track_lin_vel_xy_yaw_frame_exp(
"""Reward tracking of linear velocity commands (xy axes) in the gravity aligned robot frame using exponential kernel."""
# extract the used quantities (to enable type-hinting)
asset = env.scene[asset_cfg.name]
vel_yaw = quat_rotate_inverse(yaw_quat(asset.data.root_quat_w), asset.data.root_lin_vel_w[:, :3])
vel_yaw = quat_apply_inverse(yaw_quat(asset.data.root_quat_w), asset.data.root_lin_vel_w[:, :3])
lin_vel_error = torch.sum(
torch.square(env.command_manager.get_command(command_name)[:, :2] - vel_yaw[:, :2]), dim=1
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment