Unverified Commit 3acff1be authored by Pascal Roth's avatar Pascal Roth Committed by GitHub

Adds preserving of joint and body indices to the `SceneEntityCfg` (#469)

# Description

The `SceneEntityCfg` always return the `joint_ids` and `body_ids` in the
simulation order even if a different order was specified when passing
the names to the `SceneEntityCfg`. This fix changes the behavior, to
always return the ids in the specified order by introducing the flag
`simulation_order` that is true per default for everything within orbit
but false for the `SceneEntityCfg`.

Fixes #461 

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarPascal Roth <57946385+pascal-roth@users.noreply.github.com>
Signed-off-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Co-authored-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Co-authored-by: 's avatarMayank Mittal <mittalma@leggedrobotics.com>
parent fcc216a5
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.15.6" version = "0.15.7"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.15.7 (2024-03-28)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Adds option to return indices/data in the specified query keys order in
:class:`omni.isaac.orbit.managers.SceneEntityCfg` class, and the respective
:func:`omni.isaac.orbit.utils.string.resolve_matching_names_values` and
:func:`omni.isaac.orbit.utils.string.resolve_matching_names` functions.
0.15.6 (2024-03-28) 0.15.6 (2024-03-28)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -228,7 +228,7 @@ class Articulation(RigidObject): ...@@ -228,7 +228,7 @@ class Articulation(RigidObject):
self._previous_joint_vel[:] = self._data.joint_vel[:] self._previous_joint_vel[:] = self._data.joint_vel[:]
def find_joints( def find_joints(
self, name_keys: str | Sequence[str], joint_subset: list[str] | None = None self, name_keys: str | Sequence[str], joint_subset: list[str] | None = None, preserve_order: bool = False
) -> tuple[list[int], list[str]]: ) -> tuple[list[int], list[str]]:
"""Find joints in the articulation based on the name keys. """Find joints in the articulation based on the name keys.
...@@ -239,6 +239,7 @@ class Articulation(RigidObject): ...@@ -239,6 +239,7 @@ class Articulation(RigidObject):
name_keys: A regular expression or a list of regular expressions to match the joint names. name_keys: A regular expression or a list of regular expressions to match the joint names.
joint_subset: A subset of joints to search for. Defaults to None, which means all joints joint_subset: A subset of joints to search for. Defaults to None, which means all joints
in the articulation are searched. in the articulation are searched.
preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False.
Returns: Returns:
A tuple of lists containing the joint indices and names. A tuple of lists containing the joint indices and names.
...@@ -246,7 +247,7 @@ class Articulation(RigidObject): ...@@ -246,7 +247,7 @@ class Articulation(RigidObject):
if joint_subset is None: if joint_subset is None:
joint_subset = self.joint_names joint_subset = self.joint_names
# find joints # find joints
return string_utils.resolve_matching_names(name_keys, joint_subset) return string_utils.resolve_matching_names(name_keys, joint_subset, preserve_order)
""" """
Operations - Setters. Operations - Setters.
......
...@@ -148,7 +148,7 @@ class RigidObject(AssetBase): ...@@ -148,7 +148,7 @@ class RigidObject(AssetBase):
# -- update common data # -- update common data
self._update_common_data(dt) self._update_common_data(dt)
def find_bodies(self, name_keys: str | Sequence[str]) -> tuple[list[int], list[str]]: def find_bodies(self, name_keys: str | Sequence[str], preserve_order: bool = False) -> tuple[list[int], list[str]]:
"""Find bodies in the articulation based on the name keys. """Find bodies in the articulation based on the name keys.
Please check the :meth:`omni.isaac.orbit.utils.string_utils.resolve_matching_names` function for more Please check the :meth:`omni.isaac.orbit.utils.string_utils.resolve_matching_names` function for more
...@@ -156,11 +156,12 @@ class RigidObject(AssetBase): ...@@ -156,11 +156,12 @@ class RigidObject(AssetBase):
Args: Args:
name_keys: A regular expression or a list of regular expressions to match the body names. name_keys: A regular expression or a list of regular expressions to match the body names.
preserve_order: Whether to preserve the order of the name keys in the output. Defaults to False.
Returns: Returns:
A tuple of lists containing the body indices and names. A tuple of lists containing the body indices and names.
""" """
return string_utils.resolve_matching_names(name_keys, self.body_names) return string_utils.resolve_matching_names(name_keys, self.body_names, preserve_order)
""" """
Operations - Write to simulation. Operations - Write to simulation.
......
...@@ -63,6 +63,19 @@ class SceneEntityCfg: ...@@ -63,6 +63,19 @@ class SceneEntityCfg:
manager. manager.
""" """
preserve_order: bool = False
"""Whether to preserve indices ordering to match with that in the specified joint or body names. Defaults to False.
If False, the ordering of the indices are sorted in ascending order (i.e. the ordering in the entity's joints
or bodies). Otherwise, the indices are preserved in the order of the specified joint and body names.
For more details, see the :meth:`omni.isaac.orbit.utils.string.resolve_matching_names` function.
.. note::
This attribute is only used when :attr:`joint_names` or :attr:`body_names` are specified.
"""
def resolve(self, scene: InteractiveScene): def resolve(self, scene: InteractiveScene):
"""Resolves the scene entity and converts the joint and body names to indices. """Resolves the scene entity and converts the joint and body names to indices.
...@@ -91,7 +104,7 @@ class SceneEntityCfg: ...@@ -91,7 +104,7 @@ class SceneEntityCfg:
self.joint_names = [self.joint_names] self.joint_names = [self.joint_names]
if isinstance(self.joint_ids, int): if isinstance(self.joint_ids, int):
self.joint_ids = [self.joint_ids] self.joint_ids = [self.joint_ids]
joint_ids, _ = entity.find_joints(self.joint_names) joint_ids, _ = entity.find_joints(self.joint_names, preserve_order=self.preserve_order)
joint_names = [entity.joint_names[i] for i in self.joint_ids] joint_names = [entity.joint_names[i] for i in self.joint_ids]
if joint_ids != self.joint_ids or joint_names != self.joint_names: if joint_ids != self.joint_ids or joint_names != self.joint_names:
raise ValueError( raise ValueError(
...@@ -104,9 +117,10 @@ class SceneEntityCfg: ...@@ -104,9 +117,10 @@ class SceneEntityCfg:
elif self.joint_names is not None: elif self.joint_names is not None:
if isinstance(self.joint_names, str): if isinstance(self.joint_names, str):
self.joint_names = [self.joint_names] self.joint_names = [self.joint_names]
self.joint_ids, _ = entity.find_joints(self.joint_names) self.joint_ids, _ = entity.find_joints(self.joint_names, preserve_order=self.preserve_order)
# performance optimization (slice offers faster indexing than list of indices) # performance optimization (slice offers faster indexing than list of indices)
if len(self.joint_ids) == entity.num_joints: # only all joint in the entity order are selected
if len(self.joint_ids) == entity.num_joints and self.joint_names == entity.joint_names:
self.joint_ids = slice(None) self.joint_ids = slice(None)
# -- from joint indices to joint names # -- from joint indices to joint names
elif self.joint_ids != slice(None): elif self.joint_ids != slice(None):
...@@ -123,7 +137,7 @@ class SceneEntityCfg: ...@@ -123,7 +137,7 @@ class SceneEntityCfg:
self.body_names = [self.body_names] self.body_names = [self.body_names]
if isinstance(self.body_ids, int): if isinstance(self.body_ids, int):
self.body_ids = [self.body_ids] self.body_ids = [self.body_ids]
body_ids, _ = entity.find_bodies(self.body_names) body_ids, _ = entity.find_bodies(self.body_names, preserve_order=self.preserve_order)
body_names = [entity.body_names[i] for i in self.body_ids] body_names = [entity.body_names[i] for i in self.body_ids]
if body_ids != self.body_ids or body_names != self.body_names: if body_ids != self.body_ids or body_names != self.body_names:
raise ValueError( raise ValueError(
...@@ -136,9 +150,10 @@ class SceneEntityCfg: ...@@ -136,9 +150,10 @@ class SceneEntityCfg:
elif self.body_names is not None: elif self.body_names is not None:
if isinstance(self.body_names, str): if isinstance(self.body_names, str):
self.body_names = [self.body_names] self.body_names = [self.body_names]
self.body_ids, _ = entity.find_bodies(self.body_names) self.body_ids, _ = entity.find_bodies(self.body_names, preserve_order=self.preserve_order)
# performance optimization (slice offers faster indexing than list of indices) # performance optimization (slice offers faster indexing than list of indices)
if len(self.body_ids) == entity.num_bodies: # only all bodies in the entity order are selected
if len(self.body_ids) == entity.num_bodies and self.body_names == entity.body_names:
self.body_ids = slice(None) self.body_ids = slice(None)
# -- from body indices to body names # -- from body indices to body names
elif self.body_ids != slice(None): elif self.body_ids != slice(None):
......
...@@ -147,17 +147,25 @@ Regex operations. ...@@ -147,17 +147,25 @@ Regex operations.
""" """
def resolve_matching_names(keys: str | Sequence[str], list_of_strings: Sequence[str]) -> tuple[list[int], list[str]]: def resolve_matching_names(
keys: str | Sequence[str], list_of_strings: Sequence[str], preserve_order: bool = False
) -> tuple[list[int], list[str]]:
"""Match a list of query regular expressions against a list of strings and return the matched indices and names. """Match a list of query regular expressions against a list of strings and return the matched indices and names.
When a list of query regular expressions is provided, the function checks each target string against each When a list of query regular expressions is provided, the function checks each target string against each
query regular expression and returns the indices of the matched strings and the matched strings. query regular expression and returns the indices of the matched strings and the matched strings.
This means that the ordering is dictated by the order of the target strings and not the order of the query
regular expressions.
For example, if the list of strings is ['a', 'b', 'c', 'd', 'e'] and the regular expressions are ['a|c', 'b'], If the :attr:`preserve_order` is True, the ordering of the matched indices and names is the same as the order
then the function will return the indices of the matched strings and the matched strings, i.e. of the provided list of strings. This means that the ordering is dictated by the order of the target strings
([0, 1, 2], ['a', 'b', 'c']). and not the order of the query regular expressions.
If the :attr:`preserve_order` is False, the ordering of the matched indices and names is the same as the order
of the provided list of query regular expressions.
For example, consider the list of strings is ['a', 'b', 'c', 'd', 'e'] and the regular expressions are ['a|c', 'b'].
If :attr:`preserve_order` is False, then the function will return the indices of the matched strings and the
strings as: ([0, 1, 2], ['a', 'b', 'c']). When :attr:`preserve_order` is True, it will return them as:
([0, 2, 1], ['a', 'c', 'b']).
Note: Note:
The function does not sort the indices. It returns the indices in the order they are found. The function does not sort the indices. It returns the indices in the order they are found.
...@@ -165,6 +173,7 @@ def resolve_matching_names(keys: str | Sequence[str], list_of_strings: Sequence[ ...@@ -165,6 +173,7 @@ def resolve_matching_names(keys: str | Sequence[str], list_of_strings: Sequence[
Args: Args:
keys: A regular expression or a list of regular expressions to match the strings in the list. keys: A regular expression or a list of regular expressions to match the strings in the list.
list_of_strings: A list of strings to match. list_of_strings: A list of strings to match.
preserve_order: Whether to preserve the order of the query keys in the returned values. Defaults to False.
Returns: Returns:
A tuple of lists containing the matched indices and names. A tuple of lists containing the matched indices and names.
...@@ -179,6 +188,7 @@ def resolve_matching_names(keys: str | Sequence[str], list_of_strings: Sequence[ ...@@ -179,6 +188,7 @@ def resolve_matching_names(keys: str | Sequence[str], list_of_strings: Sequence[
# find matching patterns # find matching patterns
index_list = [] index_list = []
names_list = [] names_list = []
key_idx_list = []
# book-keeping to check that we always have a one-to-one mapping # book-keeping to check that we always have a one-to-one mapping
# i.e. each target string should match only one regular expression # i.e. each target string should match only one regular expression
target_strings_match_found = [None for _ in range(len(list_of_strings))] target_strings_match_found = [None for _ in range(len(list_of_strings))]
...@@ -197,8 +207,27 @@ def resolve_matching_names(keys: str | Sequence[str], list_of_strings: Sequence[ ...@@ -197,8 +207,27 @@ def resolve_matching_names(keys: str | Sequence[str], list_of_strings: Sequence[
target_strings_match_found[target_index] = re_key target_strings_match_found[target_index] = re_key
index_list.append(target_index) index_list.append(target_index)
names_list.append(potential_match_string) names_list.append(potential_match_string)
key_idx_list.append(key_index)
# add for regex key # add for regex key
keys_match_found[key_index].append(potential_match_string) keys_match_found[key_index].append(potential_match_string)
# reorder keys if they should be returned in order of the query keys
if preserve_order:
reordered_index_list = [None] * len(index_list)
global_index = 0
for key_index in range(len(keys)):
for key_idx_position, key_idx_entry in enumerate(key_idx_list):
if key_idx_entry == key_index:
reordered_index_list[key_idx_position] = global_index
global_index += 1
# reorder index and names list
index_list_reorder = [None] * len(index_list)
names_list_reorder = [None] * len(index_list)
for idx, reorder_idx in enumerate(reordered_index_list):
index_list_reorder[reorder_idx] = index_list[idx]
names_list_reorder[reorder_idx] = names_list[idx]
# update
index_list = index_list_reorder
names_list = names_list_reorder
# check that all regular expressions are matched # check that all regular expressions are matched
if not all(keys_match_found): if not all(keys_match_found):
# make this print nicely aligned for debugging # make this print nicely aligned for debugging
...@@ -215,27 +244,33 @@ def resolve_matching_names(keys: str | Sequence[str], list_of_strings: Sequence[ ...@@ -215,27 +244,33 @@ def resolve_matching_names(keys: str | Sequence[str], list_of_strings: Sequence[
def resolve_matching_names_values( def resolve_matching_names_values(
data: dict[str, Any], list_of_strings: Sequence[str] data: dict[str, Any], list_of_strings: Sequence[str], preserve_order: bool = False
) -> tuple[list[int], list[str], list[Any]]: ) -> tuple[list[int], list[str], list[Any]]:
"""Match a list of regular expressions in a dictionary against a list of strings and return """Match a list of regular expressions in a dictionary against a list of strings and return
the matched indices, names, and values. the matched indices, names, and values.
For example, if the dictionary is {'a|b|c': 1, 'd|e': 2} and the list of strings is ['a', 'b', 'c', 'd', 'e'], If the :attr:`preserve_order` is True, the ordering of the matched indices and names is the same as the order
then the function will return the indices of the matched strings, the matched strings, and the values, i.e. of the provided list of strings. This means that the ordering is dictated by the order of the target strings
([0, 1, 2, 3, 4], ['a', 'b', 'c', 'd', 'e'], [1, 1, 1, 2, 2]). and not the order of the query regular expressions.
Note: If the :attr:`preserve_order` is False, the ordering of the matched indices and names is the same as the order
The function does not sort the indices. It returns the indices in the order they are found. of the provided list of query regular expressions.
For example, consider the dictionary is {"a|d|e": 1, "b|c": 2}, the list of strings is ['a', 'b', 'c', 'd', 'e'].
If :attr:`preserve_order` is False, then the function will return the indices of the matched strings, the
matched strings, and the values as: ([0, 1, 2, 3, 4], ['a', 'b', 'c', 'd', 'e'], [1, 2, 2, 1, 1]). When
:attr:`preserve_order` is True, it will return them as: ([0, 3, 4, 1, 2], ['a', 'd', 'e', 'b', 'c'], [1, 1, 1, 2, 2]).
Args: Args:
data: A dictionary of regular expressions and values to match the strings in the list. data: A dictionary of regular expressions and values to match the strings in the list.
list_of_strings: A list of strings to match. list_of_strings: A list of strings to match.
preserve_order: Whether to preserve the order of the query keys in the returned values. Defaults to False.
Returns: Returns:
A tuple of lists containing the matched indices, names, and values. A tuple of lists containing the matched indices, names, and values.
Raises: Raises:
TypeError: When the input argument `data` is not a dictionary. TypeError: When the input argument :attr:`data` is not a dictionary.
ValueError: When multiple matches are found for a string in the dictionary. ValueError: When multiple matches are found for a string in the dictionary.
ValueError: When not all regular expressions in the data keys are matched. ValueError: When not all regular expressions in the data keys are matched.
""" """
...@@ -246,6 +281,7 @@ def resolve_matching_names_values( ...@@ -246,6 +281,7 @@ def resolve_matching_names_values(
index_list = [] index_list = []
names_list = [] names_list = []
values_list = [] values_list = []
key_idx_list = []
# book-keeping to check that we always have a one-to-one mapping # book-keeping to check that we always have a one-to-one mapping
# i.e. each target string should match only one regular expression # i.e. each target string should match only one regular expression
target_strings_match_found = [None for _ in range(len(list_of_strings))] target_strings_match_found = [None for _ in range(len(list_of_strings))]
...@@ -265,8 +301,30 @@ def resolve_matching_names_values( ...@@ -265,8 +301,30 @@ def resolve_matching_names_values(
index_list.append(target_index) index_list.append(target_index)
names_list.append(potential_match_string) names_list.append(potential_match_string)
values_list.append(value) values_list.append(value)
key_idx_list.append(key_index)
# add for regex key # add for regex key
keys_match_found[key_index].append(potential_match_string) keys_match_found[key_index].append(potential_match_string)
# reorder keys if they should be returned in order of the query keys
if preserve_order:
reordered_index_list = [None] * len(index_list)
global_index = 0
for key_index in range(len(data)):
for key_idx_position, key_idx_entry in enumerate(key_idx_list):
if key_idx_entry == key_index:
reordered_index_list[key_idx_position] = global_index
global_index += 1
# reorder index and names list
index_list_reorder = [None] * len(index_list)
names_list_reorder = [None] * len(index_list)
values_list_reorder = [None] * len(index_list)
for idx, reorder_idx in enumerate(reordered_index_list):
index_list_reorder[reorder_idx] = index_list[idx]
names_list_reorder[reorder_idx] = names_list[idx]
values_list_reorder[reorder_idx] = values_list[idx]
# update
index_list = index_list_reorder
names_list = names_list_reorder
values_list = values_list_reorder
# check that all regular expressions are matched # check that all regular expressions are matched
if not all(keys_match_found): if not all(keys_match_found):
# make this print nicely aligned for debugging # make this print nicely aligned for debugging
......
...@@ -103,6 +103,41 @@ class TestStringUtilities(unittest.TestCase): ...@@ -103,6 +103,41 @@ class TestStringUtilities(unittest.TestCase):
self.assertEqual(index_list, ground_truth_index_list) self.assertEqual(index_list, ground_truth_index_list)
self.assertEqual(names_list, [robot_joint_names[i] for i in ground_truth_index_list]) self.assertEqual(names_list, [robot_joint_names[i] for i in ground_truth_index_list])
def test_resolve_matching_names_with_preserved_order(self):
# list of strings and query list
robot_joint_names = []
for i in ["hip", "thigh", "calf"]:
for j in ["FL", "FR", "RL", "RR"]:
robot_joint_names.append(f"{j}_{i}_joint")
query_list = [
"FL_hip_joint",
"FL_thigh_joint",
"FR_hip_joint",
"FR_thigh_joint",
"FL_calf_joint",
"FR_calf_joint",
]
# test return in target ordering with sublist
query_list.reverse()
index_list, names_list = string_utils.resolve_matching_names(query_list, robot_joint_names, preserve_order=True)
ground_truth_index_list = [9, 8, 5, 1, 4, 0]
self.assertEqual(names_list, query_list)
self.assertEqual(index_list, ground_truth_index_list)
# test return in target ordering with regex expression
index_list, names_list = string_utils.resolve_matching_names(
["FR.*", "FL.*"], robot_joint_names, preserve_order=True
)
ground_truth_index_list = [1, 5, 9, 0, 4, 8]
self.assertEqual(index_list, ground_truth_index_list)
self.assertEqual(names_list, [robot_joint_names[i] for i in ground_truth_index_list])
# test return in target ordering with a mix of regex and non-regex expression
index_list, names_list = string_utils.resolve_matching_names(
["FR.*", "FL_calf_joint", "FL_thigh_joint", "FL_hip_joint"], robot_joint_names, preserve_order=True
)
ground_truth_index_list = [1, 5, 9, 8, 4, 0]
self.assertEqual(index_list, ground_truth_index_list)
self.assertEqual(names_list, [robot_joint_names[i] for i in ground_truth_index_list])
def test_resolve_matching_names_values_with_basic_strings(self): def test_resolve_matching_names_values_with_basic_strings(self):
"""Test resolving matching names with a basic expression.""" """Test resolving matching names with a basic expression."""
# list of strings # list of strings
...@@ -126,7 +161,36 @@ class TestStringUtilities(unittest.TestCase): ...@@ -126,7 +161,36 @@ class TestStringUtilities(unittest.TestCase):
# test no regex match # test no regex match
query_names = {"a|c": 1, "b": 0, "f": 2} query_names = {"a|c": 1, "b": 0, "f": 2}
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = string_utils.resolve_matching_names(query_names, target_names) _ = string_utils.resolve_matching_names_values(query_names, target_names)
def test_resolve_matching_names_values_with_basic_strings_and_preserved_order(self):
"""Test resolving matching names with a basic expression."""
# list of strings
target_names = ["a", "b", "c", "d", "e"]
# test matching names
data = {"a|c": 1, "b": 2}
index_list, names_list, values_list = string_utils.resolve_matching_names_values(
data, target_names, preserve_order=True
)
self.assertEqual(index_list, [0, 2, 1])
self.assertEqual(names_list, ["a", "c", "b"])
self.assertEqual(values_list, [1, 1, 2])
# test matching names with regex
data = {"a|d|e": 1, "b|c": 2}
index_list, names_list, values_list = string_utils.resolve_matching_names_values(
data, target_names, preserve_order=True
)
self.assertEqual(index_list, [0, 3, 4, 1, 2])
self.assertEqual(names_list, ["a", "d", "e", "b", "c"])
self.assertEqual(values_list, [1, 1, 1, 2, 2])
# test matching names with regex
data = {"a|d|e|b": 1, "b|c": 2}
with self.assertRaises(ValueError):
_ = string_utils.resolve_matching_names_values(data, target_names, preserve_order=True)
# test no regex match
query_names = {"a|c": 1, "b": 0, "f": 2}
with self.assertRaises(ValueError):
_ = string_utils.resolve_matching_names_values(query_names, target_names, preserve_order=True)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment