Unverified Commit 0c3fb1e1 authored by dxyy1's avatar dxyy1 Committed by GitHub

Uses `torch.einsum` for quat_rotate and quat_rotate_inverse operations (#900)

# Description
Extended the two functions' capability so they now can take in
multidimensional tensors and no longer limited to 2D tensors of shape
(B,4) and (B, 3)

## Type of change
- New feature (non-breaking change which adds functionality)

## Checklist
- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Signed-off-by: 's avatardxyy1 <139338590+dxyy1@users.noreply.github.com>
Co-authored-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
parent af088f59
...@@ -39,6 +39,7 @@ Guidelines for modifications: ...@@ -39,6 +39,7 @@ Guidelines for modifications:
* Brayden Zhang * Brayden Zhang
* Calvin Yu * Calvin Yu
* Chenyu Yang * Chenyu Yang
* David Yang
* HoJin Jeon * HoJin Jeon
* Jia Lin Yuan * Jia Lin Yuan
* Jingzhou Liu * Jingzhou Liu
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.22.8" version = "0.22.9"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.22.9 (2024-09-08)
~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Modified:meth:`quat_rotate` and :meth:`quat_rotate_inverse` operations to use :meth:`torch.einsum`
for faster processing of high dimensional input tensors.
0.22.8 (2024-09-06) 0.22.8 (2024-09-06)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -580,41 +580,47 @@ def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: ...@@ -580,41 +580,47 @@ def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
@torch.jit.script @torch.jit.script
def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor: def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Rotate a vector by a quaternion. """Rotate a vector by a quaternion along the last dimension of q and v.
Args: Args:
q: The quaternion in (w, x, y, z). Shape is (N, 4). q: The quaternion in (w, x, y, z). Shape is (..., 4).
v: The vector in (x, y, z). Shape is (N, 3). v: The vector in (x, y, z). Shape is (..., 3).
Returns: Returns:
The rotated vector in (x, y, z). Shape is (N, 3). The rotated vector in (x, y, z). Shape is (..., 3).
""" """
shape = q.shape q_w = q[..., 0]
q_w = q[:, 0] q_vec = q[..., 1:]
q_vec = q[:, 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0 # for two-dimensional tensors, bmm is faster than einsum
if q_vec.dim() == 2:
c = q_vec * torch.bmm(q_vec.view(q.shape[0], 1, 3), v.view(q.shape[0], 3, 1)).squeeze(-1) * 2.0
else:
c = q_vec * torch.einsum("...i,...i->...", q_vec, v).unsqueeze(-1) * 2.0
return a + b + c return a + b + c
@torch.jit.script @torch.jit.script
def quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor: def quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Rotate a vector by the inverse of a quaternion. """Rotate a vector by the inverse of a quaternion along the last dimension of q and v.
Args: Args:
q: The quaternion in (w, x, y, z). Shape is (N, 4). q: The quaternion in (w, x, y, z). Shape is (..., 4).
v: The vector in (x, y, z). Shape is (N, 3). v: The vector in (x, y, z). Shape is (..., 3).
Returns: Returns:
The rotated vector in (x, y, z). Shape is (N, 3). The rotated vector in (x, y, z). Shape is (..., 3).
""" """
shape = q.shape q_w = q[..., 0]
q_w = q[:, 0] q_vec = q[..., 1:]
q_vec = q[:, 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0 # for two-dimensional tensors, bmm is faster than einsum
if q_vec.dim() == 2:
c = q_vec * torch.bmm(q_vec.view(q.shape[0], 1, 3), v.view(q.shape[0], 3, 1)).squeeze(-1) * 2.0
else:
c = q_vec * torch.einsum("...i,...i->...", q_vec, v).unsqueeze(-1) * 2.0
return a - b + c return a - b + c
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
import torch
import unittest import unittest
"""Launch Isaac Sim Simulator first. """Launch Isaac Sim Simulator first.
...@@ -19,6 +18,9 @@ simulation_app = AppLauncher(headless=True).app ...@@ -19,6 +18,9 @@ simulation_app = AppLauncher(headless=True).app
"""Rest everything follows.""" """Rest everything follows."""
import math
import torch
import torch.utils.benchmark as benchmark
from math import pi as PI from math import pi as PI
import omni.isaac.lab.utils.math as math_utils import omni.isaac.lab.utils.math as math_utils
...@@ -227,6 +229,153 @@ class TestMathUtilities(unittest.TestCase): ...@@ -227,6 +229,153 @@ class TestMathUtilities(unittest.TestCase):
# Check that the wrapped angle is close to the expected value # Check that the wrapped angle is close to the expected value
torch.testing.assert_close(wrapped_angle, expected_angle) torch.testing.assert_close(wrapped_angle, expected_angle)
def test_quat_rotate_and_quat_rotate_inverse(self):
"""Test for quat_rotate and quat_rotate_inverse methods.
The new implementation uses :meth:`torch.einsum` instead of `torch.bmm` which allows
for more flexibility in the input dimensions and is faster than `torch.bmm`.
"""
# define old implementation for quat_rotate and quat_rotate_inverse
# Based on commit: cdfa954fcc4394ca8daf432f61994e25a7b8e9e2
@torch.jit.script
def old_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
shape = q.shape
q_w = q[:, 0]
q_vec = q[:, 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
return a + b + c
@torch.jit.script
def old_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
shape = q.shape
q_w = q[:, 0]
q_vec = q[:, 1:]
a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0
return a - b + c
# check that implementation produces the same result as the new implementation
for device in ["cpu", "cuda:0"]:
# prepare random quaternions and vectors
q_rand = math_utils.random_orientation(num=1024, device=device)
v_rand = math_utils.sample_uniform(-1000, 1000, (1024, 3), device=device)
# compute the result using the old implementation
old_result = old_quat_rotate(q_rand, v_rand)
old_result_inv = old_quat_rotate_inverse(q_rand, v_rand)
# compute the result using the new implementation
new_result = math_utils.quat_rotate(q_rand, v_rand)
new_result_inv = math_utils.quat_rotate_inverse(q_rand, v_rand)
# check that the result is close to the expected value
torch.testing.assert_close(old_result, new_result)
torch.testing.assert_close(old_result_inv, new_result_inv)
# check the performance of the new implementation
for device in ["cpu", "cuda:0"]:
# prepare random quaternions and vectors
# new implementation supports batched inputs
q_shape = (1024, 2, 5, 4)
v_shape = (1024, 2, 5, 3)
# sample random quaternions and vectors
num_quats = math.prod(q_shape[:-1])
q_rand = math_utils.random_orientation(num=num_quats, device=device).reshape(q_shape)
v_rand = math_utils.sample_uniform(-1000, 1000, v_shape, device=device)
# create functions to test
def iter_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of new quat_rotate."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = math_utils.quat_rotate(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of new quat_rotate_inverse."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = math_utils.quat_rotate_inverse(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_old_quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of old quat_rotate."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = old_quat_rotate(q_rand[:, i, j], v_rand[:, i, j])
return out
def iter_old_quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Iterative implementation of old quat_rotate_inverse."""
out = torch.empty_like(v)
for i in range(q.shape[1]):
for j in range(q.shape[2]):
out[:, i, j] = old_quat_rotate_inverse(q_rand[:, i, j], v_rand[:, i, j])
return out
# create benchmark
timer_iter_quat_rotate = benchmark.Timer(
stmt="iter_quat_rotate(q_rand, v_rand)",
globals={"iter_quat_rotate": iter_quat_rotate, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_quat_rotate_inverse = benchmark.Timer(
stmt="iter_quat_rotate_inverse(q_rand, v_rand)",
globals={"iter_quat_rotate_inverse": iter_quat_rotate_inverse, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_old_quat_rotate = benchmark.Timer(
stmt="iter_old_quat_rotate(q_rand, v_rand)",
globals={"iter_old_quat_rotate": iter_old_quat_rotate, "q_rand": q_rand, "v_rand": v_rand},
)
timer_iter_old_quat_rotate_inverse = benchmark.Timer(
stmt="iter_old_quat_rotate_inverse(q_rand, v_rand)",
globals={
"iter_old_quat_rotate_inverse": iter_old_quat_rotate_inverse,
"q_rand": q_rand,
"v_rand": v_rand,
},
)
timer_quat_rotate = benchmark.Timer(
stmt="math_utils.quat_rotate(q_rand, v_rand)",
globals={"math_utils": math_utils, "q_rand": q_rand, "v_rand": v_rand},
)
timer_quat_rotate_inverse = benchmark.Timer(
stmt="math_utils.quat_rotate_inverse(q_rand, v_rand)",
globals={"math_utils": math_utils, "q_rand": q_rand, "v_rand": v_rand},
)
# run the benchmark
print("--------------------------------")
print(f"Device: {device}")
print("Time for quat_rotate:", timer_quat_rotate.timeit(number=1000))
print("Time for iter_quat_rotate:", timer_iter_quat_rotate.timeit(number=1000))
print("Time for iter_old_quat_rotate:", timer_iter_old_quat_rotate.timeit(number=1000))
print("--------------------------------")
print("Time for quat_rotate_inverse:", timer_quat_rotate_inverse.timeit(number=1000))
print("Time for iter_quat_rotate_inverse:", timer_iter_quat_rotate_inverse.timeit(number=1000))
print("Time for iter_old_quat_rotate_inverse:", timer_iter_old_quat_rotate_inverse.timeit(number=1000))
print("--------------------------------")
# check output values are the same
torch.testing.assert_close(math_utils.quat_rotate(q_rand, v_rand), iter_quat_rotate(q_rand, v_rand))
torch.testing.assert_close(math_utils.quat_rotate(q_rand, v_rand), iter_old_quat_rotate(q_rand, v_rand))
torch.testing.assert_close(
math_utils.quat_rotate_inverse(q_rand, v_rand), iter_quat_rotate_inverse(q_rand, v_rand)
)
torch.testing.assert_close(
math_utils.quat_rotate_inverse(q_rand, v_rand),
iter_old_quat_rotate_inverse(q_rand, v_rand),
)
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment