Unverified Commit 514baa4a authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Adds support for drift in `RayCaster` and makes fields in `ContactSensorData` optional (#201)

# Description

This MR adds support for 2D drift into the `RayCaster`. It also makes
certain attributes in the `ContactSensorData` optional since they are
not needed by default.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent c6af307a
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.9.13"
version = "0.9.14"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.9.14 (2023-10-21)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added 2-D drift (i.e. along x and y) to the :class:`omni.isaac.orbit.sensors.RayCaster` class.
* Added flags to the :class:`omni.isaac.orbit.sensors.ContactSensorCfg` to optionally obtain the
sensor origin and air time information. Since these are not required by default, they are
disabled by default.
Fixed
^^^^^
* Fixed the handling of contact sensor history buffer in the :class:`omni.isaac.orbit.sensors.ContactSensor` class.
Earlier, the buffer was not being updated correctly.
0.9.13 (2023-10-20)
~~~~~~~~~~~~~~~~~~~
......
......@@ -3,6 +3,8 @@
#
# SPDX-License-Identifier: BSD-3-Clause
# Ignore optional memory usage warning globally
# pyright: reportOptionalSubscript=false
from __future__ import annotations
......@@ -137,11 +139,17 @@ class ContactSensor(SensorBase):
if env_ids is None:
env_ids = slice(None)
# reset accumulative data buffers
self._data.current_air_time[env_ids] = 0.0
self._data.last_air_time[env_ids] = 0.0
self._data.net_forces_w[env_ids] = 0.0
# reset the data history
self._data.net_forces_w_history[env_ids] = 0.0
if self.cfg.history_length > 0:
self._data.net_forces_w_history[env_ids] = 0.0
# reset force matrix
if len(self.cfg.filter_prim_paths_expr) != 0:
self._data.force_matrix_w[env_ids] = 0.0
# reset the current air time
if self.cfg.track_air_time:
self._data.current_air_time[env_ids] = 0.0
self._data.last_air_time[env_ids] = 0.0
# Set all reset sensors to not outdated since their value won't be updated till next sim step.
self._is_outdated[env_ids] = False
......@@ -202,15 +210,24 @@ class ContactSensor(SensorBase):
f"\n\tResolved prim paths: {body_names_regex}"
)
# fill the data buffer
# prepare data buffers
self._data.net_forces_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
# optional buffers
# -- history of net forces
if self.cfg.history_length > 0:
self._data.net_forces_w_history = torch.zeros(
self._num_envs, self.cfg.history_length, self._num_bodies, 3, device=self._device
)
else:
self._data.net_forces_w_history = self._data.net_forces_w.unsqueeze(1)
# -- pose of sensor origins
if self.cfg.track_pose:
self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device)
# -- air time between contacts
if self.cfg.track_air_time:
self._data.last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.net_forces_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.net_forces_w_history = torch.zeros(
self._num_envs, self.cfg.history_length + 1, self._num_bodies, 3, device=self._device
)
# force matrix: (num_sensors, num_bodies, num_shapes, num_filter_shapes, 3)
if len(self.cfg.filter_prim_paths_expr) != 0:
num_shapes = self.contact_physx_view.sensor_count // self._num_bodies
......@@ -224,17 +241,16 @@ class ContactSensor(SensorBase):
# default to all sensors
if len(env_ids) == self._num_envs:
env_ids = slice(None)
# obtain the poses of the sensors:
# TODO decide if we really to track poses -- This is the body's CoM. Not contact location.
pose = self.body_physx_view.get_transforms()
self._data.pos_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, :3]
self._data.quat_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, 3:]
# obtain the contact forces
# TODO: We are handling the indexing ourself because of the shape; (N, B) vs expected (N * B).
# This isn't the most efficient way to do this, but it's the easiest to implement.
net_forces_w = self.contact_physx_view.get_net_contact_forces(dt=self._sim_physics_dt)
self._data.net_forces_w[env_ids, :, :] = net_forces_w.view(-1, self._num_bodies, 3)[env_ids]
# update contact force history
if self.cfg.history_length > 0:
self._data.net_forces_w_history[env_ids, 1:] = self._data.net_forces_w_history[env_ids, :-1].clone()
self._data.net_forces_w_history[env_ids, 0] = self._data.net_forces_w[env_ids]
# obtain the contact force matrix
if len(self.cfg.filter_prim_paths_expr) != 0:
......@@ -245,14 +261,13 @@ class ContactSensor(SensorBase):
force_matrix_w = self.contact_physx_view.get_contact_force_matrix(dt=self._sim_physics_dt)
force_matrix_w = force_matrix_w.view(-1, self._num_bodies, num_shapes, num_filters, 3)
self._data.force_matrix_w[env_ids] = force_matrix_w[env_ids]
# update contact force history
previous_net_forces_w = self._data.net_forces_w_history.clone()
self._data.net_forces_w_history[env_ids, 0, :, :] = self._data.net_forces_w[env_ids, :, :]
if self.cfg.history_length > 0:
self._data.net_forces_w_history[env_ids, 1:, :, :] = previous_net_forces_w[env_ids, :-1, :, :]
# contact state
# obtain the pose of the sensor origin
if self.cfg.track_pose:
pose = self.body_physx_view.get_transforms()
self._data.pos_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, :3]
self._data.quat_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, 3:]
# obtain the air time
if self.cfg.track_air_time:
# -- time elapsed since last update
# since this function is called every frame, we can use the difference to get the elapsed time
elapsed_time = self._timestamp[env_ids] - self._timestamp_last_update[env_ids]
......@@ -276,4 +291,10 @@ class ContactSensor(SensorBase):
net_contact_force_w = torch.norm(self._data.net_forces_w, dim=-1)
marker_indices = torch.where(net_contact_force_w > 1.0, 0, 1)
# check if prim is visualized
self.contact_visualizer.visualize(self._data.pos_w.view(-1, 3), marker_indices=marker_indices.view(-1))
if self.cfg.track_pose:
frame_origins: torch.Tensor = self._data.pos_w
else:
pose = self.body_physx_view.get_transforms()
frame_origins = pose.view(-1, self._num_bodies, 7)[:, :, :3]
# visualize
self.contact_visualizer.visualize(frame_origins.view(-1, 3), marker_indices=marker_indices.view(-1))
......@@ -18,6 +18,12 @@ class ContactSensorCfg(SensorBaseCfg):
class_type: type = ContactSensor
track_pose: bool = False
"""Whether to track the pose of the sensor's origin. Defaults to False."""
track_air_time: bool = False
"""Whether to track the air time of the bodies (time between contacts). Defaults to False."""
filter_prim_paths_expr: list[str] = list()
"""The list of primitive paths to filter contacts with.
......
......@@ -13,15 +13,22 @@ from dataclasses import dataclass
class ContactSensorData:
"""Data container for the contact reporting sensor."""
pos_w: torch.Tensor = None
pos_w: torch.Tensor | None = None
"""Position of the sensor origin in world frame.
Shape is (N, 3), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_pose` is False, then this qunatity is None.
"""
quat_w: torch.Tensor = None
quat_w: torch.Tensor | None = None
"""Orientation of the sensor origin in quaternion ``(w, x, y, z)`` in world frame.
Shape is (N, 4), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_pose` is False, then this qunatity is None.
"""
net_forces_w: torch.Tensor = None
......@@ -29,6 +36,7 @@ class ContactSensorData:
Shape is (N, B, 3), where ``N`` is the number of sensors and ``B`` is the number of bodies in each sensor.
"""
net_forces_w_history: torch.Tensor = None
"""The net contact forces in world frame.
......@@ -38,22 +46,30 @@ class ContactSensorData:
In the history dimension, the first index is the most recent and the last index is the oldest.
"""
force_matrix_w: torch.Tensor = None
force_matrix_w: torch.Tensor | None = None
"""The contact forces filtered between the sensor bodies and filtered bodies in world frame.
Shape is (N, B, S, M, 3), where ``N`` is the number of sensors, ``B`` is number of bodies in each sensor,
``S`` is number of shapes per body and ``M`` is the number of filtered bodies.
If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this tensor will be empty.
Note:
If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None.
"""
last_air_time: torch.Tensor = None
last_air_time: torch.Tensor | None = None
"""Time spent (in s) in the air before the last contact.
Shape is (N,), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
"""
current_air_time: torch.Tensor = None
current_air_time: torch.Tensor | None = None
"""Time spent (in s) in the air since the last contact.
Shape is (N,), where ``N`` is the number of sensors.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
"""
......@@ -95,6 +95,15 @@ class RayCaster(SensorBase):
if self.ray_visualizer is not None:
self.ray_visualizer.set_visibility(debug_vis)
def reset(self, env_ids: Sequence[int] | None = None):
# reset the timers and counters
super().reset(env_ids)
# resolve None
if env_ids is None:
env_ids = slice(None)
# resample the drift
self.drift[env_ids].uniform_(*self.cfg.drift_range)
"""
Implementation.
"""
......@@ -180,7 +189,8 @@ class RayCaster(SensorBase):
# repeat the rays for each sensor
self.ray_starts = self.ray_starts.repeat(self._view.count, 1, 1)
self.ray_directions = self.ray_directions.repeat(self._view.count, 1, 1)
# prepare drift
self.drift = torch.zeros(self._view.count, 3, device=self.device)
# fill the data buffer
self._data.pos_w = torch.zeros(self._view.count, 3, device=self._device)
self._data.quat_w = torch.zeros(self._view.count, 4, device=self._device)
......@@ -190,6 +200,7 @@ class RayCaster(SensorBase):
"""Fills the buffers of the sensor data."""
# obtain the poses of the sensors
pos_w, quat_w = self._view.get_world_poses(env_ids, clone=False)
pos_w += self.drift[env_ids]
self._data.pos_w[env_ids] = pos_w
self._data.quat_w[env_ids] = quat_w
......
......@@ -53,3 +53,9 @@ class RayCasterCfg(SensorBaseCfg):
max_distance: float = 100.0
"""Maximum distance (in meters) from the sensor to ray cast to. Defaults to 100.0."""
drift_range: tuple[float, float] = (0.0, 0.0)
"""The range of drift (in meters) to add to the ray starting positions (xyz). Defaults to (0.0, 0.0).
For floating base robots, this is useful for simulating drift in the robot's pose estimation.
"""
......@@ -33,6 +33,28 @@ class TestTorchOperations(unittest.TestCase):
self.assertEqual(my_tensor[slice(None), 0, 0].shape, (400,))
self.assertEqual(my_tensor[:, 0, 0].shape, (400,))
def test_array_copying(self):
"""Check how indexing effects the returned tensor."""
size = (400, 300, 5)
my_tensor = torch.rand(size, device="cuda:0")
# obtain a slice of the tensor
my_slice = my_tensor[0, ...]
self.assertEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())
# obtain a slice over ranges
my_slice = my_tensor[0:2, ...]
self.assertEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())
# obtain a slice over list
my_slice = my_tensor[[0, 1], ...]
self.assertNotEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())
# obtain a slice over tensor
my_slice = my_tensor[torch.tensor([0, 1]), ...]
self.assertNotEqual(my_slice.untyped_storage().data_ptr(), my_tensor.untyped_storage().data_ptr())
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment