Unverified Commit 6f4cc59d authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Changes internal working of APIs to use physics views directly (#267)

# Description

For a long time, we have been seeing a slow simulation setup time (i.e.
time spent in `sim.reset` call). It takes around 70-75 seconds to set up
the simulation for ANYmal locomotion task with the new USD asset for it.
This number is only increasing with other more complex robots we have
been trying to import.

The MR dives into the possible causes and gets rid of costly operations.
Many of these are coming from Isaac Sim itself, particularly related to
the initialization of views. Hence, the following breaking changes:

* We no longer depend on Isaac Sim for `RigidPrimView` and
`ArticulationView`. Instead, we directly create underlying PhysX views
for them.
* We add faster reimplementations of functions that are used for regex
matching.

With these changes, the simulation load time is reduced from up to 80
sec to 15 sec. A bulk of the time is still going to setting up the
simulation step for the first time.

## Type of change

- Breaking change (fix or feature that would cause existing
functionality to not work as expected)

## Screenshots

| Before | After |
| ------ | ----- |
|
![orig-fg](https://github.com/isaac-orbit/orbit/assets/12863862/c13f1634-bd2c-4daf-97e0-3b5776b5cd37)
|
![ref-fg](https://github.com/isaac-orbit/orbit/assets/12863862/b509049d-4cbd-45d6-a4f4-6082f4caf7f2)
|

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 99a238e7
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.9.55"
version = "0.10.0"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.10.0 (2023-12-04)
~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Modified the sensor and asset base classes to use the underlying PhysX views instead of Isaac Sim views.
Using Isaac Sim classes led to a very high load time (of the order of minutes) when using a scene with
many assets. This is because Isaac Sim supports USD paths which are slow and not required.
Added
^^^^^
* Added faster implementation of USD stage traversal methods inside the :class:`omni.isaac.orbit.sim.utils` module.
* Added properties :attr:`omni.isaac.orbit.assets.AssetBase.num_instances` and
:attr:`omni.isaac.orbit.sensor.SensorBase.num_instances` to obtain the number of instances of the asset
or sensor in the simulation respectively.
Removed
^^^^^^^
* Removed dependencies on Isaac Sim view classes. It is no longer possible to use :attr:`root_view` and
:attr:`body_view`. Instead use :attr:`root_physx_view` and :attr:`body_physx_view` to access the underlying
PhysX views.
0.9.55 (2023-12-03)
~~~~~~~~~~~~~~~~~~~
......
......@@ -639,9 +639,13 @@ class AppLauncher:
enable_extension("omni.kit.viewport.bundle")
# extension for window status bar
enable_extension("omni.kit.window.status_bar")
# enable isaac replicator extension
# enable replicator extension
# note: moved here since it requires to have the viewport extension to be enabled first.
enable_extension("omni.replicator.isaac")
enable_extension("omni.replicator.core")
# enable UI tools
# note: we need to always import this even with headless to make
# the module for orbit.envs.ui work
enable_extension("omni.isaac.ui")
# set the nucleus directory manually to the 2023.1.0 version
# TODO: Remove this once the 2023.1.0 version is released
......
......@@ -11,10 +11,11 @@ import weakref
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Sequence
import omni.isaac.core.utils.prims as prim_utils
import omni.kit.app
import omni.timeline
import omni.isaac.orbit.sim as sim_utils
if TYPE_CHECKING:
from .asset_base_cfg import AssetBaseCfg
......@@ -77,8 +78,8 @@ class AssetBase(ABC):
orientation=self.cfg.init_state.rot,
)
# check that spawn was successful
matching_prim_paths = prim_utils.find_matching_prim_paths(self.cfg.prim_path)
if len(matching_prim_paths) == 0:
matching_prims = sim_utils.find_matching_prims(self.cfg.prim_path)
if len(matching_prims) == 0:
raise RuntimeError(f"Could not find prim with path {self.cfg.prim_path}.")
# note: Use weakref on all callbacks to ensure that this object can be deleted when its destructor is called.
......@@ -120,9 +121,17 @@ class AssetBase(ABC):
@property
@abstractmethod
def num_instances(self) -> int:
"""Number of instances of the asset.
This is equal to the number of asset instances per environment multiplied by the number of environments.
"""
return NotImplementedError
@property
def device(self) -> str:
"""Memory device for computation."""
return NotImplementedError
return self._device
@property
@abstractmethod
......@@ -235,7 +244,15 @@ class AssetBase(ABC):
called whenever the simulator "plays" from a "stop" state.
"""
if not self._is_initialized:
# obtain simulation related information
sim = sim_utils.SimulationContext.instance()
if sim is None:
raise RuntimeError("SimulationContext is not initialized! Please initialize SimulationContext first.")
self._backend = sim.backend
self._device = sim.device
# initialize the asset
self._initialize_impl()
# set flag
self._is_initialized = True
def _invalidate_initialize_callback(self, event):
......
......@@ -123,7 +123,8 @@ class BaseEnv:
# note: this activates the physics simulation view that exposes TensorAPIs
# note: when started in extension mode, first call sim.reset_async() and then initialize the managers
if builtins.ISAAC_LAUNCHED_FROM_TERMINAL is False:
with Timer("[INFO]: Time taken for simulation reset"):
print("[INFO]: Starting the simulation. This may take a few seconds. Please wait...")
with Timer("[INFO]: Time taken for simulation start"):
self.sim.reset()
# add timeline event to load managers
self.load_managers()
......@@ -279,9 +280,14 @@ class BaseEnv:
Returns:
The seed used for random generator.
"""
import omni.replicator.core as rep
rep.set_global_seed(seed)
# set seed for replicator
try:
import omni.replicator.core as rep
rep.set_global_seed(seed)
except ModuleNotFoundError:
pass
# set seed for torch and other libraries
return torch_utils.set_seed(seed)
def close(self):
......
......@@ -67,10 +67,10 @@ def randomize_rigid_body_material(
material_buckets[:, 1].uniform_(*dynamic_friction_range)
material_buckets[:, 2].uniform_(*restitution_range)
# create random material assignments based on the total number of shapes: num_assets x num_bodies x num_shapes
material_ids = torch.randint(0, num_buckets, (asset.body_view.count, asset.body_view.num_shapes))
material_ids = torch.randint(0, num_buckets, (asset.body_physx_view.count, asset.body_physx_view.max_shapes))
materials = material_buckets[material_ids]
# resolve the global body indices from the env_ids and the env_body_ids
bodies_per_env = asset.body_view.count // num_envs # - number of bodies per spawned asset
bodies_per_env = asset.body_physx_view.count // num_envs # - number of bodies per spawned asset
indices = torch.tensor(asset_cfg.body_ids, dtype=torch.int).repeat(len(env_ids), 1)
indices += env_ids.unsqueeze(1) * bodies_per_env
......@@ -99,7 +99,7 @@ def add_body_mass(
masses = asset.body_physx_view.get_masses()
masses += sample_uniform(*mass_range, masses.shape, device=masses.device)
# resolve the global body indices from the env_ids and the env_body_ids
bodies_per_env = asset.body_view.count // env.num_envs
bodies_per_env = asset.body_physx_view.count // env.num_envs
indices = torch.tensor(asset_cfg.body_ids, dtype=torch.int).repeat(len(env_ids), 1)
indices += env_ids.unsqueeze(1) * bodies_per_env
......
......@@ -10,12 +10,5 @@ This includes functionalities such as tracking a robot in the simulation,
toggling different debug visualization tools, and other user-defined functionalities.
"""
# enable the extension for UI elements
# this only needs to be done once
from omni.isaac.core.utils.extensions import enable_extension
enable_extension("omni.isaac.ui")
# import all UI elements here
from .base_env_window import BaseEnvWindow
from .rl_task_env_window import RLTaskEnvWindow
......@@ -10,14 +10,13 @@ import torch
from typing import Any, Sequence
import carb
import omni.isaac.core.utils.prims as prim_utils
import omni.isaac.core.utils.stage as stage_utils
import omni.usd
from omni.isaac.cloner import GridCloner
from omni.isaac.core.prims import XFormPrimView
from omni.isaac.core.simulation_context import SimulationContext
from omni.isaac.version import get_version
from pxr import PhysxSchema
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.assets import Articulation, ArticulationCfg, AssetBaseCfg, RigidObject, RigidObjectCfg
from omni.isaac.orbit.sensors import FrameTransformerCfg, SensorBase, SensorBaseCfg
from omni.isaac.orbit.terrains import TerrainImporter, TerrainImporterCfg
......@@ -112,12 +111,14 @@ class InteractiveScene:
"""
# store inputs
self.cfg = cfg
# obtain the current stage
self.stage = omni.usd.get_context().get_stage()
# prepare cloner for environment replication
self.cloner = GridCloner(spacing=self.cfg.env_spacing)
self.cloner.define_base_env(self.env_ns)
self.env_prim_paths = self.cloner.generate_paths(f"{self.env_ns}/env", self.cfg.num_envs)
# create source prim
prim_utils.define_prim(self.env_prim_paths[0], "Xform")
self.stage.DefinePrim(self.env_prim_paths[0], "Xform")
# obtain major isaac sim version
isaac_major_version = int(get_version()[2])
# clone the env xform
......@@ -158,7 +159,7 @@ class InteractiveScene:
)
# obtain the current physics scene
physics_scene_prim_path = None
for prim in stage_utils.traverse_stage():
for prim in self.stage.Traverse():
if prim.HasAPI(PhysxSchema.PhysxSceneAPI):
physics_scene_prim_path = prim.GetPrimPath()
carb.log_info(f"Physics scene prim path: {physics_scene_prim_path}")
......@@ -188,12 +189,12 @@ class InteractiveScene:
@property
def physics_dt(self) -> float:
"""The physics timestep of the scene."""
return SimulationContext.instance().get_physics_dt() # pyright: ignore [reportOptionalMemberAccess]
return sim_utils.SimulationContext.instance().get_physics_dt() # pyright: ignore [reportOptionalMemberAccess]
@property
def device(self) -> str:
"""The device on which the scene is created."""
return SimulationContext.instance().device # pyright: ignore [reportOptionalMemberAccess]
return sim_utils.SimulationContext.instance().device # pyright: ignore [reportOptionalMemberAccess]
@property
def env_ns(self) -> str:
......@@ -247,7 +248,7 @@ class InteractiveScene:
# note: In standalone mode, this method is called in the `step()` method of the simulation context.
# So we only need to flush when running in extension mode.
if builtins.ISAAC_LAUNCHED_FROM_TERMINAL:
SimulationContext.instance().physics_sim_view.flush() # pyright: ignore [reportOptionalMemberAccess]
sim_utils.SimulationContext.instance().physics_sim_view.flush() # pyright: ignore [reportOptionalMemberAccess]
def write_data_to_sim(self):
"""Writes the data of the scene entities to the simulation."""
......@@ -262,7 +263,7 @@ class InteractiveScene:
# note: In standalone mode, this method is called in the `step()` method of the simulation context.
# So we only need to flush when running in extension mode.
if builtins.ISAAC_LAUNCHED_FROM_TERMINAL:
SimulationContext.instance().physics_sim_view.flush() # pyright: ignore [reportOptionalMemberAccess]
sim_utils.SimulationContext.instance().physics_sim_view.flush() # pyright: ignore [reportOptionalMemberAccess]
def update(self, dt: float) -> None:
"""Update the scene entities.
......@@ -370,5 +371,5 @@ class InteractiveScene:
raise ValueError(f"Unknown asset config type for {asset_name}: {asset_cfg}")
# store global collision paths
if hasattr(asset_cfg, "collision_group") and asset_cfg.collision_group == -1:
asset_paths = prim_utils.find_matching_prim_paths(asset_cfg.prim_path)
asset_paths = sim_utils.find_matching_prim_paths(asset_cfg.prim_path)
self._global_prim_paths += asset_paths
......@@ -12,13 +12,12 @@ from tensordict import TensorDict
from typing import TYPE_CHECKING, Any, Sequence
from typing_extensions import Literal
import omni.isaac.core.utils.prims as prim_utils
import omni.kit.commands
import omni.usd
from omni.isaac.core.prims import XFormPrimView
from pxr import UsdGeom
# omni-isaac-orbit
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.utils import to_camel_case
from omni.isaac.orbit.utils.array import convert_to_torch
from omni.isaac.orbit.utils.math import quat_from_matrix
......@@ -93,8 +92,8 @@ class Camera(SensorBase):
self.cfg.prim_path, self.cfg.spawn, translation=self.cfg.offset.pos, orientation=rot_offset
)
# check that spawn was successful
matching_prim_paths = prim_utils.find_matching_prim_paths(self.cfg.prim_path)
if len(matching_prim_paths) == 0:
matching_prims = sim_utils.find_matching_prims(self.cfg.prim_path)
if len(matching_prims) == 0:
raise RuntimeError(f"Could not find prim with path {self.cfg.prim_path}.")
# UsdGeom Camera prim for the sensor
......@@ -127,6 +126,10 @@ class Camera(SensorBase):
Properties
"""
@property
def num_instances(self) -> int:
return self._view.count
@property
def data(self) -> CameraData:
# update sensors if needed
......@@ -351,10 +354,12 @@ class Camera(SensorBase):
device_name = self._device.split(":")[0]
else:
device_name = "cpu"
# Obtain current stage
stage = omni.usd.get_context().get_stage()
# Convert all encapsulated prims to Camera
for cam_prim_path in self._view.prim_paths:
# Get camera prim
cam_prim = prim_utils.get_prim_at_path(cam_prim_path)
cam_prim = stage.GetPrimAtPath(cam_prim_path)
# Check if prim is a camera
if not cam_prim.IsA(UsdGeom.Camera):
raise RuntimeError(f"Prim at path '{cam_prim_path}' is not a Camera.")
......
......@@ -11,13 +11,13 @@ from __future__ import annotations
import torch
from typing import TYPE_CHECKING, Sequence
import omni.isaac.core.utils.prims as prim_utils
import omni.physics.tensors.impl.api as physx
from omni.isaac.core.prims import RigidContactView, RigidPrimView
from pxr import PhysxSchema
import omni.isaac.orbit.sim as sim_utils
import omni.isaac.orbit.utils.string as string_utils
from omni.isaac.orbit.markers import VisualizationMarkers
from omni.isaac.orbit.utils.math import convert_quat
from ..sensor_base import SensorBase
from .contact_sensor_data import ContactSensorData
......@@ -62,7 +62,7 @@ class ContactSensor(SensorBase):
"""Returns: A string containing information about the instance."""
return (
f"Contact sensor @ '{self.cfg.prim_path}': \n"
f"\tview type : {self._view.__class__}\n"
f"\tview type : {self.body_physx_view.__class__}\n"
f"\tupdate period (s) : {self.cfg.update_period}\n"
f"\tnumber of bodies : {self.num_bodies}\n"
f"\tbody names : {self.body_names}\n"
......@@ -72,6 +72,10 @@ class ContactSensor(SensorBase):
Properties
"""
@property
def num_instances(self) -> int:
return self.body_physx_view.count
@property
def data(self) -> ContactSensorData:
# update sensors if needed
......@@ -87,38 +91,26 @@ class ContactSensor(SensorBase):
@property
def body_names(self) -> list[str]:
"""Ordered names of bodies with contact sensors attached."""
prim_paths = self._view.prim_paths[: self.num_bodies]
prim_paths = self.body_physx_view.prim_paths[: self.num_bodies]
return [path.split("/")[-1] for path in prim_paths]
@property
def body_view(self) -> RigidPrimView:
"""View for the rigid bodies captured (Isaac Sim)."""
return self._view
@property
def contact_view(self) -> RigidContactView:
"""Contact reporter view for the bodies (Isaac Sim)."""
return self._view._contact_view # pyright: ignore [reportPrivateUsage]
@property
def body_physx_view(self) -> physx.RigidBodyView:
"""View for the rigid bodies captured (PhysX).
Note:
Use this view with caution! It requires handling of tensors in a specific way and is exposed for
advanced users who have a deep understanding of PhysX SDK. Prefer using the Isaac Sim view when possible.
Use this view with caution. It requires handling of tensors in a specific way.
"""
return self._view._physics_view # pyright: ignore [reportPrivateUsage]
return self._body_physx_view
@property
def contact_physx_view(self) -> physx.RigidContactView:
"""Contact reporter view for the bodies (PhysX).
Note:
Use this view with caution! It requires handling of tensors in a specific way and is exposed for
advanced users who have a deep understanding of PhysX SDK. Prefer using the Isaac Sim view when possible.
Use this view with caution. It requires handling of tensors in a specific way.
"""
return self._view._contact_view._physics_view # pyright: ignore [reportPrivateUsage]
return self._contact_physx_view
"""
Operations
......@@ -163,14 +155,17 @@ class ContactSensor(SensorBase):
def _initialize_impl(self):
super()._initialize_impl()
# create simulation view
self._physics_sim_view = physx.create_simulation_view(self._backend)
self._physics_sim_view.set_subspace_roots("/")
# check that only rigid bodies are selected
matching_prim_paths = prim_utils.find_matching_prim_paths(self.cfg.prim_path)
num_prim_matches = len(matching_prim_paths) // self._num_envs
leaf_pattern = self.cfg.prim_path.rsplit("/", 1)[-1]
template_prim_path = self._parent_prims[0].GetPath().pathString
body_names = list()
for prim_path in matching_prim_paths[:num_prim_matches]:
prim = prim_utils.get_prim_at_path(prim_path)
for prim in sim_utils.find_matching_prims(template_prim_path + "/" + leaf_pattern):
# check if prim has contact reporter API
if prim.HasAPI(PhysxSchema.PhysxContactReportAPI):
prim_path = prim.GetPath().pathString
body_names.append(prim_path.rsplit("/", 1)[-1])
# check that there is at least one body with contact reporter API
if not body_names:
......@@ -183,17 +178,12 @@ class ContactSensor(SensorBase):
body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}"
# construct a new regex expression
# create a rigid prim view for the sensor
self._view = RigidPrimView(
prim_paths_expr=body_names_regex,
reset_xform_properties=False,
track_contact_forces=True,
contact_filter_prim_paths_expr=self.cfg.filter_prim_paths_expr,
prepare_contact_sensors=False,
disable_stablization=True,
self._body_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex.replace(".*", "*"))
self._contact_physx_view = self._physics_sim_view.create_rigid_contact_view(
body_names_regex.replace(".*", "*"), filter_patterns=self.cfg.filter_prim_paths_expr
)
self._view.initialize()
# resolve the true count of bodies
self._num_bodies = self._view.count // self._num_envs
self._num_bodies = self.body_physx_view.count // self._num_envs
# check that contact reporter succeeded
if self._num_bodies != len(body_names):
raise RuntimeError(
......@@ -225,7 +215,7 @@ class ContactSensor(SensorBase):
num_shapes = self.contact_physx_view.sensor_count // self._num_bodies
num_filters = self.contact_physx_view.filter_count
self._data.force_matrix_w = torch.zeros(
self.count, self._num_bodies, num_shapes, num_filters, 3, device=self._device
self._num_envs, self._num_bodies, num_shapes, num_filters, 3, device=self._device
)
def _update_buffers_impl(self, env_ids: Sequence[int]):
......@@ -255,9 +245,9 @@ class ContactSensor(SensorBase):
self._data.force_matrix_w[env_ids] = force_matrix_w[env_ids]
# obtain the pose of the sensor origin
if self.cfg.track_pose:
pose = self.body_physx_view.get_transforms()
self._data.pos_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, :3]
self._data.quat_w[env_ids] = pose.view(-1, self._num_bodies, 7)[env_ids, :, 3:]
pose = self.body_physx_view.get_transforms().view(-1, self._num_bodies, 7)[env_ids]
pose[..., 3:] = convert_quat(pose[..., 3:], to="wxyz")
self._data.pos_w[env_ids], self._data.quat_w[env_ids] = pose.split([3, 4], dim=-1)
# obtain the air time
if self.cfg.track_air_time:
# -- time elapsed since last update
......
......@@ -9,10 +9,10 @@ import torch
from typing import TYPE_CHECKING, Sequence
import carb
import omni.isaac.core.utils.prims as prim_utils
from omni.isaac.core.prims import RigidPrimView
import omni.physics.tensors.impl.api as physx
from pxr import UsdPhysics
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.markers import VisualizationMarkers
from omni.isaac.orbit.utils.math import (
combine_frame_transforms,
......@@ -151,14 +151,15 @@ class FrameTransformer(SensorBase):
frame_offsets = [None] + [target_frame.offset for target_frame in self.cfg.target_frames]
for frame, prim_path, offset in zip(frames, frame_prim_paths, frame_offsets):
# Find correct prim
matching_prims = prim_utils.find_matching_prim_paths(prim_path)
matching_prims = sim_utils.find_matching_prims(prim_path)
if len(matching_prims) == 0:
raise ValueError(
f"Failed to create frame transformer for frame '{frame}' with path '{prim_path}'."
" No matching prims were found."
)
for matching_prim_path in matching_prims:
prim = prim_utils.get_prim_at_path(matching_prim_path)
for prim in matching_prims:
# Get the prim path of the matching prim
matching_prim_path = prim.GetPath().pathString
# check if it is a rigid prim
if not prim.HasAPI(UsdPhysics.RigidBodyAPI):
raise ValueError(
......@@ -216,14 +217,16 @@ class FrameTransformer(SensorBase):
body_names_regex = r"(" + "|".join(self._tracked_body_names) + r")"
body_names_regex = f"{self.cfg.prim_path.rsplit('/', 1)[0]}/{body_names_regex}"
# create simulation view
self._physics_sim_view = physx.create_simulation_view(self._backend)
self._physics_sim_view.set_subspace_roots("/")
# Create a prim view for all frames and initialize it
# order of transforms coming out of view will be source frame followed by target frame(s)
self._frame_view = RigidPrimView(prim_paths_expr=body_names_regex, reset_xform_properties=False)
self._frame_view.initialize()
self._frame_physx_view = self._physics_sim_view.create_rigid_body_view(body_names_regex.replace(".*", "*"))
# Determine the order in which regex evaluated body names so we can later index into frame transforms
# by frame name correctly
all_prim_paths = self._frame_view.prim_paths
all_prim_paths = self._frame_physx_view.prim_paths
# Only need first env as the names and their orderring are the same across environments
first_env_prim_paths = all_prim_paths[0 : self._num_target_body_frames + 1]
......@@ -282,18 +285,14 @@ class FrameTransformer(SensorBase):
# Extract transforms from view - shape is:
# (the total number of source and target body frames being tracked * self._num_envs, 7)
transforms = self._frame_view._physics_view.get_transforms()
transforms = self._frame_physx_view.get_transforms()
# Convert quaternions as PhysX uses xyzw form
transforms[:, 3:] = convert_quat(transforms[:, 3:], to="wxyz")
# Process source frame transform
source_frames = transforms[self._source_frame_idxs]
target_frames = transforms[self._target_frame_idxs]
# Convert quaternions as Isaac uses xyzw form
source_frames[:, 3:] = convert_quat(source_frames[:, 3:], to="wxyz")
target_frames[:, 3:] = convert_quat(target_frames[:, 3:], to="wxyz")
# Only apply offset if the offsets will result in a coordinate frame transform
if self._apply_source_frame_offset:
# Apply offsets for source frame
source_pos_w, source_rot_w = combine_frame_transforms(
source_frames[:, :3],
source_frames[:, 3:],
......@@ -304,12 +303,12 @@ class FrameTransformer(SensorBase):
source_pos_w = source_frames[:, :3]
source_rot_w = source_frames[:, 3:]
# Process target frame transforms
target_frames = transforms[self._target_frame_idxs]
duplicated_target_frame_pos_w = target_frames[self._duplicate_frame_indices, :3]
duplicated_target_frame_rot_w = target_frames[self._duplicate_frame_indices, 3:]
# Only apply offset if the offsets will result in a coordinate frame transform
if self._apply_target_frame_offset:
# Apply offsets for target frame
target_pos_w, target_rot_w = combine_frame_transforms(
duplicated_target_frame_pos_w,
duplicated_target_frame_rot_w,
......@@ -320,8 +319,8 @@ class FrameTransformer(SensorBase):
target_pos_w = duplicated_target_frame_pos_w
target_rot_w = duplicated_target_frame_rot_w
# Compute the transform of the target frame with respect to the source frame
total_num_frames = len(self._target_frame_names)
target_pos_source, target_rot_source = subtract_frame_transforms(
source_pos_w.unsqueeze(1).expand(-1, total_num_frames, -1).reshape(-1, 3),
source_rot_w.unsqueeze(1).expand(-1, total_num_frames, -1).reshape(-1, 4),
......@@ -330,7 +329,7 @@ class FrameTransformer(SensorBase):
)
# Update buffers
# NOTE: The frame names / orderring don't change so no need to update them after initialization
# note: The frame names / ordering don't change so no need to update them after initialization
self._data.source_pos_w[:] = source_pos_w.view(-1, 3)
self._data.source_rot_w[:] = source_rot_w.view(-1, 4)
self._data.target_pos_w[:] = target_pos_w.view(-1, total_num_frames, 3)
......
......@@ -10,15 +10,15 @@ import torch
from typing import TYPE_CHECKING, ClassVar, Sequence
import carb
import omni.isaac.core.utils.prims as prim_utils
import omni.physics.tensors.impl.api as physx
import warp as wp
from omni.isaac.core.articulations import ArticulationView
from omni.isaac.core.prims import RigidPrimView, XFormPrimView
from omni.isaac.core.prims import XFormPrimView
from pxr import UsdGeom, UsdPhysics
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.markers import VisualizationMarkers
from omni.isaac.orbit.terrains.trimesh.utils import make_plane
from omni.isaac.orbit.utils.math import quat_apply, quat_apply_yaw
from omni.isaac.orbit.utils.math import convert_quat, quat_apply, quat_apply_yaw
from omni.isaac.orbit.utils.warp import convert_to_warp_mesh, raycast_mesh
from ..sensor_base import SensorBase
......@@ -82,6 +82,10 @@ class RayCaster(SensorBase):
Properties
"""
@property
def num_instances(self) -> int:
return self._view.count
@property
def data(self) -> RayCasterData:
# update sensors if needed
......@@ -108,29 +112,30 @@ class RayCaster(SensorBase):
def _initialize_impl(self):
super()._initialize_impl()
# create simulation view
self._physics_sim_view = physx.create_simulation_view(self._backend)
self._physics_sim_view.set_subspace_roots("/")
# check if the prim at path is an articulated or rigid prim
# we do this since for physics-based view classes we can access their data directly
# otherwise we need to use the xform view class which is slower
prim_view_class = None
for prim_path in prim_utils.find_matching_prim_paths(self.cfg.prim_path):
# get prim at path
prim = prim_utils.get_prim_at_path(prim_path)
# check if it is a rigid prim
if prim.HasAPI(UsdPhysics.ArticulationRootAPI):
prim_view_class = ArticulationView
elif prim.HasAPI(UsdPhysics.RigidBodyAPI):
prim_view_class = RigidPrimView
else:
prim_view_class = XFormPrimView
carb.log_warn(f"The prim at path {prim_path} is not a physics prim! Using XFormPrimView.")
# break the loop
break
found_supported_prim_class = False
prim = sim_utils.find_first_matching_prim(self.cfg.prim_path)
if prim is None:
raise RuntimeError(f"Failed to find a prim at path expression: {self.cfg.prim_path}")
# create view based on the type of prim
if prim.HasAPI(UsdPhysics.ArticulationRootAPI):
self._view = self._physics_sim_view.create_articulation_view(self.cfg.prim_path.replace(".*", "*"))
found_supported_prim_class = True
elif prim.HasAPI(UsdPhysics.RigidBodyAPI):
self._view = self._physics_sim_view.create_rigid_body_view(self.cfg.prim_path.replace(".*", "*"))
found_supported_prim_class = True
else:
self._view = XFormPrimView(self.cfg.prim_path, reset_xform_properties=False)
found_supported_prim_class = True
carb.log_warn(f"The prim at path {prim.GetPath().pathString} is not a physics prim! Using XFormPrimView.")
# check if prim view class is found
if prim_view_class is None:
if not found_supported_prim_class:
raise RuntimeError(f"Failed to find a valid prim view class for the prim paths: {self.cfg.prim_path}")
# create a rigid prim view for the sensor
self._view = prim_view_class(self.cfg.prim_path, reset_xform_properties=False)
self._view.initialize()
# load the meshes by parsing the stage
self._initialize_warp_meshes()
......@@ -152,17 +157,17 @@ class RayCaster(SensorBase):
# check if the prim is a plane - handle PhysX plane as a special case
# if a plane exists then we need to create an infinite mesh that is a plane
mesh_prim = prim_utils.get_first_matching_child_prim(
mesh_prim_path, lambda p: prim_utils.get_prim_type_name(p) == "Plane"
mesh_prim = sim_utils.get_first_matching_child_prim(
mesh_prim_path, lambda prim: prim.GetTypeName() == "Plane"
)
# if we did not find a plane then we need to read the mesh
if mesh_prim is None:
# obtain the mesh prim
mesh_prim = prim_utils.get_first_matching_child_prim(
mesh_prim_path, lambda p: prim_utils.get_prim_type_name(p) == "Mesh"
mesh_prim = sim_utils.get_first_matching_child_prim(
mesh_prim_path, lambda prim: prim.GetTypeName() == "Mesh"
)
# check if valid
if not prim_utils.is_prim_path_valid(mesh_prim_path):
if mesh_prim is None or not mesh_prim.IsValid():
raise RuntimeError(f"Invalid mesh prim path: {mesh_prim_path}")
# cast into UsdGeomMesh
mesh_prim = UsdGeom.Mesh(mesh_prim)
......@@ -210,8 +215,22 @@ class RayCaster(SensorBase):
def _update_buffers_impl(self, env_ids: Sequence[int]):
"""Fills the buffers of the sensor data."""
# obtain the poses of the sensors
pos_w, quat_w = self._view.get_world_poses(env_ids, clone=False)
if isinstance(self._view, XFormPrimView):
pos_w, quat_w = self._view.get_world_poses(env_ids)
elif isinstance(self._view, physx.ArticulationView):
pos_w, quat_w = self._view.get_root_transforms()[env_ids].split([3, 4], dim=-1)
quat_w = convert_quat(quat_w, to="wxyz")
elif isinstance(self._view, physx.RigidBodyView):
pos_w, quat_w = self._view.get_transforms()[env_ids].split([3, 4], dim=-1)
quat_w = convert_quat(quat_w, to="wxyz")
else:
raise RuntimeError(f"Unsupported view type: {type(self._view)}")
# note: we clone here because we are read-only operations
pos_w = pos_w.clone()
quat_w = quat_w.clone()
# apply drift
pos_w += self.drift[env_ids]
# store the poses
self._data.pos_w[env_ids] = pos_w
self._data.quat_w[env_ids] = quat_w
......
......@@ -10,6 +10,7 @@ from tensordict import TensorDict
from typing import TYPE_CHECKING, ClassVar, Sequence
from typing_extensions import Literal
import omni.physics.tensors.impl.api as physx
from omni.isaac.core.prims import XFormPrimView
import omni.isaac.orbit.utils.math as math_utils
......@@ -365,12 +366,17 @@ class RayCasterCamera(RayCaster):
# obtain the poses of the sensors
# note: clone arg doesn't exist for xform prim view so we need to do this manually
if isinstance(self._view, XFormPrimView):
pos_w_temp, quat_w_temp = self._view.get_world_poses(env_ids)
pos_w = pos_w_temp.clone()
quat_w = quat_w_temp.clone()
pos_w, quat_w = self._view.get_world_poses(env_ids)
elif isinstance(self._view, physx.ArticulationView):
pos_w, quat_w = self._view.get_root_transforms()[env_ids].split([3, 4], dim=-1)
quat_w = math_utils.convert_quat(quat_w, to="wxyz")
elif isinstance(self._view, physx.RigidBodyView):
pos_w, quat_w = self._view.get_transforms()[env_ids].split([3, 4], dim=-1)
quat_w = math_utils.convert_quat(quat_w, to="wxyz")
else:
pos_w, quat_w = self._view.get_world_poses(env_ids, clone=True)
return pos_w, quat_w
raise RuntimeError(f"Unsupported view type: {type(self._view)}")
# return the pose
return pos_w.clone(), quat_w.clone()
def _compute_camera_world_poses(self, env_ids: Sequence[int]) -> tuple[torch.Tensor, torch.Tensor]:
"""Computes the pose of the camera in the world frame.
......
......@@ -17,10 +17,10 @@ import weakref
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Sequence
import omni.isaac.core.utils.prims as prim_utils
import omni.kit.app
import omni.timeline
from omni.isaac.core.simulation_context import SimulationContext
import omni.isaac.orbit.sim as sim_utils
if TYPE_CHECKING:
from .sensor_base_cfg import SensorBaseCfg
......@@ -91,6 +91,14 @@ class SensorBase(ABC):
Properties
"""
@property
def num_instances(self) -> int:
"""Number of instances of the sensor.
This is equal to the number of sensors per environment multiplied by the number of environments.
"""
return self._num_envs
@property
def device(self) -> str:
"""Memory device for computation."""
......@@ -193,15 +201,17 @@ class SensorBase(ABC):
def _initialize_impl(self):
"""Initializes the sensor-related handles and internal buffers."""
# Obtain Simulation Context
sim = SimulationContext.instance()
if sim is not None:
self._device = sim.device
self._sim_physics_dt = sim.get_physics_dt()
else:
sim = sim_utils.SimulationContext.instance()
if sim is None:
raise RuntimeError("Simulation Context is not initialized!")
# Obtain device and backend
self._device = sim.device
self._backend = sim.backend
self._sim_physics_dt = sim.get_physics_dt()
# Count number of environments
env_prim_path_expr = self.cfg.prim_path.rsplit("/", 1)[0]
self._num_envs = len(prim_utils.find_matching_prim_paths(env_prim_path_expr))
self._parent_prims = sim_utils.find_matching_prims(env_prim_path_expr)
self._num_envs = len(self._parent_prims)
# Boolean tensor indicating whether the sensor data has to be refreshed
self._is_outdated = torch.ones(self._num_envs, dtype=torch.bool, device=self._device)
# Current timestamp (in seconds)
......
......@@ -559,12 +559,13 @@ class SimulationContext(_SimulationContext):
# check if the simulation is stopped
if event.type == int(omni.timeline.TimelineEventType.STOP):
# keep running the simulator when configured to not shutdown the app
self.app.print_and_log(
"Simulation is stopped. The app will keep running with physics disabled."
" Press Ctrl+C or close the window to exit the app."
)
while self.app.is_running():
self.render()
if self._has_gui:
self.app.print_and_log(
"Simulation is stopped. The app will keep running with physics disabled."
" Press Ctrl+C or close the window to exit the app."
)
while self.app.is_running():
self.render()
# make sure that any replicator workflows finish rendering/writing
if not builtins.ISAAC_LAUNCHED_FROM_TERMINAL:
try:
......
......@@ -13,7 +13,6 @@ import re
from typing import TYPE_CHECKING, Any, Callable
import carb
import omni.isaac.core.utils.prims as prim_utils
import omni.isaac.core.utils.stage as stage_utils
import omni.kit.commands
from omni.isaac.cloner import Cloner
......@@ -226,7 +225,7 @@ def clone(func: Callable) -> Callable:
# resolve matching prims for source prim path expression
if is_regex_expression and root_path != "":
source_prim_paths = prim_utils.find_matching_prim_paths(root_path)
source_prim_paths = find_matching_prim_paths(root_path)
# if no matching prims are found, raise an error
if len(source_prim_paths) == 0:
raise RuntimeError(
......@@ -241,7 +240,11 @@ def clone(func: Callable) -> Callable:
prim = func(prim_paths[0], cfg, *args, **kwargs)
# set the prim visibility
if hasattr(cfg, "visible"):
prim_utils.set_prim_visibility(prim, cfg.visible)
imageable = UsdGeom.Imageable(prim)
if cfg.visible:
imageable.MakeVisible()
else:
imageable.MakeInvisible()
# set the semantic annotations
if hasattr(cfg, "semantic_tags") and cfg.semantic_tags is not None:
# note: taken from replicator scripts.utils.utils.py
......@@ -492,3 +495,164 @@ def make_uninstanceable(prim_path: str, stage: Usd.Stage | None = None):
child_prim.SetInstanceable(False)
# add children to list
all_prims += child_prim.GetChildren()
"""
USD Stage traversal.
"""
def get_first_matching_child_prim(
prim_path: str, predicate: Callable[[Usd.Prim], bool], stage: Usd.Stage | None = None
) -> Usd.Prim | None:
"""Recursively get the first USD Prim at the path string that passes the predicate function
Args:
prim_path: The path of the prim in the stage.
predicate: The function to test the prims against. It takes a prim as input and returns a boolean.
stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
Returns:
The first prim on the path that passes the predicate. If no prim passes the predicate, it returns None.
"""
# get current stage
if stage is None:
stage = stage_utils.get_current_stage()
# get prim
prim = stage.GetPrimAtPath(prim_path)
# check if prim is valid
if not prim.IsValid():
raise ValueError(f"Prim at path '{prim_path}' is not valid.")
# iterate over all prims under prim-path
all_prims = [prim]
while len(all_prims) > 0:
# get current prim
child_prim = all_prims.pop(0)
# check if prim passes predicate
if predicate(child_prim):
return child_prim
# add children to list
all_prims += child_prim.GetChildren()
return None
def get_all_matching_child_prims(
prim_path: str,
predicate: Callable[[Usd.Prim], bool] = lambda _: True,
depth: int | None = None,
stage: Usd.Stage | None = None,
) -> list[Usd.Prim]:
"""Performs a search starting from the root and returns all the prims matching the predicate.
Args:
prim_path: The root prim path to start the search from.
predicate: The predicate that checks if the prim matches the desired criteria. It takes a prim as input
and returns a boolean. Defaults to a function that always returns True.
depth: The maximum depth for traversal, should be bigger than zero if specified.
Defaults to None (i.e: traversal happens till the end of the tree).
stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
Returns:
A list containing all the prims matching the predicate.
"""
# get current stage
if stage is None:
stage = stage_utils.get_current_stage()
# get prim
prim = stage.GetPrimAtPath(prim_path)
# check if prim is valid
if not prim.IsValid():
raise ValueError(f"Prim at path '{prim_path}' is not valid.")
# check if depth is valid
if depth is not None and depth <= 0:
raise ValueError(f"Depth must be bigger than zero, got {depth}.")
# iterate over all prims under prim-path
# list of tuples (prim, current_depth)
all_prims_queue = [(prim, 0)]
output_prims = []
while len(all_prims_queue) > 0:
# get current prim
child_prim, current_depth = all_prims_queue.pop(0)
# check if prim passes predicate
if predicate(child_prim):
output_prims.append(child_prim)
# add children to list
if depth is None or current_depth < depth:
all_prims_queue += [(child, current_depth + 1) for child in child_prim.GetChildren()]
return output_prims
def find_first_matching_prim(prim_path_regex: str, stage: Usd.Stage | None = None) -> Usd.Prim | None:
"""Find the first matching prim in the stage based on input regex expression.
Args:
prim_path_regex: The regex expression for prim path.
stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
Returns:
The first prim that matches input expression. If no prim matches, returns None.
"""
# get current stage
if stage is None:
stage = stage_utils.get_current_stage()
# need to wrap the token patterns in '^' and '$' to prevent matching anywhere in the string
pattern = f"^{prim_path_regex}$"
compiled_pattern = re.compile(pattern)
# obtain matching prim (depth-first search)
for prim in stage.Traverse():
# check if prim passes predicate
if compiled_pattern.match(prim.GetPath().pathString) is not None:
return prim
return None
def find_matching_prims(prim_path_regex: str, stage: Usd.Stage | None = None) -> list[Usd.Prim]:
"""Find all the matching prims in the stage based on input regex expression.
Args:
prim_path_regex: The regex expression for prim path.
stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
Returns:
A list of prims that match input expression.
"""
# get current stage
if stage is None:
stage = stage_utils.get_current_stage()
# need to wrap the token patterns in '^' and '$' to prevent matching anywhere in the string
tokens = prim_path_regex.split("/")[1:]
tokens = [f"^{token}$" for token in tokens]
# iterate over all prims in stage (breath-first search)
all_prims = [stage.GetPseudoRoot()]
output_prims = []
for index, token in enumerate(tokens):
token_compiled = re.compile(token)
for prim in all_prims:
for child in prim.GetAllChildren():
if token_compiled.match(child.GetName()) is not None:
output_prims.append(child)
if index < len(tokens) - 1:
all_prims = output_prims
output_prims = []
return output_prims
def find_matching_prim_paths(prim_path_regex: str, stage: Usd.Stage | None = None) -> list[str]:
"""Find all the matching prim paths in the stage based on input regex expression.
Args:
prim_path_regex: The regex expression for prim path.
stage: The stage where the prim exists. Defaults to None, in which case the current stage is used.
Returns:
A list of prim paths that match input expression.
"""
# obtain matching prims
output_prims = find_matching_prims(prim_path_regex, stage)
# convert prims to prim paths
output_prim_paths = []
for prim in output_prims:
output_prim_paths.append(prim.GetPath().pathString)
return output_prim_paths
......@@ -10,9 +10,7 @@ import torch
import trimesh
from typing import TYPE_CHECKING
import omni.isaac.core.utils.prims as prim_utils
import warp
from omni.isaac.core.simulation_context import SimulationContext
from pxr import UsdGeom
import omni.isaac.orbit.sim as sim_utils
......@@ -71,7 +69,7 @@ class TerrainImporter:
"""
# store inputs
self.cfg = cfg
self.device = SimulationContext.instance().device
self.device = sim_utils.SimulationContext.instance().device # type: ignore
# create a dict of meshes
self.meshes = dict()
......@@ -246,8 +244,8 @@ class TerrainImporter:
# traverse the prim and get the collision mesh
# THINK: Should the user specify the collision mesh?
mesh_prim = prim_utils.get_first_matching_child_prim(
self.cfg.prim_path + f"/{key}", lambda p: prim_utils.get_prim_type_name(p) == "Mesh"
mesh_prim = sim_utils.get_first_matching_child_prim(
self.cfg.prim_path + f"/{key}", lambda prim: prim.GetTypeName() == "Mesh"
)
# check if the mesh is valid
if mesh_prim is None:
......
......@@ -101,10 +101,12 @@ def retrieve_file_path(path: str, download_dir: str | None = None, force_downloa
# download file in temp directory using os
file_name = os.path.basename(omni.client.break_url(path).path)
target_path = os.path.join(download_dir, file_name)
# copy file to local machine
result = omni.client.copy(path, target_path)
if result != omni.client.Result.OK and not force_download:
raise RuntimeError(f"Unable to copy file: '{path}'. File already exists locally at: {target_path}")
# check if file already exists locally
if not os.path.isfile(target_path) or force_download:
# copy file to local machine
result = omni.client.copy(path, target_path)
if result != omni.client.Result.OK and force_download:
raise RuntimeError(f"Unable to copy file: '{path}'. Is the Nucleus Server running?")
return os.path.abspath(target_path)
else:
raise FileNotFoundError(f"Unable to find the file: {path}")
......
......@@ -25,8 +25,8 @@ from omni.isaac.orbit.app import AppLauncher
# add argparse arguments
parser = argparse.ArgumentParser(description="This script demonstrates how to external force on a legged robot.")
parser.add_argument("--body", type=str, help="Name of the body to apply force on.")
parser.add_argument("--force", type=float, help="Force to apply on the body.")
parser.add_argument("--body", default="base", type=str, help="Name of the body to apply force on.")
parser.add_argument("--force", default=1000.0, type=float, help="Force to apply on the body.")
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
......@@ -78,7 +78,7 @@ def main():
# Find bodies to apply the force
body_ids, body_names = robot.find_bodies(args_cli.body)
# Sample a large force
external_wrench_b = torch.zeros(robot.root_view.count, len(body_ids), 6, device=sim.device)
external_wrench_b = torch.zeros(robot.num_instances, len(body_ids), 6, device=sim.device)
external_wrench_b[..., 1] = args_cli.force
# Now we are ready!
......
......@@ -159,7 +159,7 @@ class TestArticulation(unittest.TestCase):
# Find bodies to apply the force
body_ids, _ = robot.find_bodies("base")
# Sample a large force
external_wrench_b = torch.zeros(robot.root_view.count, len(body_ids), 6, device=self.sim.device)
external_wrench_b = torch.zeros(robot.num_instances, len(body_ids), 6, device=self.sim.device)
external_wrench_b[..., 1] = 1000.0
# Now we are ready!
......@@ -207,7 +207,7 @@ class TestArticulation(unittest.TestCase):
# Find bodies to apply the force
body_ids, _ = robot.find_bodies(".*_SHANK")
# Sample a large force
external_wrench_b = torch.zeros(robot.root_view.count, len(body_ids), 6, device=self.sim.device)
external_wrench_b = torch.zeros(robot.num_instances, len(body_ids), 6, device=self.sim.device)
external_wrench_b[..., 1] = 100.0
# Now we are ready!
......@@ -268,7 +268,7 @@ class TestArticulation(unittest.TestCase):
".*_foot.*": 2.0,
}
indices_list, _, values_list = string_utils.resolve_matching_names_values(expected_stiffness, robot.joint_names)
expected_stiffness = torch.zeros(robot.root_view.count, robot.num_joints, device=robot.device)
expected_stiffness = torch.zeros(robot.num_instances, robot.num_joints, device=robot.device)
expected_stiffness[:, indices_list] = torch.tensor(values_list, device=robot.device)
# -- Damping values
expected_damping = {
......@@ -308,7 +308,7 @@ class TestArticulation(unittest.TestCase):
self.sim.reset()
# Expected gains
expected_stiffness = torch.full((robot.root_view.count, robot.num_joints), 10.0, device=robot.device)
expected_stiffness = torch.full((robot.num_instances, robot.num_joints), 10.0, device=robot.device)
expected_damping = torch.full_like(expected_stiffness, 2.0)
# Check that gains are loaded from USD file
......@@ -333,7 +333,7 @@ class TestArticulation(unittest.TestCase):
self.sim.reset()
# Expected gains
expected_stiffness = torch.full((robot.root_view.count, robot.num_joints), 10.0, device=robot.device)
expected_stiffness = torch.full((robot.num_instances, robot.num_joints), 10.0, device=robot.device)
expected_damping = torch.full_like(expected_stiffness, 2.0)
# Check that gains are loaded from USD file
......
......@@ -112,8 +112,8 @@ class TestRigidObject(unittest.TestCase):
body_ids, _ = cube_object.find_bodies(".*")
# Sample a large force
external_wrench_b = torch.zeros(cube_object.root_view.count, len(body_ids), 6, device=self.sim.device)
external_wrench_b[0, 0, 2] = 9.81 * cube_object.root_view.get_masses(indices=[0])
external_wrench_b = torch.zeros(cube_object.num_instances, len(body_ids), 6, device=self.sim.device)
external_wrench_b[0, 0, 2] = 9.81 * cube_object.root_physx_view.get_masses()[0]
# Now we are ready!
for _ in range(5):
......
......@@ -25,7 +25,10 @@ parser = argparse.ArgumentParser(
parser.add_argument("--headless", action="store_true", default=False, help="Force display off at all times.")
parser.add_argument("--num_robots", type=int, default=128, help="Number of robots to spawn.")
parser.add_argument(
"--asset", type=str, default="orbit", help="The asset source location for the robot. Can be: orbit, oige."
"--asset",
type=str,
default="orbit",
help="The asset source location for the robot. Can be: orbit, oige, custom asset path.",
)
args_cli = parser.parse_args()
......
# Copyright (c) 2022-2023, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
"""Launch Isaac Sim Simulator first."""
from omni.isaac.kit import SimulationApp
# launch omniverse app
config = {"headless": True}
simulation_app = SimulationApp(config)
"""Rest everything follows."""
import numpy as np
import traceback
import unittest
import carb
import omni.isaac.core.utils.prims as prim_utils
import omni.isaac.core.utils.stage as stage_utils
import omni.isaac.orbit.sim as sim_utils
class TestUtilities(unittest.TestCase):
"""Test fixture for the sim utility functions."""
def setUp(self):
"""Create a blank new stage for each test."""
# Create a new stage
stage_utils.create_new_stage()
stage_utils.update_stage()
def tearDown(self) -> None:
"""Clear stage after each test."""
stage_utils.clear_stage()
def test_get_all_matching_child_prims(self):
"""Test get_all_matching_child_prims() function."""
# create scene
prim_utils.create_prim("/World/Floor")
prim_utils.create_prim(
"/World/Floor/thefloor", "Cube", position=np.array([75, 75, -150.1]), attributes={"size": 300}
)
prim_utils.create_prim("/World/Room", "Sphere", attributes={"radius": 1e3})
# test
isaac_sim_result = prim_utils.get_all_matching_child_prims("/World")
orbit_result = sim_utils.get_all_matching_child_prims("/World")
self.assertListEqual(isaac_sim_result, orbit_result)
def test_find_matching_prim_paths(self):
"""Test find_matching_prim_paths() function."""
# create scene
for index in range(2048):
random_pos = np.random.uniform(-100, 100, size=3)
prim_utils.create_prim(f"/World/Floor_{index}", "Cube", position=random_pos, attributes={"size": 2.0})
prim_utils.create_prim(f"/World/Floor_{index}/Sphere", "Sphere", attributes={"radius": 10})
prim_utils.create_prim(f"/World/Floor_{index}/Sphere/childSphere", "Sphere", attributes={"radius": 1})
prim_utils.create_prim(f"/World/Floor_{index}/Sphere/childSphere2", "Sphere", attributes={"radius": 1})
# test leaf paths
isaac_sim_result = prim_utils.find_matching_prim_paths("/World/Floor_.*/Sphere")
orbit_result = sim_utils.find_matching_prim_paths("/World/Floor_.*/Sphere")
self.assertListEqual(isaac_sim_result, orbit_result)
# test non-leaf paths
isaac_sim_result = prim_utils.find_matching_prim_paths("/World/Floor_.*")
orbit_result = sim_utils.find_matching_prim_paths("/World/Floor_.*")
self.assertListEqual(isaac_sim_result, orbit_result)
# test child-leaf paths
isaac_sim_result = prim_utils.find_matching_prim_paths("/World/Floor_.*/Sphere/childSphere.*")
orbit_result = sim_utils.find_matching_prim_paths("/World/Floor_.*/Sphere/childSphere.*")
self.assertListEqual(isaac_sim_result, orbit_result)
if __name__ == "__main__":
try:
unittest.main()
except Exception as err:
carb.log_error(err)
carb.log_error(traceback.format_exc())
raise
finally:
# close sim app
simulation_app.close()
......@@ -34,7 +34,7 @@ def base_up_proj(env: HumanoidEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("r
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
# compute base up vector
base_up_vec = math_utils.quat_rotate(asset.data.root_quat_w, -asset._GRAVITY_VEC_W) # type: ignore
base_up_vec = math_utils.quat_rotate(asset.data.root_quat_w, -asset.GRAVITY_VEC_W)
return base_up_vec[:, 2].unsqueeze(-1)
......@@ -50,7 +50,7 @@ def base_heading_proj(
to_target_pos[:, 2] = 0.0
to_target_dir = math_utils.normalize(to_target_pos)
# compute base forward vector
heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset._FORWARD_VEC_B) # type: ignore
heading_vec = math_utils.quat_rotate(asset.data.root_quat_w, asset.FORWARD_VEC_B)
# compute dot product between heading and target direction
heading_proj = torch.bmm(heading_vec.view(env.num_envs, 1, 3), to_target_dir.view(env.num_envs, 3, 1))
......
......@@ -20,7 +20,7 @@ class LiftCubePPORunnerCfg(RslRlOnPolicyRunnerCfg):
experiment_name = "franka_lift"
empirical_normalization = False
policy = RslRlPpoActorCriticCfg(
init_noise_std=0.8,
init_noise_std=1.0,
actor_hidden_dims=[256, 128, 64],
critic_hidden_dims=[256, 128, 64],
activation="elu",
......
......@@ -33,7 +33,7 @@ class FrankaCubeLiftEnvCfg(LiftEnvCfg):
# Set actions for the specific robot type (franka)
self.actions.body_joint_pos = mdp.JointPositionActionCfg(
asset_name="robot", joint_names=["panda_joint.*"], scale=1.0, use_default_offset=True
asset_name="robot", joint_names=["panda_joint.*"], scale=0.5, use_default_offset=True
)
self.actions.finger_joint_pos = mdp.BinaryJointPositionActionCfg(
asset_name="robot",
......
......@@ -95,7 +95,7 @@ def main():
print("[INFO]: Setup complete...")
# dummy actions
actions = torch.rand(robot.root_view.count, robot.num_joints, device=robot.device) + robot.data.default_joint_pos
actions = torch.rand(robot.num_instances, robot.num_joints, device=robot.device) + robot.data.default_joint_pos
has_gripper = args_cli.robot == "franka_panda"
# Define simulation stepping
......@@ -114,9 +114,8 @@ def main():
robot.write_joint_state_to_sim(joint_pos, joint_vel)
robot.reset()
# reset command
actions = (
torch.rand(robot.root_view.count, robot.num_joints, device=robot.device) + robot.data.default_joint_pos
)
actions = torch.rand(robot.num_instances, robot.num_joints, device=robot.device)
actions += robot.data.default_joint_pos
# reset gripper
if has_gripper:
actions[:, -2:] = 0.04
......
......@@ -105,7 +105,7 @@ def main():
print("[INFO]: Setup complete...")
# dummy actions
actions = torch.rand(robot.root_view.count, robot.num_joints, device=robot.device) + robot.data.default_joint_pos
actions = torch.rand(robot.num_instances, robot.num_joints, device=robot.device) + robot.data.default_joint_pos
has_gripper = args_cli.robot == "franka_panda"
# Define simulation stepping
......@@ -124,9 +124,8 @@ def main():
robot.write_joint_state_to_sim(joint_pos, joint_vel)
robot.reset()
# reset command
actions = (
torch.rand(robot.root_view.count, robot.num_joints, device=robot.device) + robot.data.default_joint_pos
)
actions = torch.rand(robot.num_instances, robot.num_joints, device=robot.device)
actions += robot.data.default_joint_pos
# reset gripper
if has_gripper:
actions[:, -2:] = 0.04
......
......@@ -147,9 +147,8 @@ def main():
actions[:, 2] = 1.0
# change the arm action
if ep_step_count % 100:
actions[:, 3:10] = (
torch.rand(robot.root_view.count, 7, device=robot.device) + robot.data.default_joint_pos[:, 3:10]
)
actions[:, 3:10] = torch.rand(robot.num_instances, 7, device=robot.device)
actions[:, 3:10] += robot.data.default_joint_pos[:, 3:10]
# apply action
robot.set_joint_velocity_target(actions[:, :3], joint_ids=[0, 1, 2])
robot.set_joint_position_target(actions[:, 3:], joint_ids=[3, 4, 5, 6, 7, 8, 9, 10, 11])
......
......@@ -142,8 +142,8 @@ def main():
print("[INFO]: Setup complete...")
# Create buffers to store actions
rmp_commands = torch.zeros(robot.count, rmp_controller.num_actions, device=robot.device)
robot_actions = torch.ones(robot.count, robot.num_actions, device=robot.device)
rmp_commands = torch.zeros(robot.num_instances, rmp_controller.num_actions, device=robot.device)
robot_actions = torch.ones(robot.num_instances, robot.num_actions, device=robot.device)
has_gripper = robot.cfg.meta_info.tool_num_dof > 0
# Set end effector goals
......
......@@ -98,11 +98,11 @@ def main():
root_state = rigid_object.data.default_root_state.clone()
# -- position
root_state[:, :3] = sample_cylinder(
radius=0.5, h_range=(0.15, 0.25), size=rigid_object.root_view.count, device=rigid_object.device
radius=0.5, h_range=(0.15, 0.25), size=rigid_object.num_instances, device=rigid_object.device
)
# -- orientation: apply yaw rotation
root_state[:, 3:7] = quat_mul(
random_yaw_orientation(rigid_object.root_view.count, rigid_object.device), root_state[:, 3:7]
random_yaw_orientation(rigid_object.num_instances, rigid_object.device), root_state[:, 3:7]
)
# -- set root state
rigid_object.write_root_state_to_sim(root_state)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment