Commit 1a42eb97 authored by arbhardwaj98's avatar arbhardwaj98 Committed by Mayank Mittal

Adds new MDP observation, randomization and reward terms (#60)

# Description

This MR adds the following:

1. Observations: Adds observations for root state (pos, quat, linear
vel, and angular vel) in the environment frame. Important for assets
such as objects during manipulation.

2. Randomizations: Adds random orientation randomization for assets
(such as objects) and joint position randomization for articulations.

3. Rewards: Adds a termination reward function for specific termination
terms. Needed if terminations are to be weighted individually, for eg,
if successful termination reward should have a different weighting
factor than illegal state termination reward.

Tested for functionality.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent e7506fea
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.12.1" version = "0.12.2"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.12.2 (2024-03-10)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added observation terms for states of a rigid object in world frame.
* Added randomization terms to set root state with randomized orientation and joint state within user-specified limits.
* Added reward term for penalizing specific termination terms.
Fixed
^^^^^
* Improved sampling of states inside randomization terms. Earlier, the code did multiple torch calls
for sampling different components of the vector. Now, it uses a single call to sample the entire vector.
0.12.1 (2024-03-09) 0.12.1 (2024-03-09)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -114,7 +114,7 @@ class ArticulationData(RigidObjectData): ...@@ -114,7 +114,7 @@ class ArticulationData(RigidObjectData):
"""Joint positions limits for all joints. Shape is (count, num_joints, 2).""" """Joint positions limits for all joints. Shape is (count, num_joints, 2)."""
soft_joint_vel_limits: torch.Tensor = None soft_joint_vel_limits: torch.Tensor = None
"""Joint velocity limits for all joints. Shape is (count, num_joints, 2).""" """Joint velocity limits for all joints. Shape is (count, num_joints)."""
gear_ratio: torch.Tensor = None gear_ratio: torch.Tensor = None
"""Gear ratio for relating motor torques to applied Joint torques. Shape is (count, num_joints).""" """Gear ratio for relating motor torques to applied Joint torques. Shape is (count, num_joints)."""
...@@ -55,6 +55,34 @@ def projected_gravity(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg(" ...@@ -55,6 +55,34 @@ def projected_gravity(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("
return asset.data.projected_gravity_b return asset.data.projected_gravity_b
def root_pos_w(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Asset root position in the environment frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_pos_w - env.scene.env_origins
def root_quat_w(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Asset root orientation in the environment frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_quat_w
def root_lin_vel_w(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Asset root linear velocity in the environment frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_lin_vel_w
def root_ang_vel_w(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Asset root angular velocity in the environment frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_ang_vel_w
""" """
Joint state. Joint state.
""" """
......
...@@ -19,7 +19,10 @@ from typing import TYPE_CHECKING ...@@ -19,7 +19,10 @@ from typing import TYPE_CHECKING
from omni.isaac.orbit.assets import Articulation, RigidObject from omni.isaac.orbit.assets import Articulation, RigidObject
from omni.isaac.orbit.managers import SceneEntityCfg from omni.isaac.orbit.managers import SceneEntityCfg
from omni.isaac.orbit.utils.math import quat_from_euler_xyz, sample_uniform from omni.isaac.orbit.managers.manager_base import ManagerTermBase
from omni.isaac.orbit.managers.manager_term_cfg import RandomizationTermCfg
from omni.isaac.orbit.terrains import TerrainImporter
from omni.isaac.orbit.utils.math import quat_from_euler_xyz, random_orientation, sample_uniform
if TYPE_CHECKING: if TYPE_CHECKING:
from omni.isaac.orbit.envs import BaseEnv from omni.isaac.orbit.envs import BaseEnv
...@@ -159,12 +162,9 @@ def push_by_setting_velocity( ...@@ -159,12 +162,9 @@ def push_by_setting_velocity(
# velocities # velocities
vel_w = asset.data.root_vel_w[env_ids] vel_w = asset.data.root_vel_w[env_ids]
# sample random velocities # sample random velocities
vel_w[:, 0].uniform_(*velocity_range.get("x", (0.0, 0.0))) range_list = [velocity_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]]
vel_w[:, 1].uniform_(*velocity_range.get("y", (0.0, 0.0))) ranges = torch.tensor(range_list, device=asset.device)
vel_w[:, 2].uniform_(*velocity_range.get("z", (0.0, 0.0))) vel_w[:] = sample_uniform(ranges[:, 0], ranges[:, 1], vel_w.shape, device=asset.device)
vel_w[:, 3].uniform_(*velocity_range.get("roll", (0.0, 0.0)))
vel_w[:, 4].uniform_(*velocity_range.get("pitch", (0.0, 0.0)))
vel_w[:, 5].uniform_(*velocity_range.get("yaw", (0.0, 0.0)))
# set the velocities into the physics simulation # set the velocities into the physics simulation
asset.write_root_velocity_to_sim(vel_w, env_ids=env_ids) asset.write_root_velocity_to_sim(vel_w, env_ids=env_ids)
...@@ -194,26 +194,72 @@ def reset_root_state_uniform( ...@@ -194,26 +194,72 @@ def reset_root_state_uniform(
# get default root state # get default root state
root_states = asset.data.default_root_state[env_ids].clone() root_states = asset.data.default_root_state[env_ids].clone()
# positions # poses
pos_offset = torch.zeros_like(root_states[:, 0:3]) range_list = [pose_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]]
pos_offset[:, 0].uniform_(*pose_range.get("x", (0.0, 0.0))) ranges = torch.tensor(range_list, device=asset.device)
pos_offset[:, 1].uniform_(*pose_range.get("y", (0.0, 0.0))) rand_samples = sample_uniform(ranges[:, 0], ranges[:, 1], (len(env_ids), 6), device=asset.device)
pos_offset[:, 2].uniform_(*pose_range.get("z", (0.0, 0.0)))
positions = root_states[:, 0:3] + env.scene.env_origins[env_ids] + pos_offset positions = root_states[:, 0:3] + env.scene.env_origins[env_ids] + rand_samples[:, 0:3]
# orientations orientations = quat_from_euler_xyz(rand_samples[:, 3], rand_samples[:, 4], rand_samples[:, 5])
euler_angles = torch.zeros_like(positions)
euler_angles[:, 0].uniform_(*pose_range.get("roll", (0.0, 0.0))) # velocities
euler_angles[:, 1].uniform_(*pose_range.get("pitch", (0.0, 0.0))) range_list = [velocity_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]]
euler_angles[:, 2].uniform_(*pose_range.get("yaw", (0.0, 0.0))) ranges = torch.tensor(range_list, device=asset.device)
orientations = quat_from_euler_xyz(euler_angles[:, 0], euler_angles[:, 1], euler_angles[:, 2]) rand_samples = sample_uniform(ranges[:, 0], ranges[:, 1], (len(env_ids), 6), device=asset.device)
velocities = root_states[:, 7:13] + rand_samples
# set into the physics simulation
asset.write_root_pose_to_sim(torch.cat([positions, orientations], dim=-1), env_ids=env_ids)
asset.write_root_velocity_to_sim(velocities, env_ids=env_ids)
def reset_root_state_with_random_orientation(
env: BaseEnv,
env_ids: torch.Tensor,
pose_range: dict[str, tuple[float, float]],
velocity_range: dict[str, tuple[float, float]],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
"""Reset the asset root position and velocities sampled randomly within the given ranges
and the asset root orientation sampled randomly from the SO(3).
This function randomizes the root position and velocity of the asset.
* It samples the root position from the given ranges and adds them to the default root position, before setting
them into the physics simulation.
* It samples the root orientation uniformly from the SO(3) and sets them into the physics simulation.
* It samples the root velocity from the given ranges and sets them into the physics simulation.
The function takes a dictionary of position and velocity ranges for each axis and rotation:
* :attr:`pose_range` - a dictionary of position ranges for each axis. The keys of the dictionary are ``x``,
``y``, and ``z``.
* :attr:`velocity_range` - a dictionary of velocity ranges for each axis and rotation. The keys of the dictionary
are ``x``, ``y``, ``z``, ``roll``, ``pitch``, and ``yaw``.
The values are tuples of the form ``(min, max)``. If the dictionary does not contain a particular key,
the position is set to zero for that axis.
"""
# extract the used quantities (to enable type-hinting)
asset: RigidObject | Articulation = env.scene[asset_cfg.name]
# get default root state
root_states = asset.data.default_root_state[env_ids].clone()
# poses
range_list = [pose_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z"]]
ranges = torch.tensor(range_list, device=asset.device)
rand_samples = sample_uniform(ranges[:, 0], ranges[:, 1], (len(env_ids), 3), device=asset.device)
positions = root_states[:, 0:3] + env.scene.env_origins[env_ids] + rand_samples
orientations = random_orientation(len(env_ids), device=asset.device)
# velocities # velocities
velocities = root_states[:, 7:13] range_list = [velocity_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]]
velocities[:, 0].uniform_(*velocity_range.get("x", (0.0, 0.0))) ranges = torch.tensor(range_list, device=asset.device)
velocities[:, 1].uniform_(*velocity_range.get("y", (0.0, 0.0))) rand_samples = sample_uniform(ranges[:, 0], ranges[:, 1], (len(env_ids), 6), device=asset.device)
velocities[:, 2].uniform_(*velocity_range.get("z", (0.0, 0.0)))
velocities[:, 3].uniform_(*velocity_range.get("roll", (0.0, 0.0))) velocities = root_states[:, 7:13] + rand_samples
velocities[:, 4].uniform_(*velocity_range.get("pitch", (0.0, 0.0)))
velocities[:, 5].uniform_(*velocity_range.get("yaw", (0.0, 0.0)))
# set into the physics simulation # set into the physics simulation
asset.write_root_pose_to_sim(torch.cat([positions, orientations], dim=-1), env_ids=env_ids) asset.write_root_pose_to_sim(torch.cat([positions, orientations], dim=-1), env_ids=env_ids)
...@@ -242,48 +288,35 @@ def reset_robot_root_from_terrain( ...@@ -242,48 +288,35 @@ def reset_robot_root_from_terrain(
""" """
# access the used quantities (to enable type-hinting) # access the used quantities (to enable type-hinting)
asset: RigidObject | Articulation = env.scene[asset_cfg.name] asset: RigidObject | Articulation = env.scene[asset_cfg.name]
terrain: TerrainImporter = env.scene.terrain
# obtain all flat patches corresponding to the valid poses # obtain all flat patches corresponding to the valid poses
valid_poses: torch.Tensor = env.scene.terrain.flat_patches.get("init_pos") valid_poses: torch.Tensor = terrain.flat_patches.get("init_pos")
if valid_poses is None: if valid_poses is None:
raise ValueError( raise ValueError(
"The randomization term 'reset_robot_root_from_terrain' requires valid flat patches under 'init_pos'." "The randomization term 'reset_robot_root_from_terrain' requires valid flat patches under 'init_pos'."
f" Found: {list(env.scene.terrain.flat_patches.keys())}" f" Found: {list(terrain.flat_patches.keys())}"
) )
# sample random valid poses # sample random valid poses
ids = torch.randint(0, valid_poses.shape[2], size=(len(env_ids),), device=env.device) ids = torch.randint(0, valid_poses.shape[2], size=(len(env_ids),), device=env.device)
positions = valid_poses[env.scene.terrain.terrain_levels[env_ids], env.scene.terrain.terrain_types[env_ids], ids] positions = valid_poses[terrain.terrain_levels[env_ids], terrain.terrain_types[env_ids], ids]
positions += asset.data.default_root_state[env_ids, :3] positions += asset.data.default_root_state[env_ids, :3]
# sample random orientations # sample random orientations
ranges = torch.tensor( range_list = [pose_range.get(key, (0.0, 0.0)) for key in ["roll", "pitch", "yaw"]]
[ ranges = torch.tensor(range_list, device=asset.device)
pose_range.get("roll", (0.0, 0.0)), rand_samples = sample_uniform(ranges[:, 0], ranges[:, 1], (len(env_ids), 3), device=asset.device)
pose_range.get("pitch", (0.0, 0.0)),
pose_range.get("yaw", (0.0, 0.0)),
],
device=env.device,
)
euler_angles = torch.zeros_like(positions).uniform_()
euler_angles = ranges[0] + (ranges[1] - ranges[0]) * euler_angles
# convert to quaternions # convert to quaternions
orientations = quat_from_euler_xyz(euler_angles[:, 0], euler_angles[:, 1], euler_angles[:, 2]) orientations = quat_from_euler_xyz(rand_samples[:, 0], rand_samples[:, 1], rand_samples[:, 2])
# sample random velocities # sample random velocities
ranges = torch.tensor( range_list = [velocity_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]]
[ ranges = torch.tensor(range_list, device=asset.device)
velocity_range.get("x", (0.0, 0.0)), rand_samples = sample_uniform(ranges[:, 0], ranges[:, 1], (len(env_ids), 6), device=asset.device)
velocity_range.get("y", (0.0, 0.0)),
velocity_range.get("z", (0.0, 0.0)), velocities = asset.data.default_root_state[:, 7:13] + rand_samples
velocity_range.get("roll", (0.0, 0.0)),
velocity_range.get("pitch", (0.0, 0.0)),
velocity_range.get("yaw", (0.0, 0.0)),
],
device=env.device,
)
velocities = torch.zeros(len(env_ids), 6, device=asset.device).uniform_()
velocities = ranges[:, 0] + (ranges[:, 1] - ranges[:, 0]) * velocities
# set into the physics simulation # set into the physics simulation
asset.write_root_pose_to_sim(torch.cat([positions, orientations], dim=-1), env_ids=env_ids) asset.write_root_pose_to_sim(torch.cat([positions, orientations], dim=-1), env_ids=env_ids)
...@@ -347,6 +380,124 @@ def reset_joints_by_offset( ...@@ -347,6 +380,124 @@ def reset_joints_by_offset(
asset.write_joint_state_to_sim(joint_pos, joint_vel, env_ids=env_ids) asset.write_joint_state_to_sim(joint_pos, joint_vel, env_ids=env_ids)
class reset_joints_within_range(ManagerTermBase):
"""Reset an articulation's joints to a random position in the given ranges.
This function samples random values for the joint position and velocities from the given ranges.
The values are then set into the physics simulation.
The parameters to the function are:
* :attr:`position_range` - a dictionary of position ranges for each joint. The keys of the dictionary are the
joint names (or regular expressions) of the asset.
* :attr:`velocity_range` - a dictionary of velocity ranges for each joint. The keys of the dictionary are the
joint names (or regular expressions) of the asset.
* :attr:`use_default_offset` - a boolean flag to indicate if the ranges are offset by the default joint state.
Defaults to False.
* :attr:`asset_cfg` - the configuration of the asset to reset. Defaults to the entity named "robot" in the scene.
The dictionary values are a tuple of the form ``(min, max)``, where ``min`` and ``max`` are the minimum and
maximum values. If the dictionary does not contain a key, the joint position or joint velocity is set to
the default value for that joint. If the ``min`` or the ``max`` value is ``None``, the joint limits are used
instead.
"""
def __init__(self, cfg: RandomizationTermCfg, env: BaseEnv):
# initialize the base class
super().__init__(cfg, env)
# check if the cfg has the required parameters
if "position_range" not in cfg.params or "velocity_range" not in cfg.params:
raise ValueError(
f"The term 'reset_joints_within_range' requires parameters: 'position_range' and 'velocity_range'."
f" Received: {list(cfg.params.keys())}."
)
# parse the parameters
asset_cfg: SceneEntityCfg = cfg.params.get("asset_cfg", SceneEntityCfg("robot"))
use_default_offset = cfg.params.get("use_default_offset", False)
# extract the used quantities (to enable type-hinting)
self._asset: Articulation = env.scene[asset_cfg.name]
default_joint_pos = self._asset.data.default_joint_pos[0]
default_joint_vel = self._asset.data.default_joint_vel[0]
# create buffers to store the joint position and velocity ranges
self._pos_ranges = self._asset.data.soft_joint_pos_limits[0].clone()
self._vel_ranges = torch.stack(
[-self._asset.data.soft_joint_vel_limits[0], self._asset.data.soft_joint_vel_limits[0]], dim=1
)
# parse joint position ranges
pos_joint_ids = []
for joint_name, joint_range in cfg.params["position_range"].items():
# find the joint ids
joint_ids = self._asset.find_joints(joint_name)[0]
pos_joint_ids.extend(joint_ids)
# set the joint position ranges based on the given values
if joint_range[0] is not None:
self._pos_ranges[joint_ids, 0] = joint_range[0] + use_default_offset * default_joint_pos[joint_ids]
if joint_range[1] is not None:
self._pos_ranges[joint_ids, 1] = joint_range[1] + use_default_offset * default_joint_pos[joint_ids]
# store the joint pos ids (used later to sample the joint positions)
self._pos_joint_ids = torch.tensor(pos_joint_ids, device=self._pos_ranges.device)
# clamp sampling range to the joint position limits
joint_pos_limits = self._asset.data.soft_joint_pos_limits[0]
self._pos_ranges = self._pos_ranges.clamp(min=joint_pos_limits[:, 0], max=joint_pos_limits[:, 1])
self._pos_ranges = self._pos_ranges[self._pos_joint_ids]
# parse joint velocity ranges
vel_joint_ids = []
for joint_name, joint_range in cfg.params["velocity_range"].items():
# find the joint ids
joint_ids = self._asset.find_joints(joint_name)[0]
vel_joint_ids.extend(joint_ids)
# set the joint position ranges based on the given values
if joint_range[0] is not None:
self._vel_ranges[joint_ids, 0] = joint_range[0] + use_default_offset * default_joint_vel[joint_ids]
if joint_range[1] is not None:
self._vel_ranges[joint_ids, 1] = joint_range[1] + use_default_offset * default_joint_vel[joint_ids]
# store the joint vel ids (used later to sample the joint positions)
self._vel_joint_ids = torch.tensor(vel_joint_ids, device=self._vel_ranges.device)
# clamp sampling range to the joint velocity limits
joint_vel_limits = self._asset.data.soft_joint_vel_limits[0]
self._vel_ranges = self._vel_ranges.clamp(min=-joint_vel_limits[:, None], max=joint_vel_limits[:, None])
self._vel_ranges = self._vel_ranges[self._vel_joint_ids]
def __call__(
self,
env: BaseEnv,
env_ids: torch.Tensor,
position_range: dict[str, tuple[float | None, float | None]],
velocity_range: dict[str, tuple[float | None, float | None]],
use_default_offset: bool = False,
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
# get default joint state
joint_pos = self._asset.data.default_joint_pos[env_ids].clone()
joint_vel = self._asset.data.default_joint_vel[env_ids].clone()
# sample random joint positions for each joint
if len(self._pos_joint_ids) > 0:
joint_pos_shape = (len(env_ids), len(self._pos_joint_ids))
joint_pos[:, self._pos_joint_ids] = sample_uniform(
self._pos_ranges[:, 0], self._pos_ranges[:, 1], joint_pos_shape, device=joint_pos.device
)
# sample random joint velocities for each joint
if len(self._vel_joint_ids) > 0:
joint_vel_shape = (len(env_ids), len(self._vel_joint_ids))
joint_vel[:, self._vel_joint_ids] = sample_uniform(
self._vel_ranges[:, 0], self._vel_ranges[:, 1], joint_vel_shape, device=joint_vel.device
)
# set into the physics simulation
self._asset.write_joint_state_to_sim(joint_pos, joint_vel, env_ids=env_ids)
def reset_scene_to_default(env: BaseEnv, env_ids: torch.Tensor): def reset_scene_to_default(env: BaseEnv, env_ids: torch.Tensor):
"""Reset the scene to the default state specified in the scene configuration.""" """Reset the scene to the default state specified in the scene configuration."""
# rigid bodies # rigid bodies
......
...@@ -16,6 +16,8 @@ from typing import TYPE_CHECKING ...@@ -16,6 +16,8 @@ from typing import TYPE_CHECKING
from omni.isaac.orbit.assets import Articulation, RigidObject from omni.isaac.orbit.assets import Articulation, RigidObject
from omni.isaac.orbit.managers import SceneEntityCfg from omni.isaac.orbit.managers import SceneEntityCfg
from omni.isaac.orbit.managers.manager_base import ManagerTermBase
from omni.isaac.orbit.managers.manager_term_cfg import RewardTermCfg
from omni.isaac.orbit.sensors import ContactSensor from omni.isaac.orbit.sensors import ContactSensor
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -36,6 +38,36 @@ def is_terminated(env: RLTaskEnv) -> torch.Tensor: ...@@ -36,6 +38,36 @@ def is_terminated(env: RLTaskEnv) -> torch.Tensor:
return env.termination_manager.terminated.float() return env.termination_manager.terminated.float()
class is_terminated_term(ManagerTermBase):
"""Penalize termination for specific terms that don't correspond to episodic timeouts.
The parameters are as follows:
* attr:`term_keys`: The termination terms to penalize. This can be a string, a list of strings
or regular expressions. Default is ".*" which penalizes all terminations.
The reward is computed as the sum of the termination terms that are not episodic timeouts.
This means that the reward is 0 if the episode is terminated due to an episodic timeout. Otherwise,
if two termination terms are active, the reward is 2.
"""
def __init__(self, cfg: RewardTermCfg, env: RLTaskEnv):
# initialize the base class
super().__init__(cfg, env)
# find and store the termination terms
term_keys = cfg.params.get("term_keys", ".*")
self._term_names = env.termination_manager.find_terms(term_keys)
def __call__(self, env: RLTaskEnv, term_keys: str | list[str] = ".*") -> torch.Tensor:
# Return the unweighted reward for the termination terms
reset_buf = torch.zeros(env.num_envs, device=env.device)
for term in self._term_names:
# Sums over terminations term values to account for multiple terminations in the same step
reset_buf += env.termination_manager.get_term(term)
return (reset_buf * (~env.termination_manager.time_outs)).float()
""" """
Root penalties. Root penalties.
""" """
......
...@@ -120,18 +120,19 @@ class RLTaskEnv(BaseEnv, gym.Env): ...@@ -120,18 +120,19 @@ class RLTaskEnv(BaseEnv, gym.Env):
def load_managers(self): def load_managers(self):
# note: this order is important since observation manager needs to know the command and action managers # note: this order is important since observation manager needs to know the command and action managers
# and the reward manager needs to know the termination manager
# -- command manager # -- command manager
self.command_manager: CommandManager = CommandManager(self.cfg.commands, self) self.command_manager: CommandManager = CommandManager(self.cfg.commands, self)
print("[INFO] Command Manager: ", self.command_manager) print("[INFO] Command Manager: ", self.command_manager)
# call the parent class to load the managers for observations and actions. # call the parent class to load the managers for observations and actions.
super().load_managers() super().load_managers()
# prepare the managers # prepare the managers
# -- reward manager
self.reward_manager = RewardManager(self.cfg.rewards, self)
print("[INFO] Reward Manager: ", self.reward_manager)
# -- termination manager # -- termination manager
self.termination_manager = TerminationManager(self.cfg.terminations, self) self.termination_manager = TerminationManager(self.cfg.terminations, self)
print("[INFO] Termination Manager: ", self.termination_manager) print("[INFO] Termination Manager: ", self.termination_manager)
# -- reward manager
self.reward_manager = RewardManager(self.cfg.rewards, self)
print("[INFO] Reward Manager: ", self.reward_manager)
# -- curriculum manager # -- curriculum manager
self.curriculum_manager = CurriculumManager(self.cfg.curriculum, self) self.curriculum_manager = CurriculumManager(self.cfg.curriculum, self)
print("[INFO] Curriculum Manager: ", self.curriculum_manager) print("[INFO] Curriculum Manager: ", self.curriculum_manager)
......
...@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any ...@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any
import carb import carb
import omni.isaac.orbit.utils.string as string_utils
from omni.isaac.orbit.utils import string_to_callable from omni.isaac.orbit.utils import string_to_callable
from .manager_term_cfg import ManagerTermBaseCfg from .manager_term_cfg import ManagerTermBaseCfg
...@@ -164,6 +165,33 @@ class ManagerBase(ABC): ...@@ -164,6 +165,33 @@ class ManagerBase(ABC):
""" """
return {} return {}
def find_terms(self, name_keys: str | Sequence[str]) -> list[str]:
"""Find terms in the manager based on the names.
This function searches the manager for terms based on the names. The names can be
specified as regular expressions or a list of regular expressions. The search is
performed on the active terms in the manager.
Please check the :meth:`omni.isaac.orbit.utils.string_utils.resolve_matching_names` function for more
information on the name matching.
Args:
name_keys: A regular expression or a list of regular expressions to match the term names.
Returns:
A list of term names that match the input keys.
"""
# resolve search keys
if isinstance(self.active_terms, dict):
list_of_strings = []
for names in self.active_terms.values():
list_of_strings.extend(names)
else:
list_of_strings = self.active_terms
# return the matching names
return string_utils.resolve_matching_names(name_keys, list_of_strings)[1]
""" """
Implementation specific. Implementation specific.
""" """
......
...@@ -55,9 +55,9 @@ class TerminationManager(ManagerBase): ...@@ -55,9 +55,9 @@ class TerminationManager(ManagerBase):
""" """
super().__init__(cfg, env) super().__init__(cfg, env)
# prepare extra info to store individual termination term information # prepare extra info to store individual termination term information
self._episode_dones = dict() self._term_dones = dict()
for term_name in self._term_names: for term_name in self._term_names:
self._episode_dones[term_name] = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) self._term_dones[term_name] = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
# create buffer for managing termination per environment # create buffer for managing termination per environment
self._truncated_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) self._truncated_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
self._terminated_buf = torch.zeros_like(self._truncated_buf) self._terminated_buf = torch.zeros_like(self._truncated_buf)
...@@ -133,11 +133,11 @@ class TerminationManager(ManagerBase): ...@@ -133,11 +133,11 @@ class TerminationManager(ManagerBase):
env_ids = slice(None) env_ids = slice(None)
# add to episode dict # add to episode dict
extras = {} extras = {}
for key in self._episode_dones.keys(): for key in self._term_dones.keys():
# store information # store information
extras["Episode Termination/" + key] = torch.count_nonzero(self._episode_dones[key][env_ids]).item() extras["Episode Termination/" + key] = torch.count_nonzero(self._term_dones[key][env_ids]).item()
# reset episode dones # reset episode dones
self._episode_dones[key][env_ids] = False self._term_dones[key][env_ids] = False
# reset all the reward terms # reset all the reward terms
for term_cfg in self._class_term_cfgs: for term_cfg in self._class_term_cfgs:
term_cfg.func.reset(env_ids=env_ids) term_cfg.func.reset(env_ids=env_ids)
...@@ -165,10 +165,21 @@ class TerminationManager(ManagerBase): ...@@ -165,10 +165,21 @@ class TerminationManager(ManagerBase):
else: else:
self._terminated_buf |= value self._terminated_buf |= value
# add to episode dones # add to episode dones
self._episode_dones[name] |= value self._term_dones[name] |= value
# return combined termination signal # return combined termination signal
return self._truncated_buf | self._terminated_buf return self._truncated_buf | self._terminated_buf
def get_term(self, name: str) -> torch.Tensor:
"""Returns the termination term with the specified name.
Args:
name: The name of the termination term.
Returns:
The corresponding termination term value. Shape is (num_envs,).
"""
return self._term_dones[name]
""" """
Operations - Term settings. Operations - Term settings.
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment