Unverified Commit e1534550 authored by Farbod Farshidian's avatar Farbod Farshidian Committed by GitHub

Enables contact sensor to measure both in-contact and detached intervals (#412)

# Description

Previously, the contact sensor only measured the air time (last_air_time
and current_air_time). In this PR, the sensor keeps track of both the air
time and the contact time. Moreover, the hard-coded threshold for detecting
contact is replaced by a config parameter.

This is currently a non-breaking feature. But I suggest renaming
`ContactSensorCfg.track_air_time` to `ContactSensorCfg.track_intervals`.
This will make the PR a breaking feature.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [X] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [X] I have run all the tests with `./orbit.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: Farbod Farshidian <ffarshidian@theaiinstitute.com>
Co-authored-by: Mayank Mittal <mittalma@leggedrobotics.com>
parent e6012c96
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.10.23" version = "0.10.24"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.10.24 (2024-02-26)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added tracking of contact time in the :class:`omni.isaac.orbit.sensors.ContactSensor` class. Previously,
only the air time was being tracked.
* Added contact force threshold, :attr:`omni.isaac.orbit.sensors.ContactSensorCfg.force_threshold`, to detect
when the contact sensor is in contact. Previously, this was set to hard-coded 1.0 in the sensor class.
0.10.23 (2024-02-21) 0.10.23 (2024-02-21)
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
......
...@@ -135,6 +135,8 @@ class ContactSensor(SensorBase): ...@@ -135,6 +135,8 @@ class ContactSensor(SensorBase):
if self.cfg.track_air_time: if self.cfg.track_air_time:
self._data.current_air_time[env_ids] = 0.0 self._data.current_air_time[env_ids] = 0.0
self._data.last_air_time[env_ids] = 0.0 self._data.last_air_time[env_ids] = 0.0
self._data.current_contact_time[env_ids] = 0.0
self._data.last_contact_time[env_ids] = 0.0
def find_bodies(self, name_keys: str | Sequence[str]) -> tuple[list[int], list[str]]: def find_bodies(self, name_keys: str | Sequence[str]) -> tuple[list[int], list[str]]:
"""Find bodies in the articulation based on the name keys. """Find bodies in the articulation based on the name keys.
...@@ -148,6 +150,77 @@ class ContactSensor(SensorBase): ...@@ -148,6 +150,77 @@ class ContactSensor(SensorBase):
""" """
return string_utils.resolve_matching_names(name_keys, self.body_names) return string_utils.resolve_matching_names(name_keys, self.body_names)
def compute_first_contact(self, dt: float, abs_tol: float = 1.0e-8) -> torch.Tensor:
    """Check which bodies have established contact within the last :attr:`dt` seconds.

    The check compares the current contact time of each body with the given time period.
    A body whose contact timer is running (greater than zero) but has not yet exceeded
    :attr:`dt` (plus the tolerance) is considered to have just established contact.

    Note:
        The function assumes that :attr:`dt` is a multiple of the sensor update time-step. In other
        words :math:`dt / dt_sensor = n`, where :math:`n` is a natural number. This is always true
        if the sensor is updated by the physics or the environment stepping time-step and the sensor
        is read by the environment stepping time-step.

    Args:
        dt: The time period since the contact was established.
        abs_tol: The absolute tolerance for the comparison.

    Returns:
        A boolean tensor indicating the bodies that have established contact within the last
        :attr:`dt` seconds. Shape is (N, B), where N is the number of sensors and B is the
        number of bodies in each sensor.

    Raises:
        RuntimeError: If the sensor is not configured to track contact time.
    """
    # the contact timers are only read when interval tracking is enabled in the configuration
    if not self.cfg.track_air_time:
        raise RuntimeError(
            "The contact sensor is not configured to track contact time."
            " Please enable the 'track_air_time' in the sensor configuration."
        )
    # a body made first contact if its contact timer is running but has not exceeded dt
    currently_in_contact = self.data.current_contact_time > 0.0
    less_than_dt_in_contact = self.data.current_contact_time < (dt + abs_tol)
    return currently_in_contact & less_than_dt_in_contact
def compute_first_air(self, dt: float, abs_tol: float = 1.0e-8) -> torch.Tensor:
    """Check which bodies have broken contact within the last :attr:`dt` seconds.

    The check compares the current air time of each body with the given time period.
    A body whose air timer is running (greater than zero) but has not yet exceeded
    :attr:`dt` (plus the tolerance) is considered to have just broken contact.

    Note:
        It assumes that :attr:`dt` is a multiple of the sensor update time-step. In other words,
        :math:`dt / dt_sensor = n`, where :math:`n` is a natural number. This is always true if
        the sensor is updated by the physics or the environment stepping time-step and the sensor
        is read by the environment stepping time-step.

    Args:
        dt: The time period since the contact was broken.
        abs_tol: The absolute tolerance for the comparison.

    Returns:
        A boolean tensor indicating the bodies that have broken contact within the last :attr:`dt` seconds.
        Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.

    Raises:
        RuntimeError: If the sensor is not configured to track contact time.
    """
    # the air timers are only read when interval tracking is enabled in the configuration
    if not self.cfg.track_air_time:
        raise RuntimeError(
            "The contact sensor is not configured to track contact time."
            " Please enable the 'track_air_time' in the sensor configuration."
        )
    # a body has just detached if its air timer is running but has not exceeded dt
    currently_detached = self.data.current_air_time > 0.0
    less_than_dt_detached = self.data.current_air_time < (dt + abs_tol)
    return currently_detached & less_than_dt_detached
""" """
Implementation. Implementation.
""" """
...@@ -205,10 +278,12 @@ class ContactSensor(SensorBase): ...@@ -205,10 +278,12 @@ class ContactSensor(SensorBase):
if self.cfg.track_pose: if self.cfg.track_pose:
self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device) self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device)
self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device) self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device)
# -- air time between contacts # -- air/contact time between contacts
if self.cfg.track_air_time: if self.cfg.track_air_time:
self._data.last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device) self._data.last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device) self._data.current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.last_contact_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
self._data.current_contact_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device)
# force matrix: (num_envs, num_bodies, num_filter_shapes, 3) # force matrix: (num_envs, num_bodies, num_filter_shapes, 3)
if len(self.cfg.filter_prim_paths_expr) != 0: if len(self.cfg.filter_prim_paths_expr) != 0:
num_filters = self.contact_physx_view.filter_count num_filters = self.contact_physx_view.filter_count
...@@ -251,14 +326,29 @@ class ContactSensor(SensorBase): ...@@ -251,14 +326,29 @@ class ContactSensor(SensorBase):
# since this function is called every frame, we can use the difference to get the elapsed time # since this function is called every frame, we can use the difference to get the elapsed time
elapsed_time = self._timestamp[env_ids] - self._timestamp_last_update[env_ids] elapsed_time = self._timestamp[env_ids] - self._timestamp_last_update[env_ids]
# -- check contact state of bodies # -- check contact state of bodies
is_contact = torch.norm(self._data.net_forces_w[env_ids, :, :], dim=-1) > 1.0 is_contact = torch.norm(self._data.net_forces_w[env_ids, :, :], dim=-1) > self.cfg.force_threshold
is_first_contact = (self._data.current_air_time[env_ids] > 0) * is_contact is_first_contact = (self._data.current_air_time[env_ids] > 0) * is_contact
# -- update ongoing timer for bodies air is_first_detached = (self._data.current_contact_time[env_ids] > 0) * ~is_contact
self._data.current_air_time[env_ids] += elapsed_time.unsqueeze(-1) # -- update the last contact time if body has just become in contact
# -- update time for the last time bodies were in contact self._data.last_air_time[env_ids] = torch.where(
self._data.last_air_time[env_ids] = self._data.current_air_time[env_ids] * is_first_contact is_first_contact,
# -- increment timers for bodies that are not in contact self._data.current_air_time[env_ids] + elapsed_time.unsqueeze(-1),
self._data.current_air_time[env_ids] *= ~is_contact self._data.last_air_time[env_ids],
)
# -- increment time for bodies that are not in contact
self._data.current_air_time[env_ids] = torch.where(
~is_contact, self._data.current_air_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0
)
# -- update the last contact time if body has just detached
self._data.last_contact_time[env_ids] = torch.where(
is_first_detached,
self._data.current_contact_time[env_ids] + elapsed_time.unsqueeze(-1),
self._data.last_contact_time[env_ids],
)
# -- increment time for bodies that are in contact
self._data.current_contact_time[env_ids] = torch.where(
is_contact, self._data.current_contact_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0
)
def _set_debug_vis_impl(self, debug_vis: bool): def _set_debug_vis_impl(self, debug_vis: bool):
# set visibility of markers # set visibility of markers
...@@ -281,7 +371,7 @@ class ContactSensor(SensorBase): ...@@ -281,7 +371,7 @@ class ContactSensor(SensorBase):
# marker indices # marker indices
# 0: contact, 1: no contact # 0: contact, 1: no contact
net_contact_force_w = torch.norm(self._data.net_forces_w, dim=-1) net_contact_force_w = torch.norm(self._data.net_forces_w, dim=-1)
marker_indices = torch.where(net_contact_force_w > 1.0, 0, 1) marker_indices = torch.where(net_contact_force_w > self.cfg.force_threshold, 0, 1)
# check if prim is visualized # check if prim is visualized
if self.cfg.track_pose: if self.cfg.track_pose:
frame_origins: torch.Tensor = self._data.pos_w frame_origins: torch.Tensor = self._data.pos_w
......
...@@ -23,7 +23,14 @@ class ContactSensorCfg(SensorBaseCfg): ...@@ -23,7 +23,14 @@ class ContactSensorCfg(SensorBaseCfg):
"""Whether to track the pose of the sensor's origin. Defaults to False.""" """Whether to track the pose of the sensor's origin. Defaults to False."""
track_air_time: bool = False track_air_time: bool = False
"""Whether to track the air time of the bodies (time between contacts). Defaults to False.""" """Whether to track the air/contact time of the bodies (time between contacts). Defaults to False."""
force_threshold: float = 1.0
"""The threshold on the norm of the contact force that determines whether two bodies are in collision or not.
This value is only used for tracking the mode duration (the time in contact or in air),
if :attr:`track_air_time` is True.
"""
filter_prim_paths_expr: list[str] = list() filter_prim_paths_expr: list[str] = list()
"""The list of primitive paths to filter contacts with. """The list of primitive paths to filter contacts with.
......
...@@ -31,13 +31,13 @@ class ContactSensorData: ...@@ -31,13 +31,13 @@ class ContactSensorData:
If the :attr:`ContactSensorCfg.track_pose` is False, then this quantity is None. If the :attr:`ContactSensorCfg.track_pose` is False, then this quantity is None.
""" """
net_forces_w: torch.Tensor = None net_forces_w: torch.Tensor | None = None
"""The net contact forces in world frame. """The net contact forces in world frame.
Shape is (N, B, 3), where N is the number of sensors and B is the number of bodies in each sensor. Shape is (N, B, 3), where N is the number of sensors and B is the number of bodies in each sensor.
""" """
net_forces_w_history: torch.Tensor = None net_forces_w_history: torch.Tensor | None = None
"""The net contact forces in world frame. """The net contact forces in world frame.
Shape is (N, T, B, 3), where N is the number of sensors, T is the configured history length Shape is (N, T, B, 3), where N is the number of sensors, T is the configured history length
...@@ -59,16 +59,34 @@ class ContactSensorData: ...@@ -59,16 +59,34 @@ class ContactSensorData:
last_air_time: torch.Tensor | None = None last_air_time: torch.Tensor | None = None
"""Time spent (in s) in the air before the last contact. """Time spent (in s) in the air before the last contact.
Shape is (N,), where N is the number of sensors. Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
Note: Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None. If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
""" """
current_air_time: torch.Tensor | None = None current_air_time: torch.Tensor | None = None
"""Time spent (in s) in the air since the last contact. """Time spent (in s) in the air since the last detach.
Shape is (N,), where N is the number of sensors. Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
"""
last_contact_time: torch.Tensor | None = None
"""Time spent (in s) in contact before the last detach.
Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
"""
current_contact_time: torch.Tensor | None = None
"""Time spent (in s) in contact since the last contact.
Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor.
Note: Note:
If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None. If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None.
......
...@@ -17,7 +17,7 @@ class AnymalBFlatEnvCfg(AnymalBRoughEnvCfg): ...@@ -17,7 +17,7 @@ class AnymalBFlatEnvCfg(AnymalBRoughEnvCfg):
# override rewards # override rewards
self.rewards.flat_orientation_l2.weight = -5.0 self.rewards.flat_orientation_l2.weight = -5.0
self.rewards.dof_torques_l2.weight = -2.5e-5 self.rewards.dof_torques_l2.weight = -2.5e-5
self.rewards.feet_air_time.weight = 2.0 self.rewards.feet_air_time.weight = 0.5
# change terrain to flat # change terrain to flat
self.scene.terrain.terrain_type = "plane" self.scene.terrain.terrain_type = "plane"
self.scene.terrain.terrain_generator = None self.scene.terrain.terrain_generator = None
......
...@@ -17,7 +17,7 @@ class AnymalCFlatEnvCfg(AnymalCRoughEnvCfg): ...@@ -17,7 +17,7 @@ class AnymalCFlatEnvCfg(AnymalCRoughEnvCfg):
# override rewards # override rewards
self.rewards.flat_orientation_l2.weight = -5.0 self.rewards.flat_orientation_l2.weight = -5.0
self.rewards.dof_torques_l2.weight = -2.5e-5 self.rewards.dof_torques_l2.weight = -2.5e-5
self.rewards.feet_air_time.weight = 2.0 self.rewards.feet_air_time.weight = 0.5
# change terrain to flat # change terrain to flat
self.scene.terrain.terrain_type = "plane" self.scene.terrain.terrain_type = "plane"
self.scene.terrain.terrain_generator = None self.scene.terrain.terrain_generator = None
......
...@@ -17,7 +17,7 @@ class AnymalDFlatEnvCfg(AnymalDRoughEnvCfg): ...@@ -17,7 +17,7 @@ class AnymalDFlatEnvCfg(AnymalDRoughEnvCfg):
# override rewards # override rewards
self.rewards.flat_orientation_l2.weight = -5.0 self.rewards.flat_orientation_l2.weight = -5.0
self.rewards.dof_torques_l2.weight = -2.5e-5 self.rewards.dof_torques_l2.weight = -2.5e-5
self.rewards.feet_air_time.weight = 2.0 self.rewards.feet_air_time.weight = 0.5
# change terrain to flat # change terrain to flat
self.scene.terrain.terrain_type = "plane" self.scene.terrain.terrain_type = "plane"
self.scene.terrain.terrain_generator = None self.scene.terrain.terrain_generator = None
......
...@@ -15,7 +15,7 @@ class CassieFlatEnvCfg(CassieRoughEnvCfg): ...@@ -15,7 +15,7 @@ class CassieFlatEnvCfg(CassieRoughEnvCfg):
super().__post_init__() super().__post_init__()
# rewards # rewards
self.rewards.flat_orientation_l2.weight = -2.5 self.rewards.flat_orientation_l2.weight = -2.5
self.rewards.feet_air_time.weight = 20.0 self.rewards.feet_air_time.weight = 5.0
self.rewards.joint_deviation_hip.params["asset_cfg"].joint_names = ["hip_rotation_.*"] self.rewards.joint_deviation_hip.params["asset_cfg"].joint_names = ["hip_rotation_.*"]
# change terrain to flat # change terrain to flat
self.scene.terrain.terrain_type = "plane" self.scene.terrain.terrain_type = "plane"
......
...@@ -21,7 +21,7 @@ class CassieRewardsCfg(RewardsCfg): ...@@ -21,7 +21,7 @@ class CassieRewardsCfg(RewardsCfg):
termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0) termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0)
feet_air_time = RewTerm( feet_air_time = RewTerm(
func=mdp.feet_air_time_positive_biped, func=mdp.feet_air_time_positive_biped,
weight=10.0, weight=2.5,
params={ params={
"sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*toe"), "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*toe"),
"command_name": "base_velocity", "command_name": "base_velocity",
......
...@@ -16,7 +16,7 @@ class UnitreeA1FlatEnvCfg(UnitreeA1RoughEnvCfg): ...@@ -16,7 +16,7 @@ class UnitreeA1FlatEnvCfg(UnitreeA1RoughEnvCfg):
# override rewards # override rewards
self.rewards.flat_orientation_l2.weight = -2.5 self.rewards.flat_orientation_l2.weight = -2.5
self.rewards.feet_air_time.weight = 1.0 self.rewards.feet_air_time.weight = 0.25
# change terrain to flat # change terrain to flat
self.scene.terrain.terrain_type = "plane" self.scene.terrain.terrain_type = "plane"
......
...@@ -16,7 +16,7 @@ class UnitreeGo1FlatEnvCfg(UnitreeGo1RoughEnvCfg): ...@@ -16,7 +16,7 @@ class UnitreeGo1FlatEnvCfg(UnitreeGo1RoughEnvCfg):
# override rewards # override rewards
self.rewards.flat_orientation_l2.weight = -2.5 self.rewards.flat_orientation_l2.weight = -2.5
self.rewards.feet_air_time.weight = 1.0 self.rewards.feet_air_time.weight = 0.25
# change terrain to flat # change terrain to flat
self.scene.terrain.terrain_type = "plane" self.scene.terrain.terrain_type = "plane"
......
...@@ -16,7 +16,7 @@ class UnitreeGo2FlatEnvCfg(UnitreeGo2RoughEnvCfg): ...@@ -16,7 +16,7 @@ class UnitreeGo2FlatEnvCfg(UnitreeGo2RoughEnvCfg):
# override rewards # override rewards
self.rewards.flat_orientation_l2.weight = -2.5 self.rewards.flat_orientation_l2.weight = -2.5
self.rewards.feet_air_time.weight = 1.0 self.rewards.feet_air_time.weight = 0.25
# change terrain to flat # change terrain to flat
self.scene.terrain.terrain_type = "plane" self.scene.terrain.terrain_type = "plane"
......
...@@ -27,8 +27,8 @@ def feet_air_time(env: RLTaskEnv, command_name: str, sensor_cfg: SceneEntityCfg, ...@@ -27,8 +27,8 @@ def feet_air_time(env: RLTaskEnv, command_name: str, sensor_cfg: SceneEntityCfg,
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
# compute the reward # compute the reward
first_contact = contact_sensor.compute_first_contact(env.step_dt)[:, sensor_cfg.body_ids]
last_air_time = contact_sensor.data.last_air_time[:, sensor_cfg.body_ids] last_air_time = contact_sensor.data.last_air_time[:, sensor_cfg.body_ids]
first_contact = last_air_time > 0.0
reward = torch.sum((last_air_time - threshold) * first_contact, dim=1) reward = torch.sum((last_air_time - threshold) * first_contact, dim=1)
# no reward for zero command # no reward for zero command
reward *= torch.norm(env.command_manager.get_command(command_name)[:, :2], dim=1) > 0.1 reward *= torch.norm(env.command_manager.get_command(command_name)[:, :2], dim=1) > 0.1
...@@ -43,12 +43,15 @@ def feet_air_time_positive_biped(env, command_name: str, threshold: float, senso ...@@ -43,12 +43,15 @@ def feet_air_time_positive_biped(env, command_name: str, threshold: float, senso
If the commands are small (i.e. the agent is not supposed to take a step), then the reward is zero. If the commands are small (i.e. the agent is not supposed to take a step), then the reward is zero.
""" """
contact_sensor = env.scene.sensors[sensor_cfg.name] contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
# compute the reward # compute the reward
last_air_time = contact_sensor.data.last_air_time[:, sensor_cfg.body_ids] air_time = contact_sensor.data.current_air_time[:, sensor_cfg.body_ids]
was_in_air = (last_air_time > 0.0).float() contact_time = contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids]
num_feet_contact = torch.sum(was_in_air, dim=1).int() in_contact = contact_time > 0.0
reward = torch.where(num_feet_contact == 1, torch.sum(last_air_time.clamp_max_(threshold), dim=1), 0.0) in_mode_time = torch.where(in_contact, contact_time, air_time)
single_stance = torch.sum(in_contact.int(), dim=1) == 1
reward = torch.min(torch.where(single_stance.unsqueeze(-1), in_mode_time, 0.0), dim=1)[0]
reward = torch.clamp(reward, max=threshold)
# no reward for zero command # no reward for zero command
reward *= torch.norm(env.command_manager.get_command(command_name)[:, :2], dim=1) > 0.1 reward *= torch.norm(env.command_manager.get_command(command_name)[:, :2], dim=1) > 0.1
return reward return reward
...@@ -233,7 +233,7 @@ class RewardsCfg: ...@@ -233,7 +233,7 @@ class RewardsCfg:
action_rate_l2 = RewTerm(func=mdp.action_rate_l2, weight=-0.01) action_rate_l2 = RewTerm(func=mdp.action_rate_l2, weight=-0.01)
feet_air_time = RewTerm( feet_air_time = RewTerm(
func=mdp.feet_air_time, func=mdp.feet_air_time,
weight=0.5, weight=0.125,
params={ params={
"sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*FOOT"), "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*FOOT"),
"command_name": "base_velocity", "command_name": "base_velocity",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment