Unverified Commit 6f4cc59d authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Changes internal working of APIs to use physics views directly (#267)

# Description

For a long time, we have been seeing a slow simulation setup time (i.e.
time spent in `sim.reset` call). It takes around 70-75 seconds to set up
the simulation for ANYmal locomotion task with the new USD asset for it.
This number is only increasing with other more complex robots we have
been trying to import.

The MR dives into the possible causes and gets rid of costly operations.
Many of these are coming from Isaac Sim itself, particularly related to
the initialization of views. Hence, the following breaking changes:

* We no longer depend on Isaac Sim for `RigidPrimView` and
`ArticulationView`. Instead, we directly create underlying PhysX views
for them.
* We add faster reimplementations of functions that are used for regex
matching.

With these changes, the simulation load time is reduced from up to 80
sec to 15 sec. A bulk of the time is still going to setting up the
simulation step for the first time.

## Type of change

- Breaking change (fix or feature that would cause existing
functionality to not work as expected)

## Screenshots

| Before | After |
| ------ | ----- |
|
![orig-fg](https://github.com/isaac-orbit/orbit/assets/12863862/c13f1634-bd2c-4daf-97e0-3b5776b5cd37)
|
![ref-fg](https://github.com/isaac-orbit/orbit/assets/12863862/b509049d-4cbd-45d6-a4f4-6082f4caf7f2)
|

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 99a238e7
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.9.55" version = "0.10.0"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.10.0 (2023-12-04)
~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Modified the sensor and asset base classes to use the underlying PhysX views instead of Isaac Sim views.
Using Isaac Sim classes led to a very high load time (of the order of minutes) when using a scene with
many assets. This is because Isaac Sim supports USD paths which are slow and not required.
Added
^^^^^
* Added faster implementation of USD stage traversal methods inside the :class:`omni.isaac.orbit.sim.utils` module.
* Added properties :attr:`omni.isaac.orbit.assets.AssetBase.num_instances` and
:attr:`omni.isaac.orbit.sensor.SensorBase.num_instances` to obtain the number of instances of the asset
or sensor in the simulation respectively.
Removed
^^^^^^^
* Removed dependencies on Isaac Sim view classes. It is no longer possible to use :attr:`root_view` and
:attr:`body_view`. Instead use :attr:`root_physx_view` and :attr:`body_physx_view` to access the underlying
PhysX views.
0.9.55 (2023-12-03) 0.9.55 (2023-12-03)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -639,9 +639,13 @@ class AppLauncher: ...@@ -639,9 +639,13 @@ class AppLauncher:
enable_extension("omni.kit.viewport.bundle") enable_extension("omni.kit.viewport.bundle")
# extension for window status bar # extension for window status bar
enable_extension("omni.kit.window.status_bar") enable_extension("omni.kit.window.status_bar")
# enable isaac replicator extension # enable replicator extension
# note: moved here since it requires to have the viewport extension to be enabled first. # note: moved here since it requires to have the viewport extension to be enabled first.
enable_extension("omni.replicator.isaac") enable_extension("omni.replicator.core")
# enable UI tools
# note: we need to always import this even with headless to make
# the module for orbit.envs.ui work
enable_extension("omni.isaac.ui")
# set the nucleus directory manually to the 2023.1.0 version # set the nucleus directory manually to the 2023.1.0 version
# TODO: Remove this once the 2023.1.0 version is released # TODO: Remove this once the 2023.1.0 version is released
......
...@@ -13,13 +13,11 @@ from prettytable import PrettyTable ...@@ -13,13 +13,11 @@ from prettytable import PrettyTable
from typing import TYPE_CHECKING, Sequence from typing import TYPE_CHECKING, Sequence
import carb import carb
import omni.isaac.core.utils.prims as prim_utils
import omni.physics.tensors.impl.api as physx import omni.physics.tensors.impl.api as physx
from omni.isaac.core.articulations import ArticulationView
from omni.isaac.core.prims import RigidPrimView
from omni.isaac.core.utils.types import ArticulationActions from omni.isaac.core.utils.types import ArticulationActions
from pxr import Usd, UsdPhysics from pxr import UsdPhysics
import omni.isaac.orbit.sim as sim_utils
import omni.isaac.orbit.utils.math as math_utils import omni.isaac.orbit.utils.math as math_utils
import omni.isaac.orbit.utils.string as string_utils import omni.isaac.orbit.utils.string as string_utils
from omni.isaac.orbit.actuators import ActuatorBase, ActuatorBaseCfg, ImplicitActuator from omni.isaac.orbit.actuators import ActuatorBase, ActuatorBaseCfg, ImplicitActuator
...@@ -49,12 +47,10 @@ class Articulation(RigidObject): ...@@ -49,12 +47,10 @@ class Articulation(RigidObject):
articulation root prim can be specified using the :attr:`AssetBaseCfg.prim_path` attribute. articulation root prim can be specified using the :attr:`AssetBaseCfg.prim_path` attribute.
The articulation class is a subclass of the :class:`RigidObject` class. Therefore, it inherits The articulation class is a subclass of the :class:`RigidObject` class. Therefore, it inherits
all the functionality of the rigid object class. In case of an articulation, the :attr:`root_view` all the functionality of the rigid object class. In case of an articulation, the :attr:`root_physx_view`
attribute corresponds to the articulation root view and can be used to access the articulation attribute corresponds to the articulation root view and can be used to access the articulation
related data. The :attr:`body_view` attribute corresponds to the rigid body view of the articulated related data. The :attr:`body_physx_view` attribute corresponds to the rigid body view of the articulated
links and can be used to access the rigid body related data. The :attr:`root_physx_view` and links and can be used to access the rigid body related data.
:attr:`body_physx_view` attributes correspond to the underlying physics views of the articulation
root and the articulated links, respectively.
The articulation class also provides the functionality to augment the simulation of an articulated The articulation class also provides the functionality to augment the simulation of an articulated
system with custom actuator models. These models can either be explicit or implicit, as detailed in system with custom actuator models. These models can either be explicit or implicit, as detailed in
...@@ -108,44 +104,35 @@ class Articulation(RigidObject): ...@@ -108,44 +104,35 @@ class Articulation(RigidObject):
@property @property
def data(self) -> ArticulationData: def data(self) -> ArticulationData:
"""Data related to articulation."""
return self._data return self._data
@property @property
def is_fixed_base(self) -> bool: def is_fixed_base(self) -> bool:
"""Whether the articulation is a fixed-base or floating-base system.""" """Whether the articulation is a fixed-base or floating-base system."""
return self._is_fixed_base return self.root_physx_view.shared_metatype.fixed_base
@property @property
def num_joints(self) -> int: def num_joints(self) -> int:
"""Number of joints in articulation.""" """Number of joints in articulation."""
return self.root_view.num_dof return self.root_physx_view.max_dofs
@property @property
def num_bodies(self) -> int: def num_bodies(self) -> int:
"""Number of bodies in articulation.""" """Number of bodies in articulation."""
return self.root_view.num_bodies return self.root_physx_view.max_links
@property @property
def joint_names(self) -> list[str]: def joint_names(self) -> list[str]:
"""Ordered names of joints in articulation.""" """Ordered names of joints in articulation."""
return self.root_view.dof_names return self.root_physx_view.shared_metatype.dof_names
@property
def root_view(self) -> ArticulationView:
return self._root_view
@property
def body_view(self) -> RigidPrimView:
return self._body_view
@property @property
def root_physx_view(self) -> physx.ArticulationView: def root_physx_view(self) -> physx.ArticulationView:
return self._root_view._physics_view # pyright: ignore [reportPrivateUsage] return self._root_physx_view
@property @property
def body_physx_view(self) -> physx.RigidBodyView: def body_physx_view(self) -> physx.RigidBodyView:
return self._body_view._physics_view # pyright: ignore [reportPrivateUsage] return self._body_physx_view
""" """
Operations. Operations.
...@@ -483,56 +470,36 @@ class Articulation(RigidObject): ...@@ -483,56 +470,36 @@ class Articulation(RigidObject):
""" """
def _initialize_impl(self): def _initialize_impl(self):
# create simulation view
self._physics_sim_view = physx.create_simulation_view(self._backend)
self._physics_sim_view.set_subspace_roots("/")
# obtain the first prim in the regex expression (all others are assumed to be a copy of this)
template_prim = sim_utils.find_first_matching_prim(self.cfg.prim_path)
if template_prim is None:
raise RuntimeError(f"Failed to find prim for expression: '{self.cfg.prim_path}'.")
template_prim_path = template_prim.GetPath().pathString
# find articulation root prims # find articulation root prims
asset_prim_path = prim_utils.find_matching_prim_paths(self.cfg.prim_path)[0] root_prims = sim_utils.get_all_matching_child_prims(
root_prims = prim_utils.get_all_matching_child_prims( template_prim_path, predicate=lambda prim: prim.HasAPI(UsdPhysics.ArticulationRootAPI)
asset_prim_path, predicate=lambda a: prim_utils.get_prim_at_path(a).HasAPI(UsdPhysics.ArticulationRootAPI)
) )
if len(root_prims) != 1: if len(root_prims) != 1:
raise RuntimeError( raise RuntimeError(
f"Failed to find a single articulation root when resolving '{self.cfg.prim_path}'." f"Failed to find a single articulation root when resolving '{self.cfg.prim_path}'."
f" Found roots '{root_prims}' under '{asset_prim_path}'." f" Found roots '{root_prims}' under '{template_prim_path}'."
) )
# resolve articulation root prim back into regex expression # resolve articulation root prim back into regex expression
root_prim_path = prim_utils.get_prim_path(root_prims[0]) root_prim_path = root_prims[0].GetPath().pathString
root_prim_path_expr = self.cfg.prim_path + root_prim_path[len(asset_prim_path) :] root_prim_path_expr = self.cfg.prim_path + root_prim_path[len(template_prim_path) :]
# -- articulation # -- articulation
self._root_view = ArticulationView(root_prim_path_expr, reset_xform_properties=False) self._root_physx_view = self._physics_sim_view.create_articulation_view(root_prim_path_expr.replace(".*", "*"))
# Hacking the initialization of the articulation view.
# reason: The default initialization of the articulation view is not working properly as it tries to create
# default actions that is not possible within the post-play callback.
# We override their internal function that throws an error which is not desired or needed.
dummy_tensor = torch.empty(size=(0, 0), device=self.device)
dummy_joint_actions = ArticulationActions(dummy_tensor, dummy_tensor, dummy_tensor)
current_fn = self._root_view.get_applied_actions
self._root_view.get_applied_actions = lambda *args, **kwargs: dummy_joint_actions
# initialize the root view
self._root_view.initialize()
# restore the function
self._root_view.get_applied_actions = current_fn
# -- link views # -- link views
# note: we use the root view to get the body names, but we use the body view to get the # note: we use the root view to get the body names, but we use the body view to get the
# actual data. This is mainly needed to apply external forces to the bodies. # actual data. This is mainly needed to apply external forces to the bodies.
body_names_regex = r"(" + "|".join(self.root_view.body_names) + r")" physx_body_names = self.root_physx_view.shared_metatype.link_names
body_names_regex = r"(" + "|".join(physx_body_names) + r")"
body_names_regex = f"{self.cfg.prim_path}/{body_names_regex}" body_names_regex = f"{self.cfg.prim_path}/{body_names_regex}"
self._body_view = RigidPrimView(body_names_regex, reset_xform_properties=False) self._body_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex.replace(".*", "*"))
self._body_view.initialize()
# check that initialization was successful
if len(self.body_names) != self.num_bodies:
raise RuntimeError("Failed to initialize all bodies properly in the articulation.")
# -- fixed base based on root joint
self._is_fixed_base = False
for prim in Usd.PrimRange(self._root_view.prims[0]):
joint_prim = UsdPhysics.FixedJoint(prim)
# we check all joints under the root prim and classify the asset as fixed base if there exists
# a fixed joint that has only one target (i.e. the root link).
if joint_prim and joint_prim.GetJointEnabledAttr().Get():
body_0_exist = joint_prim.GetBody0Rel().GetTargets() != []
body_1_exist = joint_prim.GetBody1Rel().GetTargets() != []
if not (body_0_exist and body_1_exist):
self._is_fixed_base = True
break
# log information about the articulation # log information about the articulation
carb.log_info(f"Articulation initialized at: {self.cfg.prim_path} with root '{root_prim_path_expr}'.") carb.log_info(f"Articulation initialized at: {self.cfg.prim_path} with root '{root_prim_path_expr}'.")
carb.log_info(f"Is fixed root: {self.is_fixed_base}") carb.log_info(f"Is fixed root: {self.is_fixed_base}")
...@@ -541,7 +508,7 @@ class Articulation(RigidObject): ...@@ -541,7 +508,7 @@ class Articulation(RigidObject):
carb.log_info(f"Number of joints: {self.num_joints}") carb.log_info(f"Number of joints: {self.num_joints}")
carb.log_info(f"Joint names: {self.joint_names}") carb.log_info(f"Joint names: {self.joint_names}")
# -- assert that parsing was successful # -- assert that parsing was successful
if set(self.root_view.body_names) != set(self.body_names): if set(physx_body_names) != set(self.body_names):
raise RuntimeError("Failed to parse all bodies properly in the articulation.") raise RuntimeError("Failed to parse all bodies properly in the articulation.")
# create buffers # create buffers
self._create_buffers() self._create_buffers()
...@@ -555,13 +522,13 @@ class Articulation(RigidObject): ...@@ -555,13 +522,13 @@ class Articulation(RigidObject):
# allocate buffers # allocate buffers
super()._create_buffers() super()._create_buffers()
# history buffers # history buffers
self._previous_joint_vel = torch.zeros(self.root_view.count, self.num_joints, device=self.device) self._previous_joint_vel = torch.zeros(self.num_instances, self.num_joints, device=self.device)
# asset data # asset data
# -- properties # -- properties
self._data.joint_names = self.joint_names self._data.joint_names = self.joint_names
# -- joint states # -- joint states
self._data.joint_pos = torch.zeros(self.root_view.count, self.num_joints, dtype=torch.float, device=self.device) self._data.joint_pos = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.joint_vel = torch.zeros_like(self._data.joint_pos) self._data.joint_vel = torch.zeros_like(self._data.joint_pos)
self._data.joint_acc = torch.zeros_like(self._data.joint_pos) self._data.joint_acc = torch.zeros_like(self._data.joint_pos)
self._data.default_joint_pos = torch.zeros_like(self._data.joint_pos) self._data.default_joint_pos = torch.zeros_like(self._data.joint_pos)
...@@ -578,9 +545,9 @@ class Articulation(RigidObject): ...@@ -578,9 +545,9 @@ class Articulation(RigidObject):
self._data.computed_torque = torch.zeros_like(self._data.joint_pos) self._data.computed_torque = torch.zeros_like(self._data.joint_pos)
self._data.applied_torque = torch.zeros_like(self._data.joint_pos) self._data.applied_torque = torch.zeros_like(self._data.joint_pos)
# -- other data # -- other data
self._data.soft_joint_pos_limits = torch.zeros(self.root_view.count, self.num_joints, 2, device=self.device) self._data.soft_joint_pos_limits = torch.zeros(self.num_instances, self.num_joints, 2, device=self.device)
self._data.soft_joint_vel_limits = torch.zeros(self.root_view.count, self.num_joints, device=self.device) self._data.soft_joint_vel_limits = torch.zeros(self.num_instances, self.num_joints, device=self.device)
self._data.gear_ratio = torch.ones(self.root_view.count, self.num_joints, device=self.device) self._data.gear_ratio = torch.ones(self.num_instances, self.num_joints, device=self.device)
# soft joint position limits (recommended not to be too close to limits). # soft joint position limits (recommended not to be too close to limits).
joint_pos_limits = self.root_physx_view.get_dof_limits() joint_pos_limits = self.root_physx_view.get_dof_limits()
...@@ -635,18 +602,18 @@ class Articulation(RigidObject): ...@@ -635,18 +602,18 @@ class Articulation(RigidObject):
) )
# create actuator collection # create actuator collection
# note: for efficiency avoid indexing when over all indices # note: for efficiency avoid indexing when over all indices
sim_stiffness, sim_damping = self.root_view.get_gains(joint_indices=joint_ids)
actuator: ActuatorBase = actuator_cfg.class_type( actuator: ActuatorBase = actuator_cfg.class_type(
cfg=actuator_cfg, cfg=actuator_cfg,
joint_names=joint_names, joint_names=joint_names,
joint_ids=slice(None) if len(joint_names) == self.num_joints else joint_ids, joint_ids=slice(None) if len(joint_names) == self.num_joints else joint_ids,
num_envs=self.root_view.count, num_envs=self.num_instances,
device=self.device, device=self.device,
stiffness=sim_stiffness, stiffness=self.root_physx_view.get_dof_stiffnesses()[:, joint_ids],
damping=sim_damping, damping=self.root_physx_view.get_dof_dampings()[:, joint_ids],
armature=self.root_view.get_armatures(joint_indices=joint_ids), armature=self.root_physx_view.get_dof_armatures()[:, joint_ids],
friction=self.root_view.get_friction_coefficients(joint_indices=joint_ids), friction=self.root_physx_view.get_dof_friction_coefficients()[:, joint_ids],
effort_limit=self.root_view.get_max_efforts(joint_indices=joint_ids), effort_limit=self.root_physx_view.get_dof_max_forces()[:, joint_ids],
velocity_limit=self.root_physx_view.get_dof_max_velocities()[:, joint_ids],
) )
# log information on actuator groups # log information on actuator groups
carb.log_info( carb.log_info(
...@@ -733,16 +700,30 @@ class Articulation(RigidObject): ...@@ -733,16 +700,30 @@ class Articulation(RigidObject):
def _log_articulation_joint_info(self): def _log_articulation_joint_info(self):
"""Log information about the articulation's simulated joints.""" """Log information about the articulation's simulated joints."""
# read out all joint parameters from simulation # read out all joint parameters from simulation
gains = self.root_view.get_gains(indices=[0]) # -- gains
stiffnesses, dampings = gains[0].squeeze(0).tolist(), gains[1].squeeze(0).tolist() stiffnesses = self.root_physx_view.get_dof_stiffnesses()[0].squeeze(0).tolist()
armatures = self.root_view.get_armatures(indices=[0]).squeeze(0).tolist() dampings = self.root_physx_view.get_dof_dampings()[0].squeeze(0).tolist()
frictions = self.root_view.get_friction_coefficients(indices=[0]).squeeze(0).tolist() # -- properties
effort_limits = self.root_view.get_max_efforts(indices=[0]).squeeze(0).tolist() armatures = self.root_physx_view.get_dof_armatures()[0].squeeze(0).tolist()
pos_limits = self.root_view.get_dof_limits()[0].squeeze(0).tolist() frictions = self.root_physx_view.get_dof_friction_coefficients()[0].squeeze(0).tolist()
# -- limits
position_limits = self.root_physx_view.get_dof_limits()[0].squeeze(0).tolist()
velocity_limits = self.root_physx_view.get_dof_max_velocities()[0].squeeze(0).tolist()
effort_limits = self.root_physx_view.get_dof_max_forces()[0].squeeze(0).tolist()
# create table for term information # create table for term information
table = PrettyTable(float_format=".3f") table = PrettyTable(float_format=".3f")
table.title = f"Simulation Joint Information (Prim path: {self.cfg.prim_path})" table.title = f"Simulation Joint Information (Prim path: {self.cfg.prim_path})"
table.field_names = ["Index", "Name", "Stiffness", "Damping", "Armature", "Friction", "Effort Limit", "Limits"] table.field_names = [
"Index",
"Name",
"Stiffness",
"Damping",
"Armature",
"Friction",
"Position Limits",
"Velocity Limits",
"Effort Limits",
]
# set alignment of table columns # set alignment of table columns
table.align["Name"] = "l" table.align["Name"] = "l"
# add info on each term # add info on each term
...@@ -755,8 +736,9 @@ class Articulation(RigidObject): ...@@ -755,8 +736,9 @@ class Articulation(RigidObject):
dampings[index], dampings[index],
armatures[index], armatures[index],
frictions[index], frictions[index],
position_limits[index],
velocity_limits[index],
effort_limits[index], effort_limits[index],
pos_limits[index],
] ]
) )
# convert table to string # convert table to string
......
...@@ -11,10 +11,11 @@ import weakref ...@@ -11,10 +11,11 @@ import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Sequence from typing import TYPE_CHECKING, Any, Sequence
import omni.isaac.core.utils.prims as prim_utils
import omni.kit.app import omni.kit.app
import omni.timeline import omni.timeline
import omni.isaac.orbit.sim as sim_utils
if TYPE_CHECKING: if TYPE_CHECKING:
from .asset_base_cfg import AssetBaseCfg from .asset_base_cfg import AssetBaseCfg
...@@ -77,8 +78,8 @@ class AssetBase(ABC): ...@@ -77,8 +78,8 @@ class AssetBase(ABC):
orientation=self.cfg.init_state.rot, orientation=self.cfg.init_state.rot,
) )
# check that spawn was successful # check that spawn was successful
matching_prim_paths = prim_utils.find_matching_prim_paths(self.cfg.prim_path) matching_prims = sim_utils.find_matching_prims(self.cfg.prim_path)
if len(matching_prim_paths) == 0: if len(matching_prims) == 0:
raise RuntimeError(f"Could not find prim with path {self.cfg.prim_path}.") raise RuntimeError(f"Could not find prim with path {self.cfg.prim_path}.")
# note: Use weakref on all callbacks to ensure that this object can be deleted when its destructor is called. # note: Use weakref on all callbacks to ensure that this object can be deleted when its destructor is called.
...@@ -120,9 +121,17 @@ class AssetBase(ABC): ...@@ -120,9 +121,17 @@ class AssetBase(ABC):
@property @property
@abstractmethod @abstractmethod
def num_instances(self) -> int:
"""Number of instances of the asset.
This is equal to the number of asset instances per environment multiplied by the number of environments.
"""
return NotImplementedError
@property
def device(self) -> str: def device(self) -> str:
"""Memory device for computation.""" """Memory device for computation."""
return NotImplementedError return self._device
@property @property
@abstractmethod @abstractmethod
...@@ -235,7 +244,15 @@ class AssetBase(ABC): ...@@ -235,7 +244,15 @@ class AssetBase(ABC):
called whenever the simulator "plays" from a "stop" state. called whenever the simulator "plays" from a "stop" state.
""" """
if not self._is_initialized: if not self._is_initialized:
# obtain simulation related information
sim = sim_utils.SimulationContext.instance()
if sim is None:
raise RuntimeError("SimulationContext is not initialized! Please initialize SimulationContext first.")
self._backend = sim.backend
self._device = sim.device
# initialize the asset
self._initialize_impl() self._initialize_impl()
# set flag
self._is_initialized = True self._is_initialized = True
def _invalidate_initialize_callback(self, event): def _invalidate_initialize_callback(self, event):
......
...@@ -9,11 +9,10 @@ import torch ...@@ -9,11 +9,10 @@ import torch
from typing import TYPE_CHECKING, Sequence from typing import TYPE_CHECKING, Sequence
import carb import carb
import omni.isaac.core.utils.prims as prim_utils
import omni.physics.tensors.impl.api as physx import omni.physics.tensors.impl.api as physx
from omni.isaac.core.prims import RigidPrimView
from pxr import UsdPhysics from pxr import UsdPhysics
import omni.isaac.orbit.sim as sim_utils
import omni.isaac.orbit.utils.math as math_utils import omni.isaac.orbit.utils.math as math_utils
import omni.isaac.orbit.utils.string as string_utils import omni.isaac.orbit.utils.string as string_utils
...@@ -40,10 +39,10 @@ class RigidObject(AssetBase): ...@@ -40,10 +39,10 @@ class RigidObject(AssetBase):
.. note:: .. note::
For users familiar with Isaac Sim, they can use the :attr:`root_view` and :attr:`body_view` attributes For users familiar with Isaac Sim, the PhysX view class API is not the exactly same as Isaac Sim view
to access the rigid body views. These views are wrappers around the PhysX rigid body handles. However, class API. Similar to Orbit, Isaac Sim wraps around the PhysX view API. However, as of now (2023.1 release),
for advanced users who have a deep understanding of PhysX SDK and TensorAPI, they can use the we see a large difference in initializing the view classes in Isaac Sim. This is because the view classes
:attr:`root_physx_view` and :attr:`body_physx_view` attributes to access the rigid body handles directly. in Isaac Sim perform additional USD-related operations which are slow and also not required.
.. _`USD RigidBodyAPI`: https://openusd.org/dev/api/class_usd_physics_rigid_body_a_p_i.html .. _`USD RigidBodyAPI`: https://openusd.org/dev/api/class_usd_physics_rigid_body_a_p_i.html
""" """
...@@ -65,16 +64,14 @@ class RigidObject(AssetBase): ...@@ -65,16 +64,14 @@ class RigidObject(AssetBase):
Properties Properties
""" """
@property
def device(self) -> str:
"""Memory device for computation."""
return self.root_view._device # pyright: ignore [reportPrivateUsage]
@property @property
def data(self) -> RigidObjectData: def data(self) -> RigidObjectData:
"""Data related to articulation."""
return self._data return self._data
@property
def num_instances(self) -> int:
return self.root_physx_view.count
@property @property
def num_bodies(self) -> int: def num_bodies(self) -> int:
"""Number of bodies in the asset.""" """Number of bodies in the asset."""
...@@ -83,38 +80,26 @@ class RigidObject(AssetBase): ...@@ -83,38 +80,26 @@ class RigidObject(AssetBase):
@property @property
def body_names(self) -> list[str]: def body_names(self) -> list[str]:
"""Ordered names of bodies in articulation.""" """Ordered names of bodies in articulation."""
prim_paths = self.body_view.prim_paths[: self.num_bodies] prim_paths = self.body_physx_view.prim_paths[: self.num_bodies]
return [path.split("/")[-1] for path in prim_paths] return [path.split("/")[-1] for path in prim_paths]
@property
def root_view(self) -> RigidPrimView:
"""Rigid body view for the asset (Isaac Sim)."""
return self._root_view
@property
def body_view(self) -> RigidPrimView:
"""View for the bodies in the asset (Isaac Sim)."""
return self._root_view
@property @property
def root_physx_view(self) -> physx.RigidBodyView: def root_physx_view(self) -> physx.RigidBodyView:
"""Rigid body view for the asset (PhysX). """Rigid body view for the asset (PhysX).
Note: Note:
Use this view with caution! It requires handling of tensors in a specific way and is exposed for Use this view with caution. It requires handling of tensors in a specific way.
advanced users who have a deep understanding of PhysX SDK. Prefer using the Isaac Sim view when possible.
""" """
return self._root_view._physics_view # pyright: ignore [reportPrivateUsage] return self._root_physx_view
@property @property
def body_physx_view(self) -> physx.RigidBodyView: def body_physx_view(self) -> physx.RigidBodyView:
"""View for the bodies in the asset (PhysX). """View for the bodies in the asset (PhysX).
Note: Note:
Use this view with caution! It requires handling of tensors in a specific way and is exposed for Use this view with caution. It requires handling of tensors in a specific way.
advanced users who have a deep understanding of PhysX SDK. Prefer using the Isaac Sim view when possible.
""" """
return self._root_view._physics_view # pyright: ignore [reportPrivateUsage] return self._body_physx_view
""" """
Operations. Operations.
...@@ -281,7 +266,7 @@ class RigidObject(AssetBase): ...@@ -281,7 +266,7 @@ class RigidObject(AssetBase):
# note: we need to do this complicated indexing since torch doesn't support multi-indexing # note: we need to do this complicated indexing since torch doesn't support multi-indexing
# create global body indices from env_ids and env_body_ids # create global body indices from env_ids and env_body_ids
# (env_id * total_bodies_per_env) + body_id # (env_id * total_bodies_per_env) + body_id
total_bodies_per_env = self.body_view.count // self.root_view.count total_bodies_per_env = self.body_physx_view.count // self.root_physx_view.count
indices = body_ids.repeat(len(env_ids), 1) + env_ids.unsqueeze(1) * total_bodies_per_env indices = body_ids.repeat(len(env_ids), 1) + env_ids.unsqueeze(1) * total_bodies_per_env
indices = indices.view(-1) indices = indices.view(-1)
# set into internal buffers # set into internal buffers
...@@ -296,26 +281,34 @@ class RigidObject(AssetBase): ...@@ -296,26 +281,34 @@ class RigidObject(AssetBase):
""" """
def _initialize_impl(self): def _initialize_impl(self):
# find articulation root prims # create simulation view
asset_prim_path = prim_utils.find_matching_prim_paths(self.cfg.prim_path)[0] self._physics_sim_view = physx.create_simulation_view(self._backend)
root_prims = prim_utils.get_all_matching_child_prims( self._physics_sim_view.set_subspace_roots("/")
asset_prim_path, predicate=lambda a: prim_utils.get_prim_at_path(a).HasAPI(UsdPhysics.RigidBodyAPI) # obtain the first prim in the regex expression (all others are assumed to be a copy of this)
template_prim = sim_utils.find_first_matching_prim(self.cfg.prim_path)
if template_prim is None:
raise RuntimeError(f"Failed to find prim for expression: '{self.cfg.prim_path}'.")
template_prim_path = template_prim.GetPath().pathString
# find rigid root prims
root_prims = sim_utils.get_all_matching_child_prims(
template_prim_path, predicate=lambda prim: prim.HasAPI(UsdPhysics.RigidBodyAPI)
) )
if len(root_prims) != 1: if len(root_prims) != 1:
raise RuntimeError( raise RuntimeError(
f"Failed to find a single rigid body when resolving '{self.cfg.prim_path}'." f"Failed to find a single rigid body when resolving '{self.cfg.prim_path}'."
f" Found multiple '{root_prims}' under '{asset_prim_path}'." f" Found multiple '{root_prims}' under '{template_prim_path}'."
) )
# resolve articulation root prim back into regex expression # resolve root prim back into regex expression
root_prim_path = prim_utils.get_prim_path(root_prims[0]) root_prim_path = root_prims[0].GetPath().pathString
root_prim_path_expr = self.cfg.prim_path + root_prim_path[len(asset_prim_path) :] root_prim_path_expr = self.cfg.prim_path + root_prim_path[len(template_prim_path) :]
# -- object views # -- object view
self._root_view = RigidPrimView(root_prim_path_expr, reset_xform_properties=False) self._root_physx_view = self._physics_sim_view.create_rigid_body_view(root_prim_path_expr.replace(".*", "*"))
self._root_view.initialize() self._body_physx_view = self._root_physx_view
# log information about the articulation # log information about the articulation
carb.log_info(f"Rigid body initialized at: {self.cfg.prim_path} with root '{root_prim_path_expr}'.") carb.log_info(f"Rigid body initialized at: {self.cfg.prim_path} with root '{root_prim_path_expr}'.")
carb.log_info(f"Number of bodies (orbit): {self.num_bodies}") carb.log_info(f"Number of instances: {self.num_instances}")
carb.log_info(f"Body names (orbit): {self.body_names}") carb.log_info(f"Number of bodies: {self.num_bodies}")
carb.log_info(f"Body names: {self.body_names}")
# create buffers # create buffers
self._create_buffers() self._create_buffers()
# process configuration # process configuration
...@@ -324,32 +317,32 @@ class RigidObject(AssetBase): ...@@ -324,32 +317,32 @@ class RigidObject(AssetBase):
def _create_buffers(self): def _create_buffers(self):
"""Create buffers for storing data.""" """Create buffers for storing data."""
# constants # constants
self._ALL_INDICES = torch.arange(self.root_view.count, dtype=torch.long, device=self.device) self._ALL_INDICES = torch.arange(self.num_instances, dtype=torch.long, device=self.device)
self._ALL_BODY_INDICES = torch.arange(self.body_view.count, dtype=torch.long, device=self.device) self._ALL_BODY_INDICES = torch.arange(self.body_physx_view.count, dtype=torch.long, device=self.device)
self._GRAVITY_VEC_W = torch.tensor((0.0, 0.0, -1.0), device=self.device).repeat(self.root_view.count, 1) self.GRAVITY_VEC_W = torch.tensor((0.0, 0.0, -1.0), device=self.device).repeat(self.num_instances, 1)
self._FORWARD_VEC_B = torch.tensor((1.0, 0.0, 0.0), device=self.device).repeat(self.root_view.count, 1) self.FORWARD_VEC_B = torch.tensor((1.0, 0.0, 0.0), device=self.device).repeat(self.num_instances, 1)
# external forces and torques # external forces and torques
self.has_external_wrench = False self.has_external_wrench = False
self._external_force_b = torch.zeros((self.root_view.count, self.num_bodies, 3), device=self.device) self._external_force_b = torch.zeros((self.num_instances, self.num_bodies, 3), device=self.device)
self._external_torque_b = torch.zeros_like(self._external_force_b) self._external_torque_b = torch.zeros_like(self._external_force_b)
# asset data # asset data
# -- properties # -- properties
self._data.body_names = self.body_names self._data.body_names = self.body_names
# -- root states # -- root states
self._data.root_state_w = torch.zeros(self.root_view.count, 13, device=self.device) self._data.root_state_w = torch.zeros(self.num_instances, 13, device=self.device)
self._data.default_root_state = torch.zeros_like(self._data.root_state_w) self._data.default_root_state = torch.zeros_like(self._data.root_state_w)
# -- body states # -- body states
self._data.body_state_w = torch.zeros(self.root_view.count, self.num_bodies, 13, device=self.device) self._data.body_state_w = torch.zeros(self.num_instances, self.num_bodies, 13, device=self.device)
# -- post-computed # -- post-computed
self._data.root_vel_b = torch.zeros(self.root_view.count, 6, device=self.device) self._data.root_vel_b = torch.zeros(self.num_instances, 6, device=self.device)
self._data.projected_gravity_b = torch.zeros(self.root_view.count, 3, device=self.device) self._data.projected_gravity_b = torch.zeros(self.num_instances, 3, device=self.device)
self._data.heading_w = torch.zeros(self.root_view.count, device=self.device) self._data.heading_w = torch.zeros(self.num_instances, device=self.device)
self._data.body_acc_w = torch.zeros(self.root_view.count, self.num_bodies, 6, device=self.device) self._data.body_acc_w = torch.zeros(self.num_instances, self.num_bodies, 6, device=self.device)
# history buffers for quantities # history buffers for quantities
# -- used to compute body accelerations numerically # -- used to compute body accelerations numerically
self._last_body_vel_w = torch.zeros(self.root_view.count, self.num_bodies, 6, device=self.device) self._last_body_vel_w = torch.zeros(self.num_instances, self.num_bodies, 6, device=self.device)
def _process_cfg(self): def _process_cfg(self):
"""Post processing of configuration parameters.""" """Post processing of configuration parameters."""
...@@ -363,7 +356,7 @@ class RigidObject(AssetBase): ...@@ -363,7 +356,7 @@ class RigidObject(AssetBase):
+ tuple(self.cfg.init_state.ang_vel) + tuple(self.cfg.init_state.ang_vel)
) )
default_root_state = torch.tensor(default_root_state, dtype=torch.float, device=self.device) default_root_state = torch.tensor(default_root_state, dtype=torch.float, device=self.device)
self._data.default_root_state = default_root_state.repeat(self.root_view.count, 1) self._data.default_root_state = default_root_state.repeat(self.num_instances, 1)
def _update_common_data(self, dt: float): def _update_common_data(self, dt: float):
"""Update common quantities related to rigid objects. """Update common quantities related to rigid objects.
...@@ -386,7 +379,7 @@ class RigidObject(AssetBase): ...@@ -386,7 +379,7 @@ class RigidObject(AssetBase):
self._data.root_vel_b[:, 3:6] = math_utils.quat_rotate_inverse( self._data.root_vel_b[:, 3:6] = math_utils.quat_rotate_inverse(
self._data.root_quat_w, self._data.root_ang_vel_w self._data.root_quat_w, self._data.root_ang_vel_w
) )
self._data.projected_gravity_b[:] = math_utils.quat_rotate_inverse(self._data.root_quat_w, self._GRAVITY_VEC_W) self._data.projected_gravity_b[:] = math_utils.quat_rotate_inverse(self._data.root_quat_w, self.GRAVITY_VEC_W)
# -- heading direction of root # -- heading direction of root
forward_w = math_utils.quat_apply(self._data.root_quat_w, self._FORWARD_VEC_B) forward_w = math_utils.quat_apply(self._data.root_quat_w, self.FORWARD_VEC_B)
self._data.heading_w[:] = torch.atan2(forward_w[:, 1], forward_w[:, 0]) self._data.heading_w[:] = torch.atan2(forward_w[:, 1], forward_w[:, 0])
...@@ -123,7 +123,8 @@ class BaseEnv: ...@@ -123,7 +123,8 @@ class BaseEnv:
# note: this activates the physics simulation view that exposes TensorAPIs # note: this activates the physics simulation view that exposes TensorAPIs
# note: when started in extension mode, first call sim.reset_async() and then initialize the managers # note: when started in extension mode, first call sim.reset_async() and then initialize the managers
if builtins.ISAAC_LAUNCHED_FROM_TERMINAL is False: if builtins.ISAAC_LAUNCHED_FROM_TERMINAL is False:
with Timer("[INFO]: Time taken for simulation reset"): print("[INFO]: Starting the simulation. This may take a few seconds. Please wait...")
with Timer("[INFO]: Time taken for simulation start"):
self.sim.reset() self.sim.reset()
# add timeline event to load managers # add timeline event to load managers
self.load_managers() self.load_managers()
...@@ -279,9 +280,14 @@ class BaseEnv: ...@@ -279,9 +280,14 @@ class BaseEnv:
Returns: Returns:
The seed used for random generator. The seed used for random generator.
""" """
# set seed for replicator
try:
import omni.replicator.core as rep import omni.replicator.core as rep
rep.set_global_seed(seed) rep.set_global_seed(seed)
except ModuleNotFoundError:
pass
# set seed for torch and other libraries
return torch_utils.set_seed(seed) return torch_utils.set_seed(seed)
def close(self): def close(self):
......
...@@ -67,10 +67,10 @@ def randomize_rigid_body_material( ...@@ -67,10 +67,10 @@ def randomize_rigid_body_material(
material_buckets[:, 1].uniform_(*dynamic_friction_range) material_buckets[:, 1].uniform_(*dynamic_friction_range)
material_buckets[:, 2].uniform_(*restitution_range) material_buckets[:, 2].uniform_(*restitution_range)
# create random material assignments based on the total number of shapes: num_assets x num_bodies x num_shapes # create random material assignments based on the total number of shapes: num_assets x num_bodies x num_shapes
material_ids = torch.randint(0, num_buckets, (asset.body_view.count, asset.body_view.num_shapes)) material_ids = torch.randint(0, num_buckets, (asset.body_physx_view.count, asset.body_physx_view.max_shapes))
materials = material_buckets[material_ids] materials = material_buckets[material_ids]
# resolve the global body indices from the env_ids and the env_body_ids # resolve the global body indices from the env_ids and the env_body_ids
bodies_per_env = asset.body_view.count // num_envs # - number of bodies per spawned asset bodies_per_env = asset.body_physx_view.count // num_envs # - number of bodies per spawned asset
indices = torch.tensor(asset_cfg.body_ids, dtype=torch.int).repeat(len(env_ids), 1) indices = torch.tensor(asset_cfg.body_ids, dtype=torch.int).repeat(len(env_ids), 1)
indices += env_ids.unsqueeze(1) * bodies_per_env indices += env_ids.unsqueeze(1) * bodies_per_env
...@@ -99,7 +99,7 @@ def add_body_mass( ...@@ -99,7 +99,7 @@ def add_body_mass(
masses = asset.body_physx_view.get_masses() masses = asset.body_physx_view.get_masses()
masses += sample_uniform(*mass_range, masses.shape, device=masses.device) masses += sample_uniform(*mass_range, masses.shape, device=masses.device)
# resolve the global body indices from the env_ids and the env_body_ids # resolve the global body indices from the env_ids and the env_body_ids
bodies_per_env = asset.body_view.count // env.num_envs bodies_per_env = asset.body_physx_view.count // env.num_envs
indices = torch.tensor(asset_cfg.body_ids, dtype=torch.int).repeat(len(env_ids), 1) indices = torch.tensor(asset_cfg.body_ids, dtype=torch.int).repeat(len(env_ids), 1)
indices += env_ids.unsqueeze(1) * bodies_per_env indices += env_ids.unsqueeze(1) * bodies_per_env
......
...@@ -10,12 +10,5 @@ This includes functionalities such as tracking a robot in the simulation, ...@@ -10,12 +10,5 @@ This includes functionalities such as tracking a robot in the simulation,
toggling different debug visualization tools, and other user-defined functionalities. toggling different debug visualization tools, and other user-defined functionalities.
""" """
# enable the extension for UI elements
# this only needs to be done once
from omni.isaac.core.utils.extensions import enable_extension
enable_extension("omni.isaac.ui")
# import all UI elements here
from .base_env_window import BaseEnvWindow from .base_env_window import BaseEnvWindow
from .rl_task_env_window import RLTaskEnvWindow from .rl_task_env_window import RLTaskEnvWindow
...@@ -10,14 +10,13 @@ import torch ...@@ -10,14 +10,13 @@ import torch
from typing import Any, Sequence from typing import Any, Sequence
import carb import carb
import omni.isaac.core.utils.prims as prim_utils import omni.usd
import omni.isaac.core.utils.stage as stage_utils
from omni.isaac.cloner import GridCloner from omni.isaac.cloner import GridCloner
from omni.isaac.core.prims import XFormPrimView from omni.isaac.core.prims import XFormPrimView
from omni.isaac.core.simulation_context import SimulationContext
from omni.isaac.version import get_version from omni.isaac.version import get_version
from pxr import PhysxSchema from pxr import PhysxSchema
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.assets import Articulation, ArticulationCfg, AssetBaseCfg, RigidObject, RigidObjectCfg from omni.isaac.orbit.assets import Articulation, ArticulationCfg, AssetBaseCfg, RigidObject, RigidObjectCfg
from omni.isaac.orbit.sensors import FrameTransformerCfg, SensorBase, SensorBaseCfg from omni.isaac.orbit.sensors import FrameTransformerCfg, SensorBase, SensorBaseCfg
from omni.isaac.orbit.terrains import TerrainImporter, TerrainImporterCfg from omni.isaac.orbit.terrains import TerrainImporter, TerrainImporterCfg
...@@ -112,12 +111,14 @@ class InteractiveScene: ...@@ -112,12 +111,14 @@ class InteractiveScene:
""" """
# store inputs # store inputs
self.cfg = cfg self.cfg = cfg
# obtain the current stage
self.stage = omni.usd.get_context().get_stage()
# prepare cloner for environment replication # prepare cloner for environment replication
self.cloner = GridCloner(spacing=self.cfg.env_spacing) self.cloner = GridCloner(spacing=self.cfg.env_spacing)
self.cloner.define_base_env(self.env_ns) self.cloner.define_base_env(self.env_ns)
self.env_prim_paths = self.cloner.generate_paths(f"{self.env_ns}/env", self.cfg.num_envs) self.env_prim_paths = self.cloner.generate_paths(f"{self.env_ns}/env", self.cfg.num_envs)
# create source prim # create source prim
prim_utils.define_prim(self.env_prim_paths[0], "Xform") self.stage.DefinePrim(self.env_prim_paths[0], "Xform")
# obtain major isaac sim version # obtain major isaac sim version
isaac_major_version = int(get_version()[2]) isaac_major_version = int(get_version()[2])
# clone the env xform # clone the env xform
...@@ -158,7 +159,7 @@ class InteractiveScene: ...@@ -158,7 +159,7 @@ class InteractiveScene:
) )
# obtain the current physics scene # obtain the current physics scene
physics_scene_prim_path = None physics_scene_prim_path = None
for prim in stage_utils.traverse_stage(): for prim in self.stage.Traverse():
if prim.HasAPI(PhysxSchema.PhysxSceneAPI): if prim.HasAPI(PhysxSchema.PhysxSceneAPI):
physics_scene_prim_path = prim.GetPrimPath() physics_scene_prim_path = prim.GetPrimPath()
carb.log_info(f"Physics scene prim path: {physics_scene_prim_path}") carb.log_info(f"Physics scene prim path: {physics_scene_prim_path}")
...@@ -188,12 +189,12 @@ class InteractiveScene: ...@@ -188,12 +189,12 @@ class InteractiveScene:
@property @property
def physics_dt(self) -> float: def physics_dt(self) -> float:
"""The physics timestep of the scene.""" """The physics timestep of the scene."""
return SimulationContext.instance().get_physics_dt() # pyright: ignore [reportOptionalMemberAccess] return sim_utils.SimulationContext.instance().get_physics_dt() # pyright: ignore [reportOptionalMemberAccess]
@property @property
def device(self) -> str: def device(self) -> str:
"""The device on which the scene is created.""" """The device on which the scene is created."""
return SimulationContext.instance().device # pyright: ignore [reportOptionalMemberAccess] return sim_utils.SimulationContext.instance().device # pyright: ignore [reportOptionalMemberAccess]
@property @property
def env_ns(self) -> str: def env_ns(self) -> str:
...@@ -247,7 +248,7 @@ class InteractiveScene: ...@@ -247,7 +248,7 @@ class InteractiveScene:
# note: In standalone mode, this method is called in the `step()` method of the simulation context. # note: In standalone mode, this method is called in the `step()` method of the simulation context.
# So we only need to flush when running in extension mode. # So we only need to flush when running in extension mode.
if builtins.ISAAC_LAUNCHED_FROM_TERMINAL: if builtins.ISAAC_LAUNCHED_FROM_TERMINAL:
SimulationContext.instance().physics_sim_view.flush() # pyright: ignore [reportOptionalMemberAccess] sim_utils.SimulationContext.instance().physics_sim_view.flush() # pyright: ignore [reportOptionalMemberAccess]
def write_data_to_sim(self): def write_data_to_sim(self):
"""Writes the data of the scene entities to the simulation.""" """Writes the data of the scene entities to the simulation."""
...@@ -262,7 +263,7 @@ class InteractiveScene: ...@@ -262,7 +263,7 @@ class InteractiveScene:
# note: In standalone mode, this method is called in the `step()` method of the simulation context. # note: In standalone mode, this method is called in the `step()` method of the simulation context.
# So we only need to flush when running in extension mode. # So we only need to flush when running in extension mode.
if builtins.ISAAC_LAUNCHED_FROM_TERMINAL: if builtins.ISAAC_LAUNCHED_FROM_TERMINAL:
SimulationContext.instance().physics_sim_view.flush() # pyright: ignore [reportOptionalMemberAccess] sim_utils.SimulationContext.instance().physics_sim_view.flush() # pyright: ignore [reportOptionalMemberAccess]
def update(self, dt: float) -> None: def update(self, dt: float) -> None:
"""Update the scene entities. """Update the scene entities.
...@@ -370,5 +371,5 @@ class InteractiveScene: ...@@ -370,5 +371,5 @@ class InteractiveScene:
raise ValueError(f"Unknown asset config type for {asset_name}: {asset_cfg}") raise ValueError(f"Unknown asset config type for {asset_name}: {asset_cfg}")
# store global collision paths # store global collision paths
if hasattr(asset_cfg, "collision_group") and asset_cfg.collision_group == -1: if hasattr(asset_cfg, "collision_group") and asset_cfg.collision_group == -1:
asset_paths = prim_utils.find_matching_prim_paths(asset_cfg.prim_path) asset_paths = sim_utils.find_matching_prim_paths(asset_cfg.prim_path)
self._global_prim_paths += asset_paths self._global_prim_paths += asset_paths
...@@ -12,13 +12,12 @@ from tensordict import TensorDict ...@@ -12,13 +12,12 @@ from tensordict import TensorDict
from typing import TYPE_CHECKING, Any, Sequence from typing import TYPE_CHECKING, Any, Sequence
from typing_extensions import Literal from typing_extensions import Literal
import omni.isaac.core.utils.prims as prim_utils
import omni.kit.commands import omni.kit.commands
import omni.usd import omni.usd
from omni.isaac.core.prims import XFormPrimView from omni.isaac.core.prims import XFormPrimView
from pxr import UsdGeom from pxr import UsdGeom
# omni-isaac-orbit import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.utils import to_camel_case from omni.isaac.orbit.utils import to_camel_case
from omni.isaac.orbit.utils.array import convert_to_torch from omni.isaac.orbit.utils.array import convert_to_torch
from omni.isaac.orbit.utils.math import quat_from_matrix from omni.isaac.orbit.utils.math import quat_from_matrix
...@@ -93,8 +92,8 @@ class Camera(SensorBase): ...@@ -93,8 +92,8 @@ class Camera(SensorBase):
self.cfg.prim_path, self.cfg.spawn, translation=self.cfg.offset.pos, orientation=rot_offset self.cfg.prim_path, self.cfg.spawn, translation=self.cfg.offset.pos, orientation=rot_offset
) )
# check that spawn was successful # check that spawn was successful
matching_prim_paths = prim_utils.find_matching_prim_paths(self.cfg.prim_path) matching_prims = sim_utils.find_matching_prims(self.cfg.prim_path)
if len(matching_prim_paths) == 0: if len(matching_prims) == 0:
raise RuntimeError(f"Could not find prim with path {self.cfg.prim_path}.") raise RuntimeError(f"Could not find prim with path {self.cfg.prim_path}.")
# UsdGeom Camera prim for the sensor # UsdGeom Camera prim for the sensor
...@@ -127,6 +126,10 @@ class Camera(SensorBase): ...@@ -127,6 +126,10 @@ class Camera(SensorBase):
Properties Properties
""" """
@property
def num_instances(self) -> int:
return self._view.count
@property @property
def data(self) -> CameraData: def data(self) -> CameraData:
# update sensors if needed # update sensors if needed
...@@ -351,10 +354,12 @@ class Camera(SensorBase): ...@@ -351,10 +354,12 @@ class Camera(SensorBase):
device_name = self._device.split(":")[0] device_name = self._device.split(":")[0]
else: else:
device_name = "cpu" device_name = "cpu"
# Obtain current stage
stage = omni.usd.get_context().get_stage()
# Convert all encapsulated prims to Camera # Convert all encapsulated prims to Camera
for cam_prim_path in self._view.prim_paths: for cam_prim_path in self._view.prim_paths:
# Get camera prim # Get camera prim
cam_prim = prim_utils.get_prim_at_path(cam_prim_path) cam_prim = stage.GetPrimAtPath(cam_prim_path)
# Check if prim is a camera # Check if prim is a camera
if not cam_prim.IsA(UsdGeom.Camera): if not cam_prim.IsA(UsdGeom.Camera):
raise RuntimeError(f"Prim at path '{cam_prim_path}' is not a Camera.") raise RuntimeError(f"Prim at path '{cam_prim_path}' is not a Camera.")
......
...@@ -11,13 +11,13 @@ from __future__ import annotations ...@@ -11,13 +11,13 @@ from __future__ import annotations
import torch import torch
from typing import TYPE_CHECKING, Sequence from typing import TYPE_CHECKING, Sequence
import omni.isaac.core.utils.prims as prim_utils
import omni.physics.tensors.impl.api as physx import omni.physics.tensors.impl.api as physx
from omni.isaac.core.prims import RigidContactView, RigidPrimView
from pxr import PhysxSchema from pxr import PhysxSchema
import omni.isaac.orbit.sim as sim_utils
import omni.isaac.orbit.utils.string as string_utils import omni.isaac.orbit.utils.string as string_utils
from omni.isaac.orbit.markers import VisualizationMarkers from omni.isaac.orbit.markers import VisualizationMarkers
from omni.isaac.orbit.utils.math import convert_quat
from ..sensor_base import SensorBase from ..sensor_base import SensorBase
from .contact_sensor_data import ContactSensorData from .contact_sensor_data import ContactSensorData
...@@ -62,7 +62,7 @@ class ContactSensor(SensorBase): ...@@ -62,7 +62,7 @@ class ContactSensor(SensorBase):
"""Returns: A string containing information about the instance.""" """Returns: A string containing information about the instance."""
return ( return (
f"Contact sensor @ '{self.cfg.prim_path}': \n" f"Contact sensor @ '{self.cfg.prim_path}': \n"
f"\tview type : {self._view.__class__}\n" f"\tview type : {self.body_physx_view.__class__}\n"
f"\tupdate period (s) : {self.cfg.update_period}\n" f"\tupdate period (s) : {self.cfg.update_period}\n"
f"\tnumber of bodies : {self.num_bodies}\n" f"\tnumber of bodies : {self.num_bodies}\n"
f"\tbody names : {self.body_names}\n" f"\tbody names : {self.body_names}\n"
...@@ -72,6 +72,10 @@ class ContactSensor(SensorBase): ...@@ -72,6 +72,10 @@ class ContactSensor(SensorBase):
Properties Properties
""" """
@property
def num_instances(self) -> int:
return self.body_physx_view.count
@property @property
def data(self) -> ContactSensorData: def data(self) -> ContactSensorData:
# update sensors if needed # update sensors if needed
...@@ -87,38 +91,26 @@ class ContactSensor(SensorBase): ...@@ -87,38 +91,26 @@ class ContactSensor(SensorBase):
@property @property
def body_names(self) -> list[str]: def body_names(self) -> list[str]:
"""Ordered names of bodies with contact sensors attached.""" """Ordered names of bodies with contact sensors attached."""
prim_paths = self._view.prim_paths[: self.num_bodies] prim_paths = self.body_physx_view.prim_paths[: self.num_bodies]
return [path.split("/")[-1] for path in prim_paths] return [path.split("/")[-1] for path in prim_paths]
@property
def body_view(self) -> RigidPrimView:
"""View for the rigid bodies captured (Isaac Sim)."""
return self._view
@property
def contact_view(self) -> RigidContactView:
"""Contact reporter view for the bodies (Isaac Sim)."""
return self._view._contact_view # pyright: ignore [reportPrivateUsage]
@property @property
def body_physx_view(self) -> physx.RigidBodyView: def body_physx_view(self) -> physx.RigidBodyView:
"""View for the rigid bodies captured (PhysX). """View for the rigid bodies captured (PhysX).
Note: Note:
Use this view with caution! It requires handling of tensors in a specific way and is exposed for Use this view with caution. It requires handling of tensors in a specific way.
advanced users who have a deep understanding of PhysX SDK. Prefer using the Isaac Sim view when possible.
""" """
return self._view._physics_view # pyright: ignore [reportPrivateUsage] return self._body_physx_view
@property @property
def contact_physx_view(self) -> physx.RigidContactView: def contact_physx_view(self) -> physx.RigidContactView:
"""Contact reporter view for the bodies (PhysX). """Contact reporter view for the bodies (PhysX).
Note: Note:
Use this view with caution! It requires handling of tensors in a specific way and is exposed for Use this view with caution. It requires handling of tensors in a specific way.
advanced users who have a deep understanding of PhysX SDK. Prefer using the Isaac Sim view when possible.
""" """
return self._view._contact_view._physics_view # pyright: ignore [reportPrivateUsage] return self._contact_physx_view
""" """
Operations Operations
...@@ -163,14 +155,17 @@ class ContactSensor(SensorBase): ...@@ -163,14 +155,17 @@ class ContactSensor(SensorBase):
def _initialize_impl(self): def _initialize_impl(self):
super()._initialize_impl() super()._initialize_impl()
# create simulation view
self._physics_sim_view = physx.create_simulation_view(self._backend)
self._physics_sim_view.set_subspace_roots("/")
# check that only rigid bodies are selected # check that only rigid bodies are selected
matching_prim_paths = prim_utils.find_matching_prim_paths(self.cfg.prim_path) leaf_pattern = self.cfg.prim_path.rsplit("/", 1)[-1]
num_prim_matches = len(matching_prim_paths) // self._num_envs template_prim_path = self._parent_prims[0].GetPath().pathString
body_names = list() body_names = list()
for prim_path in matching_prim_paths[:num_prim_matches]: for prim in sim_utils.find_matching_prims(template_prim_path + "/" + leaf_pattern):
prim = prim_utils.get_prim_at_path(prim_path)
# check if prim has contact reporter API # check if prim has contact reporter API
if prim.HasAPI(PhysxSchema.PhysxContactReportAPI): if prim.HasAPI(PhysxSchema.PhysxContactReportAPI):
prim_path = prim.GetPath().pathString
body_names.append(prim_path.rsplit("/", 1)[-1]) body_names.append(prim_path.rsplit("/", 1)[-1])
# check that there is at least one body with contact reporter API # check that there is at least one body with contact reporter API
if not body_names: if not body_names:
...@@ -183,17 +178,12 @@ class ContactSensor(SensorBase): ...@@ -183,17 +178,12 @@ class ContactSensor(SensorBase):
body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}" body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}"
# construct a new regex expression # construct a new regex expression
# create a rigid prim view for the sensor # create a rigid prim view for the sensor
self._view = RigidPrimView( self._body_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex.replace(".*", "*"))
prim_paths_expr=body_names_regex, self._contact_physx_view = self._physics_sim_view.create_rigid_contact_view(
reset_xform_properties=False, body_names_regex.replace(".*", "*"), filter_patterns=self.cfg.filter_prim_paths_expr
track_contact_forces=True,
contact_filter_prim_paths_expr=self.cfg.filter_prim_paths_expr,
prepare_contact_sensors=False,
disable_stablization=True,
) )
self._view.initialize()
# resolve the true count of bodies # resolve the true count of bodies
self._num_bodies = self._view.count // self._num_envs self._num_bodies = self.body_physx_view.count // self._num_envs
# check that contact reporter succeeded # check that contact reporter succeeded
if self._num_bodies != len(body_names): if self._num_bodies != len(body_names):
raise RuntimeError( raise RuntimeError(
...@@ -225,7 +215,7 @@ class ContactSensor(SensorBase): ...@@ -225,7 +215,7 @@ class ContactSensor(SensorBase):
num_shapes = self.contact_physx_view.sensor_count // self._num_bodies num_shapes = self.contact_physx_view.sensor_count // self._num_bodies
num_filters = self.contact_physx_view.filter_count num_filters = self.contact_physx_view.filter_count
self._data.force_matrix_w = torch.zeros( self._data.force_matrix_w = torch.zeros(
self.count, self._num_bodies, num_shapes, num_filters, 3, device=self._device self._num_envs, self._num_bodies, num_shapes, num_filters, 3, device=self._device
) )
def _update_buffers_impl(self, env_ids: Sequence[int]): def _update_buffers_impl(self, env_ids: Sequence[int]):
...@@ -255,9 +245,9 @@ class ContactSensor(SensorBase): ...@@ -255,9 +245,9 @@ class ContactSensor(SensorBase):
self._data.force_matrix_w[env_ids] = force_matrix_w[env_ids] self._data.force_matrix_w[env_ids] = force_matrix_w[env_ids]
# obtain the pose of the sensor origin # obtain the pose of the sensor origin
if self.cfg.track_pose: if self.cfg.track_pose:
pose = self.body_physx_view.get_transforms() pose = self.body_physx_view.get_transforms().view(-1, self._num_bodies, 7)[env_ids]
self._data.pos_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, :3] pose[..., 3:] = convert_quat(pose[..., 3:], to="wxyz")
self._data.quat_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, 3:] self._data.pos_w[env_ids], self._data.quat_w[env_ids] = pose.split([3, 4], dim=-1)
# obtain the air time # obtain the air time
if self.cfg.track_air_time: if self.cfg.track_air_time:
# -- time elapsed since last update # -- time elapsed since last update
......
...@@ -9,10 +9,10 @@ import torch ...@@ -9,10 +9,10 @@ import torch
from typing import TYPE_CHECKING, Sequence from typing import TYPE_CHECKING, Sequence
import carb import carb
import omni.isaac.core.utils.prims as prim_utils import omni.physics.tensors.impl.api as physx
from omni.isaac.core.prims import RigidPrimView
from pxr import UsdPhysics from pxr import UsdPhysics
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.markers import VisualizationMarkers from omni.isaac.orbit.markers import VisualizationMarkers
from omni.isaac.orbit.utils.math import ( from omni.isaac.orbit.utils.math import (
combine_frame_transforms, combine_frame_transforms,
...@@ -151,14 +151,15 @@ class FrameTransformer(SensorBase): ...@@ -151,14 +151,15 @@ class FrameTransformer(SensorBase):
frame_offsets = [None] + [target_frame.offset for target_frame in self.cfg.target_frames] frame_offsets = [None] + [target_frame.offset for target_frame in self.cfg.target_frames]
for frame, prim_path, offset in zip(frames, frame_prim_paths, frame_offsets): for frame, prim_path, offset in zip(frames, frame_prim_paths, frame_offsets):
# Find correct prim # Find correct prim
matching_prims = prim_utils.find_matching_prim_paths(prim_path) matching_prims = sim_utils.find_matching_prims(prim_path)
if len(matching_prims) == 0: if len(matching_prims) == 0:
raise ValueError( raise ValueError(
f"Failed to create frame transformer for frame '{frame}' with path '{prim_path}'." f"Failed to create frame transformer for frame '{frame}' with path '{prim_path}'."
" No matching prims were found." " No matching prims were found."
) )
for matching_prim_path in matching_prims: for prim in matching_prims:
prim = prim_utils.get_prim_at_path(matching_prim_path) # Get the prim path of the matching prim
matching_prim_path = prim.GetPath().pathString
# check if it is a rigid prim # check if it is a rigid prim
if not prim.HasAPI(UsdPhysics.RigidBodyAPI): if not prim.HasAPI(UsdPhysics.RigidBodyAPI):
raise ValueError( raise ValueError(
...@@ -216,14 +217,16 @@ class FrameTransformer(SensorBase): ...@@ -216,14 +217,16 @@ class FrameTransformer(SensorBase):
body_names_regex = r"(" + "|".join(self._tracked_body_names) + r")" body_names_regex = r"(" + "|".join(self._tracked_body_names) + r")"
body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}" body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}"
# create simulation view
self._physics_sim_view = physx.create_simulation_view(self._backend)
self._physics_sim_view.set_subspace_roots("/")
# Create a prim view for all frames and initialize it # Create a prim view for all frames and initialize it
# order of transforms coming out of view will be source frame followed by target frame(s) # order of transforms coming out of view will be source frame followed by target frame(s)
self._frame_view = RigidPrimView(prim_paths_expr=body_names_regex, reset_xform_properties=False) self._frame_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex.replace(".*", "*"))
self._frame_view.initialize()
# Determine the order in which regex evaluated body names so we can later index into frame transforms # Determine the order in which regex evaluated body names so we can later index into frame transforms
# by frame name correctly # by frame name correctly
all_prim_paths = self._frame_view.prim_paths all_prim_paths = self._frame_physx_view.prim_paths
# Only need first env as the names and their orderring are the same across environments # Only need first env as the names and their orderring are the same across environments
first_env_prim_paths = all_prim_paths[0 : self._num_target_body_frames + 1] first_env_prim_paths = all_prim_paths[0 : self._num_target_body_frames + 1]
...@@ -282,18 +285,14 @@ class FrameTransformer(SensorBase): ...@@ -282,18 +285,14 @@ class FrameTransformer(SensorBase):
# Extract transforms from view - shape is: # Extract transforms from view - shape is:
# (the total number of source and target body frames being tracked * self._num_envs, 7) # (the total number of source and target body frames being tracked * self._num_envs, 7)
transforms = self._frame_view._physics_view.get_transforms() transforms = self._frame_physx_view.get_transforms()
# Convert quaternions as PhysX uses xyzw form
transforms[:, 3:] = convert_quat(transforms[:, 3:], to="wxyz")
# Process source frame transform
source_frames = transforms[self._source_frame_idxs] source_frames = transforms[self._source_frame_idxs]
target_frames = transforms[self._target_frame_idxs]
# Convert quaternions as Isaac uses xyzw form
source_frames[:, 3:] = convert_quat(source_frames[:, 3:], to="wxyz")
target_frames[:, 3:] = convert_quat(target_frames[:, 3:], to="wxyz")
# Only apply offset if the offsets will result in a coordinate frame transform # Only apply offset if the offsets will result in a coordinate frame transform
if self._apply_source_frame_offset: if self._apply_source_frame_offset:
# Apply offsets for source frame
source_pos_w, source_rot_w = combine_frame_transforms( source_pos_w, source_rot_w = combine_frame_transforms(
source_frames[:, :3], source_frames[:, :3],
source_frames[:, 3:], source_frames[:, 3:],
...@@ -304,12 +303,12 @@ class FrameTransformer(SensorBase): ...@@ -304,12 +303,12 @@ class FrameTransformer(SensorBase):
source_pos_w = source_frames[:, :3] source_pos_w = source_frames[:, :3]
source_rot_w = source_frames[:, 3:] source_rot_w = source_frames[:, 3:]
# Process target frame transforms
target_frames = transforms[self._target_frame_idxs]
duplicated_target_frame_pos_w = target_frames[self._duplicate_frame_indices, :3] duplicated_target_frame_pos_w = target_frames[self._duplicate_frame_indices, :3]
duplicated_target_frame_rot_w = target_frames[self._duplicate_frame_indices, 3:] duplicated_target_frame_rot_w = target_frames[self._duplicate_frame_indices, 3:]
# Only apply offset if the offsets will result in a coordinate frame transform # Only apply offset if the offsets will result in a coordinate frame transform
if self._apply_target_frame_offset: if self._apply_target_frame_offset:
# Apply offsets for target frame
target_pos_w, target_rot_w = combine_frame_transforms( target_pos_w, target_rot_w = combine_frame_transforms(
duplicated_target_frame_pos_w, duplicated_target_frame_pos_w,
duplicated_target_frame_rot_w, duplicated_target_frame_rot_w,
...@@ -320,8 +319,8 @@ class FrameTransformer(SensorBase): ...@@ -320,8 +319,8 @@ class FrameTransformer(SensorBase):
target_pos_w = duplicated_target_frame_pos_w target_pos_w = duplicated_target_frame_pos_w
target_rot_w = duplicated_target_frame_rot_w target_rot_w = duplicated_target_frame_rot_w
# Compute the transform of the target frame with respect to the source frame
total_num_frames = len(self._target_frame_names) total_num_frames = len(self._target_frame_names)
target_pos_source, target_rot_source = subtract_frame_transforms( target_pos_source, target_rot_source = subtract_frame_transforms(
source_pos_w.unsqueeze(1).expand(-1, total_num_frames, -1).reshape(-1, 3), source_pos_w.unsqueeze(1).expand(-1, total_num_frames, -1).reshape(-1, 3),
source_rot_w.unsqueeze(1).expand(-1, total_num_frames, -1).reshape(-1, 4), source_rot_w.unsqueeze(1).expand(-1, total_num_frames, -1).reshape(-1, 4),
...@@ -330,7 +329,7 @@ class FrameTransformer(SensorBase): ...@@ -330,7 +329,7 @@ class FrameTransformer(SensorBase):
) )
# Update buffers # Update buffers
# NOTE: The frame names / orderring don't change so no need to update them after initialization # note: The frame names / ordering don't change so no need to update them after initialization
self._data.source_pos_w[:] = source_pos_w.view(-1, 3) self._data.source_pos_w[:] = source_pos_w.view(-1, 3)
self._data.source_rot_w[:] = source_rot_w.view(-1, 4) self._data.source_rot_w[:] = source_rot_w.view(-1, 4)
self._data.target_pos_w[:] = target_pos_w.view(-1, total_num_frames, 3) self._data.target_pos_w[:] = target_pos_w.view(-1, total_num_frames, 3)
......
...@@ -10,15 +10,15 @@ import torch ...@@ -10,15 +10,15 @@ import torch
from typing import TYPE_CHECKING, ClassVar, Sequence from typing import TYPE_CHECKING, ClassVar, Sequence
import carb import carb
import omni.isaac.core.utils.prims as prim_utils import omni.physics.tensors.impl.api as physx
import warp as wp import warp as wp
from omni.isaac.core.articulations import ArticulationView from omni.isaac.core.prims import XFormPrimView
from omni.isaac.core.prims import RigidPrimView, XFormPrimView
from pxr import UsdGeom, UsdPhysics from pxr import UsdGeom, UsdPhysics
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.markers import VisualizationMarkers from omni.isaac.orbit.markers import VisualizationMarkers
from omni.isaac.orbit.terrains.trimesh.utils import make_plane from omni.isaac.orbit.terrains.trimesh.utils import make_plane
from omni.isaac.orbit.utils.math import quat_apply, quat_apply_yaw from omni.isaac.orbit.utils.math import convert_quat, quat_apply, quat_apply_yaw
from omni.isaac.orbit.utils.warp import convert_to_warp_mesh, raycast_mesh from omni.isaac.orbit.utils.warp import convert_to_warp_mesh, raycast_mesh
from ..sensor_base import SensorBase from ..sensor_base import SensorBase
...@@ -82,6 +82,10 @@ class RayCaster(SensorBase): ...@@ -82,6 +82,10 @@ class RayCaster(SensorBase):
Properties Properties
""" """
@property
def num_instances(self) -> int:
return self._view.count
@property @property
def data(self) -> RayCasterData: def data(self) -> RayCasterData:
# update sensors if needed # update sensors if needed
...@@ -108,29 +112,30 @@ class RayCaster(SensorBase): ...@@ -108,29 +112,30 @@ class RayCaster(SensorBase):
def _initialize_impl(self): def _initialize_impl(self):
super()._initialize_impl() super()._initialize_impl()
# create simulation view
self._physics_sim_view = physx.create_simulation_view(self._backend)
self._physics_sim_view.set_subspace_roots("/")
# check if the prim at path is an articulated or rigid prim # check if the prim at path is an articulated or rigid prim
# we do this since for physics-based view classes we can access their data directly # we do this since for physics-based view classes we can access their data directly
# otherwise we need to use the xform view class which is slower # otherwise we need to use the xform view class which is slower
prim_view_class = None found_supported_prim_class = False
for prim_path in prim_utils.find_matching_prim_paths(self.cfg.prim_path): prim = sim_utils.find_first_matching_prim(self.cfg.prim_path)
# get prim at path if prim is None:
prim = prim_utils.get_prim_at_path(prim_path) raise RuntimeError(f"Failed to find a prim at path expression: {self.cfg.prim_path}")
# check if it is a rigid prim # create view based on the type of prim
if prim.HasAPI(UsdPhysics.ArticulationRootAPI): if prim.HasAPI(UsdPhysics.ArticulationRootAPI):
prim_view_class = ArticulationView self._view = self._physics_sim_view.create_articulation_view(self.cfg.prim_path.replace(".*", "*"))
found_supported_prim_class = True
elif prim.HasAPI(UsdPhysics.RigidBodyAPI): elif prim.HasAPI(UsdPhysics.RigidBodyAPI):
prim_view_class = RigidPrimView self._view = self._physics_sim_view.create_rigid_body_view(self.cfg.prim_path.replace(".*", "*"))
found_supported_prim_class = True
else: else:
prim_view_class = XFormPrimView self._view = XFormPrimView(self.cfg.prim_path, reset_xform_properties=False)
carb.log_warn(f"The prim at path {prim_path} is not a physics prim! Using XFormPrimView.") found_supported_prim_class = True
# break the loop carb.log_warn(f"The prim at path {prim.GetPath().pathString} is not a physics prim! Using XFormPrimView.")
break
# check if prim view class is found # check if prim view class is found
if prim_view_class is None: if not found_supported_prim_class:
raise RuntimeError(f"Failed to find a valid prim view class for the prim paths: {self.cfg.prim_path}") raise RuntimeError(f"Failed to find a valid prim view class for the prim paths: {self.cfg.prim_path}")
# create a rigid prim view for the sensor
self._view = prim_view_class(self.cfg.prim_path, reset_xform_properties=False)
self._view.initialize()
# load the meshes by parsing the stage # load the meshes by parsing the stage
self._initialize_warp_meshes() self._initialize_warp_meshes()
...@@ -152,17 +157,17 @@ class RayCaster(SensorBase): ...@@ -152,17 +157,17 @@ class RayCaster(SensorBase):
# check if the prim is a plane - handle PhysX plane as a special case # check if the prim is a plane - handle PhysX plane as a special case
# if a plane exists then we need to create an infinite mesh that is a plane # if a plane exists then we need to create an infinite mesh that is a plane
mesh_prim = prim_utils.get_first_matching_child_prim( mesh_prim = sim_utils.get_first_matching_child_prim(
mesh_prim_path, lambda p: prim_utils.get_prim_type_name(p) == "Plane" mesh_prim_path, lambda prim: prim.GetTypeName() == "Plane"
) )
# if we did not find a plane then we need to read the mesh # if we did not find a plane then we need to read the mesh
if mesh_prim is None: if mesh_prim is None:
# obtain the mesh prim # obtain the mesh prim
mesh_prim = prim_utils.get_first_matching_child_prim( mesh_prim = sim_utils.get_first_matching_child_prim(
mesh_prim_path, lambda p: prim_utils.get_prim_type_name(p) == "Mesh" mesh_prim_path, lambda prim: prim.GetTypeName() == "Mesh"
) )
# check if valid # check if valid
if not prim_utils.is_prim_path_valid(mesh_prim_path): if mesh_prim is None or not mesh_prim.IsValid():
raise RuntimeError(f"Invalid mesh prim path: {mesh_prim_path}") raise RuntimeError(f"Invalid mesh prim path: {mesh_prim_path}")
# cast into UsdGeomMesh # cast into UsdGeomMesh
mesh_prim = UsdGeom.Mesh(mesh_prim) mesh_prim = UsdGeom.Mesh(mesh_prim)
...@@ -210,8 +215,22 @@ class RayCaster(SensorBase): ...@@ -210,8 +215,22 @@ class RayCaster(SensorBase):
def _update_buffers_impl(self, env_ids: Sequence[int]): def _update_buffers_impl(self, env_ids: Sequence[int]):
"""Fills the buffers of the sensor data.""" """Fills the buffers of the sensor data."""
# obtain the poses of the sensors # obtain the poses of the sensors
pos_w, quat_w = self._view.get_world_poses(env_ids, clone=False) if isinstance(self._view, XFormPrimView):
pos_w, quat_w = self._view.get_world_poses(env_ids)
elif isinstance(self._view, physx.ArticulationView):
pos_w, quat_w = self._view.get_root_transforms()[env_ids].split([3, 4], dim=-1)
quat_w = convert_quat(quat_w, to="wxyz")
elif isinstance(self._view, physx.RigidBodyView):
pos_w, quat_w = self._view.get_transforms()[env_ids].split([3, 4], dim=-1)
quat_w = convert_quat(quat_w, to="wxyz")
else:
raise RuntimeError(f"Unsupported view type: {type(self._view)}")
# note: we clone here because we are read-only operations
pos_w = pos_w.clone()
quat_w = quat_w.clone()
# apply drift
pos_w += self.drift[env_ids] pos_w += self.drift[env_ids]
# store the poses
self._data.pos_w[env_ids] = pos_w self._data.pos_w[env_ids] = pos_w
self._data.quat_w[env_ids] = quat_w self._data.quat_w[env_ids] = quat_w
......
...@@ -10,6 +10,7 @@ from tensordict import TensorDict ...@@ -10,6 +10,7 @@ from tensordict import TensorDict
from typing import TYPE_CHECKING, ClassVar, Sequence from typing import TYPE_CHECKING, ClassVar, Sequence
from typing_extensions import Literal from typing_extensions import Literal
import omni.physics.tensors.impl.api as physx
from omni.isaac.core.prims import XFormPrimView from omni.isaac.core.prims import XFormPrimView
import omni.isaac.orbit.utils.math as math_utils import omni.isaac.orbit.utils.math as math_utils
...@@ -365,12 +366,17 @@ class RayCasterCamera(RayCaster): ...@@ -365,12 +366,17 @@ class RayCasterCamera(RayCaster):
# obtain the poses of the sensors # obtain the poses of the sensors
# note: clone arg doesn't exist for xform prim view so we need to do this manually # note: clone arg doesn't exist for xform prim view so we need to do this manually
if isinstance(self._view, XFormPrimView): if isinstance(self._view, XFormPrimView):
pos_w_temp, quat_w_temp = self._view.get_world_poses(env_ids) pos_w, quat_w = self._view.get_world_poses(env_ids)
pos_w = pos_w_temp.clone() elif isinstance(self._view, physx.ArticulationView):
quat_w = quat_w_temp.clone() pos_w, quat_w = self._view.get_root_transforms()[env_ids].split([3, 4], dim=-1)
quat_w = math_utils.convert_quat(quat_w, to="wxyz")
elif isinstance(self._view, physx.RigidBodyView):
pos_w, quat_w = self._view.get_transforms()[env_ids].split([3, 4], dim=-1)
quat_w = math_utils.convert_quat(quat_w, to="wxyz")
else: else:
pos_w, quat_w = self._view.get_world_poses(env_ids, clone=True) raise RuntimeError(f"Unsupported view type: {type(self._view)}")
return pos_w, quat_w # return the pose
return pos_w.clone(), quat_w.clone()
def _compute_camera_world_poses(self, env_ids: Sequence[int]) -> tuple[torch.Tensor, torch.Tensor]: def _compute_camera_world_poses(self, env_ids: Sequence[int]) -> tuple[torch.Tensor, torch.Tensor]:
"""Computes the pose of the camera in the world frame. """Computes the pose of the camera in the world frame.
......
...@@ -17,10 +17,10 @@ import weakref ...@@ -17,10 +17,10 @@ import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Sequence from typing import TYPE_CHECKING, Any, Sequence
import omni.isaac.core.utils.prims as prim_utils
import omni.kit.app import omni.kit.app
import omni.timeline import omni.timeline
from omni.isaac.core.simulation_context import SimulationContext
import omni.isaac.orbit.sim as sim_utils
if TYPE_CHECKING: if TYPE_CHECKING:
from .sensor_base_cfg import SensorBaseCfg from .sensor_base_cfg import SensorBaseCfg
...@@ -91,6 +91,14 @@ class SensorBase(ABC): ...@@ -91,6 +91,14 @@ class SensorBase(ABC):
Properties Properties
""" """
@property
def num_instances(self) -> int:
"""Number of instances of the sensor.
This is equal to the number of sensors per environment multiplied by the number of environments.
"""
return self._num_envs
@property @property
def device(self) -> str: def device(self) -> str:
"""Memory device for computation.""" """Memory device for computation."""
...@@ -193,15 +201,17 @@ class SensorBase(ABC): ...@@ -193,15 +201,17 @@ class SensorBase(ABC):
def _initialize_impl(self): def _initialize_impl(self):
"""Initializes the sensor-related handles and internal buffers.""" """Initializes the sensor-related handles and internal buffers."""
# Obtain Simulation Context # Obtain Simulation Context
sim = SimulationContext.instance() sim = sim_utils.SimulationContext.instance()
if sim is not None: if sim is None:
raise RuntimeError("Simulation Context is not initialized!")
# Obtain device and backend
self._device = sim.device self._device = sim.device
self._backend = sim.backend
self._sim_physics_dt = sim.get_physics_dt() self._sim_physics_dt = sim.get_physics_dt()
else:
raise RuntimeError("Simulation Context is not initialized!")
# Count number of environments # Count number of environments
env_prim_path_expr = self.cfg.prim_path.rsplit("/", 1)[0] env_prim_path_expr = self.cfg.prim_path.rsplit("/", 1)[0]
self._num_envs = len(prim_utils.find_matching_prim_paths(env_prim_path_expr)) self._parent_prims = sim_utils.find_matching_prims(env_prim_path_expr)
self._num_envs = len(self._parent_prims)
# Boolean tensor indicating whether the sensor data has to be refreshed # Boolean tensor indicating whether the sensor data has to be refreshed
self._is_outdated = torch.ones(self._num_envs, dtype=torch.bool, device=self._device) self._is_outdated = torch.ones(self._num_envs, dtype=torch.bool, device=self._device)
# Current timestamp (in seconds) # Current timestamp (in seconds)
......
...@@ -559,6 +559,7 @@ class SimulationContext(_SimulationContext): ...@@ -559,6 +559,7 @@ class SimulationContext(_SimulationContext):
# check if the simulation is stopped # check if the simulation is stopped
if event.type == int(omni.timeline.TimelineEventType.STOP): if event.type == int(omni.timeline.TimelineEventType.STOP):
# keep running the simulator when configured to not shutdown the app # keep running the simulator when configured to not shutdown the app
if self._has_gui:
self.app.print_and_log( self.app.print_and_log(
"Simulation is stopped. The app will keep running with physics disabled." "Simulation is stopped. The app will keep running with physics disabled."
" Press Ctrl+C or close the window to exit the app." " Press Ctrl+C or close the window to exit the app."
......
...@@ -13,7 +13,6 @@ import re ...@@ -13,7 +13,6 @@ import re
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
import carb import carb
import omni.isaac.core.utils.prims as prim_utils
import omni.isaac.core.utils.stage as stage_utils import omni.isaac.core.utils.stage as stage_utils
import omni.kit.commands import omni.kit.commands
from omni.isaac.cloner import Cloner from omni.isaac.cloner import Cloner
...@@ -226,7 +225,7 @@ def clone(func: Callable) -> Callable: ...@@ -226,7 +225,7 @@ def clone(func: Callable) -> Callable:
# resolve matching prims for source prim path expression # resolve matching prims for source prim path expression
if is_regex_expression and root_path != "": if is_regex_expression and root_path != "":
source_prim_paths = prim_utils.find_matching_prim_paths(root_path) source_prim_paths = find_matching_prim_paths(root_path)
# if no matching prims are found, raise an error # if no matching prims are found, raise an error
if len(source_prim_paths) == 0: if len(source_prim_paths) == 0:
raise RuntimeError( raise RuntimeError(
...@@ -241,7 +240,11 @@ def clone(func: Callable) -> Callable: ...@@ -241,7 +240,11 @@ def clone(func: Callable) -> Callable:
prim = func(prim_paths[0], cfg, *args, **kwargs) prim = func(prim_paths[0], cfg, *args, **kwargs)
# set the prim visibility # set the prim visibility
if hasattr(cfg, "visible"): if hasattr(cfg, "visible"):
prim_utils.set_prim_visibility(prim, cfg.visible) imageable = UsdGeom.Imageable(prim)
if cfg.visible:
imageable.MakeVisible()
else:
imageable.MakeInvisible()
# set the semantic annotations # set the semantic annotations
if hasattr(cfg, "semantic_tags") and cfg.semantic_tags is not None: if hasattr(cfg, "semantic_tags") and cfg.semantic_tags is not None:
# note: taken from replicator scripts.utils.utils.py # note: taken from replicator scripts.utils.utils.py
...@@ -492,3 +495,164 @@ def make_uninstanceable(prim_path: str, stage: Usd.Stage | None = None): ...@@ -492,3 +495,164 @@ def make_uninstanceable(prim_path: str, stage: Usd.Stage | None = None):
child_prim.SetInstanceable(False) child_prim.SetInstanceable(False)
# add children to list # add children to list
all_prims += child_prim.GetChildren() all_prims += child_prim.GetChildren()
"""
USD Stage traversal.
"""
def get_first_matching_child_prim(
prim_path: str, predicate: Callable[[Usd.Prim], bool], stage: Usd.Stage | None = None
) -> Usd.Prim | None:
"""Recursively get the first USD Prim at the path string that passes the predicate function
Args:
prim_path: The path of the prim in the stage.
predicate: The function to test the prims against. It takes a prim as input and returns a boolean.
stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
Returns:
The first prim on the path that passes the predicate. If no prim passes the predicate, it returns None.
"""
# get current stage
if stage is None:
stage = stage_utils.get_current_stage()
# get prim
prim = stage.GetPrimAtPath(prim_path)
# check if prim is valid
if not prim.IsValid():
raise ValueError(f"Prim at path '{prim_path}' is not valid.")
# iterate over all prims under prim-path
all_prims = [prim]
while len(all_prims) > 0:
# get current prim
child_prim = all_prims.pop(0)
# check if prim passes predicate
if predicate(child_prim):
return child_prim
# add children to list
all_prims += child_prim.GetChildren()
return None
def get_all_matching_child_prims(
prim_path: str,
predicate: Callable[[Usd.Prim], bool] = lambda _: True,
depth: int | None = None,
stage: Usd.Stage | None = None,
) -> list[Usd.Prim]:
"""Performs a search starting from the root and returns all the prims matching the predicate.
Args:
prim_path: The root prim path to start the search from.
predicate: The predicate that checks if the prim matches the desired criteria. It takes a prim as input
and returns a boolean. Defaults to a function that always returns True.
depth: The maximum depth for traversal, should be bigger than zero if specified.
Defaults to None (i.e: traversal happens till the end of the tree).
stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
Returns:
A list containing all the prims matching the predicate.
"""
# get current stage
if stage is None:
stage = stage_utils.get_current_stage()
# get prim
prim = stage.GetPrimAtPath(prim_path)
# check if prim is valid
if not prim.IsValid():
raise ValueError(f"Prim at path '{prim_path}' is not valid.")
# check if depth is valid
if depth is not None and depth <= 0:
raise ValueError(f"Depth must be bigger than zero, got {depth}.")
# iterate over all prims under prim-path
# list of tuples (prim, current_depth)
all_prims_queue = [(prim, 0)]
output_prims = []
while len(all_prims_queue) > 0:
# get current prim
child_prim, current_depth = all_prims_queue.pop(0)
# check if prim passes predicate
if predicate(child_prim):
output_prims.append(child_prim)
# add children to list
if depth is None or current_depth < depth:
all_prims_queue += [(child, current_depth + 1) for child in child_prim.GetChildren()]
return output_prims
def find_first_matching_prim(prim_path_regex: str, stage: Usd.Stage | None = None) -> Usd.Prim | None:
"""Find the first matching prim in the stage based on input regex expression.
Args:
prim_path_regex: The regex expression for prim path.
stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
Returns:
The first prim that matches input expression. If no prim matches, returns None.
"""
# get current stage
if stage is None:
stage = stage_utils.get_current_stage()
# need to wrap the token patterns in '^' and '$' to prevent matching anywhere in the string
pattern = f"^{prim_path_regex}$"
compiled_pattern = re.compile(pattern)
# obtain matching prim (depth-first search)
for prim in stage.Traverse():
# check if prim passes predicate
if compiled_pattern.match(prim.GetPath().pathString) is not None:
return prim
return None
def find_matching_prims(prim_path_regex: str, stage: Usd.Stage | None = None) -> list[Usd.Prim]:
"""Find all the matching prims in the stage based on input regex expression.
Args:
prim_path_regex: The regex expression for prim path.
stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
Returns:
A list of prims that match input expression.
"""
# get current stage
if stage is None:
stage = stage_utils.get_current_stage()
# need to wrap the token patterns in '^' and '$' to prevent matching anywhere in the string
tokens = prim_path_regex.split("/")[1:]
tokens = [f"^{token}$" for token in tokens]
# iterate over all prims in stage (breath-first search)
all_prims = [stage.GetPseudoRoot()]
output_prims = []
for index, token in enumerate(tokens):
token_compiled = re.compile(token)
for prim in all_prims:
for child in prim.GetAllChildren():
if token_compiled.match(child.GetName()) is not None:
output_prims.append(child)
if index < len(tokens) - 1:
all_prims = output_prims
output_prims = []
return output_prims
def find_matching_prim_paths(prim_path_regex: str, stage: Usd.Stage | None = None) -> list[str]:
"""Find all the matching prim paths in the stage based on input regex expression.
Args:
prim_path_regex: The regex expression for prim path.
stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
Returns:
A list of prim paths that match input expression.
"""
# obtain matching prims
output_prims = find_matching_prims(prim_path_regex, stage)
# convert prims to prim paths
output_prim_paths = []
for prim in output_prims:
output_prim_paths.append(prim.GetPath().pathString)
return output_prim_paths
...@@ -10,9 +10,7 @@ import torch ...@@ -10,9 +10,7 @@ import torch
import trimesh import trimesh
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import omni.isaac.core.utils.prims as prim_utils
import warp import warp
from omni.isaac.core.simulation_context import SimulationContext
from pxr import UsdGeom from pxr import UsdGeom
import omni.isaac.orbit.sim as sim_utils import omni.isaac.orbit.sim as sim_utils
...@@ -71,7 +69,7 @@ class TerrainImporter: ...@@ -71,7 +69,7 @@ class TerrainImporter:
""" """
# store inputs # store inputs
self.cfg = cfg self.cfg = cfg
self.device = SimulationContext.instance().device self.device = sim_utils.SimulationContext.instance().device # type: ignore
# create a dict of meshes # create a dict of meshes
self.meshes = dict() self.meshes = dict()
...@@ -246,8 +244,8 @@ class TerrainImporter: ...@@ -246,8 +244,8 @@ class TerrainImporter:
# traverse the prim and get the collision mesh # traverse the prim and get the collision mesh
# THINK: Should the user specify the collision mesh? # THINK: Should the user specify the collision mesh?
mesh_prim = prim_utils.get_first_matching_child_prim( mesh_prim = sim_utils.get_first_matching_child_prim(
self.cfg.prim_path + f"/{key}", lambda p: prim_utils.get_prim_type_name(p) == "Mesh" self.cfg.prim_path + f"/{key}", lambda prim: prim.GetTypeName() == "Mesh"
) )
# check if the mesh is valid # check if the mesh is valid
if mesh_prim is None: if mesh_prim is None:
......
...@@ -101,10 +101,12 @@ def retrieve_file_path(path: str, download_dir: str | None = None, force_downloa ...@@ -101,10 +101,12 @@ def retrieve_file_path(path: str, download_dir: str | None = None, force_downloa
# download file in temp directory using os # download file in temp directory using os
file_name = os.path.basename(omni.client.break_url(path).path) file_name = os.path.basename(omni.client.break_url(path).path)
target_path = os.path.join(download_dir, file_name) target_path = os.path.join(download_dir, file_name)
# check if file already exists locally
if not os.path.isfile(target_path) or force_download:
# copy file to local machine # copy file to local machine
result = omni.client.copy(path, target_path) result = omni.client.copy(path, target_path)
if result != omni.client.Result.OK and not force_download: if result != omni.client.Result.OK and force_download:
raise RuntimeError(f"Unable to copy file: '{path}'. File already exists locally at: {target_path}") raise RuntimeError(f"Unable to copy file: '{path}'. Is the Nucleus Server running?")
return os.path.abspath(target_path) return os.path.abspath(target_path)
else: else:
raise FileNotFoundError(f"Unable to find the file: {path}") raise FileNotFoundError(f"Unable to find the file: {path}")
......
...@@ -25,8 +25,8 @@ from omni.isaac.orbit.app import AppLauncher ...@@ -25,8 +25,8 @@ from omni.isaac.orbit.app import AppLauncher
# add argparse arguments # add argparse arguments
parser = argparse.ArgumentParser(description="This script demonstrates how to external force on a legged robot.") parser = argparse.ArgumentParser(description="This script demonstrates how to external force on a legged robot.")
parser.add_argument("--body", type=str, help="Name of the body to apply force on.") parser.add_argument("--body", default="base", type=str, help="Name of the body to apply force on.")
parser.add_argument("--force", type=float, help="Force to apply on the body.") parser.add_argument("--force", default=1000.0, type=float, help="Force to apply on the body.")
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
...@@ -78,7 +78,7 @@ def main(): ...@@ -78,7 +78,7 @@ def main():
# Find bodies to apply the force # Find bodies to apply the force
body_ids, body_names = robot.find_bodies(args_cli.body) body_ids, body_names = robot.find_bodies(args_cli.body)
# Sample a large force # Sample a large force
external_wrench_b = torch.zeros(robot.root_view.count, len(body_ids), 6, device=sim.device) external_wrench_b = torch.zeros(robot.num_instances, len(body_ids), 6, device=sim.device)
external_wrench_b[..., 1] = args_cli.force external_wrench_b[..., 1] = args_cli.force
# Now we are ready! # Now we are ready!
......
...@@ -159,7 +159,7 @@ class TestArticulation(unittest.TestCase): ...@@ -159,7 +159,7 @@ class TestArticulation(unittest.TestCase):
# Find bodies to apply the force # Find bodies to apply the force
body_ids, _ = robot.find_bodies("base") body_ids, _ = robot.find_bodies("base")
# Sample a large force # Sample a large force
external_wrench_b = torch.zeros(robot.root_view.count, len(body_ids), 6, device=self.sim.device) external_wrench_b = torch.zeros(robot.num_instances, len(body_ids), 6, device=self.sim.device)
external_wrench_b[..., 1] = 1000.0 external_wrench_b[..., 1] = 1000.0
# Now we are ready! # Now we are ready!
...@@ -207,7 +207,7 @@ class TestArticulation(unittest.TestCase): ...@@ -207,7 +207,7 @@ class TestArticulation(unittest.TestCase):
# Find bodies to apply the force # Find bodies to apply the force
body_ids, _ = robot.find_bodies(".*_SHANK") body_ids, _ = robot.find_bodies(".*_SHANK")
# Sample a large force # Sample a large force
external_wrench_b = torch.zeros(robot.root_view.count, len(body_ids), 6, device=self.sim.device) external_wrench_b = torch.zeros(robot.num_instances, len(body_ids), 6, device=self.sim.device)
external_wrench_b[..., 1] = 100.0 external_wrench_b[..., 1] = 100.0
# Now we are ready! # Now we are ready!
...@@ -268,7 +268,7 @@ class TestArticulation(unittest.TestCase): ...@@ -268,7 +268,7 @@ class TestArticulation(unittest.TestCase):
".*_foot.*": 2.0, ".*_foot.*": 2.0,
} }
indices_list, _, values_list = string_utils.resolve_matching_names_values(expected_stiffness, robot.joint_names) indices_list, _, values_list = string_utils.resolve_matching_names_values(expected_stiffness, robot.joint_names)
expected_stiffness = torch.zeros(robot.root_view.count, robot.num_joints, device=robot.device) expected_stiffness = torch.zeros(robot.num_instances, robot.num_joints, device=robot.device)
expected_stiffness[:, indices_list] = torch.tensor(values_list, device=robot.device) expected_stiffness[:, indices_list] = torch.tensor(values_list, device=robot.device)
# -- Damping values # -- Damping values
expected_damping = { expected_damping = {
...@@ -308,7 +308,7 @@ class TestArticulation(unittest.TestCase): ...@@ -308,7 +308,7 @@ class TestArticulation(unittest.TestCase):
self.sim.reset() self.sim.reset()
# Expected gains # Expected gains
expected_stiffness = torch.full((robot.root_view.count, robot.num_joints), 10.0, device=robot.device) expected_stiffness = torch.full((robot.num_instances, robot.num_joints), 10.0, device=robot.device)
expected_damping = torch.full_like(expected_stiffness, 2.0) expected_damping = torch.full_like(expected_stiffness, 2.0)
# Check that gains are loaded from USD file # Check that gains are loaded from USD file
...@@ -333,7 +333,7 @@ class TestArticulation(unittest.TestCase): ...@@ -333,7 +333,7 @@ class TestArticulation(unittest.TestCase):
self.sim.reset() self.sim.reset()
# Expected gains # Expected gains
expected_stiffness = torch.full((robot.root_view.count, robot.num_joints), 10.0, device=robot.device) expected_stiffness = torch.full((robot.num_instances, robot.num_joints), 10.0, device=robot.device)
expected_damping = torch.full_like(expected_stiffness, 2.0) expected_damping = torch.full_like(expected_stiffness, 2.0)
# Check that gains are loaded from USD file # Check that gains are loaded from USD file
......
...@@ -112,8 +112,8 @@ class TestRigidObject(unittest.TestCase): ...@@ -112,8 +112,8 @@ class TestRigidObject(unittest.TestCase):
body_ids, _ = cube_object.find_bodies(".*") body_ids, _ = cube_object.find_bodies(".*")
# Sample a large force # Sample a large force
external_wrench_b = torch.zeros(cube_object.root_view.count, len(body_ids), 6, device=self.sim.device) external_wrench_b = torch.zeros(cube_object.num_instances, len(body_ids), 6, device=self.sim.device)
external_wrench_b[0, 0, 2] = 9.81 * cube_object.root_view.get_masses(indices=[0]) external_wrench_b[0, 0, 2] = 9.81 * cube_object.root_physx_view.get_masses()[0]
# Now we are ready! # Now we are ready!
for _ in range(5): for _ in range(5):
......
...@@ -25,7 +25,10 @@ parser = argparse.ArgumentParser( ...@@ -25,7 +25,10 @@ parser = argparse.ArgumentParser(
parser.add_argument("--headless", action="store_true", default=False, help="Force display off at all times.") parser.add_argument("--headless", action="store_true", default=False, help="Force display off at all times.")
parser.add_argument("--num_robots", type=int, default=128, help="Number of robots to spawn.") parser.add_argument("--num_robots", type=int, default=128, help="Number of robots to spawn.")
parser.add_argument( parser.add_argument(
"--asset", type=str, default="orbit", help="The asset source location for the robot. Can be: orbit, oige." "--asset",
type=str,
default="orbit",
help="The asset source location for the robot. Can be: orbit, oige, custom asset path.",
) )
args_cli = parser.parse_args() args_cli = parser.parse_args()
......
# Copyright (c) 2022-2023, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
"""Launch Isaac Sim Simulator first."""
from omni.isaac.kit import SimulationApp
# launch omniverse app
config = {"headless": True}
simulation_app = SimulationApp(config)
"""Rest everything follows."""
import numpy as np
import traceback
import unittest
import carb
import omni.isaac.core.utils.prims as prim_utils
import omni.isaac.core.utils.stage as stage_utils
import omni.isaac.orbit.sim as sim_utils
class TestUtilities(unittest.TestCase):
"""Test fixture for the sim utility functions."""
def setUp(self):
"""Create a blank new stage for each test."""
# Create a new stage
stage_utils.create_new_stage()
stage_utils.update_stage()
def tearDown(self) -> None:
"""Clear stage after each test."""
stage_utils.clear_stage()
def test_get_all_matching_child_prims(self):
"""Test get_all_matching_child_prims() function."""
# create scene
prim_utils.create_prim("/World/Floor")
prim_utils.create_prim(
"/World/Floor/thefloor", "Cube", position=np.array([75, 75, -150.1]), attributes={"size": 300}
)
prim_utils.create_prim("/World/Room", "Sphere", attributes={"radius": 1e3})
# test
isaac_sim_result = prim_utils.get_all_matching_child_prims("/World")
orbit_result = sim_utils.get_all_matching_child_prims("/World")
self.assertListEqual(isaac_sim_result, orbit_result)
def test_find_matching_prim_paths(self):
"""Test find_matching_prim_paths() function."""
# create scene
for index in range(2048):
random_pos = np.random.uniform(-100, 100, size=3)
prim_utils.create_prim(f"/World/Floor_{index}", "Cube", position=random_pos, attributes={"size": 2.0})
prim_utils.create_prim(f"/World/Floor_{index}/Sphere", "Sphere", attributes={"radius": 10})
prim_utils.create_prim(f"/World/Floor_{index}/Sphere/childSphere", "Sphere", attributes={"radius": 1})
prim_utils.create_prim(f"/World/Floor_{index}/Sphere/childSphere2", "Sphere", attributes={"radius": 1})
# test leaf paths
isaac_sim_result = prim_utils.find_matching_prim_paths("/World/Floor_.*/Sphere")
orbit_result = sim_utils.find_matching_prim_paths("/World/Floor_.*/Sphere")
self.assertListEqual(isaac_sim_result, orbit_result)
# test non-leaf paths
isaac_sim_result = prim_utils.find_matching_prim_paths("/World/Floor_.*")
orbit_result = sim_utils.find_matching_prim_paths("/World/Floor_.*")
self.assertListEqual(isaac_sim_result, orbit_result)
# test child-leaf paths
isaac_sim_result = prim_utils.find_matching_prim_paths("/World/Floor_.*/Sphere/childSphere.*")
orbit_result = sim_utils.find_matching_prim_paths("/World/Floor_.*/Sphere/childSphere.*")
self.assertListEqual(isaac_sim_result, orbit_result)
if __name__ == "__main__":
try:
unittest.main()
except Exception as err:
carb.log_error(err)
carb.log_error(traceback.format_exc())
raise
finally:
# close sim app
simulation_app.close()
...@@ -34,7 +34,7 @@ def base_up_proj(env: HumanoidEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("r ...@@ -34,7 +34,7 @@ def base_up_proj(env: HumanoidEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("r
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
# compute base up vector # compute base up vector
base_up_vec = math_utils.quat_rotate(asset.data.root_quat_w, -asset._GRAVITY_VEC_W) # type: ignore base_up_vec = math_utils.quat_rotate(asset.data.root_quat_w, -asset.GRAVITY_VEC_W)
return base_up_vec[:, 2].unsqueeze(-1) return base_up_vec[:, 2].unsqueeze(-1)
...@@ -50,7 +50,7 @@ def base_heading_proj( ...@@ -50,7 +50,7 @@ def base_heading_proj(
to_target_pos[:, 2] = 0.0 to_target_pos[:, 2] = 0.0
to_target_dir = math_utils.normalize(to_target_pos) to_target_dir = math_utils.normalize(to_target_pos)
# compute base forward vector # compute base forward vector
heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset._FORWARD_VEC_B) # type: ignore heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.FORWARD_VEC_B)
# compute dot product between heading and target direction # compute dot product between heading and target direction
heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1)) heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1))
......
...@@ -20,7 +20,7 @@ class LiftCubePPORunnerCfg(RslRlOnPolicyRunnerCfg): ...@@ -20,7 +20,7 @@ class LiftCubePPORunnerCfg(RslRlOnPolicyRunnerCfg):
experiment_name = "franka_lift" experiment_name = "franka_lift"
empirical_normalization = False empirical_normalization = False
policy = RslRlPpoActorCriticCfg( policy = RslRlPpoActorCriticCfg(
init_noise_std=0.8, init_noise_std=1.0,
actor_hidden_dims=[256, 128, 64], actor_hidden_dims=[256, 128, 64],
critic_hidden_dims=[256, 128, 64], critic_hidden_dims=[256, 128, 64],
activation="elu", activation="elu",
......
...@@ -33,7 +33,7 @@ class FrankaCubeLiftEnvCfg(LiftEnvCfg): ...@@ -33,7 +33,7 @@ class FrankaCubeLiftEnvCfg(LiftEnvCfg):
# Set actions for the specific robot type (franka) # Set actions for the specific robot type (franka)
self.actions.body_joint_pos = mdp.JointPositionActionCfg( self.actions.body_joint_pos = mdp.JointPositionActionCfg(
asset_name="robot", joint_names=["panda_joint.*"], scale=1.0, use_default_offset=True asset_name="robot", joint_names=["panda_joint.*"], scale=0.5, use_default_offset=True
) )
self.actions.finger_joint_pos = mdp.BinaryJointPositionActionCfg( self.actions.finger_joint_pos = mdp.BinaryJointPositionActionCfg(
asset_name="robot", asset_name="robot",
......
...@@ -95,7 +95,7 @@ def main(): ...@@ -95,7 +95,7 @@ def main():
print("[INFO]: Setup complete...") print("[INFO]: Setup complete...")
# dummy actions # dummy actions
actions = torch.rand(robot.root_view.count, robot.num_joints, device=robot.device) + robot.data.default_joint_pos actions = torch.rand(robot.num_instances, robot.num_joints, device=robot.device) + robot.data.default_joint_pos
has_gripper = args_cli.robot == "franka_panda" has_gripper = args_cli.robot == "franka_panda"
# Define simulation stepping # Define simulation stepping
...@@ -114,9 +114,8 @@ def main(): ...@@ -114,9 +114,8 @@ def main():
robot.write_joint_state_to_sim(joint_pos, joint_vel) robot.write_joint_state_to_sim(joint_pos, joint_vel)
robot.reset() robot.reset()
# reset command # reset command
actions = ( actions = torch.rand(robot.num_instances, robot.num_joints, device=robot.device)
torch.rand(robot.root_view.count, robot.num_joints, device=robot.device) + robot.data.default_joint_pos actions += robot.data.default_joint_pos
)
# reset gripper # reset gripper
if has_gripper: if has_gripper:
actions[:, -2:] = 0.04 actions[:, -2:] = 0.04
......
...@@ -105,7 +105,7 @@ def main(): ...@@ -105,7 +105,7 @@ def main():
print("[INFO]: Setup complete...") print("[INFO]: Setup complete...")
# dummy actions # dummy actions
actions = torch.rand(robot.root_view.count, robot.num_joints, device=robot.device) + robot.data.default_joint_pos actions = torch.rand(robot.num_instances, robot.num_joints, device=robot.device) + robot.data.default_joint_pos
has_gripper = args_cli.robot == "franka_panda" has_gripper = args_cli.robot == "franka_panda"
# Define simulation stepping # Define simulation stepping
...@@ -124,9 +124,8 @@ def main(): ...@@ -124,9 +124,8 @@ def main():
robot.write_joint_state_to_sim(joint_pos, joint_vel) robot.write_joint_state_to_sim(joint_pos, joint_vel)
robot.reset() robot.reset()
# reset command # reset command
actions = ( actions = torch.rand(robot.num_instances, robot.num_joints, device=robot.device)
torch.rand(robot.root_view.count, robot.num_joints, device=robot.device) + robot.data.default_joint_pos actions += robot.data.default_joint_pos
)
# reset gripper # reset gripper
if has_gripper: if has_gripper:
actions[:, -2:] = 0.04 actions[:, -2:] = 0.04
......
...@@ -147,9 +147,8 @@ def main(): ...@@ -147,9 +147,8 @@ def main():
actions[:, 2] = 1.0 actions[:, 2] = 1.0
# change the arm action # change the arm action
if ep_step_count % 100: if ep_step_count % 100:
actions[:, 3:10] = ( actions[:, 3:10] = torch.rand(robot.num_instances, 7, device=robot.device)
torch.rand(robot.root_view.count, 7, device=robot.device) + robot.data.default_joint_pos[:, 3:10] actions[:, 3:10] += robot.data.default_joint_pos[:, 3:10]
)
# apply action # apply action
robot.set_joint_velocity_target(actions[:, :3], joint_ids=[0, 1, 2]) robot.set_joint_velocity_target(actions[:, :3], joint_ids=[0, 1, 2])
robot.set_joint_position_target(actions[:, 3:], joint_ids=[3, 4, 5, 6, 7, 8, 9, 10, 11]) robot.set_joint_position_target(actions[:, 3:], joint_ids=[3, 4, 5, 6, 7, 8, 9, 10, 11])
......
...@@ -142,8 +142,8 @@ def main(): ...@@ -142,8 +142,8 @@ def main():
print("[INFO]: Setup complete...") print("[INFO]: Setup complete...")
# Create buffers to store actions # Create buffers to store actions
rmp_commands = torch.zeros(robot.count, rmp_controller.num_actions, device=robot.device) rmp_commands = torch.zeros(robot.num_instances, rmp_controller.num_actions, device=robot.device)
robot_actions = torch.ones(robot.count, robot.num_actions, device=robot.device) robot_actions = torch.ones(robot.num_instances, robot.num_actions, device=robot.device)
has_gripper = robot.cfg.meta_info.tool_num_dof > 0 has_gripper = robot.cfg.meta_info.tool_num_dof > 0
# Set end effector goals # Set end effector goals
......
...@@ -98,11 +98,11 @@ def main(): ...@@ -98,11 +98,11 @@ def main():
root_state = rigid_object.data.default_root_state.clone() root_state = rigid_object.data.default_root_state.clone()
# -- position # -- position
root_state[:, :3] = sample_cylinder( root_state[:, :3] = sample_cylinder(
radius=0.5, h_range=(0.15, 0.25), size=rigid_object.root_view.count, device=rigid_object.device radius=0.5, h_range=(0.15, 0.25), size=rigid_object.num_instances, device=rigid_object.device
) )
# -- orientation: apply yaw rotation # -- orientation: apply yaw rotation
root_state[:, 3:7] = quat_mul( root_state[:, 3:7] = quat_mul(
random_yaw_orientation(rigid_object.root_view.count, rigid_object.device), root_state[:, 3:7] random_yaw_orientation(rigid_object.num_instances, rigid_object.device), root_state[:, 3:7]
) )
# -- set root state # -- set root state
rigid_object.write_root_state_to_sim(root_state) rigid_object.write_root_state_to_sim(root_state)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment