Unverified Commit f4bb9875 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Fixes utils operations for math and string (#86)

# Description

* Fixes the `omni.isaac.orbit.utils.math.quat_apply_yaw` to compute the
yaw quaternion correctly.
* Adds functions to convert string and callable objects in the
`omni.isaac.orbit.utils.string` module. The function can deal with both
module functions and lambda expressions.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent 5931b15a
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.6.0" version = "0.6.1"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.6.1 (2023-07-16)
~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Fixed the :meth:`omni.isaac.orbit.utils.math.quat_apply_yaw` to compute the yaw quaternion correctly.
Added
^^^^^^^
* Added functions to convert string and callable objects in :mod:`omni.isaac.orbit.utils.string`.
0.6.0 (2023-07-16) 0.6.0 (2023-07-16)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -17,7 +17,7 @@ Sub-module containing utilities for the Orbit framework. ...@@ -17,7 +17,7 @@ Sub-module containing utilities for the Orbit framework.
from .array import TENSOR_TYPE_CONVERSIONS, TENSOR_TYPES, TensorData, convert_to_torch from .array import TENSOR_TYPE_CONVERSIONS, TENSOR_TYPES, TensorData, convert_to_torch
from .configclass import configclass from .configclass import configclass
from .dict import class_to_dict, convert_dict_to_backend, print_dict, update_class_from_dict, update_dict from .dict import class_to_dict, convert_dict_to_backend, print_dict, update_class_from_dict, update_dict
from .string import to_camel_case, to_snake_case from .string import callable_to_string, is_lambda_expression, string_to_callable, to_camel_case, to_snake_case
from .timer import Timer from .timer import Timer
__all__ = [ __all__ = [
...@@ -37,6 +37,9 @@ __all__ = [ ...@@ -37,6 +37,9 @@ __all__ = [
# string utilities # string utilities
"to_camel_case", "to_camel_case",
"to_snake_case", "to_snake_case",
"is_lambda_expression",
"string_to_callable",
"callable_to_string",
# timer # timer
"Timer", "Timer",
] ]
...@@ -8,12 +8,11 @@ ...@@ -8,12 +8,11 @@
import collections.abc import collections.abc
import hashlib import hashlib
import importlib
import inspect
import json import json
from typing import Any, Callable, Dict, Iterable, Mapping from typing import Any, Dict, Iterable, Mapping
from .array import TENSOR_TYPE_CONVERSIONS, TENSOR_TYPES from .array import TENSOR_TYPE_CONVERSIONS, TENSOR_TYPES
from .string import callable_to_string, string_to_callable
__all__ = [ __all__ = [
"class_to_dict", "class_to_dict",
...@@ -60,7 +59,7 @@ def class_to_dict(obj: object) -> Dict[str, Any]: ...@@ -60,7 +59,7 @@ def class_to_dict(obj: object) -> Dict[str, Any]:
continue continue
# check if attribute is callable -- function # check if attribute is callable -- function
if callable(value): if callable(value):
data[key] = f"{value.__module__}:{value.__name__}" data[key] = callable_to_string(value)
# check if attribute is a dictionary # check if attribute is a dictionary
elif hasattr(value, "__dict__") or isinstance(value, dict): elif hasattr(value, "__dict__") or isinstance(value, dict):
data[key] = class_to_dict(value) data[key] = class_to_dict(value)
...@@ -96,7 +95,7 @@ def update_class_from_dict(obj, data: Dict[str, Any], _ns: str = "") -> None: ...@@ -96,7 +95,7 @@ def update_class_from_dict(obj, data: Dict[str, Any], _ns: str = "") -> None:
# iterate over the dictionary to look for callable values # iterate over the dictionary to look for callable values
for k, v in obj_mem.items(): for k, v in obj_mem.items():
if callable(v): if callable(v):
value[k] = _string_to_callable(value[k]) value[k] = string_to_callable(value[k])
setattr(obj, key, value) setattr(obj, key, value)
elif isinstance(value, Mapping): elif isinstance(value, Mapping):
# recursively call if it is a dictionary # recursively call if it is a dictionary
...@@ -111,7 +110,7 @@ def update_class_from_dict(obj, data: Dict[str, Any], _ns: str = "") -> None: ...@@ -111,7 +110,7 @@ def update_class_from_dict(obj, data: Dict[str, Any], _ns: str = "") -> None:
setattr(obj, key, value) setattr(obj, key, value)
elif callable(obj_mem): elif callable(obj_mem):
# update function name # update function name
value = _string_to_callable(value) value = string_to_callable(value)
setattr(obj, key, value) setattr(obj, key, value)
elif isinstance(value, type(obj_mem)): elif isinstance(value, type(obj_mem)):
# check that they are type-safe # check that they are type-safe
...@@ -260,45 +259,7 @@ def print_dict(val, nesting: int = -4, start: bool = True): ...@@ -260,45 +259,7 @@ def print_dict(val, nesting: int = -4, start: bool = True):
print_dict(val[k], nesting, start=False) print_dict(val[k], nesting, start=False)
else: else:
# deal with functions in print statements # deal with functions in print statements
if callable(val) and val.__name__ == "<lambda>": if callable(val):
print("lambda", inspect.getsourcelines(val)[0][0].strip().split("lambda")[1].strip()[:-1]) print(callable_to_string(val))
elif callable(val):
print(f"{val.__module__}:{val.__name__}")
else: else:
print(val) print(val)
"""
Private helper functions.
"""
def _string_to_callable(name: str) -> Callable:
"""Resolves the module and function names to return the function.
Args:
name (str): The function name. The format should be 'module:attribute_name'.
Raises:
ValueError: When the resolved attribute is not a function.
ValueError: _description_
Returns:
Callable: The function loaded from the module.
"""
try:
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
callable_object = getattr(mod, attr_name)
# check if attribute is callable
if callable(callable_object):
return callable_object
else:
raise ValueError(f"The imported object is not callable: '{name}'")
except AttributeError as e:
msg = (
"While updating the config from a dictionary, we could not interpret the entry"
"as a callable object. The format of input should be 'module:attribute_name'\n"
f"While processing input '{name}', received the error:\n {e}."
)
raise ValueError(msg)
...@@ -39,6 +39,7 @@ __all__ = [ ...@@ -39,6 +39,7 @@ __all__ = [
"quat_from_euler_xyz", "quat_from_euler_xyz",
"quat_apply_yaw", "quat_apply_yaw",
"quat_box_minus", "quat_box_minus",
"yaw_quat",
"euler_xyz_from_quat", "euler_xyz_from_quat",
"axis_angle_from_quat", "axis_angle_from_quat",
# Rotation-Isaac Sim # Rotation-Isaac Sim
...@@ -269,7 +270,30 @@ def euler_xyz_from_quat(quat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, ...@@ -269,7 +270,30 @@ def euler_xyz_from_quat(quat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor,
cos_yaw = 1 - 2 * (q_y * q_y + q_z * q_z) cos_yaw = 1 - 2 * (q_y * q_y + q_z * q_z)
yaw = torch.atan2(sin_yaw, cos_yaw) yaw = torch.atan2(sin_yaw, cos_yaw)
return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi) return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi) # TODO: why not wrap_to_pi here ?
@torch.jit.script
def yaw_quat(quat: torch.Tensor) -> torch.Tensor:
"""Extract the yaw component of a quaternion.
Args:
quat (torch.Tensor): Input orientation to extract yaw from.
Returns:
torch.Tensor: A quaternion with only yaw component.
"""
quat_yaw = quat.clone().view(-1, 4)
qw = quat_yaw[:, 0]
qx = quat_yaw[:, 1]
qy = quat_yaw[:, 2]
qz = quat_yaw[:, 3]
yaw = torch.atan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz))
quat_yaw[:] = 0.0
quat_yaw[:, 3] = torch.sin(yaw / 2)
quat_yaw[:, 0] = torch.cos(yaw / 2)
quat_yaw = normalize(quat_yaw)
return quat_yaw
@torch.jit.script @torch.jit.script
...@@ -283,9 +307,7 @@ def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: ...@@ -283,9 +307,7 @@ def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
Returns: Returns:
torch.Tensor: Rotated vector. torch.Tensor: Rotated vector.
""" """
quat_yaw = quat.clone().view(-1, 4) quat_yaw = yaw_quat(quat)
quat_yaw[:, 1:3] = 0.0 # set x, y components as zero
quat_yaw = normalize(quat_yaw)
return quat_apply(quat_yaw, vec) return quat_apply(quat_yaw, vec)
...@@ -294,8 +316,8 @@ def quat_box_minus(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: ...@@ -294,8 +316,8 @@ def quat_box_minus(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
"""Implements box-minus operator (quaternion difference). """Implements box-minus operator (quaternion difference).
Args: Args:
q1 (torch.Tensor): A (N, 4) tensor for quaternion (x, y, z, w) q1 (torch.Tensor): A (N, 4) tensor for quaternion (w, x, y, z).
q2 (torch.Tensor): A (N, 4) tensor for quaternion (x, y, z, w) q2 (torch.Tensor): A (N, 4) tensor for quaternion (w, x, y, z).
Returns: Returns:
torch.Tensor: q1 box-minus q2 torch.Tensor: q1 box-minus q2
......
...@@ -5,9 +5,24 @@ ...@@ -5,9 +5,24 @@
"""Transformations of strings.""" """Transformations of strings."""
import ast
import importlib
import inspect
import re import re
from typing import Optional from typing import Callable, Optional
__all__ = [
"to_camel_case",
"to_snake_case",
"is_lambda_expression",
"string_to_callable",
"callable_to_string",
]
"""
String formatting.
"""
def to_camel_case(snake_str: str, to: Optional[str] = "cC") -> str: def to_camel_case(snake_str: str, to: Optional[str] = "cC") -> str:
...@@ -49,3 +64,85 @@ def to_snake_case(camel_str: str) -> str: ...@@ -49,3 +64,85 @@ def to_snake_case(camel_str: str) -> str:
""" """
camel_str = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_str) camel_str = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_str)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", camel_str).lower() return re.sub("([a-z0-9])([A-Z])", r"\1_\2", camel_str).lower()
"""
String <-> Callable operations.
"""
def is_lambda_expression(name: str) -> bool:
"""Checks if the input string is a lambda expression.
Args:
name (str): The input string.
Returns:
bool: Whether the input string is a lambda expression.
"""
try:
ast.parse(name)
return isinstance(ast.parse(name).body[0], ast.Expr) and isinstance(ast.parse(name).body[0].value, ast.Lambda)
except SyntaxError:
return False
def callable_to_string(value: Callable) -> str:
"""Converts a callable object to a string.
Args:
callable_object (Callable): A callable object.
Raises:
ValueError: When the input argument is not a callable object.
Returns:
str: A string representation of the callable object.
"""
# check if callable
if not callable(value):
raise ValueError(f"The input argument is not callable: {value}.")
# check if lambda function
if value.__name__ == "<lambda>":
return f"lambda {inspect.getsourcelines(value)[0][0].strip().split('lambda')[1].strip().split(',')[0]}"
else:
# get the module and function name
module_name = value.__module__
function_name = value.__name__
# return the string
return f"{module_name}:{function_name}"
def string_to_callable(name: str) -> Callable:
"""Resolves the module and function names to return the function.
Args:
name (str): The function name. The format should be 'module:attribute_name' or a
lambda expression of format: 'lambda x: x'.
Raises:
ValueError: When the resolved attribute is not a function.
ValueError: When the module cannot be found.
Returns:
Callable: The function loaded from the module.
"""
try:
if is_lambda_expression(name):
callable_object = eval(name)
else:
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
callable_object = getattr(mod, attr_name)
# check if attribute is callable
if callable(callable_object):
return callable_object
else:
raise AttributeError(f"The imported object is not callable: '{name}'")
except (ValueError, ModuleNotFoundError) as e:
msg = (
f"Could not resolve the input string '{name}' into callable object."
" The format of input should be 'module:attribute_name'.\n"
f"Received the error:\n {e}."
)
raise ValueError(msg)
...@@ -44,9 +44,9 @@ def raycast_mesh(ray_starts: torch.Tensor, ray_directions: torch.Tensor, mesh: w ...@@ -44,9 +44,9 @@ def raycast_mesh(ray_starts: torch.Tensor, ray_directions: torch.Tensor, mesh: w
ray_hits = torch.full((num_rays, 3), float("inf"), device=mesh_device) ray_hits = torch.full((num_rays, 3), float("inf"), device=mesh_device)
# map the memory to warp arrays # map the memory to warp arrays
ray_starts_wp = wp.from_torch(ray_starts, dtype=wp.vec3, requires_grad=False) ray_starts_wp = wp.from_torch(ray_starts, dtype=wp.vec3)
ray_directions_wp = wp.from_torch(ray_directions, dtype=wp.vec3, requires_grad=False) ray_directions_wp = wp.from_torch(ray_directions, dtype=wp.vec3)
ray_hits_wp = wp.from_torch(ray_hits, dtype=wp.vec3, requires_grad=False) ray_hits_wp = wp.from_torch(ray_hits, dtype=wp.vec3)
# launch the warp kernel # launch the warp kernel
wp.launch( wp.launch(
......
...@@ -8,6 +8,16 @@ import unittest ...@@ -8,6 +8,16 @@ import unittest
import omni.isaac.orbit.utils.dict as dict_utils import omni.isaac.orbit.utils.dict as dict_utils
def test_function(x):
"""Test function for string <-> callable conversion."""
return x**2
def test_lambda_function(x):
"""Test function for string <-> callable conversion."""
return x**2
class TestDictUtilities(unittest.TestCase): class TestDictUtilities(unittest.TestCase):
"""Test fixture for checking Kit utilities in Orbit.""" """Test fixture for checking Kit utilities in Orbit."""
...@@ -25,6 +35,38 @@ class TestDictUtilities(unittest.TestCase): ...@@ -25,6 +35,38 @@ class TestDictUtilities(unittest.TestCase):
# print the dictionary # print the dictionary
dict_utils.print_dict(test_dict) dict_utils.print_dict(test_dict)
def test_string_callable_function_conversion(self):
"""Test string <-> callable conversion for function."""
# convert function to string
test_string = dict_utils.callable_to_string(test_function)
# convert string to function
test_function_2 = dict_utils.string_to_callable(test_string)
# check that functions are the same
self.assertEqual(test_function(2), test_function_2(2))
def test_string_callable_function_with_lambda_in_name_conversion(self):
"""Test string <-> callable conversion for function which has lambda in its name."""
# convert function to string
test_string = dict_utils.callable_to_string(test_lambda_function)
# convert string to function
test_function_2 = dict_utils.string_to_callable(test_string)
# check that functions are the same
self.assertEqual(test_function(2), test_function_2(2))
def test_string_callable_lambda_conversion(self):
"""Test string <-> callable conversion for lambda expression."""
# create lambda function
func = lambda x: x**2
# convert function to string
test_string = dict_utils.callable_to_string(func)
# convert string to function
func_2 = dict_utils.string_to_callable(test_string)
# check that functions are the same
self.assertEqual(func(2), func_2(2))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment