Unverified Commit e1534550 authored by Farbod Farshidian's avatar Farbod Farshidian Committed by GitHub

Enables contact sensor to measure both in-contact and detached intervals (#412)

# Description

Previously, the contact sensor only measured the air time (last_air_time
and current_air_time). In this PR, the book keeps both the air-time and
contact time. Moreover, the arbitrary thresholding for detecting contact
is changed to a config parameter.

This is currently a non-breaking feature. But I suggest renaming
`ContactSensorCfg.track_air_time` to `ContactSensorCfg.track_intervals`.
This will make the PR a breaking feature.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [X] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [X] I have run all the tests with `./orbit.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarFarbod Farshidian <ffarshidian@theaiinstitute.com>
Co-authored-by: 's avatarMayank Mittal <mittalma@leggedrobotics.com>
parent e6012c96
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.10.23"
version = "0.10.24"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.10.24 (2024-02-26)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added tracking of contact time in the :class:`omni.isaac.orbit.sensors.ContactSensor` class. Previously,
only the air time was being tracked.
* Added contact force threshold, :attr:`omni.isaac.orbit.sensors.ContactSensorCfg.force_threshold`, to detect
when the contact sensor is in contact. Previously, this was set to hard-coded 1.0 in the sensor class.
0.10.23 (2024-02-21)
~~~~~~~~~~~~~~~~~~~~
......
......@@ -135,6 +135,8 @@ class ContactSensor(SensorBase):
if self.cfg.track_air_time:
self._data.current_air_time[env_ids] = 0.0
self._data.last_air_time[env_ids] = 0.0
self._data.current_contact_time[env_ids] = 0.0
self._data.last_contact_time[env_ids] = 0.0
def find_bodies(self, name_keys: str | Sequence[str]) -> tuple[list[int], list[str]]:
"""Find bodies in the articulation based on the name keys.
......@@ -148,6 +150,77 @@ class ContactSensor(SensorBase):
"""
return string_utils.resolve_matching_names(name_keys, self.body_names)
def compute_first_contact(self, dt: float, abs_tol: float = 1.0e-8) -> torch.Tensor:
"""Checks if bodies that have established contact within the last :attr:`dt` seconds.
This function checks if the bodies have established contact within the last :attr:`dt` seconds
by comparing the current contact time with the given time period. If the contact time is less
than the given time period, then the bodies are considered to be in contact.
Note:
The function assumes that :attr:`dt` is a factor of the sensor update time-step. In other
words :math:`dt / dt_sensor = n`, where :math:`n` is a natural number. This is always true
if the sensor is updated by the physics or the environment stepping time-step and the sensor
is read by the environment stepping time-step.
Args:
dt: The time period since the contact was established.
abs_tol: The absolute tolerance for the comparison.
Returns:
A boolean tensor indicating the bodies that have established contact within the last
:attr:`dt` seconds. Shape is (N, B), where N is the number of sensors and B is the
number of bodies in each sensor.
Raises:
RuntimeError: If the sensor is not configured to track contact time.
"""
# check if the sensor is configured to track contact time
if not self.cfg.track_air_time:
raise RuntimeError(
"The contact sensor is not configured to track contact time."
"Please enable the 'track_air_time' in the sensor configuration."
)
# check if the bodies are in contact
currently_in_contact = self.data.current_contact_time > 0.0
less_than_dt_in_contact = self.data.current_contact_time < (dt + abs_tol)
return currently_in_contact * less_than_dt_in_contact
def compute_first_air(self, dt: float, abs_tol: float = 1.0e-8) -> torch.Tensor:
"""Checks if bodies that have broken contact within the last :attr:`dt` seconds.
This function checks if the bodies have broken contact within the last :attr:`dt` seconds
by comparing the current air time with the given time period. If the air time is less
than the given time period, then the bodies are considered to not be in contact.
Note:
It assumes that :attr:`dt` is a factor of the sensor update time-step. In other words,
:math:`dt / dt_sensor = n`, where :math:`n` is a natural number. This is always true if
the sensor is updated by the physics or the environment stepping time-step and the sensor
is read by the environment stepping time-step.
Args:
dt: The time period since the contract is broken.
abs_tol: The absolute tolerance for the comparison.
Returns:
A boolean tensor indicating the bodies that have broken contact within the last :attr:`dt` seconds.
Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
Raises:
RuntimeError: If the sensor is not configured to track contact time.
"""
# check if the sensor is configured to track contact time
if not self.cfg.track_air_time:
raise RuntimeError(
"The contact sensor is not configured to track contact time."
"Please enable the 'track_air_time' in the sensor configuration."
)
# check if the sensor is configured to track contact time
currently_detached = self.data.current_air_time > 0.0
less_than_dt_detached = self.data.current_air_time < (dt + abs_tol)
return currently_detached * less_than_dt_detached
"""
Implementation.
"""
......@@ -205,10 +278,12 @@ class ContactSensor(SensorBase):
if self.cfg.track_pose:
self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device)
# -- air time between contacts
# -- air/contact time between contacts
if self.cfg.track_air_time:
self._data.last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.last_contact_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.current_contact_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
# force matrix: (num_envs, num_bodies, num_filter_shapes, 3)
if len(self.cfg.filter_prim_paths_expr) != 0:
num_filters = self.contact_physx_view.filter_count
......@@ -251,14 +326,29 @@ class ContactSensor(SensorBase):
# since this function is called every frame, we can use the difference to get the elapsed time
elapsed_time = self._timestamp[env_ids] - self._timestamp_last_update[env_ids]
# -- check contact state of bodies
is_contact = torch.norm(self._data.net_forces_w[env_ids, :, :], dim=-1) > 1.0
is_contact = torch.norm(self._data.net_forces_w[env_ids, :, :], dim=-1) > self.cfg.force_threshold
is_first_contact = (self._data.current_air_time[env_ids] > 0) * is_contact
# -- update ongoing timer for bodies air
self._data.current_air_time[env_ids] += elapsed_time.unsqueeze(-1)
# -- update time for the last time bodies were in contact
self._data.last_air_time[env_ids] = self._data.current_air_time[env_ids] * is_first_contact
# -- increment timers for bodies that are not in contact
self._data.current_air_time[env_ids] *= ~is_contact
is_first_detached = (self._data.current_contact_time[env_ids] > 0) * ~is_contact
# -- update the last contact time if body has just become in contact
self._data.last_air_time[env_ids] = torch.where(
is_first_contact,
self._data.current_air_time[env_ids] + elapsed_time.unsqueeze(-1),
self._data.last_air_time[env_ids],
)
# -- increment time for bodies that are not in contact
self._data.current_air_time[env_ids] = torch.where(
~is_contact, self._data.current_air_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0
)
# -- update the last contact time if body has just detached
self._data.last_contact_time[env_ids] = torch.where(
is_first_detached,
self._data.current_contact_time[env_ids] + elapsed_time.unsqueeze(-1),
self._data.last_contact_time[env_ids],
)
# -- increment time for bodies that are in contact
self._data.current_contact_time[env_ids] = torch.where(
is_contact, self._data.current_contact_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0
)
def _set_debug_vis_impl(self, debug_vis: bool):
# set visibility of markers
......@@ -281,7 +371,7 @@ class ContactSensor(SensorBase):
# marker indices
# 0: contact, 1: no contact
net_contact_force_w = torch.norm(self._data.net_forces_w, dim=-1)
marker_indices = torch.where(net_contact_force_w > 1.0, 0, 1)
marker_indices = torch.where(net_contact_force_w > self.cfg.force_threshold, 0, 1)
# check if prim is visualized
if self.cfg.track_pose:
frame_origins: torch.Tensor = self._data.pos_w
......
......@@ -23,7 +23,14 @@ class ContactSensorCfg(SensorBaseCfg):
"""Whether to track the pose of the sensor's origin. Defaults to False."""
track_air_time: bool = False
"""Whether to track the air time of the bodies (time between contacts). Defaults to False."""
"""Whether to track the air/contact time of the bodies (time between contacts). Defaults to False."""
force_threshold: float = 1.0
"""The threshold on the norm of the contact force that determines whether two bodies are in collision or not.
This value is only used for tracking the mode duration (the time in contact or in air),
if :attr:`track_air_time` is True.
"""
filter_prim_paths_expr: list[str] = list()
"""The list of primitive paths to filter contacts with.
......
......@@ -31,13 +31,13 @@ class ContactSensorData:
If the :attr:`ContactSensorCfg.track_pose` is False, then this qunatity is None.
"""
net_forces_w: torch.Tensor = None
net_forces_w: torch.Tensor | None = None
"""The net contact forces in world frame.
Shape is (N, B, 3), where N is the number of sensors and B is the number of bodies in each sensor.
"""
net_forces_w_history: torch.Tensor = None
net_forces_w_history: torch.Tensor | None = None
"""The net contact forces in world frame.
Shape is (N, T, B, 3), where N is the number of sensors, T is the configured history length
......@@ -59,16 +59,34 @@ class ContactSensorData:
last_air_time: torch.Tensor | None = None
"""Time spent (in s) in the air before the last contact.
Shape is (N,), where N is the number of sensors.
Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
"""
current_air_time: torch.Tensor | None = None
"""Time spent (in s) in the air since the last contact.
"""Time spent (in s) in the air since the last detach.
Shape is (N,), where N is the number of sensors.
Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
"""
last_contact_time: torch.Tensor | None = None
"""Time spent (in s) in contact before the last detach.
Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
"""
current_contact_time: torch.Tensor | None = None
"""Time spent (in s) in contact since the last contact.
Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
......
......@@ -17,7 +17,7 @@ class AnymalBFlatEnvCfg(AnymalBRoughEnvCfg):
# override rewards
self.rewards.flat_orientation_l2.weight = -5.0
self.rewards.dof_torques_l2.weight = -2.5e-5
self.rewards.feet_air_time.weight = 2.0
self.rewards.feet_air_time.weight = 0.5
# change terrain to flat
self.scene.terrain.terrain_type = "plane"
self.scene.terrain.terrain_generator = None
......
......@@ -17,7 +17,7 @@ class AnymalCFlatEnvCfg(AnymalCRoughEnvCfg):
# override rewards
self.rewards.flat_orientation_l2.weight = -5.0
self.rewards.dof_torques_l2.weight = -2.5e-5
self.rewards.feet_air_time.weight = 2.0
self.rewards.feet_air_time.weight = 0.5
# change terrain to flat
self.scene.terrain.terrain_type = "plane"
self.scene.terrain.terrain_generator = None
......
......@@ -17,7 +17,7 @@ class AnymalDFlatEnvCfg(AnymalDRoughEnvCfg):
# override rewards
self.rewards.flat_orientation_l2.weight = -5.0
self.rewards.dof_torques_l2.weight = -2.5e-5
self.rewards.feet_air_time.weight = 2.0
self.rewards.feet_air_time.weight = 0.5
# change terrain to flat
self.scene.terrain.terrain_type = "plane"
self.scene.terrain.terrain_generator = None
......
......@@ -15,7 +15,7 @@ class CassieFlatEnvCfg(CassieRoughEnvCfg):
super().__post_init__()
# rewards
self.rewards.flat_orientation_l2.weight = -2.5
self.rewards.feet_air_time.weight = 20.0
self.rewards.feet_air_time.weight = 5.0
self.rewards.joint_deviation_hip.params["asset_cfg"].joint_names = ["hip_rotation_.*"]
# change terrain to flat
self.scene.terrain.terrain_type = "plane"
......
......@@ -21,7 +21,7 @@ class CassieRewardsCfg(RewardsCfg):
termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0)
feet_air_time = RewTerm(
func=mdp.feet_air_time_positive_biped,
weight=10.0,
weight=2.5,
params={
"sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*toe"),
"command_name": "base_velocity",
......
......@@ -16,7 +16,7 @@ class UnitreeA1FlatEnvCfg(UnitreeA1RoughEnvCfg):
# override rewards
self.rewards.flat_orientation_l2.weight = -2.5
self.rewards.feet_air_time.weight = 1.0
self.rewards.feet_air_time.weight = 0.25
# change terrain to flat
self.scene.terrain.terrain_type = "plane"
......
......@@ -16,7 +16,7 @@ class UnitreeGo1FlatEnvCfg(UnitreeGo1RoughEnvCfg):
# override rewards
self.rewards.flat_orientation_l2.weight = -2.5
self.rewards.feet_air_time.weight = 1.0
self.rewards.feet_air_time.weight = 0.25
# change terrain to flat
self.scene.terrain.terrain_type = "plane"
......
......@@ -16,7 +16,7 @@ class UnitreeGo2FlatEnvCfg(UnitreeGo2RoughEnvCfg):
# override rewards
self.rewards.flat_orientation_l2.weight = -2.5
self.rewards.feet_air_time.weight = 1.0
self.rewards.feet_air_time.weight = 0.25
# change terrain to flat
self.scene.terrain.terrain_type = "plane"
......
......@@ -27,8 +27,8 @@ def feet_air_time(env: RLTaskEnv, command_name: str, sensor_cfg: SceneEntityCfg,
# extract the used quantities (to enable type-hinting)
contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
# compute the reward
first_contact = contact_sensor.compute_first_contact(env.step_dt)[:, sensor_cfg.body_ids]
last_air_time = contact_sensor.data.last_air_time[:, sensor_cfg.body_ids]
first_contact = last_air_time > 0.0
reward = torch.sum((last_air_time - threshold) * first_contact, dim=1)
# no reward for zero command
reward *= torch.norm(env.command_manager.get_command(command_name)[:, :2], dim=1) > 0.1
......@@ -43,12 +43,15 @@ def feet_air_time_positive_biped(env, command_name: str, threshold: float, senso
If the commands are small (i.e. the agent is not supposed to take a step), then the reward is zero.
"""
contact_sensor = env.scene.sensors[sensor_cfg.name]
contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
# compute the reward
last_air_time = contact_sensor.data.last_air_time[:, sensor_cfg.body_ids]
was_in_air = (last_air_time > 0.0).float()
num_feet_contact = torch.sum(was_in_air, dim=1).int()
reward = torch.where(num_feet_contact == 1, torch.sum(last_air_time.clamp_max_(threshold), dim=1), 0.0)
air_time = contact_sensor.data.current_air_time[:, sensor_cfg.body_ids]
contact_time = contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids]
in_contact = contact_time > 0.0
in_mode_time = torch.where(in_contact, contact_time, air_time)
single_stance = torch.sum(in_contact.int(), dim=1) == 1
reward = torch.min(torch.where(single_stance.unsqueeze(-1), in_mode_time, 0.0), dim=1)[0]
reward = torch.clamp(reward, max=threshold)
# no reward for zero command
reward *= torch.norm(env.command_manager.get_command(command_name)[:, :2], dim=1) > 0.1
return reward
......@@ -233,7 +233,7 @@ class RewardsCfg:
action_rate_l2 = RewTerm(func=mdp.action_rate_l2, weight=-0.01)
feet_air_time = RewTerm(
func=mdp.feet_air_time,
weight=0.5,
weight=0.125,
params={
"sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*FOOT"),
"command_name": "base_velocity",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment