Unverified Commit 571189b1 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Modifies behavior of debug visualization for better UI experience (#208)

# Description

Earlier the UI was disabling buttons for which the config object set
`debug_vis` flag as False. This wasn't an elegant behavior as users may
want to enable the visualization from the GUI at runtime. This MR
modifies the behavior of all classes that provide debug visualization to
support this behavior.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent 6ae3c3fb
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.9.23"
version = "0.9.24"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.9.24 (2023-10-27)
~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Changed the behavior of setting up debug visualization for assets, sensors and command generators.
Earlier it was raising an error if debug visualization was not enabled in the configuration object.
Now it checks whether debug visualization is implemented and only sets up the callback if it is
implemented.
0.9.23 (2023-10-27)
~~~~~~~~~~~~~~~~~~~
......
......@@ -5,6 +5,7 @@
from __future__ import annotations
import inspect
import re
import weakref
from abc import ABC, abstractmethod
......@@ -68,14 +69,10 @@ class AssetBase(ABC):
lambda event, obj=weakref.proxy(self): obj._invalidate_initialize_callback(event),
order=10,
)
# add callback for debug visualization
if self.cfg.debug_vis:
app_interface = omni.kit.app.get_app_interface()
self._debug_visualization_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event),
)
else:
self._debug_visualization_handle = None
# add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
self._debug_vis_handle = None
# set initial state of debug visualization
self.set_debug_vis(self.cfg.debug_vis)
def __del__(self):
"""Unsubscribe from the callbacks."""
......@@ -87,9 +84,9 @@ class AssetBase(ABC):
self._invalidate_initialize_handle.unsubscribe()
self._invalidate_initialize_handle = None
# clear debug visualization
if self._debug_visualization_handle:
self._debug_visualization_handle.unsubscribe()
self._debug_visualization_handle = None
if self._debug_vis_handle:
self._debug_vis_handle.unsubscribe()
self._debug_vis_handle = None
"""
Properties
......@@ -107,21 +104,47 @@ class AssetBase(ABC):
"""Data related to the asset."""
return NotImplementedError
@property
def has_debug_vis_implementation(self) -> bool:
"""Whether the asset has a debug visualization implemented."""
# check if function raises NotImplementedError
source_code = inspect.getsource(self._debug_vis_callback)
return "NotImplementedError" not in source_code
"""
Operations.
"""
def set_debug_vis(self, debug_vis: bool):
def set_debug_vis(self, debug_vis: bool) -> bool:
"""Sets whether to visualize the asset data.
Args:
debug_vis: Whether to visualize the asset data.
Raises:
RuntimeError: If the asset debug visualization is not enabled.
Returns:
Whether the debug visualization was successfully set. False if the asset
does not support debug visualization.
"""
if not self.cfg.debug_vis:
raise RuntimeError("Debug visualization is not enabled for this sensor.")
# check if debug visualization is supported
if not self.has_debug_vis_implementation:
return False
# toggle debug visualization objects
self._set_debug_vis_impl(debug_vis)
# toggle debug visualization handles
if debug_vis:
# create a subscriber for the post update event if it doesn't exist
if self._debug_vis_handle is None:
app_interface = omni.kit.app.get_app_interface()
self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event)
)
else:
# remove the subscriber if it exists
if self._debug_vis_handle is not None:
self._debug_vis_handle.unsubscribe()
self._debug_vis_handle = None
# return success
return True
@abstractmethod
def reset(self, env_ids: Sequence[int] | None = None):
......@@ -158,12 +181,24 @@ class AssetBase(ABC):
"""Initializes the PhysX handles and internal buffers."""
raise NotImplementedError
def _debug_vis_impl(self):
"""Perform debug visualization of the asset."""
pass
def _set_debug_vis_impl(self, debug_vis: bool):
"""Set debug visualization into visualization objects.
This function is responsible for creating the visualization objects if they don't exist
and input ``debug_vis`` is True. If the visualization objects exist, the function should
set their visibility into the stage.
"""
Simulation callbacks.
raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
def _debug_vis_callback(self, event):
"""Callback for debug visualization.
This function calls the visualization objects and sets the data to visualize into them.
"""
raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
"""
Internal simulation callbacks.
"""
def _initialize_callback(self, event):
......@@ -180,7 +215,3 @@ class AssetBase(ABC):
def _invalidate_initialize_callback(self, event):
"""Invalidates the scene elements."""
self._is_initialized = False
def _debug_vis_callback(self, event):
"""Visualizes the asset data."""
self._debug_vis_impl()
......@@ -12,6 +12,7 @@ methods.
from __future__ import annotations
import inspect
import torch
import weakref
from abc import ABC, abstractmethod
......@@ -48,6 +49,7 @@ class CommandGeneratorBase(ABC):
# store the inputs
self.cfg = cfg
self._env = env
# create buffers to store the command
# -- metrics that can be used for logging
self.metrics = dict()
......@@ -56,21 +58,16 @@ class CommandGeneratorBase(ABC):
# -- counter for the number of times the command has been resampled within the current episode
self.command_counter = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)
# add callback for debug visualization
if self.cfg.debug_vis:
app_interface = omni.kit.app.get_app_interface()
# NOTE: Use weakref on callback to ensure that this object can be deleted when its destructor is called.
self._debug_visualization_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event),
)
else:
self._debug_visualization_handle = None
# add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
self._debug_vis_handle = None
# set initial state of debug visualization
self.set_debug_vis(self.cfg.debug_vis)
def __del__(self):
"""Unsubscribe from the callbacks."""
if self._debug_visualization_handle is not None:
self._debug_visualization_handle.unsubscribe()
self._debug_visualization_handle = None
if self._debug_vis_handle:
self._debug_vis_handle.unsubscribe()
self._debug_vis_handle = None
"""
Properties
......@@ -92,21 +89,47 @@ class CommandGeneratorBase(ABC):
"""The command tensor. Shape is (num_envs, command_dim)."""
raise NotImplementedError
@property
def has_debug_vis_implementation(self) -> bool:
"""Whether the command generator has a debug visualization implemented."""
# check if function raises NotImplementedError
source_code = inspect.getsource(self._debug_vis_callback)
return "NotImplementedError" not in source_code
"""
Operations.
"""
def set_debug_vis(self, debug_vis: bool):
def set_debug_vis(self, debug_vis: bool) -> bool:
"""Sets whether to visualize the command data.
Args:
debug_vis: Whether to visualize the command data.
Raises:
RuntimeError: If the command debug visualization is not enabled.
"""
if not self.cfg.debug_vis:
raise RuntimeError("Debug visualization is not enabled for this sensor.")
Returns:
Whether the debug visualization was successfully set. False if the command
generator does not support debug visualization.
"""
# check if debug visualization is supported
if not self.has_debug_vis_implementation:
return False
# toggle debug visualization objects
self._set_debug_vis_impl(debug_vis)
# toggle debug visualization handles
if debug_vis:
# create a subscriber for the post update event if it doesn't exist
if self._debug_vis_handle is None:
app_interface = omni.kit.app.get_app_interface()
self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event)
)
else:
# remove the subscriber if it exists
if self._debug_vis_handle is not None:
self._debug_vis_handle.unsubscribe()
self._debug_vis_handle = None
# return success
return True
def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
"""Reset the command generator and log metrics.
......@@ -173,14 +196,6 @@ class CommandGeneratorBase(ABC):
# resample the command
self._resample_command(env_ids)
"""
Simulation callbacks.
"""
def _debug_vis_callback(self, event):
"""Visualizes the sensor data."""
self._debug_vis_impl()
"""
Implementation specific functions.
"""
......@@ -200,9 +215,18 @@ class CommandGeneratorBase(ABC):
"""Update the metrics based on the current state."""
raise NotImplementedError
def _debug_vis_impl(self):
"""Visualize the command in the simulator.
def _set_debug_vis_impl(self, debug_vis: bool):
"""Set debug visualization into visualization objects.
This function is responsible for creating the visualization objects if they don't exist
and input ``debug_vis`` is True. If the visualization objects exist, the function should
set their visibility into the stage.
"""
raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
def _debug_vis_callback(self, event):
"""Callback for debug visualization.
This is an optional function that can be used to visualize the command in the simulator.
This function calls the visualization objects and sets the data to visualize into them.
"""
pass
raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
......@@ -41,11 +41,16 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase):
cfg: The configuration parameters for the command generator.
env: The environment object.
"""
# initialize the base class
super().__init__(cfg, env)
# obtain the robot and terrain assets
# -- robot
self.robot: Articulation = env.scene[cfg.asset_name]
# -- terrain
self.terrain: TerrainImporter = env.scene.terrain
# crete buffers to store the command
# -- commands: (x, y, z, heading)
self.pos_command_w = torch.zeros(self.num_envs, 3, device=self.device)
self.heading_command_w = torch.zeros(self.num_envs, device=self.device)
......@@ -54,8 +59,6 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase):
# -- metrics
self.metrics["error_pos"] = torch.zeros(self.num_envs, device=self.device)
self.metrics["error_heading"] = torch.zeros(self.num_envs, device=self.device)
# -- debug vis
self.box_goal_visualizer = None
def __str__(self) -> str:
msg = "TerrainBasedPositionCommandGenerator:\n"
......@@ -73,15 +76,6 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase):
"""The desired base position in base frame. Shape is (num_envs, 3)."""
return self.pos_command_b
"""
Operations.
"""
def set_debug_vis(self, debug_vis: bool):
super().set_debug_vis(debug_vis)
if self.box_goal_visualizer is not None:
self.box_goal_visualizer.set_visibility(debug_vis)
"""
Implementation specific functions.
"""
......@@ -120,12 +114,20 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase):
self.metrics["error_pos"] = torch.norm(self.pos_command_w - self.robot.data.root_pos_w[:, :3], dim=1)
self.metrics["error_heading"] = torch.abs(wrap_to_pi(self.heading_command_w - self.robot.heading_w))
def _debug_vis_impl(self):
# create the box marker if necessary
if self.box_goal_visualizer is None:
def _set_debug_vis_impl(self, debug_vis: bool):
# create markers if necessary for the first tome
if debug_vis:
if not hasattr(self, "box_goal_visualizer"):
marker_cfg = CUBOID_MARKER_CFG.copy()
marker_cfg.prim_path = "/Visuals/Command/position_goal"
marker_cfg.markers["cuboid"].scale = (0.1, 0.1, 0.1)
self.box_goal_visualizer = VisualizationMarkers(marker_cfg)
# set their visibility to true
self.box_goal_visualizer.set_visibility(True)
else:
if hasattr(self, "box_goal_visualizer"):
self.box_goal_visualizer.set_visibility(False)
def _debug_vis_callback(self, event):
# update the box marker
self.box_goal_visualizer.visualize(self.pos_command_w)
......@@ -51,9 +51,14 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
cfg: The configuration of the command generator.
env: The environment.
"""
# initialize the base class
super().__init__(cfg, env)
# obtain the robot asset
# -- robot
self.robot: Articulation = env.scene[cfg.asset_name]
# crete buffers to store the command
# -- command: x vel, y vel, yaw vel, heading
self.vel_command_b = torch.zeros(self.num_envs, 3, device=self.device)
self.heading_target = torch.zeros(self.num_envs, device=self.device)
......@@ -62,9 +67,6 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
# -- metrics
self.metrics["error_vel_xy"] = torch.zeros(self.num_envs, device=self.device)
self.metrics["error_vel_yaw"] = torch.zeros(self.num_envs, device=self.device)
# -- debug vis
self.base_vel_goal_visualizer = None
self.base_vel_visualizer = None
def __str__(self) -> str:
"""Return a string representation of the command generator."""
......@@ -86,19 +88,6 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
"""The desired base velocity command in the base frame. Shape is (num_envs, 3)."""
return self.vel_command_b
"""
Operations.
"""
def set_debug_vis(self, debug_vis: bool):
super().set_debug_vis(debug_vis)
# -- current
if self.base_vel_visualizer is not None:
self.base_vel_visualizer.set_visibility(debug_vis)
# -- goal
if self.base_vel_goal_visualizer is not None:
self.base_vel_goal_visualizer.set_visibility(debug_vis)
"""
Implementation specific functions.
"""
......@@ -152,20 +141,31 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
torch.abs(self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]) / max_command_time
)
def _debug_vis_impl(self):
# create markers if necessary
def _set_debug_vis_impl(self, debug_vis: bool):
# set visibility of markers
# note: parent only deals with callbacks. not their visibility
if debug_vis:
# create markers if necessary for the first tome
if not hasattr(self, "base_vel_goal_visualizer"):
# -- goal
if self.base_vel_goal_visualizer is None:
marker_cfg = GREEN_ARROW_X_MARKER_CFG.copy()
marker_cfg.prim_path = "/Visuals/Command/velocity_goal"
marker_cfg.markers["arrow"].scale = (2.5, 0.1, 0.1)
self.base_vel_goal_visualizer = VisualizationMarkers(marker_cfg)
# -- current
if self.base_vel_visualizer is None:
marker_cfg = BLUE_ARROW_X_MARKER_CFG.copy()
marker_cfg.prim_path = "/Visuals/Command/velocity_current"
marker_cfg.markers["arrow"].scale = (2.5, 0.1, 0.1)
self.base_vel_visualizer = VisualizationMarkers(marker_cfg)
# set their visibility to true
self.base_vel_goal_visualizer.set_visibility(True)
self.base_vel_visualizer.set_visibility(True)
else:
if hasattr(self, "base_vel_goal_visualizer"):
self.base_vel_goal_visualizer.set_visibility(False)
self.base_vel_visualizer.set_visibility(False)
def _debug_vis_callback(self, event):
# get marker location
# -- base state
base_pos_w = self.robot.data.root_pos_w.clone()
......@@ -173,9 +173,8 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
# -- resolve the scales and quaternions
vel_des_arrow_scale, vel_des_arrow_quat = self._resolve_xy_velocity_to_arrow(self.command[:, :2])
vel_arrow_scale, vel_arrow_quat = self._resolve_xy_velocity_to_arrow(self.robot.data.root_lin_vel_b[:, :2])
# -- goal
# display markers
self.base_vel_goal_visualizer.visualize(base_pos_w, vel_des_arrow_quat, vel_des_arrow_scale)
# -- base velocity
self.base_vel_visualizer.visualize(base_pos_w, vel_arrow_quat, vel_arrow_scale)
"""
......
......@@ -10,6 +10,7 @@ import gym
import math
import numpy as np
import torch
import weakref
from typing import Any, ClassVar, Dict, Sequence, Tuple, Union
import omni.usd
......@@ -445,59 +446,42 @@ class RLEnv(BaseEnv, gym.Env):
# create stack for debug visualization
self._orbit_window_elements["debug_vstack"] = ui.VStack(spacing=5, height=0)
with self._orbit_window_elements["debug_vstack"]:
elements = [
self.scene.terrain,
self.command_manager,
*self.scene.rigid_objects.values(),
*self.scene.articulations.values(),
*self.scene.sensors.values(),
]
names = [
"terrain",
"commands",
*self.scene.rigid_objects.keys(),
*self.scene.articulations.keys(),
*self.scene.sensors.keys(),
]
# create one for the terrain
if self.scene.terrain is not None:
for elem, name in zip(elements, names):
if elem is not None:
with ui.HStack():
# create the UI element
debug_vis_checkbox = {
"model": ui.SimpleBoolModel(default_value=self.scene.terrain.cfg.debug_vis),
"enabled": self.scene.terrain.cfg.debug_vis,
"checked": self.scene.terrain.cfg.debug_vis,
"on_checked_fn": lambda value: self.scene.terrain.set_debug_vis(value),
}
ui.Label(
"Terrain",
width=ui_utils.LABEL_WIDTH - 12,
alignment=ui.Alignment.LEFT_CENTER,
tooltip="Toggle debug visualization",
text = (
"Toggle debug visualization."
if elem.has_debug_vis_implementation
else "Debug visualization not implemented."
)
self._orbit_window_elements["terrain_cb"] = SimpleCheckBox(**debug_vis_checkbox)
ui_utils.add_line_rect_flourish()
# iterate over each scene element and add a checkbox for debug visualization
for name, element in self.scene.sensors.items():
with ui.HStack():
# create the UI element
# note: need to deal with closure of lambda function inside for loop
# ref: https://stackoverflow.com/questions/66131048/python-lambda-function-is-not-being-called-correctly-from-within-a-for-loop
debug_vis_checkbox = {
"model": ui.SimpleBoolModel(default_value=element.cfg.debug_vis),
"enabled": element.cfg.debug_vis,
"checked": element.cfg.debug_vis,
"on_checked_fn": lambda value, ele=element: ele.set_debug_vis(value),
}
ui.Label(
ui_utils.format_tt(name.replace("_", " ")),
name.replace("_", " ").title(),
width=ui_utils.LABEL_WIDTH - 12,
alignment=ui.Alignment.LEFT_CENTER,
tooltip="Toggle debug visualization",
tooltip=text,
)
self._orbit_window_elements[f"sensor_{name}_cb"] = SimpleCheckBox(**debug_vis_checkbox)
ui_utils.add_line_rect_flourish()
# create one for the command manager
with ui.HStack():
debug_vis_checkbox = {
"model": ui.SimpleBoolModel(default_value=self.command_manager.cfg.debug_vis),
"enabled": self.command_manager.cfg.debug_vis,
"checked": self.command_manager.cfg.debug_vis,
"on_checked_fn": lambda value: self.command_manager.set_debug_vis(value),
}
ui.Label(
"Command Manager",
width=ui_utils.LABEL_WIDTH - 12,
alignment=ui.Alignment.LEFT_CENTER,
tooltip="Toggle debug visualization",
self._orbit_window_elements[f"{name}_cb"] = SimpleCheckBox(
model=ui.SimpleBoolModel(),
enabled=elem.has_debug_vis_implementation,
checked=elem.cfg.debug_vis,
on_checked_fn=lambda value, e=weakref.proxy(elem): e.set_debug_vis(value),
)
self._orbit_window_elements["command_cb"] = SimpleCheckBox(**debug_vis_checkbox)
ui_utils.add_line_rect_flourish()
async def _dock_window(self, window_title: str):
......
......@@ -58,8 +58,6 @@ class ContactSensor(SensorBase):
super().__init__(cfg)
# Create empty variables for storing output data
self._data: ContactSensorData = ContactSensorData()
# visualization markers
self.contact_visualizer = None
def __str__(self) -> str:
"""Returns: A string containing information about the instance."""
......@@ -127,11 +125,6 @@ class ContactSensor(SensorBase):
Operations
"""
def set_debug_vis(self, debug_vis: bool):
super().set_debug_vis(debug_vis)
if self.contact_visualizer is not None:
self.contact_visualizer.set_visibility(debug_vis)
def reset(self, env_ids: Sequence[int] | None = None):
# reset the timers and counters
super().reset(env_ids)
......@@ -281,11 +274,21 @@ class ContactSensor(SensorBase):
# -- increment timers for bodies that are not in contact
self._data.current_air_time[env_ids] *= ~is_contact
def _debug_vis_impl(self):
# visualize the contacts
if self.contact_visualizer is None:
def _set_debug_vis_impl(self, debug_vis: bool):
# set visibility of markers
# note: parent only deals with callbacks. not their visibility
if debug_vis:
# create markers if necessary for the first tome
if not hasattr(self, "contact_visualizer"):
visualizer_cfg = CONTACT_SENSOR_MARKER_CFG.replace(prim_path="/Visuals/ContactSensor")
self.contact_visualizer = VisualizationMarkers(visualizer_cfg)
# set their visibility to true
self.contact_visualizer.set_visibility(True)
else:
if hasattr(self, "contact_visualizer"):
self.contact_visualizer.set_visibility(False)
def _debug_vis_callback(self, event):
# safely return if view becomes invalid
# note: this invalidity happens because of isaac sim view callbacks
if self.body_physx_view is None:
......
......@@ -54,14 +54,12 @@ class RayCaster(SensorBase):
Args:
cfg: The configuration parameters.
"""
# initialize base class
# Initialize base class
super().__init__(cfg)
# Create empty variables for storing output data
self._data = RayCasterData()
# List of meshes to ray-cast
self.warp_meshes = []
# visualization markers
self.ray_visualizer = None
def __str__(self) -> str:
"""Returns: A string containing information about the instance."""
......@@ -90,11 +88,6 @@ class RayCaster(SensorBase):
Operations.
"""
def set_debug_vis(self, debug_vis: bool):
super().set_debug_vis(debug_vis)
if self.ray_visualizer is not None:
self.ray_visualizer.set_visibility(debug_vis)
def reset(self, env_ids: Sequence[int] | None = None):
# reset the timers and counters
super().reset(env_ids)
......@@ -219,10 +212,19 @@ class RayCaster(SensorBase):
# TODO: Make this work for multiple meshes?
self._data.ray_hits_w[env_ids] = raycast_mesh(ray_starts_w, ray_directions_w, self.warp_meshes[0])
def _debug_vis_impl(self):
# visualize the point hits
if self.ray_visualizer is None:
def _set_debug_vis_impl(self, debug_vis: bool):
# set visibility of markers
# note: parent only deals with callbacks. not their visibility
if debug_vis:
if not hasattr(self, "ray_visualizer"):
visualizer_cfg = RAY_CASTER_MARKER_CFG.replace(prim_path="/Visuals/RayCaster")
self.ray_visualizer = VisualizationMarkers(visualizer_cfg)
# check if prim is visualized
# set their visibility to true
self.ray_visualizer.set_visibility(True)
else:
if hasattr(self, "ray_visualizer"):
self.ray_visualizer.set_visibility(False)
def _debug_vis_callback(self, event):
# show ray hit positions
self.ray_visualizer.visualize(self._data.ray_hits_w.view(-1, 3))
......@@ -11,6 +11,7 @@ Each sensor class should inherit from this class and implement the abstract meth
from __future__ import annotations
import inspect
import torch
import weakref
from abc import ABC, abstractmethod
......@@ -65,14 +66,10 @@ class SensorBase(ABC):
lambda event, obj=weakref.proxy(self): obj._invalidate_initialize_callback(event),
order=10,
)
# add callback for debug visualization
if self.cfg.debug_vis:
app_interface = omni.kit.app.get_app_interface()
self._debug_visualization_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event),
)
else:
self._debug_visualization_handle = None
# add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
self._debug_vis_handle = None
# set initial state of debug visualization
self.set_debug_vis(self.cfg.debug_vis)
def __del__(self):
"""Unsubscribe from the callbacks."""
......@@ -84,9 +81,9 @@ class SensorBase(ABC):
self._invalidate_initialize_handle.unsubscribe()
self._invalidate_initialize_handle = None
# clear debug visualization
if self._debug_visualization_handle:
self._debug_visualization_handle.unsubscribe()
self._debug_visualization_handle = None
if self._debug_vis_handle:
self._debug_vis_handle.unsubscribe()
self._debug_vis_handle = None
"""
Properties
......@@ -117,21 +114,48 @@ class SensorBase(ABC):
"""
raise NotImplementedError
@property
def has_debug_vis_implementation(self) -> bool:
"""Whether the sensor has a debug visualization implemented."""
# check if function raises NotImplementedError
# check if function raises NotImplementedError
source_code = inspect.getsource(self._debug_vis_callback)
return "NotImplementedError" not in source_code
"""
Operations
"""
def set_debug_vis(self, debug_vis: bool):
def set_debug_vis(self, debug_vis: bool) -> bool:
"""Sets whether to visualize the sensor data.
Args:
debug_vis: Whether to visualize the sensor data.
Raises:
RuntimeError: If the asset debug visualization is not enabled.
Returns:
Whether the debug visualization was successfully set. False if the sensor
does not support debug visualization.
"""
if not self.cfg.debug_vis:
raise RuntimeError("Debug visualization is not enabled for this sensor.")
# check if debug visualization is supported
if not self.has_debug_vis_implementation:
return False
# toggle debug visualization objects
self._set_debug_vis_impl(debug_vis)
# toggle debug visualization handles
if debug_vis:
# create a subscriber for the post update event if it doesn't exist
if self._debug_vis_handle is None:
app_interface = omni.kit.app.get_app_interface()
self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event)
)
else:
# remove the subscriber if it exists
if self._debug_vis_handle is not None:
self._debug_vis_handle.unsubscribe()
self._debug_vis_handle = None
# return success
return True
def reset(self, env_ids: Sequence[int] | None = None):
"""Resets the sensor internals.
......@@ -194,19 +218,24 @@ class SensorBase(ABC):
"""
raise NotImplementedError
def _debug_vis_impl(self):
"""Visualizes the sensor data.
def _set_debug_vis_impl(self, debug_vis: bool):
"""Set debug visualization into visualization objects.
This is an empty function that can be overridden by the derived class to visualize the sensor data.
This function is responsible for creating the visualization objects if they don't exist
and input ``debug_vis`` is True. If the visualization objects exist, the function should
set their visibility into the stage.
"""
raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
Note:
Visualization of sensor data may add overhead to the simulation. It is recommended to disable
visualization when running the simulation in headless mode.
def _debug_vis_callback(self, event):
"""Callback for debug visualization.
This function calls the visualization objects and sets the data to visualize into them.
"""
pass
raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
"""
Simulation callbacks.
Internal simulation callbacks.
"""
def _initialize_callback(self, event):
......@@ -224,10 +253,6 @@ class SensorBase(ABC):
"""Invalidates the scene elements."""
self._is_initialized = False
def _debug_vis_callback(self, event):
"""Visualizes the sensor data."""
self._debug_vis_impl()
"""
Helper functions.
"""
......
......@@ -78,13 +78,6 @@ class TerrainImporter:
self.warp_meshes = dict()
self.env_origins = None
self.terrain_origins = None
# marker for visualization
if self.cfg.debug_vis:
self.origin_visualizer = VisualizationMarkers(
cfg=FRAME_MARKER_CFG.replace(prim_path="/Visuals/TerrainOrigin")
)
else:
self.origin_visualizer = None
# auto-import the terrain based on the config
if self.cfg.terrain_type == "generator":
......@@ -112,20 +105,57 @@ class TerrainImporter:
else:
raise ValueError(f"Terrain type '{self.cfg.terrain_type}' not available.")
# set initial state of debug visualization
self.set_debug_vis(self.cfg.debug_vis)
"""
Properties.
"""
@property
def has_debug_vis_implementation(self) -> bool:
"""Whether the terrain importer has a debug visualization implemented.
This always returns True.
"""
return True
"""
Operations - Visibility.
"""
def set_debug_vis(self, debug_vis: bool):
def set_debug_vis(self, debug_vis: bool) -> bool:
"""Set the debug visualization of the terrain importer.
Args:
debug_vis: Whether to visualize the terrain origins.
Returns:
Whether the debug visualization was successfully set. False if the terrain
importer does not support debug visualization.
Raises:
RuntimeError: If terrain origins are not configured.
"""
if not self.cfg.debug_vis:
raise RuntimeError("Debug visualization is not enabled for this sensor.")
# create a marker if necessary
if debug_vis:
if not hasattr(self, "origin_visualizer"):
self.origin_visualizer = VisualizationMarkers(
cfg=FRAME_MARKER_CFG.replace(prim_path="/Visuals/TerrainOrigin")
)
if self.terrain_origins is not None:
self.origin_visualizer.visualize(self.terrain_origins.reshape(-1, 3))
elif self.env_origins is not None:
self.origin_visualizer.visualize(self.env_origins.reshape(-1, 3))
else:
raise RuntimeError("Terrain origins are not configured.")
# set visibility
self.origin_visualizer.set_visibility(debug_vis)
self.origin_visualizer.set_visibility(True)
else:
if hasattr(self, "origin_visualizer"):
self.origin_visualizer.set_visibility(False)
# report success
return True
"""
Operations - Import.
......@@ -251,9 +281,6 @@ class TerrainImporter:
self.terrain_origins = origins.to(self.device, dtype=torch.float)
# compute environment origins
self.env_origins = self._compute_env_origins_curriculum(self.cfg.num_envs, self.terrain_origins)
# put markers on the sub-terrain origins
if self.origin_visualizer is not None:
self.origin_visualizer.visualize(self.terrain_origins.reshape(-1, 3))
else:
self.terrain_origins = None
# check if env spacing is valid
......@@ -261,9 +288,6 @@ class TerrainImporter:
raise ValueError("Environment spacing must be specified for configuring grid-like origins.")
# compute environment origins
self.env_origins = self._compute_env_origins_grid(self.cfg.num_envs, self.cfg.env_spacing)
# put markers on the grid origins
if self.origin_visualizer is not None:
self.origin_visualizer.visualize(self.env_origins.reshape(-1, 3))
def update_env_origins(self, env_ids: torch.Tensor, move_up: torch.Tensor, move_down: torch.Tensor):
"""Update the environment origins based on the terrain levels."""
......
......@@ -56,7 +56,7 @@ class TerrainSceneCfg(InteractiveSceneCfg):
static_friction=1.0,
dynamic_friction=1.0,
),
debug_vis=True,
debug_vis=False,
)
# robots
robot = ANYMAL_C_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment