Unverified Commit a8fb6b04 authored by David Hoeller's avatar David Hoeller Committed by GitHub

Improves object and articulation data update logic (#491)

Improves the `omni.isaac.lab.assets.RigidObjectData` and
`omni.isaac.lab.assets.ArticulationData` buffers to update their data
lazily. Before, all the data was always updated at every step, even if
it was not used by the task. This improves performance for all the
tasks.

## Type of change

- New feature

## Screenshots

For the Cartpole task, blue is before, red is this MR.

![image](https://github.com/isaac-sim/IsaacLab/assets/11162199/99cbe73d-fa89-4de9-a94c-306cc7e0766f)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./isaaclab.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 958b863b
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.17.11" version = "0.17.12"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.17.12 (2024-06-13)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added the class :class:`omni.isaac.lab.utils.buffers.TimestampedBuffer` to store timestamped data.
Changed
^^^^^^^
* Added time-stamped buffers in the classes :class:`omni.isaac.lab.assets.RigidObjectData` and :class:`omni.isaac.lab.assets.ArticulationData`
to update some values lazily and avoid unnecessary computations between physics updates. Before, all the data was always
updated at every step, even if it was not used by the task.
0.17.11 (2024-05-30) 0.17.11 (2024-05-30)
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
......
...@@ -95,8 +95,6 @@ class Articulation(RigidObject): ...@@ -95,8 +95,6 @@ class Articulation(RigidObject):
cfg: A configuration instance. cfg: A configuration instance.
""" """
super().__init__(cfg) super().__init__(cfg)
# container for data access
self._data = ArticulationData()
# data for storing actuator group # data for storing actuator group
self.actuators: dict[str, ActuatorBase] = dict.fromkeys(self.cfg.actuators.keys()) self.actuators: dict[str, ActuatorBase] = dict.fromkeys(self.cfg.actuators.keys())
...@@ -205,30 +203,6 @@ class Articulation(RigidObject): ...@@ -205,30 +203,6 @@ class Articulation(RigidObject):
self.root_physx_view.set_dof_position_targets(self._joint_pos_target_sim, self._ALL_INDICES) self.root_physx_view.set_dof_position_targets(self._joint_pos_target_sim, self._ALL_INDICES)
self.root_physx_view.set_dof_velocity_targets(self._joint_vel_target_sim, self._ALL_INDICES) self.root_physx_view.set_dof_velocity_targets(self._joint_vel_target_sim, self._ALL_INDICES)
def update(self, dt: float):
# -- root state (note: we roll the quaternion to match the convention used in Isaac Sim -- wxyz)
self._data.root_state_w[:, :7] = self.root_physx_view.get_root_transforms()
self._data.root_state_w[:, 3:7] = math_utils.convert_quat(self._data.root_state_w[:, 3:7], to="wxyz")
self._data.root_state_w[:, 7:] = self.root_physx_view.get_root_velocities()
# -- body-state (note: we roll the quaternion to match the convention used in Isaac Sim -- wxyz)
self._data.body_state_w[..., :7] = self.root_physx_view.get_link_transforms()
self._data.body_state_w[..., 3:7] = math_utils.convert_quat(self._data.body_state_w[..., 3:7], to="wxyz")
self._data.body_state_w[..., 7:] = self.root_physx_view.get_link_velocities()
# -- joint states
self._data.joint_pos[:] = self.root_physx_view.get_dof_positions()
self._data.joint_vel[:] = self.root_physx_view.get_dof_velocities()
if dt > 0.0:
self._data.joint_acc[:] = (self._data.joint_vel - self._previous_joint_vel) / dt
# -- update common data
# note: these are computed in the base class
self._update_common_data(dt)
# -- update history buffers
self._previous_joint_vel[:] = self._data.joint_vel[:]
def find_joints( def find_joints(
self, name_keys: str | Sequence[str], joint_subset: list[str] | None = None, preserve_order: bool = False self, name_keys: str | Sequence[str], joint_subset: list[str] | None = None, preserve_order: bool = False
) -> tuple[list[int], list[str]]: ) -> tuple[list[int], list[str]]:
...@@ -320,6 +294,7 @@ class Articulation(RigidObject): ...@@ -320,6 +294,7 @@ class Articulation(RigidObject):
# note: we need to do this here since tensors are not set into simulation until step. # note: we need to do this here since tensors are not set into simulation until step.
# set into internal buffers # set into internal buffers
self._data.root_state_w[env_ids, 7:] = root_velocity.clone() self._data.root_state_w[env_ids, 7:] = root_velocity.clone()
self._data.body_acc_w[env_ids] = 0.0
# set into simulation # set into simulation
self.root_physx_view.set_root_velocities(self._data.root_state_w[:, 7:], indices=physx_env_ids) self.root_physx_view.set_root_velocities(self._data.root_state_w[:, 7:], indices=physx_env_ids)
...@@ -350,7 +325,7 @@ class Articulation(RigidObject): ...@@ -350,7 +325,7 @@ class Articulation(RigidObject):
# set into internal buffers # set into internal buffers
self._data.joint_pos[env_ids, joint_ids] = position self._data.joint_pos[env_ids, joint_ids] = position
self._data.joint_vel[env_ids, joint_ids] = velocity self._data.joint_vel[env_ids, joint_ids] = velocity
self._previous_joint_vel[env_ids, joint_ids] = velocity self._data._previous_joint_vel[env_ids, joint_ids] = velocity
self._data.joint_acc[env_ids, joint_ids] = 0.0 self._data.joint_acc[env_ids, joint_ids] = 0.0
# set into simulation # set into simulation
self.root_physx_view.set_dof_positions(self._data.joint_pos, indices=physx_env_ids) self.root_physx_view.set_dof_positions(self._data.joint_pos, indices=physx_env_ids)
...@@ -861,6 +836,9 @@ class Articulation(RigidObject): ...@@ -861,6 +836,9 @@ class Articulation(RigidObject):
if set(physx_body_names) != set(self.body_names): if set(physx_body_names) != set(self.body_names):
raise RuntimeError("Failed to parse all bodies properly in the articulation.") raise RuntimeError("Failed to parse all bodies properly in the articulation.")
# container for data access
self._data = ArticulationData(self.root_physx_view, self.device)
# create buffers # create buffers
self._create_buffers() self._create_buffers()
# process configuration # process configuration
...@@ -877,30 +855,25 @@ class Articulation(RigidObject): ...@@ -877,30 +855,25 @@ class Articulation(RigidObject):
def _create_buffers(self): def _create_buffers(self):
# allocate buffers # allocate buffers
super()._create_buffers() super()._create_buffers()
# history buffers
self._previous_joint_vel = torch.zeros(self.num_instances, self.num_joints, device=self.device)
# asset data # asset data
# -- properties # -- properties
self._data.joint_names = self.joint_names self._data.joint_names = self.joint_names
# -- joint states
self._data.joint_pos = torch.zeros(self.num_instances, self.num_joints, device=self.device) self._data.default_joint_pos = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.joint_vel = torch.zeros_like(self._data.joint_pos) self._data.default_joint_vel = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_acc = torch.zeros_like(self._data.joint_pos)
self._data.default_joint_pos = torch.zeros_like(self._data.joint_pos)
self._data.default_joint_vel = torch.zeros_like(self._data.joint_pos)
# -- joint commands # -- joint commands
self._data.joint_pos_target = torch.zeros_like(self._data.joint_pos) self._data.joint_pos_target = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_vel_target = torch.zeros_like(self._data.joint_pos) self._data.joint_vel_target = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_effort_target = torch.zeros_like(self._data.joint_pos) self._data.joint_effort_target = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_stiffness = torch.zeros_like(self._data.joint_pos) self._data.joint_stiffness = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_damping = torch.zeros_like(self._data.joint_pos) self._data.joint_damping = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_armature = torch.zeros_like(self._data.joint_pos) self._data.joint_armature = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_friction = torch.zeros_like(self._data.joint_pos) self._data.joint_friction = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device) self._data.joint_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device)
# -- joint commands (explicit) # -- joint commands (explicit)
self._data.computed_torque = torch.zeros_like(self._data.joint_pos) self._data.computed_torque = torch.zeros_like(self._data.default_joint_pos)
self._data.applied_torque = torch.zeros_like(self._data.joint_pos) self._data.applied_torque = torch.zeros_like(self._data.default_joint_pos)
# -- tendons # -- tendons
if self.num_fixed_tendons > 0: if self.num_fixed_tendons > 0:
self._data.fixed_tendon_stiffness = torch.zeros( self._data.fixed_tendon_stiffness = torch.zeros(
......
...@@ -56,8 +56,6 @@ class RigidObject(AssetBase): ...@@ -56,8 +56,6 @@ class RigidObject(AssetBase):
cfg: A configuration instance. cfg: A configuration instance.
""" """
super().__init__(cfg) super().__init__(cfg)
# container for data access
self._data = RigidObjectData()
""" """
Properties Properties
...@@ -116,8 +114,6 @@ class RigidObject(AssetBase): ...@@ -116,8 +114,6 @@ class RigidObject(AssetBase):
# reset external wrench # reset external wrench
self._external_force_b[env_ids] = 0.0 self._external_force_b[env_ids] = 0.0
self._external_torque_b[env_ids] = 0.0 self._external_torque_b[env_ids] = 0.0
# reset last body vel
self._last_body_vel_w[env_ids] = 0.0
def write_data_to_sim(self): def write_data_to_sim(self):
"""Write external wrench to the simulation. """Write external wrench to the simulation.
...@@ -137,16 +133,7 @@ class RigidObject(AssetBase): ...@@ -137,16 +133,7 @@ class RigidObject(AssetBase):
) )
def update(self, dt: float): def update(self, dt: float):
# -- root-state (note: we roll the quaternion to match the convention used in Isaac Sim -- wxyz) self._data.update(dt)
self._data.root_state_w[:, :7] = self.root_physx_view.get_transforms()
self._data.root_state_w[:, 3:7] = math_utils.convert_quat(self._data.root_state_w[:, 3:7], to="wxyz")
self._data.root_state_w[:, 7:] = self.root_physx_view.get_velocities()
# -- body-state (note: for rigid objects, we only have one body so we just copy the root state)
self._data.body_state_w[:] = self._data.root_state_w.view(-1, self.num_bodies, 13)
# -- update common data
self._update_common_data(dt)
def find_bodies(self, name_keys: str | Sequence[str], preserve_order: bool = False) -> tuple[list[int], list[str]]: def find_bodies(self, name_keys: str | Sequence[str], preserve_order: bool = False) -> tuple[list[int], list[str]]:
"""Find bodies in the articulation based on the name keys. """Find bodies in the articulation based on the name keys.
...@@ -219,6 +206,8 @@ class RigidObject(AssetBase): ...@@ -219,6 +206,8 @@ class RigidObject(AssetBase):
# note: we need to do this here since tensors are not set into simulation until step. # note: we need to do this here since tensors are not set into simulation until step.
# set into internal buffers # set into internal buffers
self._data.root_state_w[env_ids, 7:] = root_velocity.clone() self._data.root_state_w[env_ids, 7:] = root_velocity.clone()
self._data._previous_body_vel_w[env_ids, 0] = root_velocity.clone()
self._data.body_acc_w[env_ids] = 0.0
# set into simulation # set into simulation
self.root_physx_view.set_velocities(self._data.root_state_w[:, 7:], indices=physx_env_ids) self.root_physx_view.set_velocities(self._data.root_state_w[:, 7:], indices=physx_env_ids)
...@@ -329,6 +318,9 @@ class RigidObject(AssetBase): ...@@ -329,6 +318,9 @@ class RigidObject(AssetBase):
carb.log_info(f"Number of bodies: {self.num_bodies}") carb.log_info(f"Number of bodies: {self.num_bodies}")
carb.log_info(f"Body names: {self.body_names}") carb.log_info(f"Body names: {self.body_names}")
# container for data access
self._data = RigidObjectData(self.root_physx_view, self.device)
# create buffers # create buffers
self._create_buffers() self._create_buffers()
# process configuration # process configuration
...@@ -343,35 +335,12 @@ class RigidObject(AssetBase): ...@@ -343,35 +335,12 @@ class RigidObject(AssetBase):
self._ALL_BODY_INDICES = torch.arange( self._ALL_BODY_INDICES = torch.arange(
self.root_physx_view.count * self.num_bodies, dtype=torch.long, device=self.device self.root_physx_view.count * self.num_bodies, dtype=torch.long, device=self.device
) )
self.GRAVITY_VEC_W = torch.tensor((0.0, 0.0, -1.0), device=self.device).repeat(self.num_instances, 1)
self.FORWARD_VEC_B = torch.tensor((1.0, 0.0, 0.0), device=self.device).repeat(self.num_instances, 1)
# external forces and torques # external forces and torques
self.has_external_wrench = False self.has_external_wrench = False
self._external_force_b = torch.zeros((self.num_instances, self.num_bodies, 3), device=self.device) self._external_force_b = torch.zeros((self.num_instances, self.num_bodies, 3), device=self.device)
self._external_torque_b = torch.zeros_like(self._external_force_b) self._external_torque_b = torch.zeros_like(self._external_force_b)
# asset data
# -- properties
self._data.body_names = self.body_names self._data.body_names = self.body_names
# -- root states
self._data.root_state_w = torch.zeros(self.num_instances, 13, device=self.device)
self._data.root_state_w[:, 3] = 1.0 # set default quaternion to (1, 0, 0, 0)
self._data.default_root_state = torch.zeros_like(self._data.root_state_w)
self._data.default_root_state[:, 3] = 1.0 # set default quaternion to (1, 0, 0, 0)
# -- body states
self._data.body_state_w = torch.zeros(self.num_instances, self.num_bodies, 13, device=self.device)
self._data.body_state_w[:, :, 3] = 1.0 # set default quaternion to (1, 0, 0, 0)
# -- post-computed
self._data.root_vel_b = torch.zeros(self.num_instances, 6, device=self.device)
self._data.projected_gravity_b = torch.zeros(self.num_instances, 3, device=self.device)
self._data.heading_w = torch.zeros(self.num_instances, device=self.device)
self._data.body_acc_w = torch.zeros(self.num_instances, self.num_bodies, 6, device=self.device)
# history buffers for quantities
# -- used to compute body accelerations numerically
self._last_body_vel_w = torch.zeros(self.num_instances, self.num_bodies, 6, device=self.device)
# mass
self._data.default_mass = self.root_physx_view.get_masses().clone() self._data.default_mass = self.root_physx_view.get_masses().clone()
def _process_cfg(self): def _process_cfg(self):
...@@ -388,29 +357,6 @@ class RigidObject(AssetBase): ...@@ -388,29 +357,6 @@ class RigidObject(AssetBase):
default_root_state = torch.tensor(default_root_state, dtype=torch.float, device=self.device) default_root_state = torch.tensor(default_root_state, dtype=torch.float, device=self.device)
self._data.default_root_state = default_root_state.repeat(self.num_instances, 1) self._data.default_root_state = default_root_state.repeat(self.num_instances, 1)
def _update_common_data(self, dt: float):
"""Update common quantities related to rigid objects.
Note:
This has been separated from the update function to allow for the child classes to
override the update function without having to worry about updating the common data.
"""
# -- body acceleration
if dt > 0.0:
self._data.body_acc_w[:] = (self._data.body_state_w[..., 7:] - self._last_body_vel_w) / dt
self._last_body_vel_w[:] = self._data.body_state_w[..., 7:]
# -- root state in body frame
self._data.root_vel_b[:, 0:3] = math_utils.quat_rotate_inverse(
self._data.root_quat_w, self._data.root_lin_vel_w
)
self._data.root_vel_b[:, 3:6] = math_utils.quat_rotate_inverse(
self._data.root_quat_w, self._data.root_ang_vel_w
)
self._data.projected_gravity_b[:] = math_utils.quat_rotate_inverse(self._data.root_quat_w, self.GRAVITY_VEC_W)
# -- heading direction of root
forward_w = math_utils.quat_apply(self._data.root_quat_w, self.FORWARD_VEC_B)
self._data.heading_w[:] = torch.atan2(forward_w[:, 1], forward_w[:, 0])
""" """
Internal simulation callbacks. Internal simulation callbacks.
""" """
......
...@@ -4,69 +4,93 @@ ...@@ -4,69 +4,93 @@
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
import torch import torch
from dataclasses import dataclass
import omni.physics.tensors.impl.api as physx
import omni.isaac.lab.utils.math as math_utils
from omni.isaac.lab.utils.buffers import TimestampedBuffer
@dataclass
class RigidObjectData: class RigidObjectData:
"""Data container for a rigid object.""" """Data container for a rigid object."""
## def __init__(self, root_physx_view: physx.RigidBodyView, device):
# Properties. self.device = device
## self._time_stamp = 0.0
self._root_physx_view: physx.RigidBodyView = root_physx_view
self.gravity_vec_w = torch.tensor((0.0, 0.0, -1.0), device=self.device).repeat(self._root_physx_view.count, 1)
self.forward_vec_b = torch.tensor((1.0, 0.0, 0.0), device=self.device).repeat(self._root_physx_view.count, 1)
self._previous_body_vel_w = torch.zeros((self._root_physx_view.count, 1, 6), device=self.device)
# Initialize the lazy buffers.
self._root_state_w: TimestampedBuffer = TimestampedBuffer()
self._body_acc_w: TimestampedBuffer = TimestampedBuffer()
def update(self, dt: float):
self._time_stamp += dt
# Trigger an update of the body acceleration buffer at a higher frequency since we do finite differencing.
self.body_acc_w
body_names: list[str] = None body_names: list[str] = None
"""Body names in the order parsed by the simulation view.""" """Body names in the order parsed by the simulation view."""
## ##
# Default states. # Defaults.
## ##
default_root_state: torch.Tensor = None default_root_state: torch.Tensor = None
"""Default root state ``[pos, quat, lin_vel, ang_vel]`` in local environment frame. Shape is (num_instances, 13).""" """Default root state ``[pos, quat, lin_vel, ang_vel]`` in local environment frame. Shape is (num_instances, 13)."""
default_mass: torch.Tensor = None
""" Default mass provided by simulation. Shape is (num_instances, num_bodies)."""
## ##
# Frame states. # Properties.
## ##
root_state_w: torch.Tensor = None @property
"""Root state ``[pos, quat, lin_vel, ang_vel]`` in simulation world frame. Shape is (num_instances, 13).""" def root_state_w(self):
"""Root state ``[pos, quat, lin_vel, ang_vel]`` in simulation world frame. Shape is (num_instances, 13)."""
root_vel_b: torch.Tensor = None if self._root_state_w.update_timestamp < self._time_stamp:
"""Root velocity `[lin_vel, ang_vel]` in base frame. Shape is (num_instances, 6).""" pose = self._root_physx_view.get_transforms().clone()
pose[:, 3:7] = math_utils.convert_quat(pose[:, 3:7], to="wxyz")
projected_gravity_b: torch.Tensor = None velocity = self._root_physx_view.get_velocities()
"""Projection of the gravity direction on base frame. Shape is (num_instances, 3).""" self._root_state_w.data = torch.cat((pose, velocity), dim=-1)
self._root_state_w.update_timestamp = self._time_stamp
heading_w: torch.Tensor = None return self._root_state_w.data
"""Yaw heading of the base frame (in radians). Shape is (num_instances,).
Note:
This quantity is computed by assuming that the forward-direction of the base
frame is along x-direction, i.e. :math:`(1, 0, 0)`.
"""
body_state_w: torch.Tensor = None
"""State of all bodies `[pos, quat, lin_vel, ang_vel]` in simulation world frame.
Shape is (num_instances, num_bodies, 13)."""
body_acc_w: torch.Tensor = None @property
"""Acceleration of all bodies. Shape is (num_instances, num_bodies, 6). def body_state_w(self):
"""State of all bodies `[pos, quat, lin_vel, ang_vel]` in simulation world frame. Shape is (num_instances, 1, 13)."""
return self.root_state_w.view(-1, 1, 13)
Note: @property
This quantity is computed based on the rigid body state from the last step. def body_acc_w(self):
""" """Acceleration of all bodies. Shape is (num_instances, 1, 6)."""
if self._body_acc_w.update_timestamp < self._time_stamp:
self._body_acc_w.data = (self.body_vel_w - self._previous_body_vel_w) / (
self._time_stamp - self._body_acc_w.update_timestamp
)
self._previous_body_vel_w[:] = self.body_vel_w
self._body_acc_w.update_timestamp = self._time_stamp
return self._body_acc_w.data
## @property
# Default rigid body properties def projected_gravity_b(self):
## """Projection of the gravity direction on base frame. Shape is (num_instances, 3)."""
return math_utils.quat_rotate_inverse(self.root_quat_w, self.gravity_vec_w)
default_mass: torch.Tensor = None @property
""" Default mass provided by simulation. Shape is (num_instances, num_bodies).""" def heading_w(self):
"""Yaw heading of the base frame (in radians). Shape is (num_instances,).
""" Note:
Properties This quantity is computed by assuming that the forward-direction of the base
""" frame is along x-direction, i.e. :math:`(1, 0, 0)`.
"""
forward_w = math_utils.quat_apply(self.root_quat_w, self.forward_vec_b)
return torch.atan2(forward_w[:, 1], forward_w[:, 0])
@property @property
def root_pos_w(self) -> torch.Tensor: def root_pos_w(self) -> torch.Tensor:
...@@ -96,12 +120,12 @@ class RigidObjectData: ...@@ -96,12 +120,12 @@ class RigidObjectData:
@property @property
def root_lin_vel_b(self) -> torch.Tensor: def root_lin_vel_b(self) -> torch.Tensor:
"""Root linear velocity in base frame. Shape is (num_instances, 3).""" """Root linear velocity in base frame. Shape is (num_instances, 3)."""
return self.root_vel_b[:, 0:3] return math_utils.quat_rotate_inverse(self.root_quat_w, self.root_lin_vel_w)
@property @property
def root_ang_vel_b(self) -> torch.Tensor: def root_ang_vel_b(self) -> torch.Tensor:
"""Root angular velocity in base world frame. Shape is (num_instances, 3).""" """Root angular velocity in base world frame. Shape is (num_instances, 3)."""
return self.root_vel_b[:, 3:6] return math_utils.quat_rotate_inverse(self.root_quat_w, self.root_ang_vel_w)
@property @property
def body_pos_w(self) -> torch.Tensor: def body_pos_w(self) -> torch.Tensor:
...@@ -127,13 +151,3 @@ class RigidObjectData: ...@@ -127,13 +151,3 @@ class RigidObjectData:
def body_ang_vel_w(self) -> torch.Tensor: def body_ang_vel_w(self) -> torch.Tensor:
"""Angular velocity of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3).""" """Angular velocity of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_state_w[..., 10:13] return self.body_state_w[..., 10:13]
@property
def body_lin_acc_w(self) -> torch.Tensor:
"""Linear acceleration of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_acc_w[..., 0:3]
@property
def body_ang_acc_w(self) -> torch.Tensor:
"""Angular acceleration of all bodies in simulation world frame. Shape is (num_instances, num_bodies, 3)."""
return self.body_acc_w[..., 3:6]
...@@ -7,3 +7,4 @@ ...@@ -7,3 +7,4 @@
from .circular_buffer import BatchedCircularBuffer from .circular_buffer import BatchedCircularBuffer
from .delay_buffer import DelayBuffer from .delay_buffer import DelayBuffer
from .timestamped_buffer import TimestampedBuffer
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import torch
from dataclasses import dataclass
@dataclass
class TimestampedBuffer:
"""Buffer to hold timestamped data.
Such a buffer is useful to check whether data is outdated and needs to be refreshed to create lazy buffers.
"""
data: torch.Tensor = None
"""Data stored in the buffer."""
update_timestamp: float = -1.0
"""Timestamp of the last update of the buffer."""
...@@ -34,7 +34,7 @@ def base_up_proj(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg = SceneEntityCf ...@@ -34,7 +34,7 @@ def base_up_proj(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg = SceneEntityCf
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
# compute base up vector # compute base up vector
base_up_vec = math_utils.quat_rotate(asset.data.root_quat_w, -asset.GRAVITY_VEC_W) base_up_vec = -asset.data.projected_gravity_b
return base_up_vec[:, 2].unsqueeze(-1) return base_up_vec[:, 2].unsqueeze(-1)
...@@ -50,7 +50,7 @@ def base_heading_proj( ...@@ -50,7 +50,7 @@ def base_heading_proj(
to_target_pos[:, 2] = 0.0 to_target_pos[:, 2] = 0.0
to_target_dir = math_utils.normalize(to_target_pos) to_target_dir = math_utils.normalize(to_target_pos)
# compute base forward vector # compute base forward vector
heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.FORWARD_VEC_B) heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.data.forward_vec_b)
# compute dot product between heading and target direction # compute dot product between heading and target direction
heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1)) heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment