Unverified Commit 7b16b679 authored by Greg Attra's avatar Greg Attra Committed by GitHub

Adds friction force reporting to ContactSensor (#3563)

# Description

This PR extends the `ContactSensor` class to expose aggregated friction
forces for each filtered body. It uses the same vectorized approach used
for [`contact_points`](https://github.com/isaac-sim/IsaacLab/pull/2842).

Concretely, this change introduces:
- `ContactSensorCfg.track_friction_forces` toggle to turn on friction
tracking
- `ContactSensorData.friction_forces_w` where the sum of friction forces
for each filtered body are stored

Fixes https://github.com/isaac-sim/IsaacLab/issues/2074, #2064

## Performance

Results of `check_contact_sensor.py` with `track_friction_data = False`:
```
avg dt real-time 0.017448579105403043
avg dt real-time 0.017589360827958443
avg dt real-time 0.016146250123070787
```

Results of `check_contact_sensor.py` with `track_friction_data = True`:
```
avg dt real-time 0.01818224351439858
avg dt real-time 0.017720674386015163
avg dt real-time 0.01777262271923246
```

## Type of change
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have read and understood the [contribution
guidelines](https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html)
- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent 190cc1eb
...@@ -36,6 +36,7 @@ Guidelines for modifications: ...@@ -36,6 +36,7 @@ Guidelines for modifications:
* Pascal Roth * Pascal Roth
* Sheikh Dawood * Sheikh Dawood
* Ossama Ahmed * Ossama Ahmed
* Greg Attra
## Contributors ## Contributors
......
Changelog Changelog
--------- ---------
0.49.2 (2025-11-26) 0.49.2 (2025-11-17)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
Changed Added
^^^^^^^ ^^^^^
* Changed import from ``isaacsim.core.utils.prims`` to ``isaaclab.sim.utils.prims`` across repo to reduce IsaacLab dependencies.
* Added :attr:`~isaaclab.sensors.contact_sensor.ContactSensorCfg.track_friction_forces` to toggle tracking of friction forces between sensor bodies and filtered bodies.
* Added :attr:`~isaaclab.sensors.contact_sensor.ContactSensorData.friction_forces_w` data field for tracking friction forces.
0.49.1 (2025-12-08) 0.49.1 (2025-11-26)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
Added Changed
^^^^^ ^^^^^^^
* Added write to file on close to :class:`~isaaclab.manager.RecorderManager`. * Changed import from ``isaacsim.core.utils.prims`` to ``isaaclab.sim.utils.prims`` across repo to reduce IsaacLab dependencies.
* Added :attr:`~isaaclab.manager.RecorderManagerCfg.export_in_close` configuration parameter.
0.49.0 (2025-11-10) 0.49.0 (2025-11-10)
...@@ -122,6 +121,7 @@ Added ...@@ -122,6 +121,7 @@ Added
* Added demo script ``scripts/demos/haply_teleoperation.py`` and documentation guide in * Added demo script ``scripts/demos/haply_teleoperation.py`` and documentation guide in
``docs/source/how-to/haply_teleoperation.rst`` for Haply-based robot teleoperation. ``docs/source/how-to/haply_teleoperation.rst`` for Haply-based robot teleoperation.
0.48.0 (2025-11-03) 0.48.0 (2025-11-03)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
...@@ -195,7 +195,7 @@ Changed ...@@ -195,7 +195,7 @@ Changed
0.47.6 (2025-11-01) 0.47.6 (2025-11-01)
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
Fixed Fixed
^^^^^ ^^^^^
......
...@@ -162,8 +162,9 @@ class ContactSensor(SensorBase): ...@@ -162,8 +162,9 @@ class ContactSensor(SensorBase):
# reset contact positions # reset contact positions
if self.cfg.track_contact_points: if self.cfg.track_contact_points:
self._data.contact_pos_w[env_ids, :] = torch.nan self._data.contact_pos_w[env_ids, :] = torch.nan
# buffer used during contact position aggregation # reset friction forces
self._contact_position_aggregate_buffer[env_ids, :] = torch.nan if self.cfg.track_friction_forces:
self._data.friction_forces_w[env_ids, :] = 0.0
def find_bodies(self, name_keys: str | Sequence[str], preserve_order: bool = False) -> tuple[list[int], list[str]]: def find_bodies(self, name_keys: str | Sequence[str], preserve_order: bool = False) -> tuple[list[int], list[str]]:
"""Find bodies in the articulation based on the name keys. """Find bodies in the articulation based on the name keys.
...@@ -310,6 +311,21 @@ class ContactSensor(SensorBase): ...@@ -310,6 +311,21 @@ class ContactSensor(SensorBase):
if self.cfg.track_pose: if self.cfg.track_pose:
self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device) self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device) self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device)
# check if filter paths are valid
if self.cfg.track_contact_points or self.cfg.track_friction_forces:
if len(self.cfg.filter_prim_paths_expr) == 0:
raise ValueError(
"The 'filter_prim_paths_expr' is empty. Please specify a valid filter pattern to track"
f" {'contact points' if self.cfg.track_contact_points else 'friction forces'}."
)
if self.cfg.max_contact_data_count_per_prim < 1:
raise ValueError(
f"The 'max_contact_data_count_per_prim' is {self.cfg.max_contact_data_count_per_prim}. "
"Please set it to a value greater than 0 to track"
f" {'contact points' if self.cfg.track_contact_points else 'friction forces'}."
)
# -- position of contact points # -- position of contact points
if self.cfg.track_contact_points: if self.cfg.track_contact_points:
self._data.contact_pos_w = torch.full( self._data.contact_pos_w = torch.full(
...@@ -317,10 +333,11 @@ class ContactSensor(SensorBase): ...@@ -317,10 +333,11 @@ class ContactSensor(SensorBase):
torch.nan, torch.nan,
device=self._device, device=self._device,
) )
# buffer used during contact position aggregation # -- friction forces at contact points
self._contact_position_aggregate_buffer = torch.full( if self.cfg.track_friction_forces:
(self._num_bodies * self._num_envs, self.contact_physx_view.filter_count, 3), self._data.friction_forces_w = torch.full(
torch.nan, (self._num_envs, self._num_bodies, self.contact_physx_view.filter_count, 3),
0.0,
device=self._device, device=self._device,
) )
# -- air/contact time between contacts # -- air/contact time between contacts
...@@ -382,28 +399,17 @@ class ContactSensor(SensorBase): ...@@ -382,28 +399,17 @@ class ContactSensor(SensorBase):
_, buffer_contact_points, _, _, buffer_count, buffer_start_indices = ( _, buffer_contact_points, _, _, buffer_count, buffer_start_indices = (
self.contact_physx_view.get_contact_data(dt=self._sim_physics_dt) self.contact_physx_view.get_contact_data(dt=self._sim_physics_dt)
) )
# unpack the contact points: see RigidContactView.get_contact_data() documentation for details: self._data.contact_pos_w[env_ids] = self._unpack_contact_buffer_data(
# https://docs.omniverse.nvidia.com/kit/docs/omni_physics/107.3/extensions/runtime/source/omni.physics.tensors/docs/api/python.html#omni.physics.tensors.impl.api.RigidContactView.get_net_contact_forces buffer_contact_points, buffer_count, buffer_start_indices
# buffer_count: (N_envs * N_bodies, N_filters), buffer_contact_points: (N_envs * N_bodies, 3) )[env_ids]
counts, starts = buffer_count.view(-1), buffer_start_indices.view(-1)
n_rows, total = counts.numel(), int(counts.sum())
# default to NaN rows
agg = torch.full((n_rows, 3), float("nan"), device=self._device, dtype=buffer_contact_points.dtype)
if total > 0:
row_ids = torch.repeat_interleave(torch.arange(n_rows, device=self._device), counts)
total = row_ids.numel()
block_starts = counts.cumsum(0) - counts
deltas = torch.arange(total, device=counts.device) - block_starts.repeat_interleave(counts)
flat_idx = starts[row_ids] + deltas
pts = buffer_contact_points.index_select(0, flat_idx)
agg = agg.zero_().index_add_(0, row_ids, pts) / counts.clamp_min(1).unsqueeze(1)
agg[counts == 0] = float("nan")
self._contact_position_aggregate_buffer[:] = agg.view(self._num_envs * self.num_bodies, -1, 3) # obtain friction forces
self._data.contact_pos_w[env_ids] = self._contact_position_aggregate_buffer.view( if self.cfg.track_friction_forces:
self._num_envs, self._num_bodies, self.contact_physx_view.filter_count, 3 friction_forces, _, buffer_count, buffer_start_indices = self.contact_physx_view.get_friction_data(
dt=self._sim_physics_dt
)
self._data.friction_forces_w[env_ids] = self._unpack_contact_buffer_data(
friction_forces, buffer_count, buffer_start_indices, avg=False, default=0.0
)[env_ids] )[env_ids]
# obtain the air time # obtain the air time
...@@ -436,6 +442,58 @@ class ContactSensor(SensorBase): ...@@ -436,6 +442,58 @@ class ContactSensor(SensorBase):
is_contact, self._data.current_contact_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0 is_contact, self._data.current_contact_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0
) )
def _unpack_contact_buffer_data(
self,
contact_data: torch.Tensor,
buffer_count: torch.Tensor,
buffer_start_indices: torch.Tensor,
avg: bool = True,
default: float = float("nan"),
) -> torch.Tensor:
"""
Unpacks and aggregates contact data for each (env, body, filter) group.
This function vectorizes the following nested loop:
for i in range(self._num_bodies * self._num_envs):
for j in range(self.contact_physx_view.filter_count):
start_index_ij = buffer_start_indices[i, j]
count_ij = buffer_count[i, j]
self._contact_position_aggregate_buffer[i, j, :] = torch.mean(
contact_data[start_index_ij : (start_index_ij + count_ij), :], dim=0
)
For more details, see the `RigidContactView.get_contact_data() documentation <https://docs.omniverse.nvidia.com/kit/docs/omni_physics/107.3/extensions/runtime/source/omni.physics.tensors/docs/api/python.html#omni.physics.tensors.impl.api.RigidContactView.get_contact_data>`_.
Args:
contact_data: Flat tensor of contact data, shape (N_envs * N_bodies, 3).
buffer_count: Number of contact points per (env, body, filter), shape (N_envs * N_bodies, N_filters).
buffer_start_indices: Start indices for each (env, body, filter), shape (N_envs * N_bodies, N_filters).
avg: If True, average the contact data for each group; if False, sum the data. Defaults to True.
default: Default value to use for groups with zero contacts. Defaults to NaN.
Returns:
Aggregated contact data, shape (N_envs, N_bodies, N_filters, 3).
"""
counts, starts = buffer_count.view(-1), buffer_start_indices.view(-1)
n_rows, total = counts.numel(), int(counts.sum())
agg = torch.full((n_rows, 3), default, device=self._device, dtype=contact_data.dtype)
if total > 0:
row_ids = torch.repeat_interleave(torch.arange(n_rows, device=self._device), counts)
block_starts = counts.cumsum(0) - counts
deltas = torch.arange(row_ids.numel(), device=counts.device) - block_starts.repeat_interleave(counts)
flat_idx = starts[row_ids] + deltas
pts = contact_data.index_select(0, flat_idx)
agg = agg.zero_().index_add_(0, row_ids, pts)
agg = agg / counts.clamp_min(1).unsqueeze(-1) if avg else agg
agg[counts == 0] = default
return agg.view(self._num_envs * self.num_bodies, -1, 3).view(
self._num_envs, self._num_bodies, self.contact_physx_view.filter_count, 3
)
def _set_debug_vis_impl(self, debug_vis: bool): def _set_debug_vis_impl(self, debug_vis: bool):
# set visibility of markers # set visibility of markers
# note: parent only deals with callbacks. not their visibility # note: parent only deals with callbacks. not their visibility
......
...@@ -23,6 +23,9 @@ class ContactSensorCfg(SensorBaseCfg): ...@@ -23,6 +23,9 @@ class ContactSensorCfg(SensorBaseCfg):
track_contact_points: bool = False track_contact_points: bool = False
"""Whether to track the contact point locations. Defaults to False.""" """Whether to track the contact point locations. Defaults to False."""
track_friction_forces: bool = False
"""Whether to track the friction forces at the contact points. Defaults to False."""
max_contact_data_count_per_prim: int = 4 max_contact_data_count_per_prim: int = 4
"""The maximum number of contacts across all batches of the sensor to keep track of. Default is 4. """The maximum number of contacts across all batches of the sensor to keep track of. Default is 4.
......
...@@ -35,12 +35,32 @@ class ContactSensorData: ...@@ -35,12 +35,32 @@ class ContactSensorData:
Note: Note:
* If the :attr:`ContactSensorCfg.track_contact_points` is False, then this quantity is None. * If the :attr:`ContactSensorCfg.track_contact_points` is False, then this quantity is None.
* If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is an empty tensor. * If the :attr:`ContactSensorCfg.track_contact_points` is True, a ValueError will be raised if:
* If the :attr:`ContactSensorCfg.max_contact_data_per_prim` is not specified or less than 1, then this quantity
* If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty.
* If the :attr:`ContactSensorCfg.max_contact_data_per_prim` is not specified or less than 1.
will not be calculated. will not be calculated.
""" """
friction_forces_w: torch.Tensor | None = None
"""Sum of the friction forces between sensor body and filter prim in world frame.
Shape is (N, B, M, 3), where N is the number of sensors, B is number of bodies in each sensor
and M is the number of filtered bodies.
Collision pairs not in contact will result in NaN.
Note:
* If the :attr:`ContactSensorCfg.track_friction_forces` is False, then this quantity is None.
* If the :attr:`ContactSensorCfg.track_friction_forces` is True, a ValueError will be raised if:
* The :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty.
* The :attr:`ContactSensorCfg.max_contact_data_per_prim` is not specified or less than 1.
"""
quat_w: torch.Tensor | None = None quat_w: torch.Tensor | None = None
"""Orientation of the sensor origin in quaternion (w, x, y, z) in world frame. """Orientation of the sensor origin in quaternion (w, x, y, z) in world frame.
......
...@@ -105,6 +105,7 @@ def main(): ...@@ -105,6 +105,7 @@ def main():
prim_path="/World/envs/env_.*/Robot/.*_FOOT", prim_path="/World/envs/env_.*/Robot/.*_FOOT",
track_air_time=True, track_air_time=True,
track_contact_points=True, track_contact_points=True,
track_friction_forces=True,
debug_vis=False, # not args_cli.headless, debug_vis=False, # not args_cli.headless,
filter_prim_paths_expr=["/World/defaultGroundPlane/GroundPlane/CollisionPlane"], filter_prim_paths_expr=["/World/defaultGroundPlane/GroundPlane/CollisionPlane"],
) )
......
...@@ -27,7 +27,7 @@ import isaaclab.sim as sim_utils ...@@ -27,7 +27,7 @@ import isaaclab.sim as sim_utils
from isaaclab.assets import RigidObject, RigidObjectCfg from isaaclab.assets import RigidObject, RigidObjectCfg
from isaaclab.scene import InteractiveScene, InteractiveSceneCfg from isaaclab.scene import InteractiveScene, InteractiveSceneCfg
from isaaclab.sensors import ContactSensor, ContactSensorCfg from isaaclab.sensors import ContactSensor, ContactSensorCfg
from isaaclab.sim import SimulationContext, build_simulation_context from isaaclab.sim import SimulationCfg, SimulationContext, build_simulation_context
from isaaclab.sim.utils.stage import get_current_stage from isaaclab.sim.utils.stage import get_current_stage
from isaaclab.terrains import HfRandomUniformTerrainCfg, TerrainGeneratorCfg, TerrainImporterCfg from isaaclab.terrains import HfRandomUniformTerrainCfg, TerrainGeneratorCfg, TerrainImporterCfg
from isaaclab.utils import configclass from isaaclab.utils import configclass
...@@ -438,6 +438,137 @@ def test_contact_sensor_threshold(setup_simulation, device): ...@@ -438,6 +438,137 @@ def test_contact_sensor_threshold(setup_simulation, device):
), f"Expected USD threshold to be close to 0.0, but got {threshold_value}" ), f"Expected USD threshold to be close to 0.0, but got {threshold_value}"
# minor gravity force in -z to ensure object stays on ground plane
@pytest.mark.parametrize("grav_dir", [(-10.0, 0.0, -0.1), (0.0, -10.0, -0.1)])
@pytest.mark.isaacsim_ci
def test_friction_reporting(setup_simulation, grav_dir):
"""
Test friction force reporting for contact sensors.
This test places a contact sensor enabled cube onto a ground plane under different gravity directions.
It then compares the normalized friction force dir with the direction of gravity to ensure they are aligned.
"""
sim_dt, _, _, _, carb_settings_iface = setup_simulation
carb_settings_iface.set_bool("/physics/disableContactProcessing", True)
device = "cuda:0"
sim_cfg = SimulationCfg(dt=sim_dt, device=device, gravity=grav_dir)
with build_simulation_context(sim_cfg=sim_cfg, add_lighting=False) as sim:
sim._app_control_on_stop_handle = None
scene_cfg = ContactSensorSceneCfg(num_envs=1, env_spacing=1.0, lazy_sensor_update=False)
scene_cfg.terrain = FLAT_TERRAIN_CFG
scene_cfg.shape = CUBE_CFG
filter_prim_paths_expr = [scene_cfg.terrain.prim_path + "/terrain/GroundPlane/CollisionPlane"]
scene_cfg.contact_sensor = ContactSensorCfg(
prim_path=scene_cfg.shape.prim_path,
track_pose=True,
debug_vis=False,
update_period=0.0,
track_air_time=True,
history_length=3,
track_friction_forces=True,
filter_prim_paths_expr=filter_prim_paths_expr,
)
scene = InteractiveScene(scene_cfg)
sim.reset()
scene["contact_sensor"].reset()
scene["shape"].write_root_pose_to_sim(
root_pose=torch.tensor([0, 0.0, CUBE_CFG.spawn.size[2] / 2.0, 1, 0, 0, 0])
)
# step sim once to compute friction forces
_perform_sim_step(sim, scene, sim_dt)
# check that forces are being reported match expected friction forces
expected_friction, _, _, _ = scene["contact_sensor"].contact_physx_view.get_friction_data(dt=sim_dt)
reported_friction = scene["contact_sensor"].data.friction_forces_w[0, 0, :]
torch.testing.assert_close(expected_friction.sum(dim=0), reported_friction[0], atol=1e-6, rtol=1e-5)
# check that friction force direction opposes gravity direction
grav = torch.tensor(grav_dir, device=device)
norm_reported_friction = reported_friction / reported_friction.norm()
norm_gravity = grav / grav.norm()
dot = torch.dot(norm_reported_friction[0], norm_gravity)
torch.testing.assert_close(torch.abs(dot), torch.tensor(1.0, device=device), atol=1e-4, rtol=1e-3)
@pytest.mark.isaacsim_ci
def test_invalid_prim_paths_config(setup_simulation):
sim_dt, _, _, _, carb_settings_iface = setup_simulation
carb_settings_iface.set_bool("/physics/disableContactProcessing", True)
device = "cuda:0"
sim_cfg = SimulationCfg(dt=sim_dt, device=device)
with build_simulation_context(sim_cfg=sim_cfg, add_lighting=False) as sim:
sim._app_control_on_stop_handle = None
scene_cfg = ContactSensorSceneCfg(num_envs=1, env_spacing=1.0, lazy_sensor_update=False)
scene_cfg.terrain = FLAT_TERRAIN_CFG
scene_cfg.shape = CUBE_CFG
scene_cfg.contact_sensor = ContactSensorCfg(
prim_path=scene_cfg.shape.prim_path,
track_pose=True,
debug_vis=False,
update_period=0.0,
track_air_time=True,
history_length=3,
track_friction_forces=True,
filter_prim_paths_expr=[],
)
try:
_ = InteractiveScene(scene_cfg)
sim.reset()
assert False, "Expected ValueError due to invalid contact sensor configuration."
except ValueError:
pass
@pytest.mark.isaacsim_ci
def test_invalid_max_contact_points_config(setup_simulation):
sim_dt, _, _, _, carb_settings_iface = setup_simulation
carb_settings_iface.set_bool("/physics/disableContactProcessing", True)
device = "cuda:0"
sim_cfg = SimulationCfg(dt=sim_dt, device=device)
with build_simulation_context(sim_cfg=sim_cfg, add_lighting=False) as sim:
sim._app_control_on_stop_handle = None
scene_cfg = ContactSensorSceneCfg(num_envs=1, env_spacing=1.0, lazy_sensor_update=False)
scene_cfg.terrain = FLAT_TERRAIN_CFG
scene_cfg.shape = CUBE_CFG
filter_prim_paths_expr = [scene_cfg.terrain.prim_path + "/terrain/GroundPlane/CollisionPlane"]
scene_cfg.contact_sensor = ContactSensorCfg(
prim_path=scene_cfg.shape.prim_path,
track_pose=True,
debug_vis=False,
update_period=0.0,
track_air_time=True,
history_length=3,
track_friction_forces=True,
filter_prim_paths_expr=filter_prim_paths_expr,
max_contact_data_count_per_prim=0,
)
try:
_ = InteractiveScene(scene_cfg)
sim.reset()
assert False, "Expected ValueError due to invalid contact sensor configuration."
except ValueError:
pass
""" """
Internal helpers. Internal helpers.
""" """
...@@ -459,20 +590,20 @@ def _run_contact_sensor_test( ...@@ -459,20 +590,20 @@ def _run_contact_sensor_test(
""" """
for device in devices: for device in devices:
for terrain in terrains: for terrain in terrains:
for track_contact_points in [True, False]: for track_contact_data in [True, False]:
with build_simulation_context(device=device, dt=sim_dt, add_lighting=True) as sim: with build_simulation_context(device=device, dt=sim_dt, add_lighting=True) as sim:
sim._app_control_on_stop_handle = None sim._app_control_on_stop_handle = None
scene_cfg = ContactSensorSceneCfg(num_envs=1, env_spacing=1.0, lazy_sensor_update=False) scene_cfg = ContactSensorSceneCfg(num_envs=1, env_spacing=1.0, lazy_sensor_update=False)
scene_cfg.terrain = terrain scene_cfg.terrain = terrain
scene_cfg.shape = shape_cfg scene_cfg.shape = shape_cfg
test_contact_position = False test_contact_data = False
if (type(shape_cfg.spawn) is sim_utils.SphereCfg) and (terrain.terrain_type == "plane"): if (type(shape_cfg.spawn) is sim_utils.SphereCfg) and (terrain.terrain_type == "plane"):
test_contact_position = True test_contact_data = True
elif track_contact_points: elif track_contact_data:
continue continue
if track_contact_points: if track_contact_data:
if terrain.terrain_type == "plane": if terrain.terrain_type == "plane":
filter_prim_paths_expr = [terrain.prim_path + "/terrain/GroundPlane/CollisionPlane"] filter_prim_paths_expr = [terrain.prim_path + "/terrain/GroundPlane/CollisionPlane"]
elif terrain.terrain_type == "generator": elif terrain.terrain_type == "generator":
...@@ -487,7 +618,8 @@ def _run_contact_sensor_test( ...@@ -487,7 +618,8 @@ def _run_contact_sensor_test(
update_period=0.0, update_period=0.0,
track_air_time=True, track_air_time=True,
history_length=3, history_length=3,
track_contact_points=track_contact_points, track_contact_points=track_contact_data,
track_friction_forces=track_contact_data,
filter_prim_paths_expr=filter_prim_paths_expr, filter_prim_paths_expr=filter_prim_paths_expr,
) )
scene = InteractiveScene(scene_cfg) scene = InteractiveScene(scene_cfg)
...@@ -504,7 +636,7 @@ def _run_contact_sensor_test( ...@@ -504,7 +636,7 @@ def _run_contact_sensor_test(
scene=scene, scene=scene,
sim_dt=sim_dt, sim_dt=sim_dt,
durations=durations, durations=durations,
test_contact_position=test_contact_position, test_contact_data=test_contact_data,
) )
_test_sensor_contact( _test_sensor_contact(
shape=scene["shape"], shape=scene["shape"],
...@@ -514,7 +646,7 @@ def _run_contact_sensor_test( ...@@ -514,7 +646,7 @@ def _run_contact_sensor_test(
scene=scene, scene=scene,
sim_dt=sim_dt, sim_dt=sim_dt,
durations=durations, durations=durations,
test_contact_position=test_contact_position, test_contact_data=test_contact_data,
) )
...@@ -526,7 +658,7 @@ def _test_sensor_contact( ...@@ -526,7 +658,7 @@ def _test_sensor_contact(
scene: InteractiveScene, scene: InteractiveScene,
sim_dt: float, sim_dt: float,
durations: list[float], durations: list[float],
test_contact_position: bool = False, test_contact_data: bool = False,
): ):
"""Test for the contact sensor. """Test for the contact sensor.
...@@ -593,8 +725,11 @@ def _test_sensor_contact( ...@@ -593,8 +725,11 @@ def _test_sensor_contact(
expected_last_air_time=expected_last_test_contact_time, expected_last_air_time=expected_last_test_contact_time,
dt=duration + sim_dt, dt=duration + sim_dt,
) )
if test_contact_position:
if test_contact_data:
_test_contact_position(shape, sensor, mode) _test_contact_position(shape, sensor, mode)
_test_friction_forces(shape, sensor, mode)
# switch the contact mode for 1 dt step before the next contact test begins. # switch the contact mode for 1 dt step before the next contact test begins.
shape.write_root_pose_to_sim(root_pose=reset_pose) shape.write_root_pose_to_sim(root_pose=reset_pose)
# perform simulation step # perform simulation step
...@@ -605,6 +740,33 @@ def _test_sensor_contact( ...@@ -605,6 +740,33 @@ def _test_sensor_contact(
expected_last_reset_contact_time = 2 * sim_dt expected_last_reset_contact_time = 2 * sim_dt
def _test_friction_forces(shape: RigidObject, sensor: ContactSensor, mode: ContactTestMode) -> None:
if not sensor.cfg.track_friction_forces:
assert sensor._data.friction_forces_w is None
return
# check shape of the contact_pos_w tensor
num_bodies = sensor.num_bodies
assert sensor._data.friction_forces_w.shape == (sensor.num_instances // num_bodies, num_bodies, 1, 3)
# compare friction forces
if mode == ContactTestMode.IN_CONTACT:
assert torch.any(torch.abs(sensor._data.friction_forces_w) > 1e-5).item()
friction_forces, _, buffer_count, buffer_start_indices = sensor.contact_physx_view.get_friction_data(
dt=sensor._sim_physics_dt
)
for i in range(sensor.num_instances * num_bodies):
for j in range(sensor.contact_physx_view.filter_count):
start_index_ij = buffer_start_indices[i, j]
count_ij = buffer_count[i, j]
force = torch.sum(friction_forces[start_index_ij : (start_index_ij + count_ij), :], dim=0)
env_idx = i // num_bodies
body_idx = i % num_bodies
assert torch.allclose(force, sensor._data.friction_forces_w[env_idx, body_idx, j, :], atol=1e-5)
elif mode == ContactTestMode.NON_CONTACT:
assert torch.all(sensor._data.friction_forces_w == 0.0).item()
def _test_contact_position(shape: RigidObject, sensor: ContactSensor, mode: ContactTestMode) -> None: def _test_contact_position(shape: RigidObject, sensor: ContactSensor, mode: ContactTestMode) -> None:
"""Test for the contact positions (only implemented for sphere and flat terrain) """Test for the contact positions (only implemented for sphere and flat terrain)
checks that the contact position is radius distance away from the root of the object checks that the contact position is radius distance away from the root of the object
...@@ -613,10 +775,13 @@ def _test_contact_position(shape: RigidObject, sensor: ContactSensor, mode: Cont ...@@ -613,10 +775,13 @@ def _test_contact_position(shape: RigidObject, sensor: ContactSensor, mode: Cont
sensor: The sensor reporting data to be verified by the contact sensor test. sensor: The sensor reporting data to be verified by the contact sensor test.
mode: The contact test mode: either contact with ground plane or air time. mode: The contact test mode: either contact with ground plane or air time.
""" """
if sensor.cfg.track_contact_points: if not sensor.cfg.track_contact_points:
assert sensor._data.contact_pos_w is None
return
# check shape of the contact_pos_w tensor # check shape of the contact_pos_w tensor
num_bodies = sensor.num_bodies num_bodies = sensor.num_bodies
assert sensor._data.contact_pos_w.shape == (sensor.num_instances / num_bodies, num_bodies, 1, 3) assert sensor._data.contact_pos_w.shape == (sensor.num_instances // num_bodies, num_bodies, 1, 3)
# check contact positions # check contact positions
if mode == ContactTestMode.IN_CONTACT: if mode == ContactTestMode.IN_CONTACT:
contact_position = sensor._data.pos_w + torch.tensor( contact_position = sensor._data.pos_w + torch.tensor(
...@@ -627,8 +792,6 @@ def _test_contact_position(shape: RigidObject, sensor: ContactSensor, mode: Cont ...@@ -627,8 +792,6 @@ def _test_contact_position(shape: RigidObject, sensor: ContactSensor, mode: Cont
).item() ).item()
elif mode == ContactTestMode.NON_CONTACT: elif mode == ContactTestMode.NON_CONTACT:
assert torch.all(torch.isnan(sensor._data.contact_pos_w)).item() assert torch.all(torch.isnan(sensor._data.contact_pos_w)).item()
else:
assert sensor._data.contact_pos_w is None
def _check_prim_contact_state_times( def _check_prim_contact_state_times(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment