Unverified Commit a8fb6b04 authored by David Hoeller's avatar David Hoeller Committed by GitHub

Improves object and articulation data update logic (#491)

Improves the `omni.isaac.lab.assets.RigidObjectData` and
`omni.isaac.lab.assets.ArticulationData` buffers to update their data
lazily. Before, all the data was always updated at every step, even if
it was not used by the task. This improves performance for all the
tasks.

## Type of change

- New feature

## Screenshots

For the Cartpole task, blue is before, red is this MR.

![image](https://github.com/isaac-sim/IsaacLab/assets/11162199/99cbe73d-fa89-4de9-a94c-306cc7e0766f)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./isaaclab.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 958b863b
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.17.11"
version = "0.17.12"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.17.12 (2024-06-13)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added the class :class:`omni.isaac.lab.utils.buffers.TimestampedBuffer` to store timestamped data.
Changed
^^^^^^^
* Added time-stamped buffers in the classes :class:`omni.isaac.lab.assets.RigidObjectData` and :class:`omni.isaac.lab.assets.ArticulationData`
to update some values lazily and avoid unnecessary computations between physics updates. Before, all the data was always
updated at every step, even if it was not used by the task.
0.17.11 (2024-05-30)
~~~~~~~~~~~~~~~~~~~~
......
......@@ -95,8 +95,6 @@ class Articulation(RigidObject):
cfg: A configuration instance.
"""
super().__init__(cfg)
# container for data access
self._data = ArticulationData()
# data for storing actuator group
self.actuators: dict[str, ActuatorBase] = dict.fromkeys(self.cfg.actuators.keys())
......@@ -205,30 +203,6 @@ class Articulation(RigidObject):
self.root_physx_view.set_dof_position_targets(self._joint_pos_target_sim, self._ALL_INDICES)
self.root_physx_view.set_dof_velocity_targets(self._joint_vel_target_sim, self._ALL_INDICES)
def update(self, dt: float):
# -- root state (note: we roll the quaternion to match the convention used in Isaac Sim -- wxyz)
self._data.root_state_w[:, :7] = self.root_physx_view.get_root_transforms()
self._data.root_state_w[:, 3:7] = math_utils.convert_quat(self._data.root_state_w[:, 3:7], to="wxyz")
self._data.root_state_w[:, 7:] = self.root_physx_view.get_root_velocities()
# -- body-state (note: we roll the quaternion to match the convention used in Isaac Sim -- wxyz)
self._data.body_state_w[..., :7] = self.root_physx_view.get_link_transforms()
self._data.body_state_w[..., 3:7] = math_utils.convert_quat(self._data.body_state_w[..., 3:7], to="wxyz")
self._data.body_state_w[..., 7:] = self.root_physx_view.get_link_velocities()
# -- joint states
self._data.joint_pos[:] = self.root_physx_view.get_dof_positions()
self._data.joint_vel[:] = self.root_physx_view.get_dof_velocities()
if dt > 0.0:
self._data.joint_acc[:] = (self._data.joint_vel - self._previous_joint_vel) / dt
# -- update common data
# note: these are computed in the base class
self._update_common_data(dt)
# -- update history buffers
self._previous_joint_vel[:] = self._data.joint_vel[:]
def find_joints(
self, name_keys: str | Sequence[str], joint_subset: list[str] | None = None, preserve_order: bool = False
) -> tuple[list[int], list[str]]:
......@@ -320,6 +294,7 @@ class Articulation(RigidObject):
# note: we need to do this here since tensors are not set into simulation until step.
# set into internal buffers
self._data.root_state_w[env_ids, 7:] = root_velocity.clone()
self._data.body_acc_w[env_ids] = 0.0
# set into simulation
self.root_physx_view.set_root_velocities(self._data.root_state_w[:, 7:], indices=physx_env_ids)
......@@ -350,7 +325,7 @@ class Articulation(RigidObject):
# set into internal buffers
self._data.joint_pos[env_ids, joint_ids] = position
self._data.joint_vel[env_ids, joint_ids] = velocity
self._previous_joint_vel[env_ids, joint_ids] = velocity
self._data._previous_joint_vel[env_ids, joint_ids] = velocity
self._data.joint_acc[env_ids, joint_ids] = 0.0
# set into simulation
self.root_physx_view.set_dof_positions(self._data.joint_pos, indices=physx_env_ids)
......@@ -861,6 +836,9 @@ class Articulation(RigidObject):
if set(physx_body_names) != set(self.body_names):
raise RuntimeError("Failed to parse all bodies properly in the articulation.")
# container for data access
self._data = ArticulationData(self.root_physx_view, self.device)
# create buffers
self._create_buffers()
# process configuration
......@@ -877,30 +855,25 @@ class Articulation(RigidObject):
def _create_buffers(self):
# allocate buffers
super()._create_buffers()
# history buffers
self._previous_joint_vel = torch.zeros(self.num_instances, self.num_joints, device=self.device)
# asset data
# -- properties
self._data.joint_names = self.joint_names
# -- joint states
self._data.joint_pos = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.joint_vel = torch.zeros_like(self._data.joint_pos)
self._data.joint_acc = torch.zeros_like(self._data.joint_pos)
self._data.default_joint_pos = torch.zeros_like(self._data.joint_pos)
self._data.default_joint_vel = torch.zeros_like(self._data.joint_pos)
self._data.default_joint_pos = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.default_joint_vel = torch.zeros_like(self._data.default_joint_pos)
# -- joint commands
self._data.joint_pos_target = torch.zeros_like(self._data.joint_pos)
self._data.joint_vel_target = torch.zeros_like(self._data.joint_pos)
self._data.joint_effort_target = torch.zeros_like(self._data.joint_pos)
self._data.joint_stiffness = torch.zeros_like(self._data.joint_pos)
self._data.joint_damping = torch.zeros_like(self._data.joint_pos)
self._data.joint_armature = torch.zeros_like(self._data.joint_pos)
self._data.joint_friction = torch.zeros_like(self._data.joint_pos)
self._data.joint_pos_target = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_vel_target = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_effort_target = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_stiffness = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_damping = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_armature = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_friction = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device)
# -- joint commands (explicit)
self._data.computed_torque = torch.zeros_like(self._data.joint_pos)
self._data.applied_torque = torch.zeros_like(self._data.joint_pos)
self._data.computed_torque = torch.zeros_like(self._data.default_joint_pos)
self._data.applied_torque = torch.zeros_like(self._data.default_joint_pos)
# -- tendons
if self.num_fixed_tendons > 0:
self._data.fixed_tendon_stiffness = torch.zeros(
......
......@@ -56,8 +56,6 @@ class RigidObject(AssetBase):
cfg: A configuration instance.
"""
super().__init__(cfg)
# container for data access
self._data = RigidObjectData()
"""
Properties
......@@ -116,8 +114,6 @@ class RigidObject(AssetBase):
# reset external wrench
self._external_force_b[env_ids] = 0.0
self._external_torque_b[env_ids] = 0.0
# reset last body vel
self._last_body_vel_w[env_ids] = 0.0
def write_data_to_sim(self):
"""Write external wrench to the simulation.
......@@ -137,16 +133,7 @@ class RigidObject(AssetBase):
)
def update(self, dt: float):
# -- root-state (note: we roll the quaternion to match the convention used in Isaac Sim -- wxyz)
self._data.root_state_w[:, :7] = self.root_physx_view.get_transforms()
self._data.root_state_w[:, 3:7] = math_utils.convert_quat(self._data.root_state_w[:, 3:7], to="wxyz")
self._data.root_state_w[:, 7:] = self.root_physx_view.get_velocities()
# -- body-state (note: for rigid objects, we only have one body so we just copy the root state)
self._data.body_state_w[:] = self._data.root_state_w.view(-1, self.num_bodies, 13)
# -- update common data
self._update_common_data(dt)
self._data.update(dt)
def find_bodies(self, name_keys: str | Sequence[str], preserve_order: bool = False) -> tuple[list[int], list[str]]:
"""Find bodies in the articulation based on the name keys.
......@@ -219,6 +206,8 @@ class RigidObject(AssetBase):
# note: we need to do this here since tensors are not set into simulation until step.
# set into internal buffers
self._data.root_state_w[env_ids, 7:] = root_velocity.clone()
self._data._previous_body_vel_w[env_ids, 0] = root_velocity.clone()
self._data.body_acc_w[env_ids] = 0.0
# set into simulation
self.root_physx_view.set_velocities(self._data.root_state_w[:, 7:], indices=physx_env_ids)
......@@ -329,6 +318,9 @@ class RigidObject(AssetBase):
carb.log_info(f"Number of bodies: {self.num_bodies}")
carb.log_info(f"Body names: {self.body_names}")
# container for data access
self._data = RigidObjectData(self.root_physx_view, self.device)
# create buffers
self._create_buffers()
# process configuration
......@@ -343,35 +335,12 @@ class RigidObject(AssetBase):
self._ALL_BODY_INDICES = torch.arange(
self.root_physx_view.count * self.num_bodies, dtype=torch.long, device=self.device
)
self.GRAVITY_VEC_W = torch.tensor((0.0, 0.0, -1.0), device=self.device).repeat(self.num_instances, 1)
self.FORWARD_VEC_B = torch.tensor((1.0, 0.0, 0.0), device=self.device).repeat(self.num_instances, 1)
# external forces and torques
self.has_external_wrench = False
self._external_force_b = torch.zeros((self.num_instances, self.num_bodies, 3), device=self.device)
self._external_torque_b = torch.zeros_like(self._external_force_b)
# asset data
# -- properties
self._data.body_names = self.body_names
# -- root states
self._data.root_state_w = torch.zeros(self.num_instances, 13, device=self.device)
self._data.root_state_w[:, 3] = 1.0 # set default quaternion to (1, 0, 0, 0)
self._data.default_root_state = torch.zeros_like(self._data.root_state_w)
self._data.default_root_state[:, 3] = 1.0 # set default quaternion to (1, 0, 0, 0)
# -- body states
self._data.body_state_w = torch.zeros(self.num_instances, self.num_bodies, 13, device=self.device)
self._data.body_state_w[:, :, 3] = 1.0 # set default quaternion to (1, 0, 0, 0)
# -- post-computed
self._data.root_vel_b = torch.zeros(self.num_instances, 6, device=self.device)
self._data.projected_gravity_b = torch.zeros(self.num_instances, 3, device=self.device)
self._data.heading_w = torch.zeros(self.num_instances, device=self.device)
self._data.body_acc_w = torch.zeros(self.num_instances, self.num_bodies, 6, device=self.device)
# history buffers for quantities
# -- used to compute body accelerations numerically
self._last_body_vel_w = torch.zeros(self.num_instances, self.num_bodies, 6, device=self.device)
# mass
self._data.default_mass = self.root_physx_view.get_masses().clone()
def _process_cfg(self):
......@@ -388,29 +357,6 @@ class RigidObject(AssetBase):
default_root_state = torch.tensor(default_root_state, dtype=torch.float, device=self.device)
self._data.default_root_state = default_root_state.repeat(self.num_instances, 1)
def _update_common_data(self, dt: float):
"""Update common quantities related to rigid objects.
Note:
This has been separated from the update function to allow for the child classes to
override the update function without having to worry about updating the common data.
"""
# -- body acceleration
if dt > 0.0:
self._data.body_acc_w[:] = (self._data.body_state_w[..., 7:] - self._last_body_vel_w) / dt
self._last_body_vel_w[:] = self._data.body_state_w[..., 7:]
# -- root state in body frame
self._data.root_vel_b[:, 0:3] = math_utils.quat_rotate_inverse(
self._data.root_quat_w, self._data.root_lin_vel_w
)
self._data.root_vel_b[:, 3:6] = math_utils.quat_rotate_inverse(
self._data.root_quat_w, self._data.root_ang_vel_w
)
self._data.projected_gravity_b[:] = math_utils.quat_rotate_inverse(self._data.root_quat_w, self.GRAVITY_VEC_W)
# -- heading direction of root
forward_w = math_utils.quat_apply(self._data.root_quat_w, self.FORWARD_VEC_B)
self._data.heading_w[:] = torch.atan2(forward_w[:, 1], forward_w[:, 0])
"""
Internal simulation callbacks.
"""
......
......@@ -4,69 +4,93 @@
# SPDX-License-Identifier: BSD-3-Clause
import torch
from dataclasses import dataclass
import omni.physics.tensors.impl.api as physx
import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.utils.buffers import TimestampedBuffer
@dataclass
class RigidObjectData:
"""Data container for a rigid object."""
##
# Properties.
##
def __init__(self, root_physx_view: physx.RigidBodyView, device):
self.device = device
self._time_stamp = 0.0
self._root_physx_view: physx.RigidBodyView = root_physx_view
self.gravity_vec_w = torch.tensor((0.0, 0.0, -1.0), device=self.device).repeat(self._root_physx_view.count, 1)
self.forward_vec_b = torch.tensor((1.0, 0.0, 0.0), device=self.device).repeat(self._root_physx_view.count, 1)
self._previous_body_vel_w = torch.zeros((self._root_physx_view.count, 1, 6), device=self.device)
# Initialize the lazy buffers.
self._root_state_w: TimestampedBuffer = TimestampedBuffer()
self._body_acc_w: TimestampedBuffer = TimestampedBuffer()
def update(self, dt: float):
self._time_stamp += dt
# Trigger an update of the body acceleration buffer at a higher frequency since we do finite differencing.
self.body_acc_w
body_names: list[str] = None
"""Body names in the order parsed by the simulation view."""
##
# Default states.
# Defaults.
##
default_root_state: torch.Tensor = None
"""Default root state ``[pos, quat, lin_vel, ang_vel]`` in local environment frame. Shape is (num_instances, 13)."""
default_mass: torch.Tensor = None
""" Default mass provided by simulation. Shape is (num_instances, num_bodies)."""
##
# Frame states.
# Properties.
##
root_state_w: torch.Tensor = None
"""Root state ``[pos, quat, lin_vel, ang_vel]`` in simulation world frame. Shape is (num_instances, 13)."""
root_vel_b: torch.Tensor = None
"""Root velocity `[lin_vel, ang_vel]` in base frame. Shape is (num_instances, 6)."""
projected_gravity_b: torch.Tensor = None
"""Projection of the gravity direction on base frame. Shape is (num_instances, 3)."""
heading_w: torch.Tensor = None
"""Yaw heading of the base frame (in radians). Shape is (num_instances,).
Note:
This quantity is computed by assuming that the forward-direction of the base
frame is along x-direction, i.e. :math:`(1, 0, 0)`.
"""
body_state_w: torch.Tensor = None
"""State of all bodies `[pos, quat, lin_vel, ang_vel]` in simulation world frame.
Shape is (num_instances, num_bodies, 13)."""
@property
def root_state_w(self):
"""Root state ``[pos, quat, lin_vel, ang_vel]`` in simulation world frame. Shape is (num_instances, 13)."""
if self._root_state_w.update_timestamp < self._time_stamp:
pose = self._root_physx_view.get_transforms().clone()
pose[:, 3:7] = math_utils.convert_quat(pose[:, 3:7], to="wxyz")
velocity = self._root_physx_view.get_velocities()
self._root_state_w.data = torch.cat((pose, velocity), dim=-1)
self._root_state_w.update_timestamp = self._time_stamp
return self._root_state_w.data
body_acc_w: torch.Tensor = None
"""Acceleration of all bodies. Shape is (num_instances, num_bodies, 6).
@property
def body_state_w(self):
"""State of all bodies `[pos, quat, lin_vel, ang_vel]` in simulation world frame. Shape is (num_instances, 1, 13)."""
return self.root_state_w.view(-1, 1, 13)
Note:
This quantity is computed based on the rigid body state from the last step.
"""
@property
def body_acc_w(self):
"""Acceleration of all bodies. Shape is (num_instances, 1, 6)."""
if self._body_acc_w.update_timestamp < self._time_stamp:
self._body_acc_w.data = (self.body_vel_w - self._previous_body_vel_w) / (
self._time_stamp - self._body_acc_w.update_timestamp
)
self._previous_body_vel_w[:] = self.body_vel_w
self._body_acc_w.update_timestamp = self._time_stamp
return self._body_acc_w.data
##
# Default rigid body properties
##
@property
def projected_gravity_b(self):
"""Projection of the gravity direction on base frame. Shape is (num_instances, 3)."""
return math_utils.quat_rotate_inverse(self.root_quat_w, self.gravity_vec_w)
default_mass: torch.Tensor = None
""" Default mass provided by simulation. Shape is (num_instances, num_bodies)."""
@property
def heading_w(self):
"""Yaw heading of the base frame (in radians). Shape is (num_instances,).
"""
Properties
"""
Note:
This quantity is computed by assuming that the forward-direction of the base
frame is along x-direction, i.e. :math:`(1, 0, 0)`.
"""
forward_w = math_utils.quat_apply(self.root_quat_w, self.forward_vec_b)
return torch.atan2(forward_w[:, 1], forward_w[:, 0])
@property
def root_pos_w(self) -> torch.Tensor:
......@@ -96,12 +120,12 @@ class RigidObjectData:
@property
def root_lin_vel_b(self) -> torch.Tensor:
"""Root linear velocity in base frame. Shape is (num_instances, 3)."""
return self.root_vel_b[:, 0:3]
return math_utils.quat_rotate_inverse(self.root_quat_w, self.root_lin_vel_w)
@property
def root_ang_vel_b(self) -> torch.Tensor:
"""Root angular velocity in base world frame. Shape is (num_instances, 3)."""
return self.root_vel_b[:, 3:6]
return math_utils.quat_rotate_inverse(self.root_quat_w, self.root_ang_vel_w)
@property
def body_pos_w(self) -> torch.Tensor:
......@@ -127,13 +151,3 @@ class RigidObjectData:
def body_ang_vel_w(self) -> torch.Tensor:
"""Angular velocity of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_state_w[..., 10:13]
@property
def body_lin_acc_w(self) -> torch.Tensor:
"""Linear acceleration of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_acc_w[..., 0:3]
@property
def body_ang_acc_w(self) -> torch.Tensor:
"""Angular acceleration of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_acc_w[..., 3:6]
......@@ -7,3 +7,4 @@
from .circular_buffer import BatchedCircularBuffer
from .delay_buffer import DelayBuffer
from .timestamped_buffer import TimestampedBuffer
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import torch
from dataclasses import dataclass
@dataclass
class TimestampedBuffer:
"""Buffer to hold timestamped data.
Such a buffer is useful to check whether data is outdated and needs to be refreshed to create lazy buffers.
"""
data: torch.Tensor = None
"""Data stored in the buffer."""
update_timestamp: float = -1.0
"""Timestamp of the last update of the buffer."""
......@@ -34,7 +34,7 @@ def base_up_proj(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg = SceneEntityCf
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
# compute base up vector
base_up_vec = math_utils.quat_rotate(asset.data.root_quat_w, -asset.GRAVITY_VEC_W)
base_up_vec = -asset.data.projected_gravity_b
return base_up_vec[:, 2].unsqueeze(-1)
......@@ -50,7 +50,7 @@ def base_heading_proj(
to_target_pos[:, 2] = 0.0
to_target_dir = math_utils.normalize(to_target_pos)
# compute base forward vector
heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.FORWARD_VEC_B)
heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.data.forward_vec_b)
# compute dot product between heading and target direction
heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment