Unverified Commit 0c3fb1e1 authored by dxyy1's avatar dxyy1 Committed by GitHub

Uses `torch.einsum` for quat_rotate and quat_rotate_inverse operations (#900)

# Description
Extended the two functions' capability so they now can take in
multidimensional tensors and no longer limited to 2D tensors of shape
(B,4) and (B, 3)

## Type of change
- New feature (non-breaking change which adds functionality)

## Checklist
- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Signed-off-by: 's avatardxyy1 <139338590+dxyy1@users.noreply.github.com>
Co-authored-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
parent af088f59
......@@ -39,6 +39,7 @@ Guidelines for modifications:
* Brayden Zhang
* Calvin Yu
* Chenyu Yang
* David Yang
* HoJin Jeon
* Jia Lin Yuan
* Jingzhou Liu
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.22.8"
version = "0.22.9"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.22.9 (2024-09-08)
~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Modified:meth:`quat_rotate` and :meth:`quat_rotate_inverse` operations to use :meth:`torch.einsum`
for faster processing of high dimensional input tensors.
0.22.8 (2024-09-06)
~~~~~~~~~~~~~~~~~~~
......
......@@ -580,41 +580,47 @@ def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
@torch.jit.script
def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Rotate a vector by a quaternion.
"""Rotate a vector by a quaternion along the last dimension of q and v.
Args:
q: The quaternion in (w, x, y, z). Shape is (N, 4).
v: The vector in (x, y, z). Shape is (N, 3).
q: The quaternion in (w, x, y, z). Shape is (..., 4).
v: The vector in (x, y, z). Shape is (..., 3).
Returns:
The rotated vector in (x, y, z). Shape is (N, 3).
The rotated vector in (x, y, z). Shape is (..., 3).
"""
shape = q.shape
q_w = q[:, 0]
q_vec = q[:, 1:]
q_w = q[..., 0]
q_vec = q[..., 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
# for two-dimensional tensors, bmm is faster than einsum
if q_vec.dim() == 2:
c = q_vec * torch.bmm(q_vec.view(q.shape[0], 1, 3), v.view(q.shape[0], 3, 1)).squeeze(-1) * 2.0
else:
c = q_vec * torch.einsum("...i,...i->...", q_vec, v).unsqueeze(-1) * 2.0
return a + b + c
@torch.jit.script
def quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Rotate a vector by the inverse of a quaternion.
"""Rotate a vector by the inverse of a quaternion along the last dimension of q and v.
Args:
q: The quaternion in (w, x, y, z). Shape is (N, 4).
v: The vector in (x, y, z). Shape is (N, 3).
q: The quaternion in (w, x, y, z). Shape is (..., 4).
v: The vector in (x, y, z). Shape is (..., 3).
Returns:
The rotated vector in (x, y, z). Shape is (N, 3).
The rotated vector in (x, y, z). Shape is (..., 3).
"""
shape = q.shape
q_w = q[:, 0]
q_vec = q[:, 1:]
q_w = q[..., 0]
q_vec = q[..., 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
# for two-dimensional tensors, bmm is faster than einsum
if q_vec.dim() == 2:
c = q_vec * torch.bmm(q_vec.view(q.shape[0], 1, 3), v.view(q.shape[0], 3, 1)).squeeze(-1) * 2.0
else:
c = q_vec * torch.einsum("...i,...i->...", q_vec, v).unsqueeze(-1) * 2.0
return a - b + c
......
......@@ -3,7 +3,6 @@
#
# SPDX-License-Identifier: BSD-3-Clause
import torch
import unittest
"""Launch Isaac Sim Simulator first.
......@@ -19,6 +18,9 @@ simulation_app = AppLauncher(headless=True).app
"""Rest everything follows."""
import math
import torch
import torch.utils.benchmark as benchmark
from math import pi as PI
import omni.isaac.lab.utils.math as math_utils
......@@ -227,6 +229,153 @@ class TestMathUtilities(unittest.TestCase):
# Check that the wrapped angle is close to the expected value
torch.testing.assert_close(wrapped_angle, expected_angle)
def test_quat_rotate_and_quat_rotate_inverse(self):
"""Test for quat_rotate and quat_rotate_inverse methods.
The new implementation uses :meth:`torch.einsum` instead of `torch.bmm` which allows
for more flexibility in the input dimensions and is faster than `torch.bmm`.
"""
# define old implementation for quat_rotate and quat_rotate_inverse
# Based on commit: cdfa954fcc4394ca8daf432f61994e25a7b8e9e2
@torch.jit.script
def old_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
shape = q.shape
q_w = q[:, 0]
q_vec = q[:, 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
return a + b + c
@torch.jit.script
def old_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
shape = q.shape
q_w = q[:, 0]
q_vec = q[:, 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
return a - b + c
# check that implementation produces the same result as the new implementation
for device in ["cpu", "cuda:0"]:
# prepare random quaternions and vectors
q_rand = math_utils.random_orientation(num=1024, device=device)
v_rand = math_utils.sample_uniform(-1000, 1000, (1024, 3), device=device)
# compute the result using the old implementation
old_result = old_quat_rotate(q_rand, v_rand)
old_result_inv = old_quat_rotate_inverse(q_rand, v_rand)
# compute the result using the new implementation
new_result = math_utils.quat_rotate(q_rand, v_rand)
new_result_inv = math_utils.quat_rotate_inverse(q_rand, v_rand)
# check that the result is close to the expected value
torch.testing.assert_close(old_result, new_result)
torch.testing.assert_close(old_result_inv, new_result_inv)
# check the performance of the new implementation
for device in ["cpu", "cuda:0"]:
# prepare random quaternions and vectors
# new implementation supports batched inputs
q_shape = (1024, 2, 5, 4)
v_shape = (1024, 2, 5, 3)
# sample random quaternions and vectors
num_quats = math.prod(q_shape[:-1])
q_rand = math_utils.random_orientation(num=num_quats, device=device).reshape(q_shape)
v_rand = math_utils.sample_uniform(-1000, 1000, v_shape, device=device)
# create functions to test
def iter_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of new quat_rotate."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = math_utils.quat_rotate(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of new quat_rotate_inverse."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = math_utils.quat_rotate_inverse(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_old_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of old quat_rotate."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = old_quat_rotate(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_old_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of old quat_rotate_inverse."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = old_quat_rotate_inverse(q_rand[:, i, j], v_rand[:, i, j])
return out
# create benchmark
timer_iter_quat_rotate = benchmark.Timer(
stmt="iter_quat_rotate(q_rand, v_rand)",
globals={"iter_quat_rotate": iter_quat_rotate, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_quat_rotate_inverse = benchmark.Timer(
stmt="iter_quat_rotate_inverse(q_rand, v_rand)",
globals={"iter_quat_rotate_inverse": iter_quat_rotate_inverse, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_old_quat_rotate = benchmark.Timer(
stmt="iter_old_quat_rotate(q_rand, v_rand)",
globals={"iter_old_quat_rotate": iter_old_quat_rotate, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_old_quat_rotate_inverse = benchmark.Timer(
stmt="iter_old_quat_rotate_inverse(q_rand, v_rand)",
globals={
"iter_old_quat_rotate_inverse": iter_old_quat_rotate_inverse,
"q_rand": q_rand,
"v_rand": v_rand,
},
)
timer_quat_rotate = benchmark.Timer(
stmt="math_utils.quat_rotate(q_rand, v_rand)",
globals={"math_utils": math_utils, "q_rand": q_rand, "v_rand": v_rand},
)
timer_quat_rotate_inverse = benchmark.Timer(
stmt="math_utils.quat_rotate_inverse(q_rand, v_rand)",
globals={"math_utils": math_utils, "q_rand": q_rand, "v_rand": v_rand},
)
# run the benchmark
print("--------------------------------")
print(f"Device: {device}")
print("Time for quat_rotate:", timer_quat_rotate.timeit(number=1000))
print("Time for iter_quat_rotate:", timer_iter_quat_rotate.timeit(number=1000))
print("Time for iter_old_quat_rotate:", timer_iter_old_quat_rotate.timeit(number=1000))
print("--------------------------------")
print("Time for quat_rotate_inverse:", timer_quat_rotate_inverse.timeit(number=1000))
print("Time for iter_quat_rotate_inverse:", timer_iter_quat_rotate_inverse.timeit(number=1000))
print("Time for iter_old_quat_rotate_inverse:", timer_iter_old_quat_rotate_inverse.timeit(number=1000))
print("--------------------------------")
# check output values are the same
torch.testing.assert_close(math_utils.quat_rotate(q_rand, v_rand), iter_quat_rotate(q_rand, v_rand))
torch.testing.assert_close(math_utils.quat_rotate(q_rand, v_rand), iter_old_quat_rotate(q_rand, v_rand))
torch.testing.assert_close(
math_utils.quat_rotate_inverse(q_rand, v_rand), iter_quat_rotate_inverse(q_rand, v_rand)
)
torch.testing.assert_close(
math_utils.quat_rotate_inverse(q_rand, v_rand),
iter_old_quat_rotate_inverse(q_rand, v_rand),
)
if __name__ == "__main__":
run_tests()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment