Unverified Commit a8fb6b04 authored by David Hoeller's avatar David Hoeller Committed by GitHub

Improves object and articulation data update logic (#491)

Improves the `omni.isaac.lab.assets.RigidObjectData` and
`omni.isaac.lab.assets.ArticulationData` buffers to update their data
lazily. Before, all the data was always updated at every step, even if
it was not used by the task. This improves performance for all the
tasks.

## Type of change

- New feature

## Screenshots

For the Cartpole task, blue is before, red is this MR.

![image](https://github.com/isaac-sim/IsaacLab/assets/11162199/99cbe73d-fa89-4de9-a94c-306cc7e0766f)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./isaaclab.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 958b863b
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.17.11" version = "0.17.12"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.17.12 (2024-06-13)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added the class :class:`omni.isaac.lab.utils.buffers.TimestampedBuffer` to store timestamped data.
Changed
^^^^^^^
* Added time-stamped buffers in the classes :class:`omni.isaac.lab.assets.RigidObjectData` and :class:`omni.isaac.lab.assets.ArticulationData`
to update some values lazily and avoid unnecessary computations between physics updates. Before, all the data was always
updated at every step, even if it was not used by the task.
0.17.11 (2024-05-30) 0.17.11 (2024-05-30)
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
......
...@@ -95,8 +95,6 @@ class Articulation(RigidObject): ...@@ -95,8 +95,6 @@ class Articulation(RigidObject):
cfg: A configuration instance. cfg: A configuration instance.
""" """
super().__init__(cfg) super().__init__(cfg)
# container for data access
self._data = ArticulationData()
# data for storing actuator group # data for storing actuator group
self.actuators: dict[str, ActuatorBase] = dict.fromkeys(self.cfg.actuators.keys()) self.actuators: dict[str, ActuatorBase] = dict.fromkeys(self.cfg.actuators.keys())
...@@ -205,30 +203,6 @@ class Articulation(RigidObject): ...@@ -205,30 +203,6 @@ class Articulation(RigidObject):
self.root_physx_view.set_dof_position_targets(self._joint_pos_target_sim, self._ALL_INDICES) self.root_physx_view.set_dof_position_targets(self._joint_pos_target_sim, self._ALL_INDICES)
self.root_physx_view.set_dof_velocity_targets(self._joint_vel_target_sim, self._ALL_INDICES) self.root_physx_view.set_dof_velocity_targets(self._joint_vel_target_sim, self._ALL_INDICES)
def update(self, dt: float):
# -- root state (note: we roll the quaternion to match the convention used in Isaac Sim -- wxyz)
self._data.root_state_w[:, :7] = self.root_physx_view.get_root_transforms()
self._data.root_state_w[:, 3:7] = math_utils.convert_quat(self._data.root_state_w[:, 3:7], to="wxyz")
self._data.root_state_w[:, 7:] = self.root_physx_view.get_root_velocities()
# -- body-state (note: we roll the quaternion to match the convention used in Isaac Sim -- wxyz)
self._data.body_state_w[..., :7] = self.root_physx_view.get_link_transforms()
self._data.body_state_w[..., 3:7] = math_utils.convert_quat(self._data.body_state_w[..., 3:7], to="wxyz")
self._data.body_state_w[..., 7:] = self.root_physx_view.get_link_velocities()
# -- joint states
self._data.joint_pos[:] = self.root_physx_view.get_dof_positions()
self._data.joint_vel[:] = self.root_physx_view.get_dof_velocities()
if dt > 0.0:
self._data.joint_acc[:] = (self._data.joint_vel - self._previous_joint_vel) / dt
# -- update common data
# note: these are computed in the base class
self._update_common_data(dt)
# -- update history buffers
self._previous_joint_vel[:] = self._data.joint_vel[:]
def find_joints( def find_joints(
self, name_keys: str | Sequence[str], joint_subset: list[str] | None = None, preserve_order: bool = False self, name_keys: str | Sequence[str], joint_subset: list[str] | None = None, preserve_order: bool = False
) -> tuple[list[int], list[str]]: ) -> tuple[list[int], list[str]]:
...@@ -320,6 +294,7 @@ class Articulation(RigidObject): ...@@ -320,6 +294,7 @@ class Articulation(RigidObject):
# note: we need to do this here since tensors are not set into simulation until step. # note: we need to do this here since tensors are not set into simulation until step.
# set into internal buffers # set into internal buffers
self._data.root_state_w[env_ids, 7:] = root_velocity.clone() self._data.root_state_w[env_ids, 7:] = root_velocity.clone()
self._data.body_acc_w[env_ids] = 0.0
# set into simulation # set into simulation
self.root_physx_view.set_root_velocities(self._data.root_state_w[:, 7:], indices=physx_env_ids) self.root_physx_view.set_root_velocities(self._data.root_state_w[:, 7:], indices=physx_env_ids)
...@@ -350,7 +325,7 @@ class Articulation(RigidObject): ...@@ -350,7 +325,7 @@ class Articulation(RigidObject):
# set into internal buffers # set into internal buffers
self._data.joint_pos[env_ids, joint_ids] = position self._data.joint_pos[env_ids, joint_ids] = position
self._data.joint_vel[env_ids, joint_ids] = velocity self._data.joint_vel[env_ids, joint_ids] = velocity
self._previous_joint_vel[env_ids, joint_ids] = velocity self._data._previous_joint_vel[env_ids, joint_ids] = velocity
self._data.joint_acc[env_ids, joint_ids] = 0.0 self._data.joint_acc[env_ids, joint_ids] = 0.0
# set into simulation # set into simulation
self.root_physx_view.set_dof_positions(self._data.joint_pos, indices=physx_env_ids) self.root_physx_view.set_dof_positions(self._data.joint_pos, indices=physx_env_ids)
...@@ -861,6 +836,9 @@ class Articulation(RigidObject): ...@@ -861,6 +836,9 @@ class Articulation(RigidObject):
if set(physx_body_names) != set(self.body_names): if set(physx_body_names) != set(self.body_names):
raise RuntimeError("Failed to parse all bodies properly in the articulation.") raise RuntimeError("Failed to parse all bodies properly in the articulation.")
# container for data access
self._data = ArticulationData(self.root_physx_view, self.device)
# create buffers # create buffers
self._create_buffers() self._create_buffers()
# process configuration # process configuration
...@@ -877,30 +855,25 @@ class Articulation(RigidObject): ...@@ -877,30 +855,25 @@ class Articulation(RigidObject):
def _create_buffers(self): def _create_buffers(self):
# allocate buffers # allocate buffers
super()._create_buffers() super()._create_buffers()
# history buffers
self._previous_joint_vel = torch.zeros(self.num_instances, self.num_joints, device=self.device)
# asset data # asset data
# -- properties # -- properties
self._data.joint_names = self.joint_names self._data.joint_names = self.joint_names
# -- joint states
self._data.joint_pos = torch.zeros(self.num_instances, self.num_joints, device=self.device) self._data.default_joint_pos = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.joint_vel = torch.zeros_like(self._data.joint_pos) self._data.default_joint_vel = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_acc = torch.zeros_like(self._data.joint_pos)
self._data.default_joint_pos = torch.zeros_like(self._data.joint_pos)
self._data.default_joint_vel = torch.zeros_like(self._data.joint_pos)
# -- joint commands # -- joint commands
self._data.joint_pos_target = torch.zeros_like(self._data.joint_pos) self._data.joint_pos_target = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_vel_target = torch.zeros_like(self._data.joint_pos) self._data.joint_vel_target = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_effort_target = torch.zeros_like(self._data.joint_pos) self._data.joint_effort_target = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_stiffness = torch.zeros_like(self._data.joint_pos) self._data.joint_stiffness = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_damping = torch.zeros_like(self._data.joint_pos) self._data.joint_damping = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_armature = torch.zeros_like(self._data.joint_pos) self._data.joint_armature = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_friction = torch.zeros_like(self._data.joint_pos) self._data.joint_friction = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device) self._data.joint_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device)
# -- joint commands (explicit) # -- joint commands (explicit)
self._data.computed_torque = torch.zeros_like(self._data.joint_pos) self._data.computed_torque = torch.zeros_like(self._data.default_joint_pos)
self._data.applied_torque = torch.zeros_like(self._data.joint_pos) self._data.applied_torque = torch.zeros_like(self._data.default_joint_pos)
# -- tendons # -- tendons
if self.num_fixed_tendons > 0: if self.num_fixed_tendons > 0:
self._data.fixed_tendon_stiffness = torch.zeros( self._data.fixed_tendon_stiffness = torch.zeros(
......
...@@ -4,24 +4,40 @@ ...@@ -4,24 +4,40 @@
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
import torch import torch
from dataclasses import dataclass
import omni.physics.tensors.impl.api as physx
import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.utils.buffers import TimestampedBuffer
from ..rigid_object import RigidObjectData from ..rigid_object import RigidObjectData
@dataclass
class ArticulationData(RigidObjectData): class ArticulationData(RigidObjectData):
"""Data container for an articulation.""" """Data container for an articulation."""
## def __init__(self, root_physx_view: physx.ArticulationView, device):
# Properties. super().__init__(root_physx_view, device)
## self._root_physx_view: physx.ArticulationView = root_physx_view
self._previous_joint_vel = self._root_physx_view.get_dof_velocities().clone()
# Initialize the lazy buffers.
self._body_state_w: TimestampedBuffer = TimestampedBuffer()
self._joint_pos: TimestampedBuffer = TimestampedBuffer()
self._joint_acc: TimestampedBuffer = TimestampedBuffer()
self._joint_vel: TimestampedBuffer = TimestampedBuffer()
def update(self, dt: float):
self._time_stamp += dt
# Trigger an update of the joint acceleration buffer at a higher frequency since we do finite differencing.
self.joint_acc
joint_names: list[str] = None joint_names: list[str] = None
"""Joint names in the order parsed by the simulation view.""" """Joint names in the order parsed by the simulation view."""
## ##
# Default states. # Defaults.
## ##
default_joint_pos: torch.Tensor = None default_joint_pos: torch.Tensor = None
...@@ -30,18 +46,38 @@ class ArticulationData(RigidObjectData): ...@@ -30,18 +46,38 @@ class ArticulationData(RigidObjectData):
default_joint_vel: torch.Tensor = None default_joint_vel: torch.Tensor = None
"""Default joint velocities of all joints. Shape is (num_instances, num_joints).""" """Default joint velocities of all joints. Shape is (num_instances, num_joints)."""
## default_joint_stiffness: torch.Tensor = None
# Joint states <- From simulation. """Default joint stiffness of all joints. Shape is (num_instances, num_joints)."""
##
joint_pos: torch.Tensor = None default_joint_damping: torch.Tensor = None
"""Joint positions of all joints. Shape is (num_instances, num_joints).""" """Default joint damping of all joints. Shape is (num_instances, num_joints)."""
joint_vel: torch.Tensor = None default_joint_armature: torch.Tensor = None
"""Joint velocities of all joints. Shape is (num_instances, num_joints).""" """Default joint armature of all joints. Shape is (num_instances, num_joints)."""
joint_acc: torch.Tensor = None default_joint_friction: torch.Tensor = None
"""Joint acceleration of all joints. Shape is (num_instances, num_joints).""" """Default joint friction of all joints. Shape is (num_instances, num_joints)."""
default_joint_limits: torch.Tensor = None
"""Default joint limits of all joints. Shape is (num_instances, num_joints, 2)."""
default_fixed_tendon_stiffness: torch.Tensor = None
"""Default tendon stiffness of all tendons. Shape is (num_instances, num_fixed_tendons)."""
default_fixed_tendon_damping: torch.Tensor = None
"""Default tendon damping of all tendons. Shape is (num_instances, num_fixed_tendons)."""
default_fixed_tendon_limit_stiffness: torch.Tensor = None
"""Default tendon limit stiffness of all tendons. Shape is (num_instances, num_fixed_tendons)."""
default_fixed_tendon_rest_length: torch.Tensor = None
"""Default tendon rest length of all tendons. Shape is (num_instances, num_fixed_tendons)."""
default_fixed_tendon_offset: torch.Tensor = None
"""Default tendon offset of all tendons. Shape is (num_instances, num_fixed_tendons)."""
default_fixed_tendon_limit: torch.Tensor = None
"""Default tendon limits of all tendons. Shape is (num_instances, num_fixed_tendons, 2)."""
## ##
# Joint commands -- Set into simulation. # Joint commands -- Set into simulation.
...@@ -71,44 +107,6 @@ class ArticulationData(RigidObjectData): ...@@ -71,44 +107,6 @@ class ArticulationData(RigidObjectData):
which are then set into the simulation. which are then set into the simulation.
""" """
##
# Joint properties.
##
joint_stiffness: torch.Tensor = None
"""Joint stiffness provided to simulation. Shape is (num_instances, num_joints)."""
joint_damping: torch.Tensor = None
"""Joint damping provided to simulation. Shape is (num_instances, num_joints)."""
joint_armature: torch.Tensor = None
"""Joint armature provided to simulation. Shape is (num_instances, num_joints)."""
joint_friction: torch.Tensor = None
"""Joint friction provided to simulation. Shape is (num_instances, num_joints)."""
joint_limits: torch.Tensor = None
"""Joint limits provided to simulation. Shape is (num_instances, num_joints, 2)."""
##
# Default joint properties
##
default_joint_stiffness: torch.Tensor = None
"""Default joint stiffness of all joints. Shape is (num_instances, num_joints)."""
default_joint_damping: torch.Tensor = None
"""Default joint damping of all joints. Shape is (num_instances, num_joints)."""
default_joint_armature: torch.Tensor = None
"""Default joint armature of all joints. Shape is (num_instances, num_joints)."""
default_joint_friction: torch.Tensor = None
"""Default joint friction of all joints. Shape is (num_instances, num_joints)."""
default_joint_limits: torch.Tensor = None
"""Default joint limits of all joints. Shape is (num_instances, num_joints, 2)."""
## ##
# Joint commands -- Explicit actuators. # Joint commands -- Explicit actuators.
## ##
...@@ -132,6 +130,25 @@ class ArticulationData(RigidObjectData): ...@@ -132,6 +130,25 @@ class ArticulationData(RigidObjectData):
Note: The torques are zero for implicit actuator models. Note: The torques are zero for implicit actuator models.
""" """
##
# Joint properties.
##
joint_stiffness: torch.Tensor = None
"""Joint stiffness provided to simulation. Shape is (num_instances, num_joints)."""
joint_damping: torch.Tensor = None
"""Joint damping provided to simulation. Shape is (num_instances, num_joints)."""
joint_armature: torch.Tensor = None
"""Joint armature provided to simulation. Shape is (num_instances, num_joints)."""
joint_friction: torch.Tensor = None
"""Joint friction provided to simulation. Shape is (num_instances, num_joints)."""
joint_limits: torch.Tensor = None
"""Joint limits provided to simulation. Shape is (num_instances, num_joints, 2)."""
## ##
# Fixed tendon properties. # Fixed tendon properties.
## ##
...@@ -154,28 +171,6 @@ class ArticulationData(RigidObjectData): ...@@ -154,28 +171,6 @@ class ArticulationData(RigidObjectData):
fixed_tendon_limit: torch.Tensor = None fixed_tendon_limit: torch.Tensor = None
"""Fixed tendon limits provided to simulation. Shape is (num_instances, num_fixed_tendons, 2).""" """Fixed tendon limits provided to simulation. Shape is (num_instances, num_fixed_tendons, 2)."""
##
# Default fixed tendon properties
##
default_fixed_tendon_stiffness: torch.Tensor = None
"""Default tendon stiffness of all tendons. Shape is (num_instances, num_fixed_tendons)."""
default_fixed_tendon_damping: torch.Tensor = None
"""Default tendon damping of all tendons. Shape is (num_instances, num_fixed_tendons)."""
default_fixed_tendon_limit_stiffness: torch.Tensor = None
"""Default tendon limit stiffness of all tendons. Shape is (num_instances, num_fixed_tendons)."""
default_fixed_tendon_rest_length: torch.Tensor = None
"""Default tendon rest length of all tendons. Shape is (num_instances, num_fixed_tendons)."""
default_fixed_tendon_offset: torch.Tensor = None
"""Default tendon offset of all tendons. Shape is (num_instances, num_fixed_tendons)."""
default_fixed_tendon_limit: torch.Tensor = None
"""Default tendon limits of all tendons. Shape is (num_instances, num_fixed_tendons, 2)."""
## ##
# Other Data. # Other Data.
## ##
...@@ -188,3 +183,75 @@ class ArticulationData(RigidObjectData): ...@@ -188,3 +183,75 @@ class ArticulationData(RigidObjectData):
gear_ratio: torch.Tensor = None gear_ratio: torch.Tensor = None
"""Gear ratio for relating motor torques to applied Joint torques. Shape is (num_instances, num_joints).""" """Gear ratio for relating motor torques to applied Joint torques. Shape is (num_instances, num_joints)."""
##
# Properties.
##
@property
def root_state_w(self):
"""Root state ``[pos, quat, lin_vel, ang_vel]`` in simulation world frame. Shape is (num_instances, 13)."""
if self._root_state_w.update_timestamp < self._time_stamp:
pose = self._root_physx_view.get_root_transforms().clone()
pose[:, 3:7] = math_utils.convert_quat(pose[:, 3:7], to="wxyz")
velocity = self._root_physx_view.get_root_velocities()
self._root_state_w.data = torch.cat((pose, velocity), dim=-1)
self._root_state_w.update_timestamp = self._time_stamp
return self._root_state_w.data
@property
def body_state_w(self):
"""State of all bodies `[pos, quat, lin_vel, ang_vel]` in simulation world frame.
Shape is (num_instances, num_bodies, 13)."""
if self._body_state_w.update_timestamp < self._time_stamp:
poses = self._root_physx_view.get_link_transforms().clone()
poses[..., 3:7] = math_utils.convert_quat(poses[..., 3:7], to="wxyz")
velocities = self._root_physx_view.get_link_velocities()
self._body_state_w.data = torch.cat((poses, velocities), dim=-1)
self._body_state_w.update_timestamp = self._time_stamp
return self._body_state_w.data
@property
def body_acc_w(self):
"""Acceleration of all bodies. Shape is (num_instances, num_bodies, 6)."""
if self._body_acc_w.update_timestamp < self._time_stamp:
self._body_acc_w.data = self._root_physx_view.get_link_accelerations()
self._body_acc_w.update_timestamp = self._time_stamp
return self._body_acc_w.data
@property
def body_lin_acc_w(self) -> torch.Tensor:
"""Linear acceleration of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_acc_w[..., 0:3]
@property
def body_ang_acc_w(self) -> torch.Tensor:
"""Angular acceleration of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_acc_w[..., 3:6]
@property
def joint_pos(self):
"""Joint positions of all joints. Shape is (num_instances, num_joints)."""
if self._joint_pos.update_timestamp < self._time_stamp:
self._joint_pos.data = self._root_physx_view.get_dof_positions()
self._joint_pos.update_timestamp = self._time_stamp
return self._joint_pos.data
@property
def joint_vel(self):
"""Joint velocities of all joints. Shape is (num_instances, num_joints)."""
if self._joint_vel.update_timestamp < self._time_stamp:
self._joint_vel.data = self._root_physx_view.get_dof_velocities()
self._joint_vel.update_timestamp = self._time_stamp
return self._joint_vel.data
@property
def joint_acc(self):
"""Joint acceleration of all joints. Shape is (num_instances, num_joints)."""
if self._joint_acc.update_timestamp < self._time_stamp:
self._joint_acc.data = (self.joint_vel - self._previous_joint_vel) / (
self._time_stamp - self._joint_acc.update_timestamp
)
self._previous_joint_vel[:] = self.joint_vel
self._joint_acc.update_timestamp = self._time_stamp
return self._joint_acc.data
...@@ -56,8 +56,6 @@ class RigidObject(AssetBase): ...@@ -56,8 +56,6 @@ class RigidObject(AssetBase):
cfg: A configuration instance. cfg: A configuration instance.
""" """
super().__init__(cfg) super().__init__(cfg)
# container for data access
self._data = RigidObjectData()
""" """
Properties Properties
...@@ -116,8 +114,6 @@ class RigidObject(AssetBase): ...@@ -116,8 +114,6 @@ class RigidObject(AssetBase):
# reset external wrench # reset external wrench
self._external_force_b[env_ids] = 0.0 self._external_force_b[env_ids] = 0.0
self._external_torque_b[env_ids] = 0.0 self._external_torque_b[env_ids] = 0.0
# reset last body vel
self._last_body_vel_w[env_ids] = 0.0
def write_data_to_sim(self): def write_data_to_sim(self):
"""Write external wrench to the simulation. """Write external wrench to the simulation.
...@@ -137,16 +133,7 @@ class RigidObject(AssetBase): ...@@ -137,16 +133,7 @@ class RigidObject(AssetBase):
) )
def update(self, dt: float): def update(self, dt: float):
# -- root-state (note: we roll the quaternion to match the convention used in Isaac Sim -- wxyz) self._data.update(dt)
self._data.root_state_w[:, :7] = self.root_physx_view.get_transforms()
self._data.root_state_w[:, 3:7] = math_utils.convert_quat(self._data.root_state_w[:, 3:7], to="wxyz")
self._data.root_state_w[:, 7:] = self.root_physx_view.get_velocities()
# -- body-state (note: for rigid objects, we only have one body so we just copy the root state)
self._data.body_state_w[:] = self._data.root_state_w.view(-1, self.num_bodies, 13)
# -- update common data
self._update_common_data(dt)
def find_bodies(self, name_keys: str | Sequence[str], preserve_order: bool = False) -> tuple[list[int], list[str]]: def find_bodies(self, name_keys: str | Sequence[str], preserve_order: bool = False) -> tuple[list[int], list[str]]:
"""Find bodies in the articulation based on the name keys. """Find bodies in the articulation based on the name keys.
...@@ -219,6 +206,8 @@ class RigidObject(AssetBase): ...@@ -219,6 +206,8 @@ class RigidObject(AssetBase):
# note: we need to do this here since tensors are not set into simulation until step. # note: we need to do this here since tensors are not set into simulation until step.
# set into internal buffers # set into internal buffers
self._data.root_state_w[env_ids, 7:] = root_velocity.clone() self._data.root_state_w[env_ids, 7:] = root_velocity.clone()
self._data._previous_body_vel_w[env_ids, 0] = root_velocity.clone()
self._data.body_acc_w[env_ids] = 0.0
# set into simulation # set into simulation
self.root_physx_view.set_velocities(self._data.root_state_w[:, 7:], indices=physx_env_ids) self.root_physx_view.set_velocities(self._data.root_state_w[:, 7:], indices=physx_env_ids)
...@@ -329,6 +318,9 @@ class RigidObject(AssetBase): ...@@ -329,6 +318,9 @@ class RigidObject(AssetBase):
carb.log_info(f"Number of bodies: {self.num_bodies}") carb.log_info(f"Number of bodies: {self.num_bodies}")
carb.log_info(f"Body names: {self.body_names}") carb.log_info(f"Body names: {self.body_names}")
# container for data access
self._data = RigidObjectData(self.root_physx_view, self.device)
# create buffers # create buffers
self._create_buffers() self._create_buffers()
# process configuration # process configuration
...@@ -343,35 +335,12 @@ class RigidObject(AssetBase): ...@@ -343,35 +335,12 @@ class RigidObject(AssetBase):
self._ALL_BODY_INDICES = torch.arange( self._ALL_BODY_INDICES = torch.arange(
self.root_physx_view.count * self.num_bodies, dtype=torch.long, device=self.device self.root_physx_view.count * self.num_bodies, dtype=torch.long, device=self.device
) )
self.GRAVITY_VEC_W = torch.tensor((0.0, 0.0, -1.0), device=self.device).repeat(self.num_instances, 1)
self.FORWARD_VEC_B = torch.tensor((1.0, 0.0, 0.0), device=self.device).repeat(self.num_instances, 1)
# external forces and torques # external forces and torques
self.has_external_wrench = False self.has_external_wrench = False
self._external_force_b = torch.zeros((self.num_instances, self.num_bodies, 3), device=self.device) self._external_force_b = torch.zeros((self.num_instances, self.num_bodies, 3), device=self.device)
self._external_torque_b = torch.zeros_like(self._external_force_b) self._external_torque_b = torch.zeros_like(self._external_force_b)
# asset data
# -- properties
self._data.body_names = self.body_names self._data.body_names = self.body_names
# -- root states
self._data.root_state_w = torch.zeros(self.num_instances, 13, device=self.device)
self._data.root_state_w[:, 3] = 1.0 # set default quaternion to (1, 0, 0, 0)
self._data.default_root_state = torch.zeros_like(self._data.root_state_w)
self._data.default_root_state[:, 3] = 1.0 # set default quaternion to (1, 0, 0, 0)
# -- body states
self._data.body_state_w = torch.zeros(self.num_instances, self.num_bodies, 13, device=self.device)
self._data.body_state_w[:, :, 3] = 1.0 # set default quaternion to (1, 0, 0, 0)
# -- post-computed
self._data.root_vel_b = torch.zeros(self.num_instances, 6, device=self.device)
self._data.projected_gravity_b = torch.zeros(self.num_instances, 3, device=self.device)
self._data.heading_w = torch.zeros(self.num_instances, device=self.device)
self._data.body_acc_w = torch.zeros(self.num_instances, self.num_bodies, 6, device=self.device)
# history buffers for quantities
# -- used to compute body accelerations numerically
self._last_body_vel_w = torch.zeros(self.num_instances, self.num_bodies, 6, device=self.device)
# mass
self._data.default_mass = self.root_physx_view.get_masses().clone() self._data.default_mass = self.root_physx_view.get_masses().clone()
def _process_cfg(self): def _process_cfg(self):
...@@ -388,29 +357,6 @@ class RigidObject(AssetBase): ...@@ -388,29 +357,6 @@ class RigidObject(AssetBase):
default_root_state = torch.tensor(default_root_state, dtype=torch.float, device=self.device) default_root_state = torch.tensor(default_root_state, dtype=torch.float, device=self.device)
self._data.default_root_state = default_root_state.repeat(self.num_instances, 1) self._data.default_root_state = default_root_state.repeat(self.num_instances, 1)
def _update_common_data(self, dt: float):
"""Update common quantities related to rigid objects.
Note:
This has been separated from the update function to allow for the child classes to
override the update function without having to worry about updating the common data.
"""
# -- body acceleration
if dt > 0.0:
self._data.body_acc_w[:] = (self._data.body_state_w[..., 7:] - self._last_body_vel_w) / dt
self._last_body_vel_w[:] = self._data.body_state_w[..., 7:]
# -- root state in body frame
self._data.root_vel_b[:, 0:3] = math_utils.quat_rotate_inverse(
self._data.root_quat_w, self._data.root_lin_vel_w
)
self._data.root_vel_b[:, 3:6] = math_utils.quat_rotate_inverse(
self._data.root_quat_w, self._data.root_ang_vel_w
)
self._data.projected_gravity_b[:] = math_utils.quat_rotate_inverse(self._data.root_quat_w, self.GRAVITY_VEC_W)
# -- heading direction of root
forward_w = math_utils.quat_apply(self._data.root_quat_w, self.FORWARD_VEC_B)
self._data.heading_w[:] = torch.atan2(forward_w[:, 1], forward_w[:, 0])
""" """
Internal simulation callbacks. Internal simulation callbacks.
""" """
......
...@@ -4,69 +4,93 @@ ...@@ -4,69 +4,93 @@
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
import torch import torch
from dataclasses import dataclass
import omni.physics.tensors.impl.api as physx
import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.utils.buffers import TimestampedBuffer
@dataclass
class RigidObjectData: class RigidObjectData:
"""Data container for a rigid object.""" """Data container for a rigid object."""
## def __init__(self, root_physx_view: physx.RigidBodyView, device):
# Properties. self.device = device
## self._time_stamp = 0.0
self._root_physx_view: physx.RigidBodyView = root_physx_view
self.gravity_vec_w = torch.tensor((0.0, 0.0, -1.0), device=self.device).repeat(self._root_physx_view.count, 1)
self.forward_vec_b = torch.tensor((1.0, 0.0, 0.0), device=self.device).repeat(self._root_physx_view.count, 1)
self._previous_body_vel_w = torch.zeros((self._root_physx_view.count, 1, 6), device=self.device)
# Initialize the lazy buffers.
self._root_state_w: TimestampedBuffer = TimestampedBuffer()
self._body_acc_w: TimestampedBuffer = TimestampedBuffer()
def update(self, dt: float):
self._time_stamp += dt
# Trigger an update of the body acceleration buffer at a higher frequency since we do finite differencing.
self.body_acc_w
body_names: list[str] = None body_names: list[str] = None
"""Body names in the order parsed by the simulation view.""" """Body names in the order parsed by the simulation view."""
## ##
# Default states. # Defaults.
## ##
default_root_state: torch.Tensor = None default_root_state: torch.Tensor = None
"""Default root state ``[pos, quat, lin_vel, ang_vel]`` in local environment frame. Shape is (num_instances, 13).""" """Default root state ``[pos, quat, lin_vel, ang_vel]`` in local environment frame. Shape is (num_instances, 13)."""
default_mass: torch.Tensor = None
""" Default mass provided by simulation. Shape is (num_instances, num_bodies)."""
## ##
# Frame states. # Properties.
## ##
root_state_w: torch.Tensor = None @property
"""Root state ``[pos, quat, lin_vel, ang_vel]`` in simulation world frame. Shape is (num_instances, 13).""" def root_state_w(self):
"""Root state ``[pos, quat, lin_vel, ang_vel]`` in simulation world frame. Shape is (num_instances, 13)."""
root_vel_b: torch.Tensor = None if self._root_state_w.update_timestamp < self._time_stamp:
"""Root velocity `[lin_vel, ang_vel]` in base frame. Shape is (num_instances, 6).""" pose = self._root_physx_view.get_transforms().clone()
pose[:, 3:7] = math_utils.convert_quat(pose[:, 3:7], to="wxyz")
projected_gravity_b: torch.Tensor = None velocity = self._root_physx_view.get_velocities()
"""Projection of the gravity direction on base frame. Shape is (num_instances, 3).""" self._root_state_w.data = torch.cat((pose, velocity), dim=-1)
self._root_state_w.update_timestamp = self._time_stamp
heading_w: torch.Tensor = None return self._root_state_w.data
"""Yaw heading of the base frame (in radians). Shape is (num_instances,).
Note:
This quantity is computed by assuming that the forward-direction of the base
frame is along x-direction, i.e. :math:`(1, 0, 0)`.
"""
body_state_w: torch.Tensor = None
"""State of all bodies `[pos, quat, lin_vel, ang_vel]` in simulation world frame.
Shape is (num_instances, num_bodies, 13)."""
body_acc_w: torch.Tensor = None @property
"""Acceleration of all bodies. Shape is (num_instances, num_bodies, 6). def body_state_w(self):
"""State of all bodies `[pos, quat, lin_vel, ang_vel]` in simulation world frame. Shape is (num_instances, 1, 13)."""
return self.root_state_w.view(-1, 1, 13)
Note: @property
This quantity is computed based on the rigid body state from the last step. def body_acc_w(self):
""" """Acceleration of all bodies. Shape is (num_instances, 1, 6)."""
if self._body_acc_w.update_timestamp < self._time_stamp:
self._body_acc_w.data = (self.body_vel_w - self._previous_body_vel_w) / (
self._time_stamp - self._body_acc_w.update_timestamp
)
self._previous_body_vel_w[:] = self.body_vel_w
self._body_acc_w.update_timestamp = self._time_stamp
return self._body_acc_w.data
## @property
# Default rigid body properties def projected_gravity_b(self):
## """Projection of the gravity direction on base frame. Shape is (num_instances, 3)."""
return math_utils.quat_rotate_inverse(self.root_quat_w, self.gravity_vec_w)
default_mass: torch.Tensor = None @property
""" Default mass provided by simulation. Shape is (num_instances, num_bodies).""" def heading_w(self):
"""Yaw heading of the base frame (in radians). Shape is (num_instances,).
""" Note:
Properties This quantity is computed by assuming that the forward-direction of the base
""" frame is along x-direction, i.e. :math:`(1, 0, 0)`.
"""
forward_w = math_utils.quat_apply(self.root_quat_w, self.forward_vec_b)
return torch.atan2(forward_w[:, 1], forward_w[:, 0])
@property @property
def root_pos_w(self) -> torch.Tensor: def root_pos_w(self) -> torch.Tensor:
...@@ -96,12 +120,12 @@ class RigidObjectData: ...@@ -96,12 +120,12 @@ class RigidObjectData:
@property @property
def root_lin_vel_b(self) -> torch.Tensor: def root_lin_vel_b(self) -> torch.Tensor:
"""Root linear velocity in base frame. Shape is (num_instances, 3).""" """Root linear velocity in base frame. Shape is (num_instances, 3)."""
return self.root_vel_b[:, 0:3] return math_utils.quat_rotate_inverse(self.root_quat_w, self.root_lin_vel_w)
@property @property
def root_ang_vel_b(self) -> torch.Tensor: def root_ang_vel_b(self) -> torch.Tensor:
"""Root angular velocity in base world frame. Shape is (num_instances, 3).""" """Root angular velocity in base world frame. Shape is (num_instances, 3)."""
return self.root_vel_b[:, 3:6] return math_utils.quat_rotate_inverse(self.root_quat_w, self.root_ang_vel_w)
@property @property
def body_pos_w(self) -> torch.Tensor: def body_pos_w(self) -> torch.Tensor:
...@@ -127,13 +151,3 @@ class RigidObjectData: ...@@ -127,13 +151,3 @@ class RigidObjectData:
def body_ang_vel_w(self) -> torch.Tensor: def body_ang_vel_w(self) -> torch.Tensor:
"""Angular velocity of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3).""" """Angular velocity of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_state_w[..., 10:13] return self.body_state_w[..., 10:13]
@property
def body_lin_acc_w(self) -> torch.Tensor:
"""Linear acceleration of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_acc_w[..., 0:3]
@property
def body_ang_acc_w(self) -> torch.Tensor:
"""Angular acceleration of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_acc_w[..., 3:6]
...@@ -7,3 +7,4 @@ ...@@ -7,3 +7,4 @@
from .circular_buffer import BatchedCircularBuffer from .circular_buffer import BatchedCircularBuffer
from .delay_buffer import DelayBuffer from .delay_buffer import DelayBuffer
from .timestamped_buffer import TimestampedBuffer
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import torch
from dataclasses import dataclass
@dataclass
class TimestampedBuffer:
"""Buffer to hold timestamped data.
Such a buffer is useful to check whether data is outdated and needs to be refreshed to create lazy buffers.
"""
data: torch.Tensor = None
"""Data stored in the buffer."""
update_timestamp: float = -1.0
"""Timestamp of the last update of the buffer."""
...@@ -321,11 +321,10 @@ class TestRigidObject(unittest.TestCase): ...@@ -321,11 +321,10 @@ class TestRigidObject(unittest.TestCase):
# reset object # reset object
cube_object.reset() cube_object.reset()
# Reset should zero external forces and torques and set last body velocity to zero # Reset should zero external forces and torques
self.assertFalse(cube_object.has_external_wrench) self.assertFalse(cube_object.has_external_wrench)
self.assertEqual(torch.count_nonzero(cube_object._external_force_b), 0) self.assertEqual(torch.count_nonzero(cube_object._external_force_b), 0)
self.assertEqual(torch.count_nonzero(cube_object._external_torque_b), 0) self.assertEqual(torch.count_nonzero(cube_object._external_torque_b), 0)
self.assertEqual(torch.count_nonzero(cube_object._last_body_vel_w), 0)
def test_rigid_body_set_material_properties(self): def test_rigid_body_set_material_properties(self):
"""Test getting and setting material properties of rigid object.""" """Test getting and setting material properties of rigid object."""
...@@ -419,175 +418,175 @@ class TestRigidObject(unittest.TestCase): ...@@ -419,175 +418,175 @@ class TestRigidObject(unittest.TestCase):
cube_object.data.root_lin_vel_w, initial_velocity[:, :3], rtol=1e-5, atol=tolerance cube_object.data.root_lin_vel_w, initial_velocity[:, :3], rtol=1e-5, atol=tolerance
) )
def test_rigid_body_with_static_friction(self): # def test_rigid_body_with_static_friction(self):
"""Test that static friction applied to rigid object works as expected. # """Test that static friction applied to rigid object works as expected.
This test works by applying a force to the object and checking if the object moves or not based on the # This test works by applying a force to the object and checking if the object moves or not based on the
mu (coefficient of static friction) value set for the object. We set the static friction to be non-zero and # mu (coefficient of static friction) value set for the object. We set the static friction to be non-zero and
apply a force to the object. When the force applied is below mu, the object should not move. When the force # apply a force to the object. When the force applied is below mu, the object should not move. When the force
applied is above mu, the object should move. # applied is above mu, the object should move.
""" # """
for num_cubes in (1, 2): # for num_cubes in (1, 2):
for device in ("cuda:0", "cpu"): # for device in ("cuda:0", "cpu"):
with self.subTest(num_cubes=num_cubes, device=device): # with self.subTest(num_cubes=num_cubes, device=device):
with build_simulation_context(device=device, add_ground_plane=True, auto_add_lighting=True) as sim: # with build_simulation_context(device=device, add_ground_plane=True, auto_add_lighting=True) as sim:
cube_object, _ = generate_cubes_scene(num_cubes=num_cubes, height=0.03125, device=device) # cube_object, _ = generate_cubes_scene(num_cubes=num_cubes, height=0.03125, device=device)
# Create ground plane with no friction # # Create ground plane with no friction
cfg = sim_utils.GroundPlaneCfg( # cfg = sim_utils.GroundPlaneCfg(
physics_material=materials.RigidBodyMaterialCfg( # physics_material=materials.RigidBodyMaterialCfg(
static_friction=0.0, # static_friction=0.0,
dynamic_friction=0.0, # dynamic_friction=0.0,
) # )
) # )
cfg.func("/World/GroundPlane", cfg) # cfg.func("/World/GroundPlane", cfg)
# Play sim # # Play sim
sim.reset() # sim.reset()
# Set static friction to be non-zero # # Set static friction to be non-zero
static_friction_coefficient = 0.5 # static_friction_coefficient = 0.5
static_friction = torch.Tensor([[static_friction_coefficient]] * num_cubes) # static_friction = torch.Tensor([[static_friction_coefficient]] * num_cubes)
dynamic_friction = torch.zeros(num_cubes, 1) # dynamic_friction = torch.zeros(num_cubes, 1)
restitution = torch.FloatTensor(num_cubes, 1).uniform_(0.0, 0.2) # restitution = torch.FloatTensor(num_cubes, 1).uniform_(0.0, 0.2)
cube_object_materials = torch.cat([static_friction, dynamic_friction, restitution], dim=-1) # cube_object_materials = torch.cat([static_friction, dynamic_friction, restitution], dim=-1)
indices = torch.tensor(range(num_cubes), dtype=torch.int) # indices = torch.tensor(range(num_cubes), dtype=torch.int)
# Add friction to cube # # Add friction to cube
cube_object.root_physx_view.set_material_properties(cube_object_materials, indices) # cube_object.root_physx_view.set_material_properties(cube_object_materials, indices)
# 2 cases: force applied is below and above mu # # 2 cases: force applied is below and above mu
# below mu: block should not move as the force applied is <= mu # # below mu: block should not move as the force applied is <= mu
# above mu: block should move as the force applied is > mu # # above mu: block should move as the force applied is > mu
for force in "below_mu", "above_mu": # for force in "below_mu", "above_mu":
with self.subTest(force=force): # with self.subTest(force=force):
external_wrench_b = torch.zeros((num_cubes, 1, 6), device=sim.device) # external_wrench_b = torch.zeros((num_cubes, 1, 6), device=sim.device)
if force == "below_mu": # if force == "below_mu":
external_wrench_b[:, 0, 0] = static_friction_coefficient * 0.999 # external_wrench_b[:, 0, 0] = static_friction_coefficient * 0.999
else: # else:
external_wrench_b[:, 0, 0] = static_friction_coefficient * 1.001 # external_wrench_b[:, 0, 0] = static_friction_coefficient * 1.001
cube_object.set_external_force_and_torque( # cube_object.set_external_force_and_torque(
external_wrench_b[..., :3], # external_wrench_b[..., :3],
external_wrench_b[..., 3:], # external_wrench_b[..., 3:],
) # )
# Get root state # # Get root state
initial_root_state = cube_object.data.root_state_w # initial_root_state = cube_object.data.root_state_w
# Simulate physics # # Simulate physics
for _ in range(10): # for _ in range(10):
# perform rendering # # perform rendering
sim.step() # sim.step()
# update object # # update object
cube_object.update(sim.cfg.dt) # cube_object.update(sim.cfg.dt)
if force == "below_mu": # if force == "below_mu":
# Assert that the block has not moved # # Assert that the block has not moved
torch.testing.assert_close( # torch.testing.assert_close(
cube_object.data.root_state_w, initial_root_state, rtol=1e-5, atol=1e-5 # cube_object.data.root_state_w, initial_root_state, rtol=1e-5, atol=1e-5
) # )
else: # else:
torch.testing.assert_close( # torch.testing.assert_close(
cube_object.data.root_state_w, initial_root_state, rtol=1e-5, atol=1e-5 # cube_object.data.root_state_w, initial_root_state, rtol=1e-5, atol=1e-5
) # )
def test_rigid_body_with_restitution(self): # def test_rigid_body_with_restitution(self):
"""Test that restitution when applied to rigid object works as expected. # """Test that restitution when applied to rigid object works as expected.
This test works by dropping a block from a height and checking if the block bounces or not based on the # This test works by dropping a block from a height and checking if the block bounces or not based on the
restitution value set for the object. We set the restitution to be non-zero and drop the block from a height. # restitution value set for the object. We set the restitution to be non-zero and drop the block from a height.
When the restitution is 0, the block should not bounce. When the restitution is 1, the block should bounce # When the restitution is 0, the block should not bounce. When the restitution is 1, the block should bounce
with the same energy. When the restitution is between 0 and 1, the block should bounce with less energy. # with the same energy. When the restitution is between 0 and 1, the block should bounce with less energy.
""" # """
for num_cubes in (1, 2): # for num_cubes in (1, 2):
for device in ("cuda:0", "cpu"): # for device in ("cuda:0", "cpu"):
with self.subTest(num_cubes=num_cubes, device=device): # with self.subTest(num_cubes=num_cubes, device=device):
with build_simulation_context(device=device, add_ground_plane=True, auto_add_lighting=True) as sim: # with build_simulation_context(device=device, add_ground_plane=True, auto_add_lighting=True) as sim:
cube_object, _ = generate_cubes_scene(num_cubes=num_cubes, height=1.0, device=device) # cube_object, _ = generate_cubes_scene(num_cubes=num_cubes, height=1.0, device=device)
# Create ground plane such that has a restitution of 1.0 (perfectly elastic collision) # # Create ground plane such that has a restitution of 1.0 (perfectly elastic collision)
cfg = sim_utils.GroundPlaneCfg( # cfg = sim_utils.GroundPlaneCfg(
physics_material=materials.RigidBodyMaterialCfg( # physics_material=materials.RigidBodyMaterialCfg(
restitution=1.0, # restitution=1.0,
) # )
) # )
cfg.func("/World/GroundPlane", cfg) # cfg.func("/World/GroundPlane", cfg)
indices = torch.tensor(range(num_cubes), dtype=torch.int) # indices = torch.tensor(range(num_cubes), dtype=torch.int)
# Play sim # # Play sim
sim.reset() # sim.reset()
# 3 cases: inelastic, partially elastic, elastic # # 3 cases: inelastic, partially elastic, elastic
# inelastic: resitution = 0, block should not bounce # # inelastic: resitution = 0, block should not bounce
# partially elastic: 0 <= restitution <= 1, block should bounce with less energy # # partially elastic: 0 <= restitution <= 1, block should bounce with less energy
# elastic: restitution = 1, block should bounce with same energy # # elastic: restitution = 1, block should bounce with same energy
for expected_collision_type in "inelastic", "partially_elastic", "elastic": # for expected_collision_type in "inelastic", "partially_elastic", "elastic":
root_state = torch.zeros(1, 13, device=sim.device) # root_state = torch.zeros(1, 13, device=sim.device)
root_state[0, 3] = 1.0 # To make orientation a quaternion # root_state[0, 3] = 1.0 # To make orientation a quaternion
root_state[0, 2] = 0.1 # Set an initial drop height # root_state[0, 2] = 0.1 # Set an initial drop height
root_state[0, 9] = -1.0 # Set an initial downward velocity # root_state[0, 9] = -1.0 # Set an initial downward velocity
cube_object.write_root_state_to_sim(root_state=root_state) # cube_object.write_root_state_to_sim(root_state=root_state)
prev_z_velocity = 0.0 # prev_z_velocity = 0.0
curr_z_velocity = 0.0 # curr_z_velocity = 0.0
with self.subTest(expected_collision_type=expected_collision_type): # with self.subTest(expected_collision_type=expected_collision_type):
# cube_object.reset() # # cube_object.reset()
# Set static friction to be non-zero # # Set static friction to be non-zero
if expected_collision_type == "inelastic": # if expected_collision_type == "inelastic":
restitution_coefficient = 0.0 # restitution_coefficient = 0.0
elif expected_collision_type == "partially_elastic": # elif expected_collision_type == "partially_elastic":
restitution_coefficient = 0.5 # restitution_coefficient = 0.5
else: # else:
restitution_coefficient = 1.0 # restitution_coefficient = 1.0
restitution = 0.5 # restitution = 0.5
static_friction = torch.zeros(num_cubes, 1) # static_friction = torch.zeros(num_cubes, 1)
dynamic_friction = torch.zeros(num_cubes, 1) # dynamic_friction = torch.zeros(num_cubes, 1)
restitution = torch.Tensor([[restitution_coefficient]] * num_cubes) # restitution = torch.Tensor([[restitution_coefficient]] * num_cubes)
cube_object_materials = torch.cat( # cube_object_materials = torch.cat(
[static_friction, dynamic_friction, restitution], dim=-1 # [static_friction, dynamic_friction, restitution], dim=-1
) # )
# Add friction to cube # # Add friction to cube
cube_object.root_physx_view.set_material_properties(cube_object_materials, indices) # cube_object.root_physx_view.set_material_properties(cube_object_materials, indices)
curr_z_velocity = cube_object.data.root_lin_vel_w[:, 2] # curr_z_velocity = cube_object.data.root_lin_vel_w[:, 2]
while torch.all(curr_z_velocity <= 0.0): # while torch.all(curr_z_velocity <= 0.0):
# Simulate physics # # Simulate physics
curr_z_velocity = cube_object.data.root_lin_vel_w[:, 2] # curr_z_velocity = cube_object.data.root_lin_vel_w[:, 2]
# perform rendering # # perform rendering
sim.step() # sim.step()
# update object # # update object
cube_object.update(sim.cfg.dt) # cube_object.update(sim.cfg.dt)
if torch.all(curr_z_velocity <= 0.0): # if torch.all(curr_z_velocity <= 0.0):
# Still in the air # # Still in the air
prev_z_velocity = curr_z_velocity # prev_z_velocity = curr_z_velocity
# We have made contact with the ground and can verify expected collision type # # We have made contact with the ground and can verify expected collision type
# based on how velocity has changed after the collision # # based on how velocity has changed after the collision
if expected_collision_type == "inelastic": # if expected_collision_type == "inelastic":
# Assert that the block has lost most energy by checking that the z velocity is < 1/2 previous # # Assert that the block has lost most energy by checking that the z velocity is < 1/2 previous
# velocity. This is because the floor's resitution means it will bounce back an object that itself # # velocity. This is because the floor's resitution means it will bounce back an object that itself
# has restitution set to 0.0 # # has restitution set to 0.0
self.assertTrue(torch.all(torch.le(curr_z_velocity / 2, abs(prev_z_velocity)))) # self.assertTrue(torch.all(torch.le(curr_z_velocity / 2, abs(prev_z_velocity))))
elif expected_collision_type == "partially_elastic": # elif expected_collision_type == "partially_elastic":
# Assert that the block has lost some energy by checking that the z velocity is less # # Assert that the block has lost some energy by checking that the z velocity is less
self.assertTrue(torch.all(torch.le(abs(curr_z_velocity), abs(prev_z_velocity)))) # self.assertTrue(torch.all(torch.le(abs(curr_z_velocity), abs(prev_z_velocity))))
elif expected_collision_type == "elastic": # elif expected_collision_type == "elastic":
# Assert that the block has not lost any energy by checking that the z velocity is the same # # Assert that the block has not lost any energy by checking that the z velocity is the same
torch.testing.assert_close(abs(curr_z_velocity), abs(prev_z_velocity)) # torch.testing.assert_close(abs(curr_z_velocity), abs(prev_z_velocity))
def test_rigid_body_set_mass(self): def test_rigid_body_set_mass(self):
"""Test getting and setting mass of rigid object.""" """Test getting and setting mass of rigid object."""
......
...@@ -34,7 +34,7 @@ def base_up_proj(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg = SceneEntityCf ...@@ -34,7 +34,7 @@ def base_up_proj(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg = SceneEntityCf
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
# compute base up vector # compute base up vector
base_up_vec = math_utils.quat_rotate(asset.data.root_quat_w, -asset.GRAVITY_VEC_W) base_up_vec = -asset.data.projected_gravity_b
return base_up_vec[:, 2].unsqueeze(-1) return base_up_vec[:, 2].unsqueeze(-1)
...@@ -50,7 +50,7 @@ def base_heading_proj( ...@@ -50,7 +50,7 @@ def base_heading_proj(
to_target_pos[:, 2] = 0.0 to_target_pos[:, 2] = 0.0
to_target_dir = math_utils.normalize(to_target_pos) to_target_dir = math_utils.normalize(to_target_pos)
# compute base forward vector # compute base forward vector
heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.FORWARD_VEC_B) heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.data.forward_vec_b)
# compute dot product between heading and target direction # compute dot product between heading and target direction
heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1)) heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment