Unverified Commit 1a7c86b9 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Fixes reference count in asset instances due to circular references (#580)

With the lazy buffer implementation in asset classes, we pass the
physics sim-view to the data object. This increments the reference count
of the sim-view instance and blocks the asset instance itself to be
garbage collected properly.

This MR takes steps to fix this issue.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have run all the tests with `./isaaclab.sh --test` and they pass
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent eb8d968d
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.18.3" version = "0.18.4"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.18.4 (2024-06-26)
~~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Fixed double reference count of the physics sim view inside the asset classes. This was causing issues
when destroying the asset class instance since the physics sim view was not being properly released.
Added
^^^^^
* Added the attribute :attr:`~omni.isaac.lab.assets.AssetBase.is_initialized` to check if the asset and sensor
has been initialized properly. This can be used to ensure that the asset or sensor is ready to use in the simulation.
0.18.3 (2024-06-25) 0.18.3 (2024-06-25)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -88,6 +88,14 @@ class Articulation(RigidObject): ...@@ -88,6 +88,14 @@ class Articulation(RigidObject):
cfg: ArticulationCfg cfg: ArticulationCfg
"""Configuration instance for the articulations.""" """Configuration instance for the articulations."""
actuators: dict[str, ActuatorBase]
"""Dictionary of actuator instances for the articulation.
The keys are the actuator names and the values are the actuator instances. The actuator instances
are initialized based on the actuator configurations specified in the :attr:`ArticulationCfg.actuators`
attribute. They are used to compute the joint commands during the :meth:`write_data_to_sim` function.
"""
def __init__(self, cfg: ArticulationCfg): def __init__(self, cfg: ArticulationCfg):
"""Initialize the articulation. """Initialize the articulation.
...@@ -95,8 +103,6 @@ class Articulation(RigidObject): ...@@ -95,8 +103,6 @@ class Articulation(RigidObject):
cfg: A configuration instance. cfg: A configuration instance.
""" """
super().__init__(cfg) super().__init__(cfg)
# data for storing actuator group
self.actuators: dict[str, ActuatorBase] = dict.fromkeys(self.cfg.actuators.keys())
""" """
Properties Properties
...@@ -870,8 +876,10 @@ class Articulation(RigidObject): ...@@ -870,8 +876,10 @@ class Articulation(RigidObject):
# -- properties # -- properties
self._data.joint_names = self.joint_names self._data.joint_names = self.joint_names
# -- default joint state
self._data.default_joint_pos = torch.zeros(self.num_instances, self.num_joints, device=self.device) self._data.default_joint_pos = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.default_joint_vel = torch.zeros_like(self._data.default_joint_pos) self._data.default_joint_vel = torch.zeros_like(self._data.default_joint_pos)
# -- joint commands # -- joint commands
self._data.joint_pos_target = torch.zeros_like(self._data.default_joint_pos) self._data.joint_pos_target = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_vel_target = torch.zeros_like(self._data.default_joint_pos) self._data.joint_vel_target = torch.zeros_like(self._data.default_joint_pos)
...@@ -881,9 +889,11 @@ class Articulation(RigidObject): ...@@ -881,9 +889,11 @@ class Articulation(RigidObject):
self._data.joint_armature = torch.zeros_like(self._data.default_joint_pos) self._data.joint_armature = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_friction = torch.zeros_like(self._data.default_joint_pos) self._data.joint_friction = torch.zeros_like(self._data.default_joint_pos)
self._data.joint_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device) self._data.joint_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device)
# -- joint commands (explicit) # -- joint commands (explicit)
self._data.computed_torque = torch.zeros_like(self._data.default_joint_pos) self._data.computed_torque = torch.zeros_like(self._data.default_joint_pos)
self._data.applied_torque = torch.zeros_like(self._data.default_joint_pos) self._data.applied_torque = torch.zeros_like(self._data.default_joint_pos)
# -- tendons # -- tendons
if self.num_fixed_tendons > 0: if self.num_fixed_tendons > 0:
self._data.fixed_tendon_stiffness = torch.zeros( self._data.fixed_tendon_stiffness = torch.zeros(
...@@ -907,12 +917,15 @@ class Articulation(RigidObject): ...@@ -907,12 +917,15 @@ class Articulation(RigidObject):
self._data.soft_joint_pos_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device) self._data.soft_joint_pos_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device)
self._data.soft_joint_vel_limits = torch.zeros(self.num_instances, self.num_joints, device=self.device) self._data.soft_joint_vel_limits = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.gear_ratio = torch.ones(self.num_instances, self.num_joints, device=self.device) self._data.gear_ratio = torch.ones(self.num_instances, self.num_joints, device=self.device)
# -- initialize default buffers
# -- initialize default buffers related to joint properties
self._data.default_joint_stiffness = torch.zeros(self.num_instances, self.num_joints, device=self.device) self._data.default_joint_stiffness = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.default_joint_damping = torch.zeros(self.num_instances, self.num_joints, device=self.device) self._data.default_joint_damping = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.default_joint_armature = torch.zeros(self.num_instances, self.num_joints, device=self.device) self._data.default_joint_armature = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.default_joint_friction = torch.zeros(self.num_instances, self.num_joints, device=self.device) self._data.default_joint_friction = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.default_joint_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device) self._data.default_joint_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device)
# -- initialize default buffers related to fixed tendon properties
if self.num_fixed_tendons > 0: if self.num_fixed_tendons > 0:
self._data.default_fixed_tendon_stiffness = torch.zeros( self._data.default_fixed_tendon_stiffness = torch.zeros(
self.num_instances, self.num_fixed_tendons, device=self.device self.num_instances, self.num_fixed_tendons, device=self.device
...@@ -963,6 +976,7 @@ class Articulation(RigidObject): ...@@ -963,6 +976,7 @@ class Articulation(RigidObject):
) )
self._data.default_joint_vel[:, indices_list] = torch.tensor(values_list, device=self.device) self._data.default_joint_vel[:, indices_list] = torch.tensor(values_list, device=self.device)
# -- joint limits
self._data.default_joint_limits = self.root_physx_view.get_dof_limits().to(device=self.device).clone() self._data.default_joint_limits = self.root_physx_view.get_dof_limits().to(device=self.device).clone()
self._data.joint_limits = self._data.default_joint_limits.clone() self._data.joint_limits = self._data.default_joint_limits.clone()
...@@ -972,9 +986,12 @@ class Articulation(RigidObject): ...@@ -972,9 +986,12 @@ class Articulation(RigidObject):
def _process_actuators_cfg(self): def _process_actuators_cfg(self):
"""Process and apply articulation joint properties.""" """Process and apply articulation joint properties."""
# create actuators
self.actuators = dict()
# flag for implicit actuators # flag for implicit actuators
# if this is false, we by-pass certain checks when doing actuator-related operations # if this is false, we by-pass certain checks when doing actuator-related operations
self._has_implicit_actuators = False self._has_implicit_actuators = False
# cache the values coming from the usd # cache the values coming from the usd
usd_stiffness = self.root_physx_view.get_dof_stiffnesses().clone() usd_stiffness = self.root_physx_view.get_dof_stiffnesses().clone()
usd_damping = self.root_physx_view.get_dof_dampings().clone() usd_damping = self.root_physx_view.get_dof_dampings().clone()
...@@ -982,6 +999,7 @@ class Articulation(RigidObject): ...@@ -982,6 +999,7 @@ class Articulation(RigidObject):
usd_friction = self.root_physx_view.get_dof_friction_coefficients().clone() usd_friction = self.root_physx_view.get_dof_friction_coefficients().clone()
usd_effort_limit = self.root_physx_view.get_dof_max_forces().clone() usd_effort_limit = self.root_physx_view.get_dof_max_forces().clone()
usd_velocity_limit = self.root_physx_view.get_dof_max_velocities().clone() usd_velocity_limit = self.root_physx_view.get_dof_max_velocities().clone()
# iterate over all actuator configurations # iterate over all actuator configurations
for actuator_name, actuator_cfg in self.cfg.actuators.items(): for actuator_name, actuator_cfg in self.cfg.actuators.items():
# type annotation for type checkers # type annotation for type checkers
......
...@@ -21,7 +21,12 @@ class ArticulationData(RigidObjectData): ...@@ -21,7 +21,12 @@ class ArticulationData(RigidObjectData):
""" """
_root_physx_view: physx.ArticulationView _root_physx_view: physx.ArticulationView
"""The root articulation view of the object.""" """The root articulation view of the object.
Note:
Internally, this is stored as a weak reference to avoid circular references between the asset class
and the data container. This is important to avoid memory leaks.
"""
def __init__(self, root_physx_view: physx.ArticulationView, device: str): def __init__(self, root_physx_view: physx.ArticulationView, device: str):
# Initialize the parent class # Initialize the parent class
......
...@@ -120,6 +120,14 @@ class AssetBase(ABC): ...@@ -120,6 +120,14 @@ class AssetBase(ABC):
Properties Properties
""" """
@property
def is_initialized(self) -> bool:
"""Whether the asset is initialized.
Returns True if the asset is initialized, False otherwise.
"""
return self._is_initialized
@property @property
@abstractmethod @abstractmethod
def num_instances(self) -> int: def num_instances(self) -> int:
......
...@@ -76,7 +76,7 @@ class RigidObject(AssetBase): ...@@ -76,7 +76,7 @@ class RigidObject(AssetBase):
@property @property
def body_names(self) -> list[str]: def body_names(self) -> list[str]:
"""Ordered names of bodies in articulation.""" """Ordered names of bodies in the rigid object."""
prim_paths = self.root_physx_view.prim_paths[: self.num_bodies] prim_paths = self.root_physx_view.prim_paths[: self.num_bodies]
return [path.split("/")[-1] for path in prim_paths] return [path.split("/")[-1] for path in prim_paths]
...@@ -340,6 +340,7 @@ class RigidObject(AssetBase): ...@@ -340,6 +340,7 @@ class RigidObject(AssetBase):
self._external_force_b = torch.zeros((self.num_instances, self.num_bodies, 3), device=self.device) self._external_force_b = torch.zeros((self.num_instances, self.num_bodies, 3), device=self.device)
self._external_torque_b = torch.zeros_like(self._external_force_b) self._external_torque_b = torch.zeros_like(self._external_force_b)
# set information about rigid body into data
self._data.body_names = self.body_names self._data.body_names = self.body_names
self._data.default_mass = self.root_physx_view.get_masses().clone() self._data.default_mass = self.root_physx_view.get_masses().clone()
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
import torch import torch
import weakref
import omni.physics.tensors.impl.api as physx import omni.physics.tensors.impl.api as physx
...@@ -24,7 +25,12 @@ class RigidObjectData: ...@@ -24,7 +25,12 @@ class RigidObjectData:
""" """
_root_physx_view: physx.RigidBodyView _root_physx_view: physx.RigidBodyView
"""The root rigid body view of the object.""" """The root rigid body view of the object.
Note:
Internally, this is stored as a weak reference to avoid circular references between the asset class
and the data container. This is important to avoid memory leaks.
"""
def __init__(self, root_physx_view: physx.RigidBodyView, device: str): def __init__(self, root_physx_view: physx.RigidBodyView, device: str):
"""Initializes the rigid object data. """Initializes the rigid object data.
...@@ -35,7 +41,7 @@ class RigidObjectData: ...@@ -35,7 +41,7 @@ class RigidObjectData:
""" """
# Set the parameters # Set the parameters
self.device = device self.device = device
self._root_physx_view = root_physx_view self._root_physx_view = weakref.proxy(root_physx_view) # weak reference to avoid circular references
# Set initial time stamp # Set initial time stamp
self._sim_timestamp = 0.0 self._sim_timestamp = 0.0
......
...@@ -308,11 +308,11 @@ class ManagerBasedEnv: ...@@ -308,11 +308,11 @@ class ManagerBasedEnv:
"""Cleanup for the environment.""" """Cleanup for the environment."""
if not self._is_closed: if not self._is_closed:
# destructor is order-sensitive # destructor is order-sensitive
del self.viewport_camera_controller
del self.action_manager del self.action_manager
del self.observation_manager del self.observation_manager
del self.event_manager del self.event_manager
del self.scene del self.scene
del self.viewport_camera_controller
# clear callbacks and instance # clear callbacks and instance
self.sim.clear_all_callbacks() self.sim.clear_all_callbacks()
self.sim.clear_instance() self.sim.clear_instance()
......
...@@ -92,6 +92,14 @@ class SensorBase(ABC): ...@@ -92,6 +92,14 @@ class SensorBase(ABC):
Properties Properties
""" """
@property
def is_initialized(self) -> bool:
"""Whether the sensor is initialized.
Returns True if the sensor is initialized, False otherwise.
"""
return self._is_initialized
@property @property
def num_instances(self) -> int: def num_instances(self) -> int:
"""Number of instances of the sensor. """Number of instances of the sensor.
......
...@@ -50,7 +50,7 @@ def base_heading_proj( ...@@ -50,7 +50,7 @@ def base_heading_proj(
to_target_pos[:, 2] = 0.0 to_target_pos[:, 2] = 0.0
to_target_dir = math_utils.normalize(to_target_pos) to_target_dir = math_utils.normalize(to_target_pos)
# compute base forward vector # compute base forward vector
heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.data.forward_vec_b) heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.data.FORWARD_VEC_B)
# compute dot product between heading and target direction # compute dot product between heading and target direction
heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1)) heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment