Unverified Commit 571189b1 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Modifies behavior of debug visualization for better UI experience (#208)

# Description

Earlier the UI was disabling buttons for which the config object set
`debug_vis` flag as False. This wasn't an elegant behavior as users may
want to enable the visualization from the GUI at runtime. This MR
modifies the behavior of all classes that provide debug visualization to
support this behavior.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent 6ae3c3fb
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.9.23" version = "0.9.24"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.9.24 (2023-10-27)
~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Changed the behavior of setting up debug visualization for assets, sensors and command generators.
Earlier it was raising an error if debug visualization was not enabled in the configuration object.
Now it checks whether debug visualization is implemented and only sets up the callback if it is
implemented.
0.9.23 (2023-10-27) 0.9.23 (2023-10-27)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import inspect
import re import re
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -68,14 +69,10 @@ class AssetBase(ABC): ...@@ -68,14 +69,10 @@ class AssetBase(ABC):
lambda event, obj=weakref.proxy(self): obj._invalidate_initialize_callback(event), lambda event, obj=weakref.proxy(self): obj._invalidate_initialize_callback(event),
order=10, order=10,
) )
# add callback for debug visualization # add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
if self.cfg.debug_vis: self._debug_vis_handle = None
app_interface = omni.kit.app.get_app_interface() # set initial state of debug visualization
self._debug_visualization_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop( self.set_debug_vis(self.cfg.debug_vis)
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event),
)
else:
self._debug_visualization_handle = None
def __del__(self): def __del__(self):
"""Unsubscribe from the callbacks.""" """Unsubscribe from the callbacks."""
...@@ -87,9 +84,9 @@ class AssetBase(ABC): ...@@ -87,9 +84,9 @@ class AssetBase(ABC):
self._invalidate_initialize_handle.unsubscribe() self._invalidate_initialize_handle.unsubscribe()
self._invalidate_initialize_handle = None self._invalidate_initialize_handle = None
# clear debug visualization # clear debug visualization
if self._debug_visualization_handle: if self._debug_vis_handle:
self._debug_visualization_handle.unsubscribe() self._debug_vis_handle.unsubscribe()
self._debug_visualization_handle = None self._debug_vis_handle = None
""" """
Properties Properties
...@@ -107,21 +104,47 @@ class AssetBase(ABC): ...@@ -107,21 +104,47 @@ class AssetBase(ABC):
"""Data related to the asset.""" """Data related to the asset."""
return NotImplementedError return NotImplementedError
@property
def has_debug_vis_implementation(self) -> bool:
"""Whether the asset has a debug visualization implemented."""
# check if function raises NotImplementedError
source_code = inspect.getsource(self._debug_vis_callback)
return "NotImplementedError" not in source_code
""" """
Operations. Operations.
""" """
def set_debug_vis(self, debug_vis: bool): def set_debug_vis(self, debug_vis: bool) -> bool:
"""Sets whether to visualize the asset data. """Sets whether to visualize the asset data.
Args: Args:
debug_vis: Whether to visualize the asset data. debug_vis: Whether to visualize the asset data.
Raises: Returns:
RuntimeError: If the asset debug visualization is not enabled. Whether the debug visualization was successfully set. False if the asset
does not support debug visualization.
""" """
if not self.cfg.debug_vis: # check if debug visualization is supported
raise RuntimeError("Debug visualization is not enabled for this sensor.") if not self.has_debug_vis_implementation:
return False
# toggle debug visualization objects
self._set_debug_vis_impl(debug_vis)
# toggle debug visualization handles
if debug_vis:
# create a subscriber for the post update event if it doesn't exist
if self._debug_vis_handle is None:
app_interface = omni.kit.app.get_app_interface()
self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event)
)
else:
# remove the subscriber if it exists
if self._debug_vis_handle is not None:
self._debug_vis_handle.unsubscribe()
self._debug_vis_handle = None
# return success
return True
@abstractmethod @abstractmethod
def reset(self, env_ids: Sequence[int] | None = None): def reset(self, env_ids: Sequence[int] | None = None):
...@@ -158,12 +181,24 @@ class AssetBase(ABC): ...@@ -158,12 +181,24 @@ class AssetBase(ABC):
"""Initializes the PhysX handles and internal buffers.""" """Initializes the PhysX handles and internal buffers."""
raise NotImplementedError raise NotImplementedError
def _debug_vis_impl(self): def _set_debug_vis_impl(self, debug_vis: bool):
"""Perform debug visualization of the asset.""" """Set debug visualization into visualization objects.
pass
This function is responsible for creating the visualization objects if they don't exist
and input ``debug_vis`` is True. If the visualization objects exist, the function should
set their visibility into the stage.
""" """
Simulation callbacks. raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
def _debug_vis_callback(self, event):
"""Callback for debug visualization.
This function calls the visualization objects and sets the data to visualize into them.
"""
raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
"""
Internal simulation callbacks.
""" """
def _initialize_callback(self, event): def _initialize_callback(self, event):
...@@ -180,7 +215,3 @@ class AssetBase(ABC): ...@@ -180,7 +215,3 @@ class AssetBase(ABC):
def _invalidate_initialize_callback(self, event): def _invalidate_initialize_callback(self, event):
"""Invalidates the scene elements.""" """Invalidates the scene elements."""
self._is_initialized = False self._is_initialized = False
def _debug_vis_callback(self, event):
"""Visualizes the asset data."""
self._debug_vis_impl()
...@@ -12,6 +12,7 @@ methods. ...@@ -12,6 +12,7 @@ methods.
from __future__ import annotations from __future__ import annotations
import inspect
import torch import torch
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -48,6 +49,7 @@ class CommandGeneratorBase(ABC): ...@@ -48,6 +49,7 @@ class CommandGeneratorBase(ABC):
# store the inputs # store the inputs
self.cfg = cfg self.cfg = cfg
self._env = env self._env = env
# create buffers to store the command # create buffers to store the command
# -- metrics that can be used for logging # -- metrics that can be used for logging
self.metrics = dict() self.metrics = dict()
...@@ -56,21 +58,16 @@ class CommandGeneratorBase(ABC): ...@@ -56,21 +58,16 @@ class CommandGeneratorBase(ABC):
# -- counter for the number of times the command has been resampled within the current episode # -- counter for the number of times the command has been resampled within the current episode
self.command_counter = torch.zeros(self.num_envs, device=self.device, dtype=torch.long) self.command_counter = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)
# add callback for debug visualization # add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
if self.cfg.debug_vis: self._debug_vis_handle = None
app_interface = omni.kit.app.get_app_interface() # set initial state of debug visualization
# NOTE: Use weakref on callback to ensure that this object can be deleted when its destructor is called. self.set_debug_vis(self.cfg.debug_vis)
self._debug_visualization_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event),
)
else:
self._debug_visualization_handle = None
def __del__(self): def __del__(self):
"""Unsubscribe from the callbacks.""" """Unsubscribe from the callbacks."""
if self._debug_visualization_handle is not None: if self._debug_vis_handle:
self._debug_visualization_handle.unsubscribe() self._debug_vis_handle.unsubscribe()
self._debug_visualization_handle = None self._debug_vis_handle = None
""" """
Properties Properties
...@@ -92,21 +89,47 @@ class CommandGeneratorBase(ABC): ...@@ -92,21 +89,47 @@ class CommandGeneratorBase(ABC):
"""The command tensor. Shape is (num_envs, command_dim).""" """The command tensor. Shape is (num_envs, command_dim)."""
raise NotImplementedError raise NotImplementedError
@property
def has_debug_vis_implementation(self) -> bool:
"""Whether the command generator has a debug visualization implemented."""
# check if function raises NotImplementedError
source_code = inspect.getsource(self._debug_vis_callback)
return "NotImplementedError" not in source_code
""" """
Operations. Operations.
""" """
def set_debug_vis(self, debug_vis: bool): def set_debug_vis(self, debug_vis: bool) -> bool:
"""Sets whether to visualize the command data. """Sets whether to visualize the command data.
Args: Args:
debug_vis: Whether to visualize the command data. debug_vis: Whether to visualize the command data.
Raises: Returns:
RuntimeError: If the command debug visualization is not enabled. Whether the debug visualization was successfully set. False if the command
""" generator does not support debug visualization.
if not self.cfg.debug_vis: """
raise RuntimeError("Debug visualization is not enabled for this sensor.") # check if debug visualization is supported
if not self.has_debug_vis_implementation:
return False
# toggle debug visualization objects
self._set_debug_vis_impl(debug_vis)
# toggle debug visualization handles
if debug_vis:
# create a subscriber for the post update event if it doesn't exist
if self._debug_vis_handle is None:
app_interface = omni.kit.app.get_app_interface()
self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event)
)
else:
# remove the subscriber if it exists
if self._debug_vis_handle is not None:
self._debug_vis_handle.unsubscribe()
self._debug_vis_handle = None
# return success
return True
def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]: def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
"""Reset the command generator and log metrics. """Reset the command generator and log metrics.
...@@ -173,14 +196,6 @@ class CommandGeneratorBase(ABC): ...@@ -173,14 +196,6 @@ class CommandGeneratorBase(ABC):
# resample the command # resample the command
self._resample_command(env_ids) self._resample_command(env_ids)
"""
Simulation callbacks.
"""
def _debug_vis_callback(self, event):
"""Visualizes the sensor data."""
self._debug_vis_impl()
""" """
Implementation specific functions. Implementation specific functions.
""" """
...@@ -200,9 +215,18 @@ class CommandGeneratorBase(ABC): ...@@ -200,9 +215,18 @@ class CommandGeneratorBase(ABC):
"""Update the metrics based on the current state.""" """Update the metrics based on the current state."""
raise NotImplementedError raise NotImplementedError
def _debug_vis_impl(self): def _set_debug_vis_impl(self, debug_vis: bool):
"""Visualize the command in the simulator. """Set debug visualization into visualization objects.
This function is responsible for creating the visualization objects if they don't exist
and input ``debug_vis`` is True. If the visualization objects exist, the function should
set their visibility into the stage.
"""
raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
def _debug_vis_callback(self, event):
"""Callback for debug visualization.
This is an optional function that can be used to visualize the command in the simulator. This function calls the visualization objects and sets the data to visualize into them.
""" """
pass raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
...@@ -41,11 +41,16 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase): ...@@ -41,11 +41,16 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase):
cfg: The configuration parameters for the command generator. cfg: The configuration parameters for the command generator.
env: The environment object. env: The environment object.
""" """
# initialize the base class
super().__init__(cfg, env) super().__init__(cfg, env)
# obtain the robot and terrain assets
# -- robot # -- robot
self.robot: Articulation = env.scene[cfg.asset_name] self.robot: Articulation = env.scene[cfg.asset_name]
# -- terrain # -- terrain
self.terrain: TerrainImporter = env.scene.terrain self.terrain: TerrainImporter = env.scene.terrain
# crete buffers to store the command
# -- commands: (x, y, z, heading) # -- commands: (x, y, z, heading)
self.pos_command_w = torch.zeros(self.num_envs, 3, device=self.device) self.pos_command_w = torch.zeros(self.num_envs, 3, device=self.device)
self.heading_command_w = torch.zeros(self.num_envs, device=self.device) self.heading_command_w = torch.zeros(self.num_envs, device=self.device)
...@@ -54,8 +59,6 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase): ...@@ -54,8 +59,6 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase):
# -- metrics # -- metrics
self.metrics["error_pos"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_pos"] = torch.zeros(self.num_envs, device=self.device)
self.metrics["error_heading"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_heading"] = torch.zeros(self.num_envs, device=self.device)
# -- debug vis
self.box_goal_visualizer = None
def __str__(self) -> str: def __str__(self) -> str:
msg = "TerrainBasedPositionCommandGenerator:\n" msg = "TerrainBasedPositionCommandGenerator:\n"
...@@ -73,15 +76,6 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase): ...@@ -73,15 +76,6 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase):
"""The desired base position in base frame. Shape is (num_envs, 3).""" """The desired base position in base frame. Shape is (num_envs, 3)."""
return self.pos_command_b return self.pos_command_b
"""
Operations.
"""
def set_debug_vis(self, debug_vis: bool):
super().set_debug_vis(debug_vis)
if self.box_goal_visualizer is not None:
self.box_goal_visualizer.set_visibility(debug_vis)
""" """
Implementation specific functions. Implementation specific functions.
""" """
...@@ -120,12 +114,20 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase): ...@@ -120,12 +114,20 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase):
self.metrics["error_pos"] = torch.norm(self.pos_command_w - self.robot.data.root_pos_w[:, :3], dim=1) self.metrics["error_pos"] = torch.norm(self.pos_command_w - self.robot.data.root_pos_w[:, :3], dim=1)
self.metrics["error_heading"] = torch.abs(wrap_to_pi(self.heading_command_w - self.robot.heading_w)) self.metrics["error_heading"] = torch.abs(wrap_to_pi(self.heading_command_w - self.robot.heading_w))
def _debug_vis_impl(self): def _set_debug_vis_impl(self, debug_vis: bool):
# create the box marker if necessary # create markers if necessary for the first tome
if self.box_goal_visualizer is None: if debug_vis:
if not hasattr(self, "box_goal_visualizer"):
marker_cfg = CUBOID_MARKER_CFG.copy() marker_cfg = CUBOID_MARKER_CFG.copy()
marker_cfg.prim_path = "/Visuals/Command/position_goal" marker_cfg.prim_path = "/Visuals/Command/position_goal"
marker_cfg.markers["cuboid"].scale = (0.1, 0.1, 0.1) marker_cfg.markers["cuboid"].scale = (0.1, 0.1, 0.1)
self.box_goal_visualizer = VisualizationMarkers(marker_cfg) self.box_goal_visualizer = VisualizationMarkers(marker_cfg)
# set their visibility to true
self.box_goal_visualizer.set_visibility(True)
else:
if hasattr(self, "box_goal_visualizer"):
self.box_goal_visualizer.set_visibility(False)
def _debug_vis_callback(self, event):
# update the box marker # update the box marker
self.box_goal_visualizer.visualize(self.pos_command_w) self.box_goal_visualizer.visualize(self.pos_command_w)
...@@ -51,9 +51,14 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase): ...@@ -51,9 +51,14 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
cfg: The configuration of the command generator. cfg: The configuration of the command generator.
env: The environment. env: The environment.
""" """
# initialize the base class
super().__init__(cfg, env) super().__init__(cfg, env)
# obtain the robot asset
# -- robot # -- robot
self.robot: Articulation = env.scene[cfg.asset_name] self.robot: Articulation = env.scene[cfg.asset_name]
# crete buffers to store the command
# -- command: x vel, y vel, yaw vel, heading # -- command: x vel, y vel, yaw vel, heading
self.vel_command_b = torch.zeros(self.num_envs, 3, device=self.device) self.vel_command_b = torch.zeros(self.num_envs, 3, device=self.device)
self.heading_target = torch.zeros(self.num_envs, device=self.device) self.heading_target = torch.zeros(self.num_envs, device=self.device)
...@@ -62,9 +67,6 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase): ...@@ -62,9 +67,6 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
# -- metrics # -- metrics
self.metrics["error_vel_xy"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_vel_xy"] = torch.zeros(self.num_envs, device=self.device)
self.metrics["error_vel_yaw"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_vel_yaw"] = torch.zeros(self.num_envs, device=self.device)
# -- debug vis
self.base_vel_goal_visualizer = None
self.base_vel_visualizer = None
def __str__(self) -> str: def __str__(self) -> str:
"""Return a string representation of the command generator.""" """Return a string representation of the command generator."""
...@@ -86,19 +88,6 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase): ...@@ -86,19 +88,6 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
"""The desired base velocity command in the base frame. Shape is (num_envs, 3).""" """The desired base velocity command in the base frame. Shape is (num_envs, 3)."""
return self.vel_command_b return self.vel_command_b
"""
Operations.
"""
def set_debug_vis(self, debug_vis: bool):
super().set_debug_vis(debug_vis)
# -- current
if self.base_vel_visualizer is not None:
self.base_vel_visualizer.set_visibility(debug_vis)
# -- goal
if self.base_vel_goal_visualizer is not None:
self.base_vel_goal_visualizer.set_visibility(debug_vis)
""" """
Implementation specific functions. Implementation specific functions.
""" """
...@@ -152,20 +141,31 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase): ...@@ -152,20 +141,31 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
torch.abs(self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]) / max_command_time torch.abs(self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]) / max_command_time
) )
def _debug_vis_impl(self): def _set_debug_vis_impl(self, debug_vis: bool):
# create markers if necessary # set visibility of markers
# note: parent only deals with callbacks. not their visibility
if debug_vis:
# create markers if necessary for the first tome
if not hasattr(self, "base_vel_goal_visualizer"):
# -- goal # -- goal
if self.base_vel_goal_visualizer is None:
marker_cfg = GREEN_ARROW_X_MARKER_CFG.copy() marker_cfg = GREEN_ARROW_X_MARKER_CFG.copy()
marker_cfg.prim_path = "/Visuals/Command/velocity_goal" marker_cfg.prim_path = "/Visuals/Command/velocity_goal"
marker_cfg.markers["arrow"].scale = (2.5, 0.1, 0.1) marker_cfg.markers["arrow"].scale = (2.5, 0.1, 0.1)
self.base_vel_goal_visualizer = VisualizationMarkers(marker_cfg) self.base_vel_goal_visualizer = VisualizationMarkers(marker_cfg)
# -- current # -- current
if self.base_vel_visualizer is None:
marker_cfg = BLUE_ARROW_X_MARKER_CFG.copy() marker_cfg = BLUE_ARROW_X_MARKER_CFG.copy()
marker_cfg.prim_path = "/Visuals/Command/velocity_current" marker_cfg.prim_path = "/Visuals/Command/velocity_current"
marker_cfg.markers["arrow"].scale = (2.5, 0.1, 0.1) marker_cfg.markers["arrow"].scale = (2.5, 0.1, 0.1)
self.base_vel_visualizer = VisualizationMarkers(marker_cfg) self.base_vel_visualizer = VisualizationMarkers(marker_cfg)
# set their visibility to true
self.base_vel_goal_visualizer.set_visibility(True)
self.base_vel_visualizer.set_visibility(True)
else:
if hasattr(self, "base_vel_goal_visualizer"):
self.base_vel_goal_visualizer.set_visibility(False)
self.base_vel_visualizer.set_visibility(False)
def _debug_vis_callback(self, event):
# get marker location # get marker location
# -- base state # -- base state
base_pos_w = self.robot.data.root_pos_w.clone() base_pos_w = self.robot.data.root_pos_w.clone()
...@@ -173,9 +173,8 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase): ...@@ -173,9 +173,8 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
# -- resolve the scales and quaternions # -- resolve the scales and quaternions
vel_des_arrow_scale, vel_des_arrow_quat = self._resolve_xy_velocity_to_arrow(self.command[:, :2]) vel_des_arrow_scale, vel_des_arrow_quat = self._resolve_xy_velocity_to_arrow(self.command[:, :2])
vel_arrow_scale, vel_arrow_quat = self._resolve_xy_velocity_to_arrow(self.robot.data.root_lin_vel_b[:, :2]) vel_arrow_scale, vel_arrow_quat = self._resolve_xy_velocity_to_arrow(self.robot.data.root_lin_vel_b[:, :2])
# -- goal # display markers
self.base_vel_goal_visualizer.visualize(base_pos_w, vel_des_arrow_quat, vel_des_arrow_scale) self.base_vel_goal_visualizer.visualize(base_pos_w, vel_des_arrow_quat, vel_des_arrow_scale)
# -- base velocity
self.base_vel_visualizer.visualize(base_pos_w, vel_arrow_quat, vel_arrow_scale) self.base_vel_visualizer.visualize(base_pos_w, vel_arrow_quat, vel_arrow_scale)
""" """
......
...@@ -10,6 +10,7 @@ import gym ...@@ -10,6 +10,7 @@ import gym
import math import math
import numpy as np import numpy as np
import torch import torch
import weakref
from typing import Any, ClassVar, Dict, Sequence, Tuple, Union from typing import Any, ClassVar, Dict, Sequence, Tuple, Union
import omni.usd import omni.usd
...@@ -445,59 +446,42 @@ class RLEnv(BaseEnv, gym.Env): ...@@ -445,59 +446,42 @@ class RLEnv(BaseEnv, gym.Env):
# create stack for debug visualization # create stack for debug visualization
self._orbit_window_elements["debug_vstack"] = ui.VStack(spacing=5, height=0) self._orbit_window_elements["debug_vstack"] = ui.VStack(spacing=5, height=0)
with self._orbit_window_elements["debug_vstack"]: with self._orbit_window_elements["debug_vstack"]:
elements = [
self.scene.terrain,
self.command_manager,
*self.scene.rigid_objects.values(),
*self.scene.articulations.values(),
*self.scene.sensors.values(),
]
names = [
"terrain",
"commands",
*self.scene.rigid_objects.keys(),
*self.scene.articulations.keys(),
*self.scene.sensors.keys(),
]
# create one for the terrain # create one for the terrain
if self.scene.terrain is not None: for elem, name in zip(elements, names):
if elem is not None:
with ui.HStack(): with ui.HStack():
# create the UI element # create the UI element
debug_vis_checkbox = { text = (
"model": ui.SimpleBoolModel(default_value=self.scene.terrain.cfg.debug_vis), "Toggle debug visualization."
"enabled": self.scene.terrain.cfg.debug_vis, if elem.has_debug_vis_implementation
"checked": self.scene.terrain.cfg.debug_vis, else "Debug visualization not implemented."
"on_checked_fn": lambda value: self.scene.terrain.set_debug_vis(value),
}
ui.Label(
"Terrain",
width=ui_utils.LABEL_WIDTH - 12,
alignment=ui.Alignment.LEFT_CENTER,
tooltip="Toggle debug visualization",
) )
self._orbit_window_elements["terrain_cb"] = SimpleCheckBox(**debug_vis_checkbox)
ui_utils.add_line_rect_flourish()
# iterate over each scene element and add a checkbox for debug visualization
for name, element in self.scene.sensors.items():
with ui.HStack():
# create the UI element
# note: need to deal with closure of lambda function inside for loop
# ref: https://stackoverflow.com/questions/66131048/python-lambda-function-is-not-being-called-correctly-from-within-a-for-loop
debug_vis_checkbox = {
"model": ui.SimpleBoolModel(default_value=element.cfg.debug_vis),
"enabled": element.cfg.debug_vis,
"checked": element.cfg.debug_vis,
"on_checked_fn": lambda value, ele=element: ele.set_debug_vis(value),
}
ui.Label( ui.Label(
ui_utils.format_tt(name.replace("_", " ")), name.replace("_", " ").title(),
width=ui_utils.LABEL_WIDTH - 12, width=ui_utils.LABEL_WIDTH - 12,
alignment=ui.Alignment.LEFT_CENTER, alignment=ui.Alignment.LEFT_CENTER,
tooltip="Toggle debug visualization", tooltip=text,
) )
self._orbit_window_elements[f"sensor_{name}_cb"] = SimpleCheckBox(**debug_vis_checkbox) self._orbit_window_elements[f"{name}_cb"] = SimpleCheckBox(
ui_utils.add_line_rect_flourish() model=ui.SimpleBoolModel(),
# create one for the command manager enabled=elem.has_debug_vis_implementation,
with ui.HStack(): checked=elem.cfg.debug_vis,
debug_vis_checkbox = { on_checked_fn=lambda value, e=weakref.proxy(elem): e.set_debug_vis(value),
"model": ui.SimpleBoolModel(default_value=self.command_manager.cfg.debug_vis),
"enabled": self.command_manager.cfg.debug_vis,
"checked": self.command_manager.cfg.debug_vis,
"on_checked_fn": lambda value: self.command_manager.set_debug_vis(value),
}
ui.Label(
"Command Manager",
width=ui_utils.LABEL_WIDTH - 12,
alignment=ui.Alignment.LEFT_CENTER,
tooltip="Toggle debug visualization",
) )
self._orbit_window_elements["command_cb"] = SimpleCheckBox(**debug_vis_checkbox)
ui_utils.add_line_rect_flourish() ui_utils.add_line_rect_flourish()
async def _dock_window(self, window_title: str): async def _dock_window(self, window_title: str):
......
...@@ -58,8 +58,6 @@ class ContactSensor(SensorBase): ...@@ -58,8 +58,6 @@ class ContactSensor(SensorBase):
super().__init__(cfg) super().__init__(cfg)
# Create empty variables for storing output data # Create empty variables for storing output data
self._data: ContactSensorData = ContactSensorData() self._data: ContactSensorData = ContactSensorData()
# visualization markers
self.contact_visualizer = None
def __str__(self) -> str: def __str__(self) -> str:
"""Returns: A string containing information about the instance.""" """Returns: A string containing information about the instance."""
...@@ -127,11 +125,6 @@ class ContactSensor(SensorBase): ...@@ -127,11 +125,6 @@ class ContactSensor(SensorBase):
Operations Operations
""" """
def set_debug_vis(self, debug_vis: bool):
super().set_debug_vis(debug_vis)
if self.contact_visualizer is not None:
self.contact_visualizer.set_visibility(debug_vis)
def reset(self, env_ids: Sequence[int] | None = None): def reset(self, env_ids: Sequence[int] | None = None):
# reset the timers and counters # reset the timers and counters
super().reset(env_ids) super().reset(env_ids)
...@@ -281,11 +274,21 @@ class ContactSensor(SensorBase): ...@@ -281,11 +274,21 @@ class ContactSensor(SensorBase):
# -- increment timers for bodies that are not in contact # -- increment timers for bodies that are not in contact
self._data.current_air_time[env_ids] *= ~is_contact self._data.current_air_time[env_ids] *= ~is_contact
def _debug_vis_impl(self): def _set_debug_vis_impl(self, debug_vis: bool):
# visualize the contacts # set visibility of markers
if self.contact_visualizer is None: # note: parent only deals with callbacks. not their visibility
if debug_vis:
# create markers if necessary for the first tome
if not hasattr(self, "contact_visualizer"):
visualizer_cfg = CONTACT_SENSOR_MARKER_CFG.replace(prim_path="/Visuals/ContactSensor") visualizer_cfg = CONTACT_SENSOR_MARKER_CFG.replace(prim_path="/Visuals/ContactSensor")
self.contact_visualizer = VisualizationMarkers(visualizer_cfg) self.contact_visualizer = VisualizationMarkers(visualizer_cfg)
# set their visibility to true
self.contact_visualizer.set_visibility(True)
else:
if hasattr(self, "contact_visualizer"):
self.contact_visualizer.set_visibility(False)
def _debug_vis_callback(self, event):
# safely return if view becomes invalid # safely return if view becomes invalid
# note: this invalidity happens because of isaac sim view callbacks # note: this invalidity happens because of isaac sim view callbacks
if self.body_physx_view is None: if self.body_physx_view is None:
......
...@@ -54,14 +54,12 @@ class RayCaster(SensorBase): ...@@ -54,14 +54,12 @@ class RayCaster(SensorBase):
Args: Args:
cfg: The configuration parameters. cfg: The configuration parameters.
""" """
# initialize base class # Initialize base class
super().__init__(cfg) super().__init__(cfg)
# Create empty variables for storing output data # Create empty variables for storing output data
self._data = RayCasterData() self._data = RayCasterData()
# List of meshes to ray-cast # List of meshes to ray-cast
self.warp_meshes = [] self.warp_meshes = []
# visualization markers
self.ray_visualizer = None
def __str__(self) -> str: def __str__(self) -> str:
"""Returns: A string containing information about the instance.""" """Returns: A string containing information about the instance."""
...@@ -90,11 +88,6 @@ class RayCaster(SensorBase): ...@@ -90,11 +88,6 @@ class RayCaster(SensorBase):
Operations. Operations.
""" """
def set_debug_vis(self, debug_vis: bool):
super().set_debug_vis(debug_vis)
if self.ray_visualizer is not None:
self.ray_visualizer.set_visibility(debug_vis)
def reset(self, env_ids: Sequence[int] | None = None): def reset(self, env_ids: Sequence[int] | None = None):
# reset the timers and counters # reset the timers and counters
super().reset(env_ids) super().reset(env_ids)
...@@ -219,10 +212,19 @@ class RayCaster(SensorBase): ...@@ -219,10 +212,19 @@ class RayCaster(SensorBase):
# TODO: Make this work for multiple meshes? # TODO: Make this work for multiple meshes?
self._data.ray_hits_w[env_ids] = raycast_mesh(ray_starts_w, ray_directions_w, self.warp_meshes[0]) self._data.ray_hits_w[env_ids] = raycast_mesh(ray_starts_w, ray_directions_w, self.warp_meshes[0])
def _debug_vis_impl(self): def _set_debug_vis_impl(self, debug_vis: bool):
# visualize the point hits # set visibility of markers
if self.ray_visualizer is None: # note: parent only deals with callbacks. not their visibility
if debug_vis:
if not hasattr(self, "ray_visualizer"):
visualizer_cfg = RAY_CASTER_MARKER_CFG.replace(prim_path="/Visuals/RayCaster") visualizer_cfg = RAY_CASTER_MARKER_CFG.replace(prim_path="/Visuals/RayCaster")
self.ray_visualizer = VisualizationMarkers(visualizer_cfg) self.ray_visualizer = VisualizationMarkers(visualizer_cfg)
# check if prim is visualized # set their visibility to true
self.ray_visualizer.set_visibility(True)
else:
if hasattr(self, "ray_visualizer"):
self.ray_visualizer.set_visibility(False)
def _debug_vis_callback(self, event):
# show ray hit positions
self.ray_visualizer.visualize(self._data.ray_hits_w.view(-1, 3)) self.ray_visualizer.visualize(self._data.ray_hits_w.view(-1, 3))
...@@ -11,6 +11,7 @@ Each sensor class should inherit from this class and implement the abstract meth ...@@ -11,6 +11,7 @@ Each sensor class should inherit from this class and implement the abstract meth
from __future__ import annotations from __future__ import annotations
import inspect
import torch import torch
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -65,14 +66,10 @@ class SensorBase(ABC): ...@@ -65,14 +66,10 @@ class SensorBase(ABC):
lambda event, obj=weakref.proxy(self): obj._invalidate_initialize_callback(event), lambda event, obj=weakref.proxy(self): obj._invalidate_initialize_callback(event),
order=10, order=10,
) )
# add callback for debug visualization # add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
if self.cfg.debug_vis: self._debug_vis_handle = None
app_interface = omni.kit.app.get_app_interface() # set initial state of debug visualization
self._debug_visualization_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop( self.set_debug_vis(self.cfg.debug_vis)
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event),
)
else:
self._debug_visualization_handle = None
def __del__(self): def __del__(self):
"""Unsubscribe from the callbacks.""" """Unsubscribe from the callbacks."""
...@@ -84,9 +81,9 @@ class SensorBase(ABC): ...@@ -84,9 +81,9 @@ class SensorBase(ABC):
self._invalidate_initialize_handle.unsubscribe() self._invalidate_initialize_handle.unsubscribe()
self._invalidate_initialize_handle = None self._invalidate_initialize_handle = None
# clear debug visualization # clear debug visualization
if self._debug_visualization_handle: if self._debug_vis_handle:
self._debug_visualization_handle.unsubscribe() self._debug_vis_handle.unsubscribe()
self._debug_visualization_handle = None self._debug_vis_handle = None
""" """
Properties Properties
...@@ -117,21 +114,48 @@ class SensorBase(ABC): ...@@ -117,21 +114,48 @@ class SensorBase(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@property
def has_debug_vis_implementation(self) -> bool:
"""Whether the sensor has a debug visualization implemented."""
# check if function raises NotImplementedError
# check if function raises NotImplementedError
source_code = inspect.getsource(self._debug_vis_callback)
return "NotImplementedError" not in source_code
""" """
Operations Operations
""" """
def set_debug_vis(self, debug_vis: bool): def set_debug_vis(self, debug_vis: bool) -> bool:
"""Sets whether to visualize the sensor data. """Sets whether to visualize the sensor data.
Args: Args:
debug_vis: Whether to visualize the sensor data. debug_vis: Whether to visualize the sensor data.
Raises: Returns:
RuntimeError: If the asset debug visualization is not enabled. Whether the debug visualization was successfully set. False if the sensor
does not support debug visualization.
""" """
if not self.cfg.debug_vis: # check if debug visualization is supported
raise RuntimeError("Debug visualization is not enabled for this sensor.") if not self.has_debug_vis_implementation:
return False
# toggle debug visualization objects
self._set_debug_vis_impl(debug_vis)
# toggle debug visualization handles
if debug_vis:
# create a subscriber for the post update event if it doesn't exist
if self._debug_vis_handle is None:
app_interface = omni.kit.app.get_app_interface()
self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event)
)
else:
# remove the subscriber if it exists
if self._debug_vis_handle is not None:
self._debug_vis_handle.unsubscribe()
self._debug_vis_handle = None
# return success
return True
def reset(self, env_ids: Sequence[int] | None = None): def reset(self, env_ids: Sequence[int] | None = None):
"""Resets the sensor internals. """Resets the sensor internals.
...@@ -194,19 +218,24 @@ class SensorBase(ABC): ...@@ -194,19 +218,24 @@ class SensorBase(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def _debug_vis_impl(self): def _set_debug_vis_impl(self, debug_vis: bool):
"""Visualizes the sensor data. """Set debug visualization into visualization objects.
This is an empty function that can be overridden by the derived class to visualize the sensor data. This function is responsible for creating the visualization objects if they don't exist
and input ``debug_vis`` is True. If the visualization objects exist, the function should
set their visibility into the stage.
"""
raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
Note: def _debug_vis_callback(self, event):
Visualization of sensor data may add overhead to the simulation. It is recommended to disable """Callback for debug visualization.
visualization when running the simulation in headless mode.
This function calls the visualization objects and sets the data to visualize into them.
""" """
pass raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
""" """
Simulation callbacks. Internal simulation callbacks.
""" """
def _initialize_callback(self, event): def _initialize_callback(self, event):
...@@ -224,10 +253,6 @@ class SensorBase(ABC): ...@@ -224,10 +253,6 @@ class SensorBase(ABC):
"""Invalidates the scene elements.""" """Invalidates the scene elements."""
self._is_initialized = False self._is_initialized = False
def _debug_vis_callback(self, event):
"""Visualizes the sensor data."""
self._debug_vis_impl()
""" """
Helper functions. Helper functions.
""" """
......
...@@ -78,13 +78,6 @@ class TerrainImporter: ...@@ -78,13 +78,6 @@ class TerrainImporter:
self.warp_meshes = dict() self.warp_meshes = dict()
self.env_origins = None self.env_origins = None
self.terrain_origins = None self.terrain_origins = None
# marker for visualization
if self.cfg.debug_vis:
self.origin_visualizer = VisualizationMarkers(
cfg=FRAME_MARKER_CFG.replace(prim_path="/Visuals/TerrainOrigin")
)
else:
self.origin_visualizer = None
# auto-import the terrain based on the config # auto-import the terrain based on the config
if self.cfg.terrain_type == "generator": if self.cfg.terrain_type == "generator":
...@@ -112,20 +105,57 @@ class TerrainImporter: ...@@ -112,20 +105,57 @@ class TerrainImporter:
else: else:
raise ValueError(f"Terrain type '{self.cfg.terrain_type}' not available.") raise ValueError(f"Terrain type '{self.cfg.terrain_type}' not available.")
# set initial state of debug visualization
self.set_debug_vis(self.cfg.debug_vis)
"""
Properties.
"""
@property
def has_debug_vis_implementation(self) -> bool:
"""Whether the terrain importer has a debug visualization implemented.
This always returns True.
"""
return True
""" """
Operations - Visibility. Operations - Visibility.
""" """
def set_debug_vis(self, debug_vis: bool): def set_debug_vis(self, debug_vis: bool) -> bool:
"""Set the debug visualization of the terrain importer. """Set the debug visualization of the terrain importer.
Args: Args:
debug_vis: Whether to visualize the terrain origins. debug_vis: Whether to visualize the terrain origins.
Returns:
Whether the debug visualization was successfully set. False if the terrain
importer does not support debug visualization.
Raises:
RuntimeError: If terrain origins are not configured.
""" """
if not self.cfg.debug_vis: # create a marker if necessary
raise RuntimeError("Debug visualization is not enabled for this sensor.") if debug_vis:
if not hasattr(self, "origin_visualizer"):
self.origin_visualizer = VisualizationMarkers(
cfg=FRAME_MARKER_CFG.replace(prim_path="/Visuals/TerrainOrigin")
)
if self.terrain_origins is not None:
self.origin_visualizer.visualize(self.terrain_origins.reshape(-1, 3))
elif self.env_origins is not None:
self.origin_visualizer.visualize(self.env_origins.reshape(-1, 3))
else:
raise RuntimeError("Terrain origins are not configured.")
# set visibility # set visibility
self.origin_visualizer.set_visibility(debug_vis) self.origin_visualizer.set_visibility(True)
else:
if hasattr(self, "origin_visualizer"):
self.origin_visualizer.set_visibility(False)
# report success
return True
""" """
Operations - Import. Operations - Import.
...@@ -251,9 +281,6 @@ class TerrainImporter: ...@@ -251,9 +281,6 @@ class TerrainImporter:
self.terrain_origins = origins.to(self.device, dtype=torch.float) self.terrain_origins = origins.to(self.device, dtype=torch.float)
# compute environment origins # compute environment origins
self.env_origins = self._compute_env_origins_curriculum(self.cfg.num_envs, self.terrain_origins) self.env_origins = self._compute_env_origins_curriculum(self.cfg.num_envs, self.terrain_origins)
# put markers on the sub-terrain origins
if self.origin_visualizer is not None:
self.origin_visualizer.visualize(self.terrain_origins.reshape(-1, 3))
else: else:
self.terrain_origins = None self.terrain_origins = None
# check if env spacing is valid # check if env spacing is valid
...@@ -261,9 +288,6 @@ class TerrainImporter: ...@@ -261,9 +288,6 @@ class TerrainImporter:
raise ValueError("Environment spacing must be specified for configuring grid-like origins.") raise ValueError("Environment spacing must be specified for configuring grid-like origins.")
# compute environment origins # compute environment origins
self.env_origins = self._compute_env_origins_grid(self.cfg.num_envs, self.cfg.env_spacing) self.env_origins = self._compute_env_origins_grid(self.cfg.num_envs, self.cfg.env_spacing)
# put markers on the grid origins
if self.origin_visualizer is not None:
self.origin_visualizer.visualize(self.env_origins.reshape(-1, 3))
def update_env_origins(self, env_ids: torch.Tensor, move_up: torch.Tensor, move_down: torch.Tensor): def update_env_origins(self, env_ids: torch.Tensor, move_up: torch.Tensor, move_down: torch.Tensor):
"""Update the environment origins based on the terrain levels.""" """Update the environment origins based on the terrain levels."""
......
...@@ -56,7 +56,7 @@ class TerrainSceneCfg(InteractiveSceneCfg): ...@@ -56,7 +56,7 @@ class TerrainSceneCfg(InteractiveSceneCfg):
static_friction=1.0, static_friction=1.0,
dynamic_friction=1.0, dynamic_friction=1.0,
), ),
debug_vis=True, debug_vis=False,
) )
# robots # robots
robot = ANYMAL_C_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") robot = ANYMAL_C_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment