Unverified Commit 514baa4a authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Adds support for drift in `RayCaster` and makes fields in `ContactSensorData` optional (#201)

# Description

This MR adds support for 2D drift into the `RayCaster`. It also makes
certain attributes in the `ContactSensorData` optional since they are
not needed by default.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent c6af307a
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.9.13" version = "0.9.14"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.9.14 (2023-10-21)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added 2-D drift (i.e. along x and y) to the :class:`omni.isaac.orbit.sensors.RayCaster` class.
* Added flags to the :class:`omni.isaac.orbit.sensors.ContactSensorCfg` to optionally obtain the
sensor origin and air time information. Since these are not required by default, they are
disabled by default.
Fixed
^^^^^
* Fixed the handling of contact sensor history buffer in the :class:`omni.isaac.orbit.sensors.ContactSensor` class.
Earlier, the buffer was not being updated correctly.
0.9.13 (2023-10-20) 0.9.13 (2023-10-20)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
# Ignore optional memory usage warning globally
# pyright: reportOptionalSubscript=false
from __future__ import annotations from __future__ import annotations
...@@ -137,11 +139,17 @@ class ContactSensor(SensorBase): ...@@ -137,11 +139,17 @@ class ContactSensor(SensorBase):
if env_ids is None: if env_ids is None:
env_ids = slice(None) env_ids = slice(None)
# reset accumulative data buffers # reset accumulative data buffers
self._data.current_air_time[env_ids] = 0.0
self._data.last_air_time[env_ids] = 0.0
self._data.net_forces_w[env_ids] = 0.0 self._data.net_forces_w[env_ids] = 0.0
# reset the data history
self._data.net_forces_w_history[env_ids] = 0.0 self._data.net_forces_w_history[env_ids] = 0.0
if self.cfg.history_length > 0:
self._data.net_forces_w_history[env_ids] = 0.0
# reset force matrix
if len(self.cfg.filter_prim_paths_expr) != 0:
self._data.force_matrix_w[env_ids] = 0.0
# reset the current air time
if self.cfg.track_air_time:
self._data.current_air_time[env_ids] = 0.0
self._data.last_air_time[env_ids] = 0.0
# Set all reset sensors to not outdated since their value won't be updated till next sim step. # Set all reset sensors to not outdated since their value won't be updated till next sim step.
self._is_outdated[env_ids] = False self._is_outdated[env_ids] = False
...@@ -202,15 +210,24 @@ class ContactSensor(SensorBase): ...@@ -202,15 +210,24 @@ class ContactSensor(SensorBase):
f"\n\tResolved prim paths: {body_names_regex}" f"\n\tResolved prim paths: {body_names_regex}"
) )
# fill the data buffer # prepare data buffers
self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device)
self._data.last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.net_forces_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device) self._data.net_forces_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.net_forces_w_history = torch.zeros( # optional buffers
self._num_envs, self.cfg.history_length + 1, self._num_bodies, 3, device=self._device # -- history of net forces
) if self.cfg.history_length > 0:
self._data.net_forces_w_history = torch.zeros(
self._num_envs, self.cfg.history_length, self._num_bodies, 3, device=self._device
)
else:
self._data.net_forces_w_history = self._data.net_forces_w.unsqueeze(1)
# -- pose of sensor origins
if self.cfg.track_pose:
self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device)
# -- air time between contacts
if self.cfg.track_air_time:
self._data.last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
# force matrix: (num_sensors, num_bodies, num_shapes, num_filter_shapes, 3) # force matrix: (num_sensors, num_bodies, num_shapes, num_filter_shapes, 3)
if len(self.cfg.filter_prim_paths_expr) != 0: if len(self.cfg.filter_prim_paths_expr) != 0:
num_shapes = self.contact_physx_view.sensor_count // self._num_bodies num_shapes = self.contact_physx_view.sensor_count // self._num_bodies
...@@ -224,17 +241,16 @@ class ContactSensor(SensorBase): ...@@ -224,17 +241,16 @@ class ContactSensor(SensorBase):
# default to all sensors # default to all sensors
if len(env_ids) == self._num_envs: if len(env_ids) == self._num_envs:
env_ids = slice(None) env_ids = slice(None)
# obtain the poses of the sensors:
# TODO decide if we really to track poses -- This is the body's CoM. Not contact location.
pose = self.body_physx_view.get_transforms()
self._data.pos_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, :3]
self._data.quat_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, 3:]
# obtain the contact forces # obtain the contact forces
# TODO: We are handling the indexing ourself because of the shape; (N, B) vs expected (N * B). # TODO: We are handling the indexing ourself because of the shape; (N, B) vs expected (N * B).
# This isn't the most efficient way to do this, but it's the easiest to implement. # This isn't the most efficient way to do this, but it's the easiest to implement.
net_forces_w = self.contact_physx_view.get_net_contact_forces(dt=self._sim_physics_dt) net_forces_w = self.contact_physx_view.get_net_contact_forces(dt=self._sim_physics_dt)
self._data.net_forces_w[env_ids, :, :] = net_forces_w.view(-1, self._num_bodies, 3)[env_ids] self._data.net_forces_w[env_ids, :, :] = net_forces_w.view(-1, self._num_bodies, 3)[env_ids]
# update contact force history
if self.cfg.history_length > 0:
self._data.net_forces_w_history[env_ids, 1:] = self._data.net_forces_w_history[env_ids, :-1].clone()
self._data.net_forces_w_history[env_ids, 0] = self._data.net_forces_w[env_ids]
# obtain the contact force matrix # obtain the contact force matrix
if len(self.cfg.filter_prim_paths_expr) != 0: if len(self.cfg.filter_prim_paths_expr) != 0:
...@@ -245,26 +261,25 @@ class ContactSensor(SensorBase): ...@@ -245,26 +261,25 @@ class ContactSensor(SensorBase):
force_matrix_w = self.contact_physx_view.get_contact_force_matrix(dt=self._sim_physics_dt) force_matrix_w = self.contact_physx_view.get_contact_force_matrix(dt=self._sim_physics_dt)
force_matrix_w = force_matrix_w.view(-1, self._num_bodies, num_shapes, num_filters, 3) force_matrix_w = force_matrix_w.view(-1, self._num_bodies, num_shapes, num_filters, 3)
self._data.force_matrix_w[env_ids] = force_matrix_w[env_ids] self._data.force_matrix_w[env_ids] = force_matrix_w[env_ids]
# obtain the pose of the sensor origin
# update contact force history if self.cfg.track_pose:
previous_net_forces_w = self._data.net_forces_w_history.clone() pose = self.body_physx_view.get_transforms()
self._data.net_forces_w_history[env_ids, 0, :, :] = self._data.net_forces_w[env_ids, :, :] self._data.pos_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, :3]
if self.cfg.history_length > 0: self._data.quat_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, 3:]
self._data.net_forces_w_history[env_ids, 1:, :, :] = previous_net_forces_w[env_ids, :-1, :, :] # obtain the air time
if self.cfg.track_air_time:
# contact state # -- time elapsed since last update
# -- time elapsed since last update # since this function is called every frame, we can use the difference to get the elapsed time
# since this function is called every frame, we can use the difference to get the elapsed time elapsed_time = self._timestamp[env_ids] - self._timestamp_last_update[env_ids]
elapsed_time = self._timestamp[env_ids] - self._timestamp_last_update[env_ids] # -- check contact state of bodies
# -- check contact state of bodies is_contact = torch.norm(self._data.net_forces_w[env_ids, :, :], dim=-1) > 1.0
is_contact = torch.norm(self._data.net_forces_w[env_ids, :, :], dim=-1) > 1.0 is_first_contact = (self._data.current_air_time[env_ids] > 0) * is_contact
is_first_contact = (self._data.current_air_time[env_ids] > 0) * is_contact # -- update ongoing timer for bodies air
# -- update ongoing timer for bodies air self._data.current_air_time[env_ids] += elapsed_time.unsqueeze(-1)
self._data.current_air_time[env_ids] += elapsed_time.unsqueeze(-1) # -- update time for the last time bodies were in contact
# -- update time for the last time bodies were in contact self._data.last_air_time[env_ids] = self._data.current_air_time[env_ids] * is_first_contact
self._data.last_air_time[env_ids] = self._data.current_air_time[env_ids] * is_first_contact # -- increment timers for bodies that are not in contact
# -- increment timers for bodies that are not in contact self._data.current_air_time[env_ids] *= ~is_contact
self._data.current_air_time[env_ids] *= ~is_contact
def _debug_vis_impl(self): def _debug_vis_impl(self):
# visualize the contacts # visualize the contacts
...@@ -276,4 +291,10 @@ class ContactSensor(SensorBase): ...@@ -276,4 +291,10 @@ class ContactSensor(SensorBase):
net_contact_force_w = torch.norm(self._data.net_forces_w, dim=-1) net_contact_force_w = torch.norm(self._data.net_forces_w, dim=-1)
marker_indices = torch.where(net_contact_force_w > 1.0, 0, 1) marker_indices = torch.where(net_contact_force_w > 1.0, 0, 1)
# check if prim is visualized # check if prim is visualized
self.contact_visualizer.visualize(self._data.pos_w.view(-1, 3), marker_indices=marker_indices.view(-1)) if self.cfg.track_pose:
frame_origins: torch.Tensor = self._data.pos_w
else:
pose = self.body_physx_view.get_transforms()
frame_origins = pose.view(-1, self._num_bodies, 7)[:, :, :3]
# visualize
self.contact_visualizer.visualize(frame_origins.view(-1, 3), marker_indices=marker_indices.view(-1))
...@@ -18,6 +18,12 @@ class ContactSensorCfg(SensorBaseCfg): ...@@ -18,6 +18,12 @@ class ContactSensorCfg(SensorBaseCfg):
class_type: type = ContactSensor class_type: type = ContactSensor
track_pose: bool = False
"""Whether to track the pose of the sensor's origin. Defaults to False."""
track_air_time: bool = False
"""Whether to track the air time of the bodies (time between contacts). Defaults to False."""
filter_prim_paths_expr: list[str] = list() filter_prim_paths_expr: list[str] = list()
"""The list of primitive paths to filter contacts with. """The list of primitive paths to filter contacts with.
......
...@@ -13,15 +13,22 @@ from dataclasses import dataclass ...@@ -13,15 +13,22 @@ from dataclasses import dataclass
class ContactSensorData: class ContactSensorData:
"""Data container for the contact reporting sensor.""" """Data container for the contact reporting sensor."""
pos_w: torch.Tensor = None pos_w: torch.Tensor | None = None
"""Position of the sensor origin in world frame. """Position of the sensor origin in world frame.
Shape is (N, 3), where ``N`` is the number of sensors. Shape is (N, 3), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_pose` is False, then this qunatity is None.
""" """
quat_w: torch.Tensor = None
quat_w: torch.Tensor | None = None
"""Orientation of the sensor origin in quaternion ``(w, x, y, z)`` in world frame. """Orientation of the sensor origin in quaternion ``(w, x, y, z)`` in world frame.
Shape is (N, 4), where ``N`` is the number of sensors. Shape is (N, 4), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_pose` is False, then this qunatity is None.
""" """
net_forces_w: torch.Tensor = None net_forces_w: torch.Tensor = None
...@@ -29,6 +36,7 @@ class ContactSensorData: ...@@ -29,6 +36,7 @@ class ContactSensorData:
Shape is (N, B, 3), where ``N`` is the number of sensors and ``B`` is the number of bodies in each sensor. Shape is (N, B, 3), where ``N`` is the number of sensors and ``B`` is the number of bodies in each sensor.
""" """
net_forces_w_history: torch.Tensor = None net_forces_w_history: torch.Tensor = None
"""The net contact forces in world frame. """The net contact forces in world frame.
...@@ -38,22 +46,30 @@ class ContactSensorData: ...@@ -38,22 +46,30 @@ class ContactSensorData:
In the history dimension, the first index is the most recent and the last index is the oldest. In the history dimension, the first index is the most recent and the last index is the oldest.
""" """
force_matrix_w: torch.Tensor = None force_matrix_w: torch.Tensor | None = None
"""The contact forces filtered between the sensor bodies and filtered bodies in world frame. """The contact forces filtered between the sensor bodies and filtered bodies in world frame.
Shape is (N, B, S, M, 3), where ``N`` is the number of sensors, ``B`` is number of bodies in each sensor, Shape is (N, B, S, M, 3), where ``N`` is the number of sensors, ``B`` is number of bodies in each sensor,
``S`` is number of shapes per body and ``M`` is the number of filtered bodies. ``S`` is number of shapes per body and ``M`` is the number of filtered bodies.
If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this tensor will be empty. Note:
If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None.
""" """
last_air_time: torch.Tensor = None last_air_time: torch.Tensor | None = None
"""Time spent (in s) in the air before the last contact. """Time spent (in s) in the air before the last contact.
Shape is (N,), where ``N`` is the number of sensors. Shape is (N,), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
""" """
current_air_time: torch.Tensor = None
current_air_time: torch.Tensor | None = None
"""Time spent (in s) in the air since the last contact. """Time spent (in s) in the air since the last contact.
Shape is (N,), where ``N`` is the number of sensors. Shape is (N,), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
""" """
...@@ -95,6 +95,15 @@ class RayCaster(SensorBase): ...@@ -95,6 +95,15 @@ class RayCaster(SensorBase):
if self.ray_visualizer is not None: if self.ray_visualizer is not None:
self.ray_visualizer.set_visibility(debug_vis) self.ray_visualizer.set_visibility(debug_vis)
def reset(self, env_ids: Sequence[int] | None = None):
# reset the timers and counters
super().reset(env_ids)
# resolve None
if env_ids is None:
env_ids = slice(None)
# resample the drift
self.drift[env_ids].uniform_(*self.cfg.drift_range)
""" """
Implementation. Implementation.
""" """
...@@ -180,7 +189,8 @@ class RayCaster(SensorBase): ...@@ -180,7 +189,8 @@ class RayCaster(SensorBase):
# repeat the rays for each sensor # repeat the rays for each sensor
self.ray_starts = self.ray_starts.repeat(self._view.count, 1, 1) self.ray_starts = self.ray_starts.repeat(self._view.count, 1, 1)
self.ray_directions = self.ray_directions.repeat(self._view.count, 1, 1) self.ray_directions = self.ray_directions.repeat(self._view.count, 1, 1)
# prepare drift
self.drift = torch.zeros(self._view.count, 3, device=self.device)
# fill the data buffer # fill the data buffer
self._data.pos_w = torch.zeros(self._view.count, 3, device=self._device) self._data.pos_w = torch.zeros(self._view.count, 3, device=self._device)
self._data.quat_w = torch.zeros(self._view.count, 4, device=self._device) self._data.quat_w = torch.zeros(self._view.count, 4, device=self._device)
...@@ -190,6 +200,7 @@ class RayCaster(SensorBase): ...@@ -190,6 +200,7 @@ class RayCaster(SensorBase):
"""Fills the buffers of the sensor data.""" """Fills the buffers of the sensor data."""
# obtain the poses of the sensors # obtain the poses of the sensors
pos_w, quat_w = self._view.get_world_poses(env_ids, clone=False) pos_w, quat_w = self._view.get_world_poses(env_ids, clone=False)
pos_w += self.drift[env_ids]
self._data.pos_w[env_ids] = pos_w self._data.pos_w[env_ids] = pos_w
self._data.quat_w[env_ids] = quat_w self._data.quat_w[env_ids] = quat_w
......
...@@ -53,3 +53,9 @@ class RayCasterCfg(SensorBaseCfg): ...@@ -53,3 +53,9 @@ class RayCasterCfg(SensorBaseCfg):
max_distance: float = 100.0 max_distance: float = 100.0
"""Maximum distance (in meters) from the sensor to ray cast to. Defaults to 100.0.""" """Maximum distance (in meters) from the sensor to ray cast to. Defaults to 100.0."""
drift_range: tuple[float, float] = (0.0, 0.0)
"""The range of drift (in meters) to add to the ray starting positions (xyz). Defaults to (0.0, 0.0).
For floating base robots, this is useful for simulating drift in the robot's pose estimation.
"""
...@@ -33,6 +33,28 @@ class TestTorchOperations(unittest.TestCase): ...@@ -33,6 +33,28 @@ class TestTorchOperations(unittest.TestCase):
self.assertEqual(my_tensor[slice(None), 0, 0].shape, (400,)) self.assertEqual(my_tensor[slice(None), 0, 0].shape, (400,))
self.assertEqual(my_tensor[:, 0, 0].shape, (400,)) self.assertEqual(my_tensor[:, 0, 0].shape, (400,))
def test_array_copying(self):
"""Check how indexing effects the returned tensor."""
size = (400, 300, 5)
my_tensor = torch.rand(size, device="cuda:0")
# obtain a slice of the tensor
my_slice = my_tensor[0, ...]
self.assertEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())
# obtain a slice over ranges
my_slice = my_tensor[0:2, ...]
self.assertEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())
# obtain a slice over list
my_slice = my_tensor[[0, 1], ...]
self.assertNotEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())
# obtain a slice over tensor
my_slice = my_tensor[torch.tensor([0, 1]), ...]
self.assertNotEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment