Unverified Commit 40cdf401 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Makes function order consistent in the command terms (#496)

# Description

This MR makes sure that the function ordering is consistent with their
call order in the command terms.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have run all the tests with `./orbit.sh --test` and they pass
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 06ada157
......@@ -7,9 +7,11 @@ from __future__ import annotations
import builtins
import torch
import warnings
from collections.abc import Sequence
from typing import Any, Dict
import carb
import omni.isaac.core.utils.torch as torch_utils
from omni.isaac.orbit.managers import ActionManager, EventManager, ObservationManager
......@@ -202,6 +204,17 @@ class BaseEnv:
:meth:`SimulationContext.reset_async` and it isn't possible to call async functions in the constructor.
"""
# check the configs
if self.cfg.randomization is not None:
msg = (
"The 'randomization' attribute is deprecated and will be removed in a future release. "
"Please use the 'events' attribute to configure the randomization settings."
)
warnings.warn(msg, category=DeprecationWarning)
carb.log_warn(msg)
# set the randomization as events (for backward compatibility)
self.cfg.events = self.cfg.randomization
# prepare the managers
# -- action manager
self.action_manager = ActionManager(self.cfg.actions, self)
......
......@@ -11,7 +11,6 @@ configuring the environment instances, viewer settings, and simulation parameter
from __future__ import annotations
import warnings
from dataclasses import MISSING
from typing import Literal
......@@ -141,12 +140,3 @@ class BaseEnvCfg:
attribute to configure the randomization settings.
"""
def __post_init__(self):
if self.randomization is not None:
warnings.warn(
"The 'randomization' attribute is deprecated and will be removed in a future release. "
"Please use the 'events' attribute to configure the randomization settings.",
DeprecationWarning,
)
self.events = self.randomization
......@@ -59,11 +59,11 @@ class NullCommand(CommandTerm):
Implementation specific functions.
"""
def _resample_command(self, env_ids: Sequence[int]):
def _update_metrics(self):
pass
def _update_command(self):
def _resample_command(self, env_ids: Sequence[int]):
pass
def _update_metrics(self):
def _update_command(self):
pass
......@@ -80,6 +80,11 @@ class UniformPose2dCommand(CommandTerm):
Implementation specific functions.
"""
def _update_metrics(self):
# logs data
self.metrics["error_pos_2d"] = torch.norm(self.pos_command_w[:, :2] - self.robot.data.root_pos_w[:, :2], dim=1)
self.metrics["error_heading"] = torch.abs(wrap_to_pi(self.heading_command_w - self.robot.data.heading_w))
def _resample_command(self, env_ids: Sequence[int]):
# obtain env origins for the environments
self.pos_command_w[env_ids] = self._env.scene.env_origins[env_ids]
......@@ -116,11 +121,6 @@ class UniformPose2dCommand(CommandTerm):
self.pos_command_b[:] = quat_rotate_inverse(yaw_quat(self.robot.data.root_quat_w), target_vec)
self.heading_command_b[:] = wrap_to_pi(self.heading_command_w - self.robot.data.heading_w)
def _update_metrics(self):
# logs data
self.metrics["error_pos_2d"] = torch.norm(self.pos_command_w[:, :2] - self.robot.data.root_pos_w[:, :2], dim=1)
self.metrics["error_heading"] = torch.abs(wrap_to_pi(self.heading_command_w - self.robot.data.heading_w))
def _set_debug_vis_impl(self, debug_vis: bool):
# create markers if necessary for the first tome
if debug_vis:
......
......@@ -90,6 +90,24 @@ class UniformPoseCommand(CommandTerm):
Implementation specific functions.
"""
def _update_metrics(self):
# transform command from base frame to simulation world frame
self.pose_command_w[:, :3], self.pose_command_w[:, 3:] = combine_frame_transforms(
self.robot.data.root_pos_w,
self.robot.data.root_quat_w,
self.pose_command_b[:, :3],
self.pose_command_b[:, 3:],
)
# compute the error
pos_error, rot_error = compute_pose_error(
self.pose_command_w[:, :3],
self.pose_command_w[:, 3:],
self.robot.data.body_state_w[:, self.body_idx, :3],
self.robot.data.body_state_w[:, self.body_idx, 3:7],
)
self.metrics["position_error"] = torch.norm(pos_error, dim=-1)
self.metrics["orientation_error"] = torch.norm(rot_error, dim=-1)
def _resample_command(self, env_ids: Sequence[int]):
# sample new pose targets
# -- position
......@@ -109,24 +127,6 @@ class UniformPoseCommand(CommandTerm):
def _update_command(self):
pass
def _update_metrics(self):
# transform command from base frame to simulation world frame
self.pose_command_w[:, :3], self.pose_command_w[:, 3:] = combine_frame_transforms(
self.robot.data.root_pos_w,
self.robot.data.root_quat_w,
self.pose_command_b[:, :3],
self.pose_command_b[:, 3:],
)
# compute the error
pos_error, rot_error = compute_pose_error(
self.pose_command_w[:, :3],
self.pose_command_w[:, 3:],
self.robot.data.body_state_w[:, self.body_idx, :3],
self.robot.data.body_state_w[:, self.body_idx, 3:7],
)
self.metrics["position_error"] = torch.norm(pos_error, dim=-1)
self.metrics["orientation_error"] = torch.norm(rot_error, dim=-1)
def _set_debug_vis_impl(self, debug_vis: bool):
# create markers if necessary for the first tome
if debug_vis:
......
......@@ -92,6 +92,18 @@ class UniformVelocityCommand(CommandTerm):
Implementation specific functions.
"""
def _update_metrics(self):
# time for which the command was executed
max_command_time = self.cfg.resampling_time_range[1]
max_command_step = max_command_time / self._env.step_dt
# logs data
self.metrics["error_vel_xy"] += (
torch.norm(self.vel_command_b[:, :2] - self.robot.data.root_lin_vel_b[:, :2], dim=-1) / max_command_step
)
self.metrics["error_vel_yaw"] += (
torch.abs(self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]) / max_command_step
)
def _resample_command(self, env_ids: Sequence[int]):
# sample velocity commands
r = torch.empty(len(env_ids), device=self.device)
......@@ -131,18 +143,6 @@ class UniformVelocityCommand(CommandTerm):
standing_env_ids = self.is_standing_env.nonzero(as_tuple=False).flatten()
self.vel_command_b[standing_env_ids, :] = 0.0
def _update_metrics(self):
# time for which the command was executed
max_command_time = self.cfg.resampling_time_range[1]
max_command_step = max_command_time / self._env.step_dt
# logs data
self.metrics["error_vel_xy"] += (
torch.norm(self.vel_command_b[:, :2] - self.robot.data.root_lin_vel_b[:, :2], dim=-1) / max_command_step
)
self.metrics["error_vel_yaw"] += (
torch.abs(self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]) / max_command_step
)
def _set_debug_vis_impl(self, debug_vis: bool):
# set visibility of markers
# note: parent only deals with callbacks. not their visibility
......
......@@ -186,6 +186,11 @@ class CommandTerm(ManagerTermBase):
Implementation specific functions.
"""
@abstractmethod
def _update_metrics(self):
"""Update the metrics based on the current state."""
raise NotImplementedError
@abstractmethod
def _resample_command(self, env_ids: Sequence[int]):
"""Resample the command for the specified environments."""
......@@ -196,11 +201,6 @@ class CommandTerm(ManagerTermBase):
"""Update the command based on the current state."""
raise NotImplementedError
@abstractmethod
def _update_metrics(self):
"""Update the metrics based on the current state."""
raise NotImplementedError
def _set_debug_vis_impl(self, debug_vis: bool):
"""Set debug visualization into visualization objects.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment