Unverified Commit f4bb9875 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Fixes utils operations for math and string (#86)

# Description

* Fixes the `omni.isaac.orbit.utils.math.quat_apply_yaw` to compute the
yaw quaternion correctly.
* Adds functions to convert string and callable objects in the
`omni.isaac.orbit.utils.string` module. The function can deal with both
module functions and lambda expressions.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent 5931b15a
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.6.0"
version = "0.6.1"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.6.1 (2023-07-16)
~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Fixed the :meth:`omni.isaac.orbit.utils.math.quat_apply_yaw` to compute the yaw quaternion correctly.
Added
^^^^^^^
* Added functions to convert string and callable objects in :mod:`omni.isaac.orbit.utils.string`.
0.6.0 (2023-07-16)
~~~~~~~~~~~~~~~~~~
......
......@@ -17,7 +17,7 @@ Sub-module containing utilities for the Orbit framework.
from .array import TENSOR_TYPE_CONVERSIONS, TENSOR_TYPES, TensorData, convert_to_torch
from .configclass import configclass
from .dict import class_to_dict, convert_dict_to_backend, print_dict, update_class_from_dict, update_dict
from .string import to_camel_case, to_snake_case
from .string import callable_to_string, is_lambda_expression, string_to_callable, to_camel_case, to_snake_case
from .timer import Timer
__all__ = [
......@@ -37,6 +37,9 @@ __all__ = [
# string utilities
"to_camel_case",
"to_snake_case",
"is_lambda_expression",
"string_to_callable",
"callable_to_string",
# timer
"Timer",
]
......@@ -8,12 +8,11 @@
import collections.abc
import hashlib
import importlib
import inspect
import json
from typing import Any, Callable, Dict, Iterable, Mapping
from typing import Any, Dict, Iterable, Mapping
from .array import TENSOR_TYPE_CONVERSIONS, TENSOR_TYPES
from .string import callable_to_string, string_to_callable
__all__ = [
"class_to_dict",
......@@ -60,7 +59,7 @@ def class_to_dict(obj: object) -> Dict[str, Any]:
continue
# check if attribute is callable -- function
if callable(value):
data[key] = f"{value.__module__}:{value.__name__}"
data[key] = callable_to_string(value)
# check if attribute is a dictionary
elif hasattr(value, "__dict__") or isinstance(value, dict):
data[key] = class_to_dict(value)
......@@ -96,7 +95,7 @@ def update_class_from_dict(obj, data: Dict[str, Any], _ns: str = "") -> None:
# iterate over the dictionary to look for callable values
for k, v in obj_mem.items():
if callable(v):
value[k] = _string_to_callable(value[k])
value[k] = string_to_callable(value[k])
setattr(obj, key, value)
elif isinstance(value, Mapping):
# recursively call if it is a dictionary
......@@ -111,7 +110,7 @@ def update_class_from_dict(obj, data: Dict[str, Any], _ns: str = "") -> None:
setattr(obj, key, value)
elif callable(obj_mem):
# update function name
value = _string_to_callable(value)
value = string_to_callable(value)
setattr(obj, key, value)
elif isinstance(value, type(obj_mem)):
# check that they are type-safe
......@@ -260,45 +259,7 @@ def print_dict(val, nesting: int = -4, start: bool = True):
print_dict(val[k], nesting, start=False)
else:
# deal with functions in print statements
if callable(val) and val.__name__ == "<lambda>":
print("lambda", inspect.getsourcelines(val)[0][0].strip().split("lambda")[1].strip()[:-1])
elif callable(val):
print(f"{val.__module__}:{val.__name__}")
if callable(val):
print(callable_to_string(val))
else:
print(val)
"""
Private helper functions.
"""
def _string_to_callable(name: str) -> Callable:
"""Resolves the module and function names to return the function.
Args:
name (str): The function name. The format should be 'module:attribute_name'.
Raises:
ValueError: When the resolved attribute is not a function.
ValueError: _description_
Returns:
Callable: The function loaded from the module.
"""
try:
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
callable_object = getattr(mod, attr_name)
# check if attribute is callable
if callable(callable_object):
return callable_object
else:
raise ValueError(f"The imported object is not callable: '{name}'")
except AttributeError as e:
msg = (
"While updating the config from a dictionary, we could not interpret the entry"
"as a callable object. The format of input should be 'module:attribute_name'\n"
f"While processing input '{name}', received the error:\n {e}."
)
raise ValueError(msg)
......@@ -39,6 +39,7 @@ __all__ = [
"quat_from_euler_xyz",
"quat_apply_yaw",
"quat_box_minus",
"yaw_quat",
"euler_xyz_from_quat",
"axis_angle_from_quat",
# Rotation-Isaac Sim
......@@ -269,7 +270,30 @@ def euler_xyz_from_quat(quat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor,
cos_yaw = 1 - 2 * (q_y * q_y + q_z * q_z)
yaw = torch.atan2(sin_yaw, cos_yaw)
return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi)
return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi) # TODO: why not wrap_to_pi here ?
@torch.jit.script
def yaw_quat(quat: torch.Tensor) -> torch.Tensor:
"""Extract the yaw component of a quaternion.
Args:
quat (torch.Tensor): Input orientation to extract yaw from.
Returns:
torch.Tensor: A quaternion with only yaw component.
"""
quat_yaw = quat.clone().view(-1, 4)
qw = quat_yaw[:, 0]
qx = quat_yaw[:, 1]
qy = quat_yaw[:, 2]
qz = quat_yaw[:, 3]
yaw = torch.atan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz))
quat_yaw[:] = 0.0
quat_yaw[:, 3] = torch.sin(yaw / 2)
quat_yaw[:, 0] = torch.cos(yaw / 2)
quat_yaw = normalize(quat_yaw)
return quat_yaw
@torch.jit.script
......@@ -283,9 +307,7 @@ def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: Rotated vector.
"""
quat_yaw = quat.clone().view(-1, 4)
quat_yaw[:, 1:3] = 0.0 # set x, y components as zero
quat_yaw = normalize(quat_yaw)
quat_yaw = yaw_quat(quat)
return quat_apply(quat_yaw, vec)
......@@ -294,8 +316,8 @@ def quat_box_minus(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
"""Implements box-minus operator (quaternion difference).
Args:
q1 (torch.Tensor): A (N, 4) tensor for quaternion (x, y, z, w)
q2 (torch.Tensor): A (N, 4) tensor for quaternion (x, y, z, w)
q1 (torch.Tensor): A (N, 4) tensor for quaternion (w, x, y, z).
q2 (torch.Tensor): A (N, 4) tensor for quaternion (w, x, y, z).
Returns:
torch.Tensor: q1 box-minus q2
......
......@@ -5,9 +5,24 @@
"""Transformations of strings."""
import ast
import importlib
import inspect
import re
from typing import Optional
from typing import Callable, Optional
__all__ = [
"to_camel_case",
"to_snake_case",
"is_lambda_expression",
"string_to_callable",
"callable_to_string",
]
"""
String formatting.
"""
def to_camel_case(snake_str: str, to: Optional[str] = "cC") -> str:
......@@ -49,3 +64,85 @@ def to_snake_case(camel_str: str) -> str:
"""
camel_str = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_str)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", camel_str).lower()
"""
String <-> Callable operations.
"""
def is_lambda_expression(name: str) -> bool:
"""Checks if the input string is a lambda expression.
Args:
name (str): The input string.
Returns:
bool: Whether the input string is a lambda expression.
"""
try:
ast.parse(name)
return isinstance(ast.parse(name).body[0], ast.Expr) and isinstance(ast.parse(name).body[0].value, ast.Lambda)
except SyntaxError:
return False
def callable_to_string(value: Callable) -> str:
"""Converts a callable object to a string.
Args:
callable_object (Callable): A callable object.
Raises:
ValueError: When the input argument is not a callable object.
Returns:
str: A string representation of the callable object.
"""
# check if callable
if not callable(value):
raise ValueError(f"The input argument is not callable: {value}.")
# check if lambda function
if value.__name__ == "<lambda>":
return f"lambda {inspect.getsourcelines(value)[0][0].strip().split('lambda')[1].strip().split(',')[0]}"
else:
# get the module and function name
module_name = value.__module__
function_name = value.__name__
# return the string
return f"{module_name}:{function_name}"
def string_to_callable(name: str) -> Callable:
"""Resolves the module and function names to return the function.
Args:
name (str): The function name. The format should be 'module:attribute_name' or a
lambda expression of format: 'lambda x: x'.
Raises:
ValueError: When the resolved attribute is not a function.
ValueError: When the module cannot be found.
Returns:
Callable: The function loaded from the module.
"""
try:
if is_lambda_expression(name):
callable_object = eval(name)
else:
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
callable_object = getattr(mod, attr_name)
# check if attribute is callable
if callable(callable_object):
return callable_object
else:
raise AttributeError(f"The imported object is not callable: '{name}'")
except (ValueError, ModuleNotFoundError) as e:
msg = (
f"Could not resolve the input string '{name}' into callable object."
" The format of input should be 'module:attribute_name'.\n"
f"Received the error:\n {e}."
)
raise ValueError(msg)
......@@ -44,9 +44,9 @@ def raycast_mesh(ray_starts: torch.Tensor, ray_directions: torch.Tensor, mesh: w
ray_hits = torch.full((num_rays, 3), float("inf"), device=mesh_device)
# map the memory to warp arrays
ray_starts_wp = wp.from_torch(ray_starts, dtype=wp.vec3, requires_grad=False)
ray_directions_wp = wp.from_torch(ray_directions, dtype=wp.vec3, requires_grad=False)
ray_hits_wp = wp.from_torch(ray_hits, dtype=wp.vec3, requires_grad=False)
ray_starts_wp = wp.from_torch(ray_starts, dtype=wp.vec3)
ray_directions_wp = wp.from_torch(ray_directions, dtype=wp.vec3)
ray_hits_wp = wp.from_torch(ray_hits, dtype=wp.vec3)
# launch the warp kernel
wp.launch(
......
......@@ -8,6 +8,16 @@ import unittest
import omni.isaac.orbit.utils.dict as dict_utils
def test_function(x):
"""Test function for string <-> callable conversion."""
return x**2
def test_lambda_function(x):
"""Test function for string <-> callable conversion."""
return x**2
class TestDictUtilities(unittest.TestCase):
"""Test fixture for checking Kit utilities in Orbit."""
......@@ -25,6 +35,38 @@ class TestDictUtilities(unittest.TestCase):
# print the dictionary
dict_utils.print_dict(test_dict)
def test_string_callable_function_conversion(self):
"""Test string <-> callable conversion for function."""
# convert function to string
test_string = dict_utils.callable_to_string(test_function)
# convert string to function
test_function_2 = dict_utils.string_to_callable(test_string)
# check that functions are the same
self.assertEqual(test_function(2), test_function_2(2))
def test_string_callable_function_with_lambda_in_name_conversion(self):
"""Test string <-> callable conversion for function which has lambda in its name."""
# convert function to string
test_string = dict_utils.callable_to_string(test_lambda_function)
# convert string to function
test_function_2 = dict_utils.string_to_callable(test_string)
# check that functions are the same
self.assertEqual(test_function(2), test_function_2(2))
def test_string_callable_lambda_conversion(self):
"""Test string <-> callable conversion for lambda expression."""
# create lambda function
func = lambda x: x**2
# convert function to string
test_string = dict_utils.callable_to_string(func)
# convert string to function
func_2 = dict_utils.string_to_callable(test_string)
# check that functions are the same
self.assertEqual(func(2), func_2(2))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment