Unverified Commit 40cdf401 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Makes function order consistent in the command terms (#496)

# Description

This MR makes sure that the function ordering is consistent with their
call order in the command terms.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have run all the tests with `./orbit.sh --test` and they pass
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 06ada157
...@@ -7,9 +7,11 @@ from __future__ import annotations ...@@ -7,9 +7,11 @@ from __future__ import annotations
import builtins import builtins
import torch import torch
import warnings
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Dict from typing import Any, Dict
import carb
import omni.isaac.core.utils.torch as torch_utils import omni.isaac.core.utils.torch as torch_utils
from omni.isaac.orbit.managers import ActionManager, EventManager, ObservationManager from omni.isaac.orbit.managers import ActionManager, EventManager, ObservationManager
...@@ -202,6 +204,17 @@ class BaseEnv: ...@@ -202,6 +204,17 @@ class BaseEnv:
:meth:`SimulationContext.reset_async` and it isn't possible to call async functions in the constructor. :meth:`SimulationContext.reset_async` and it isn't possible to call async functions in the constructor.
""" """
# check the configs
if self.cfg.randomization is not None:
msg = (
"The 'randomization' attribute is deprecated and will be removed in a future release. "
"Please use the 'events' attribute to configure the randomization settings."
)
warnings.warn(msg, category=DeprecationWarning)
carb.log_warn(msg)
# set the randomization as events (for backward compatibility)
self.cfg.events = self.cfg.randomization
# prepare the managers # prepare the managers
# -- action manager # -- action manager
self.action_manager = ActionManager(self.cfg.actions, self) self.action_manager = ActionManager(self.cfg.actions, self)
......
...@@ -11,7 +11,6 @@ configuring the environment instances, viewer settings, and simulation parameter ...@@ -11,7 +11,6 @@ configuring the environment instances, viewer settings, and simulation parameter
from __future__ import annotations from __future__ import annotations
import warnings
from dataclasses import MISSING from dataclasses import MISSING
from typing import Literal from typing import Literal
...@@ -141,12 +140,3 @@ class BaseEnvCfg: ...@@ -141,12 +140,3 @@ class BaseEnvCfg:
attribute to configure the randomization settings. attribute to configure the randomization settings.
""" """
def __post_init__(self):
if self.randomization is not None:
warnings.warn(
"The 'randomization' attribute is deprecated and will be removed in a future release. "
"Please use the 'events' attribute to configure the randomization settings.",
DeprecationWarning,
)
self.events = self.randomization
...@@ -59,11 +59,11 @@ class NullCommand(CommandTerm): ...@@ -59,11 +59,11 @@ class NullCommand(CommandTerm):
Implementation specific functions. Implementation specific functions.
""" """
def _resample_command(self, env_ids: Sequence[int]): def _update_metrics(self):
pass pass
def _update_command(self): def _resample_command(self, env_ids: Sequence[int]):
pass pass
def _update_metrics(self): def _update_command(self):
pass pass
...@@ -80,6 +80,11 @@ class UniformPose2dCommand(CommandTerm): ...@@ -80,6 +80,11 @@ class UniformPose2dCommand(CommandTerm):
Implementation specific functions. Implementation specific functions.
""" """
def _update_metrics(self):
# logs data
self.metrics["error_pos_2d"] = torch.norm(self.pos_command_w[:, :2] - self.robot.data.root_pos_w[:, :2], dim=1)
self.metrics["error_heading"] = torch.abs(wrap_to_pi(self.heading_command_w - self.robot.data.heading_w))
def _resample_command(self, env_ids: Sequence[int]): def _resample_command(self, env_ids: Sequence[int]):
# obtain env origins for the environments # obtain env origins for the environments
self.pos_command_w[env_ids] = self._env.scene.env_origins[env_ids] self.pos_command_w[env_ids] = self._env.scene.env_origins[env_ids]
...@@ -116,11 +121,6 @@ class UniformPose2dCommand(CommandTerm): ...@@ -116,11 +121,6 @@ class UniformPose2dCommand(CommandTerm):
self.pos_command_b[:] = quat_rotate_inverse(yaw_quat(self.robot.data.root_quat_w), target_vec) self.pos_command_b[:] = quat_rotate_inverse(yaw_quat(self.robot.data.root_quat_w), target_vec)
self.heading_command_b[:] = wrap_to_pi(self.heading_command_w - self.robot.data.heading_w) self.heading_command_b[:] = wrap_to_pi(self.heading_command_w - self.robot.data.heading_w)
def _update_metrics(self):
# logs data
self.metrics["error_pos_2d"] = torch.norm(self.pos_command_w[:, :2] - self.robot.data.root_pos_w[:, :2], dim=1)
self.metrics["error_heading"] = torch.abs(wrap_to_pi(self.heading_command_w - self.robot.data.heading_w))
def _set_debug_vis_impl(self, debug_vis: bool): def _set_debug_vis_impl(self, debug_vis: bool):
# create markers if necessary for the first tome # create markers if necessary for the first tome
if debug_vis: if debug_vis:
......
...@@ -90,6 +90,24 @@ class UniformPoseCommand(CommandTerm): ...@@ -90,6 +90,24 @@ class UniformPoseCommand(CommandTerm):
Implementation specific functions. Implementation specific functions.
""" """
def _update_metrics(self):
# transform command from base frame to simulation world frame
self.pose_command_w[:, :3], self.pose_command_w[:, 3:] = combine_frame_transforms(
self.robot.data.root_pos_w,
self.robot.data.root_quat_w,
self.pose_command_b[:, :3],
self.pose_command_b[:, 3:],
)
# compute the error
pos_error, rot_error = compute_pose_error(
self.pose_command_w[:, :3],
self.pose_command_w[:, 3:],
self.robot.data.body_state_w[:, self.body_idx, :3],
self.robot.data.body_state_w[:, self.body_idx, 3:7],
)
self.metrics["position_error"] = torch.norm(pos_error, dim=-1)
self.metrics["orientation_error"] = torch.norm(rot_error, dim=-1)
def _resample_command(self, env_ids: Sequence[int]): def _resample_command(self, env_ids: Sequence[int]):
# sample new pose targets # sample new pose targets
# -- position # -- position
...@@ -109,24 +127,6 @@ class UniformPoseCommand(CommandTerm): ...@@ -109,24 +127,6 @@ class UniformPoseCommand(CommandTerm):
def _update_command(self): def _update_command(self):
pass pass
def _update_metrics(self):
# transform command from base frame to simulation world frame
self.pose_command_w[:, :3], self.pose_command_w[:, 3:] = combine_frame_transforms(
self.robot.data.root_pos_w,
self.robot.data.root_quat_w,
self.pose_command_b[:, :3],
self.pose_command_b[:, 3:],
)
# compute the error
pos_error, rot_error = compute_pose_error(
self.pose_command_w[:, :3],
self.pose_command_w[:, 3:],
self.robot.data.body_state_w[:, self.body_idx, :3],
self.robot.data.body_state_w[:, self.body_idx, 3:7],
)
self.metrics["position_error"] = torch.norm(pos_error, dim=-1)
self.metrics["orientation_error"] = torch.norm(rot_error, dim=-1)
def _set_debug_vis_impl(self, debug_vis: bool): def _set_debug_vis_impl(self, debug_vis: bool):
# create markers if necessary for the first tome # create markers if necessary for the first tome
if debug_vis: if debug_vis:
......
...@@ -92,6 +92,18 @@ class UniformVelocityCommand(CommandTerm): ...@@ -92,6 +92,18 @@ class UniformVelocityCommand(CommandTerm):
Implementation specific functions. Implementation specific functions.
""" """
def _update_metrics(self):
# time for which the command was executed
max_command_time = self.cfg.resampling_time_range[1]
max_command_step = max_command_time / self._env.step_dt
# logs data
self.metrics["error_vel_xy"] += (
torch.norm(self.vel_command_b[:, :2] - self.robot.data.root_lin_vel_b[:, :2], dim=-1) / max_command_step
)
self.metrics["error_vel_yaw"] += (
torch.abs(self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]) / max_command_step
)
def _resample_command(self, env_ids: Sequence[int]): def _resample_command(self, env_ids: Sequence[int]):
# sample velocity commands # sample velocity commands
r = torch.empty(len(env_ids), device=self.device) r = torch.empty(len(env_ids), device=self.device)
...@@ -131,18 +143,6 @@ class UniformVelocityCommand(CommandTerm): ...@@ -131,18 +143,6 @@ class UniformVelocityCommand(CommandTerm):
standing_env_ids = self.is_standing_env.nonzero(as_tuple=False).flatten() standing_env_ids = self.is_standing_env.nonzero(as_tuple=False).flatten()
self.vel_command_b[standing_env_ids, :] = 0.0 self.vel_command_b[standing_env_ids, :] = 0.0
def _update_metrics(self):
# time for which the command was executed
max_command_time = self.cfg.resampling_time_range[1]
max_command_step = max_command_time / self._env.step_dt
# logs data
self.metrics["error_vel_xy"] += (
torch.norm(self.vel_command_b[:, :2] - self.robot.data.root_lin_vel_b[:, :2], dim=-1) / max_command_step
)
self.metrics["error_vel_yaw"] += (
torch.abs(self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]) / max_command_step
)
def _set_debug_vis_impl(self, debug_vis: bool): def _set_debug_vis_impl(self, debug_vis: bool):
# set visibility of markers # set visibility of markers
# note: parent only deals with callbacks. not their visibility # note: parent only deals with callbacks. not their visibility
......
...@@ -186,6 +186,11 @@ class CommandTerm(ManagerTermBase): ...@@ -186,6 +186,11 @@ class CommandTerm(ManagerTermBase):
Implementation specific functions. Implementation specific functions.
""" """
@abstractmethod
def _update_metrics(self):
"""Update the metrics based on the current state."""
raise NotImplementedError
@abstractmethod @abstractmethod
def _resample_command(self, env_ids: Sequence[int]): def _resample_command(self, env_ids: Sequence[int]):
"""Resample the command for the specified environments.""" """Resample the command for the specified environments."""
...@@ -196,11 +201,6 @@ class CommandTerm(ManagerTermBase): ...@@ -196,11 +201,6 @@ class CommandTerm(ManagerTermBase):
"""Update the command based on the current state.""" """Update the command based on the current state."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
def _update_metrics(self):
"""Update the metrics based on the current state."""
raise NotImplementedError
def _set_debug_vis_impl(self, debug_vis: bool): def _set_debug_vis_impl(self, debug_vis: bool):
"""Set debug visualization into visualization objects. """Set debug visualization into visualization objects.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment