Unverified Commit 00702aa3 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Adds support for keyword arguments into `ManagerBase` (#198)

# Description

This MR adds support for keyword arguments into the `ManagerBase` class.
This helps in defining some default arguments into the term function
call, which makes the configuration for those terms easier. Earlier, we
had a lot of `SceneEntity("robot")` references, which is redundant for
most cases that we are interested in.

## Type of change

- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file

---------
Signed-off-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Co-authored-by: 's avatarFarbod Farshidian <ffarshidian@theaiinstitute.com>
parent cd645f74
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.9.15"
version = "0.9.16"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.9.16 (2023-10-22)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added support for keyword arguments for terms in the :class:`omni.isaac.orbit.managers.ManagerBase`.
Fixed
^^^^^
* Fixed resetting of buffers in the :class:`TerminationManager` class. Earlier, the values were being set
to ``0.0`` instead of ``False``.
0.9.15 (2023-10-22)
~~~~~~~~~~~~~~~~~~~
......
......@@ -22,7 +22,9 @@ if TYPE_CHECKING:
from omni.isaac.orbit.envs.rl_env import RLEnv
def terrain_levels_vel(env: RLEnv, env_ids: Sequence[int], asset_cfg: SceneEntityCfg) -> torch.Tensor:
def terrain_levels_vel(
env: RLEnv, env_ids: Sequence[int], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Curriculum based on the distance the robot walked when commanded to move at a desired velocity.
This term is used to increase the difficulty of the terrain when the robot walks far enough and decrease the
......
......@@ -26,21 +26,21 @@ Root state.
"""
def base_lin_vel(env: BaseEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def base_lin_vel(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Root linear velocity in the asset's root frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_lin_vel_b
def base_ang_vel(env: BaseEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def base_ang_vel(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Root angular velocity in the asset's root frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_ang_vel_b
def projected_gravity(env: BaseEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def projected_gravity(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Gravity projection on the asset's root frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
......@@ -52,14 +52,14 @@ Joint state.
"""
def joint_pos_rel(env: BaseEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def joint_pos_rel(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""The joint positions of the asset w.r.t. the default joint positions."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
return asset.data.joint_pos - asset.data.default_joint_pos
def joint_vel_rel(env: BaseEnv, asset_cfg: SceneEntityCfg):
def joint_vel_rel(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")):
"""The joint velocities of the asset w.r.t. the default joint velocities."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
......@@ -71,16 +71,12 @@ Sensors.
"""
def height_scan(env: BaseEnv, asset_cfg: SceneEntityCfg, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
"""Height scan from the given sensor w.r.t. the asset's root frame."""
def height_scan(env: BaseEnv, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
"""Height scan from the given sensor w.r.t. the sensor's frame."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
sensor: RayCaster = env.scene.sensors[sensor_cfg.name]
# TODO (@dhoeller): is this sensor specific or we can generalize it?
hit_points_z = torch.nan_to_num(sensor.data.ray_hits_w[..., 2], posinf=-1.0)
# compute the height scan: robot_z - ground_z - offset
heights = asset.data.root_state_w[:, 2].unsqueeze(1) - hit_points_z - 0.5
# return the height scan
heights = sensor.data.pos_w[:, 2].unsqueeze(1) - sensor.data.ray_hits_w[..., 2] - 0.5
return heights
......@@ -89,7 +85,7 @@ Actions.
"""
def action(env: BaseEnv) -> torch.Tensor:
def last_action(env: BaseEnv) -> torch.Tensor:
"""The last input action to the environment."""
return env.action_manager.action
......
......@@ -28,11 +28,11 @@ if TYPE_CHECKING:
def randomize_rigid_body_material(
env: RLEnv,
env_ids: torch.Tensor | None,
asset_cfg: SceneEntityCfg,
static_friction_range: tuple[float, float],
dynamic_friction_range: tuple[float, float],
restitution_range: tuple[float, float],
num_buckets: int,
asset_cfg: SceneEntityCfg,
):
"""Randomize the physics materials on all geometries of the asset.
......@@ -79,7 +79,7 @@ def randomize_rigid_body_material(
asset.body_physx_view.set_material_properties(materials, indices)
def add_body_mass(env: RLEnv, env_ids: torch.Tensor | None, asset_cfg: SceneEntityCfg, mass_range: tuple[float, float]):
def add_body_mass(env: RLEnv, env_ids: torch.Tensor | None, mass_range: tuple[float, float], asset_cfg: SceneEntityCfg):
"""Randomize the mass of the bodies by adding a random value sampled from the given range.
.. tip::
......@@ -109,9 +109,9 @@ def add_body_mass(env: RLEnv, env_ids: torch.Tensor | None, asset_cfg: SceneEnti
def apply_external_force_torque(
env: RLEnv,
env_ids: torch.Tensor,
asset_cfg: SceneEntityCfg,
force_range: tuple[float, float],
torque_range: tuple[float, float],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
"""Randomize the external forces and torques applied to the bodies.
......@@ -137,7 +137,10 @@ def apply_external_force_torque(
def push_by_setting_velocity(
env: RLEnv, env_ids: torch.Tensor, asset_cfg: SceneEntityCfg, velocity_range: dict[str, tuple[float, float]]
env: RLEnv,
env_ids: torch.Tensor,
velocity_range: dict[str, tuple[float, float]],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
"""Push the asset by setting the root velocity to a random value within the given ranges.
......@@ -167,9 +170,9 @@ def push_by_setting_velocity(
def reset_root_state(
env: RLEnv,
env_ids: torch.Tensor,
asset_cfg: SceneEntityCfg,
pose_range: dict[str, tuple[float, float]],
velocity_range: dict[str, tuple[float, float]],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
"""Reset the asset root state to a random position and velocity within the given ranges.
......@@ -218,9 +221,9 @@ def reset_root_state(
def reset_joints_by_scale(
env: RLEnv,
env_ids: torch.Tensor,
asset_cfg: SceneEntityCfg,
position_range: tuple[float, float],
velocity_range: tuple[float, float],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
"""Reset the robot joints by scaling the default position and velocity by the given ranges.
......
......@@ -21,27 +21,36 @@ from omni.isaac.orbit.sensors import ContactSensor
if TYPE_CHECKING:
from omni.isaac.orbit.envs.rl_env import RLEnv
"""
General.
"""
def termination_penalty(env: RLEnv) -> torch.Tensor:
"""Penalize terminated episodes that don't correspond to episodic timeouts."""
return env.reset_buf * (~env.termination_manager.time_outs)
"""
Root penalties.
"""
def lin_vel_z_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def lin_vel_z_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize z-axis base linear velocity using L2-kernel."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return torch.square(asset.data.root_lin_vel_b[:, 2])
def ang_vel_xy_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def ang_vel_xy_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize xy-axis base angular velocity using L2-kernel."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
return torch.sum(torch.square(asset.data.root_ang_vel_b[:, :2]), dim=1)
def flat_orientation_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def flat_orientation_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize non-flat base orientation using L2-kernel.
This is computed by penalizing the xy-components of the projected gravity vector.
......@@ -51,7 +60,9 @@ def flat_orientation_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
return torch.sum(torch.square(asset.data.projected_gravity_b[:, :2]), dim=1)
def base_height_l2(env: RLEnv, asset_cfg: SceneEntityCfg, target_height: float) -> torch.Tensor:
def base_height_l2(
env: RLEnv, target_height: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Penalize asset height from its target using L2-kernel.
Note:
......@@ -63,33 +74,39 @@ def base_height_l2(env: RLEnv, asset_cfg: SceneEntityCfg, target_height: float)
return torch.square(asset.data.root_pos_w[:, 2] - target_height)
def body_lin_acc_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize the linear acceleration of bodies using L2-kernel."""
asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.norm(asset.data.body_lin_acc_w[:, asset_cfg.body_ids, :], dim=-1), dim=1)
"""
Joint penalties.
"""
def joint_torques_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def joint_torques_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize torques applied on the articulation using L2-kernel."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.square(asset.data.applied_torque), dim=1)
def joint_vel_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def joint_vel_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize joint velocities on the articulation."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.square(asset.data.joint_vel), dim=1)
def joint_acc_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def joint_acc_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize joint accelerations on the articulation using L2-kernel."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.square(asset.data.joint_acc), dim=1)
def joint_pos_limits(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def joint_pos_limits(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize joint positions if they cross the soft limits.
This is computed as a sum of the absolute value of the difference between the joint position and the soft limits.
......@@ -102,7 +119,9 @@ def joint_pos_limits(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
return torch.sum(out_of_limits, dim=1)
def joint_vel_limits(env: RLEnv, asset_cfg: SceneEntityCfg, soft_ratio: float) -> torch.Tensor:
def joint_vel_limits(
env: RLEnv, soft_ratio: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Penalize joint velocities if they cross the soft limits.
This is computed as a sum of the absolute value of the difference between the joint velocity and the soft limits.
......@@ -124,7 +143,7 @@ Action penalties.
"""
def applied_torque_limits(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def applied_torque_limits(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize applied torques if they cross the limits.
This is computed as a sum of the absolute value of the difference between the applied torques and the limits.
......@@ -151,7 +170,7 @@ Contact sensor.
"""
def undesired_contacts(env: RLEnv, sensor_cfg: SceneEntityCfg, threshold: float) -> torch.Tensor:
def undesired_contacts(env: RLEnv, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
"""Penalize undesired contacts as the number of violations that are above a threshold."""
# extract the used quantities (to enable type-hinting)
contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
......@@ -162,7 +181,7 @@ def undesired_contacts(env: RLEnv, sensor_cfg: SceneEntityCfg, threshold: float)
return torch.sum(is_contact, dim=1)
def contact_forces(env: RLEnv, sensor_cfg: SceneEntityCfg, threshold: float) -> torch.Tensor:
def contact_forces(env: RLEnv, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
"""Penalize contact forces as the amount of violations of the net contact force."""
# extract the used quantities (to enable type-hinting)
contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
......@@ -178,7 +197,7 @@ Velocity-tracking rewards.
"""
def track_lin_vel_xy_exp(env: RLEnv, asset_cfg: SceneEntityCfg, std: float) -> torch.Tensor:
def track_lin_vel_xy_exp(env: RLEnv, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Reward tracking of linear velocity commands (xy axes) using exponential kernel."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
......@@ -189,7 +208,7 @@ def track_lin_vel_xy_exp(env: RLEnv, asset_cfg: SceneEntityCfg, std: float) -> t
return torch.exp(-lin_vel_error / std**2)
def track_ang_vel_z_exp(env: RLEnv, asset_cfg: SceneEntityCfg, std: float) -> torch.Tensor:
def track_ang_vel_z_exp(env: RLEnv, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Reward tracking of angular velocity commands (yaw) using exponential kernel."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
......
......@@ -31,8 +31,15 @@ def time_out(env: RLEnv) -> torch.Tensor:
return env.episode_length_buf >= env.max_episode_length
def command_resample(env: RLEnv, num_commands: torch.Tensor) -> torch.Tensor:
return torch.logical_and((env.command_manager.time_left <= 0.0), (env.command_manager.num_commands == num_commands))
def command_resample(env: RLEnv, num_resamples: int = 1) -> torch.Tensor:
"""Terminate the episode based on the total number of times commands have been re-sampled.
This makes the maximum episode length fluid in nature as it depends on how the commands are
sampled. It is useful in situations where delayed rewards are used :cite:`rudin2022advanced`.
"""
return torch.logical_and(
(env.command_manager.time_left <= env.step_dt), (env.command_manager.command_counter == num_resamples)
)
"""
......@@ -40,7 +47,9 @@ Root terminations.
"""
def bad_orientation(env: RLEnv, asset_cfg: SceneEntityCfg, limit_angle: float) -> torch.Tensor:
def bad_orientation(
env: RLEnv, limit_angle: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Terminate when the asset's orientation is too far from the desired orientation limits.
This is computed by checking the angle between the projected gravity vector and the z-axis.
......@@ -50,7 +59,7 @@ def bad_orientation(env: RLEnv, asset_cfg: SceneEntityCfg, limit_angle: float) -
return torch.acos(-asset.data.projected_gravity_b[:, 2]).abs() > limit_angle
def base_height(env: RLEnv, asset_cfg: SceneEntityCfg, minimum_height: float) -> torch.Tensor:
def base_height(env: RLEnv, minimum_height: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Terminate when the asset's height is below the minimum height.
Note:
......@@ -66,7 +75,7 @@ Joint terminations.
"""
def joint_pos_limit(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def joint_pos_limit(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Terminate when the asset's joint positions are outside of the soft joint limits."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
......@@ -76,7 +85,7 @@ def joint_pos_limit(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
return torch.logical_or(out_of_upper_limits, out_of_lower_limits)
def joint_velocity_limit(env: RLEnv, asset_cfg: SceneEntityCfg, max_velocity) -> torch.Tensor:
def joint_velocity_limit(env: RLEnv, max_velocity, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Terminate when the asset's joint velocities are outside of the soft joint limits."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
......@@ -84,7 +93,7 @@ def joint_velocity_limit(env: RLEnv, asset_cfg: SceneEntityCfg, max_velocity) ->
return torch.any(torch.abs(asset.data.joint_vel) > max_velocity, dim=1)
def joint_torque_limit(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def joint_torque_limit(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Terminate when torque applied on the asset's joints are are outside of the soft joint limits."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
......@@ -99,7 +108,7 @@ Contact sensor.
"""
def illegal_contact(env: RLEnv, sensor_cfg: SceneEntityCfg, threshold: float) -> torch.Tensor:
def illegal_contact(env: RLEnv, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
"""Terminate when the contact force on the sensor exceeds the force threshold."""
# extract the used quantities (to enable type-hinting)
contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
......
......@@ -309,10 +309,10 @@ class RLEnv(BaseEnv, gym.Env):
"""
# update the curriculum for environments that need a reset
self.curriculum_manager.compute(env_ids=env_ids)
# randomize the MDP for environments that need a reset
self.randomization_manager.randomize(env_ids=env_ids, mode="reset")
# reset the internal buffers of the scene elements
self.scene.reset(env_ids)
# randomize the MDP for environments that need a reset
self.randomization_manager.randomize(env_ids=env_ids, mode="reset")
# iterate over all managers and reset them
# this returns a dictionary of information which is stored in the extras
......
......@@ -166,10 +166,15 @@ class ManagerBase(ABC):
raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}")
# check if term's arguments are matched by params
term_params = list(term_cfg.params.keys())
args = inspect.getfullargspec(term_cfg.func).args
args = inspect.signature(term_cfg.func).parameters
args_with_defaults = [arg for arg in args if args[arg].default is not inspect.Parameter.empty]
args_without_defaults = [arg for arg in args if args[arg].default is inspect.Parameter.empty]
args = args_without_defaults + args_with_defaults
# ignore first two arguments for env and env_ids
# Think: Check for cases when kwargs are set inside the function?
if len(args) > min_argc:
if set(args[min_argc:]) != set(term_params):
msg = f"The term '{term_name}' expects parameters: {args[min_argc:]}, but {term_params} provided."
raise ValueError(msg)
if set(args[min_argc:]) != set(term_params + args_with_defaults):
raise ValueError(
f"The term '{term_name}' expects mandatory parameters: {args_without_defaults[min_argc:]}"
f" and optional parameters: {args_with_defaults}, but received: {term_params}."
)
......@@ -123,8 +123,8 @@ class TerminationManager(ManagerBase):
The combined termination signal of shape ``(num_envs,)``.
"""
# reset computation
self._done_buf[:] = 0.0
self._time_out_buf[:] = 0.0
self._done_buf[:] = False
self._time_out_buf[:] = False
# iterate over all the termination terms
for name, term_cfg in zip(self._term_names, self._term_cfgs):
value = term_cfg.func(self._env, **term_cfg.params)
......
......@@ -39,6 +39,10 @@ def grilled_chicken_with_yoghurt(env, hot: bool, bland: float):
return hot * bland * torch.ones(env.num_envs, 5, device=env.device)
def grilled_chicken_with_yoghurt_and_bbq(env, hot: bool, bland: float, bbq: bool = False):
return hot * bland * bbq * torch.ones(env.num_envs, 3, device=env.device)
class complex_function_class:
def __init__(self, cfg: ObservationTermCfg, env: object):
self.cfg = cfg
......@@ -84,13 +88,16 @@ class TestObservationManager(unittest.TestCase):
term_4 = ObservationTermCfg(
func=grilled_chicken_with_yoghurt, scale=1.0, params={"hot": False, "bland": 2.0}
)
term_5 = ObservationTermCfg(
func=grilled_chicken_with_yoghurt_and_bbq, scale=1.0, params={"hot": False, "bland": 2.0}
)
policy: ObservationGroupCfg = SampleGroupCfg()
# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
self.assertEqual(len(self.obs_man.active_terms["policy"]), 4)
self.assertEqual(len(self.obs_man.active_terms["policy"]), 5)
# print the expected string
print()
print(self.obs_man)
......
......@@ -101,28 +101,19 @@ class ObservationsCfg:
"""Observations for policy group."""
# observation terms (order preserved)
base_lin_vel = ObsTerm(
func=mdp.base_lin_vel, params={"asset_cfg": SceneEntityCfg("robot")}, noise=Unoise(n_min=-0.1, n_max=0.1)
)
base_ang_vel = ObsTerm(
func=mdp.base_ang_vel, params={"asset_cfg": SceneEntityCfg("robot")}, noise=Unoise(n_min=-0.2, n_max=0.2)
)
base_lin_vel = ObsTerm(func=mdp.base_lin_vel, noise=Unoise(n_min=-0.1, n_max=0.1))
base_ang_vel = ObsTerm(func=mdp.base_ang_vel, noise=Unoise(n_min=-0.2, n_max=0.2))
projected_gravity = ObsTerm(
func=mdp.projected_gravity,
params={"asset_cfg": SceneEntityCfg("robot")},
noise=Unoise(n_min=-0.05, n_max=0.05),
)
velocity_commands = ObsTerm(func=mdp.generated_commands)
joint_pos = ObsTerm(
func=mdp.joint_pos_rel, params={"asset_cfg": SceneEntityCfg("robot")}, noise=Unoise(n_min=-0.01, n_max=0.01)
)
joint_vel = ObsTerm(
func=mdp.joint_vel_rel, params={"asset_cfg": SceneEntityCfg("robot")}, noise=Unoise(n_min=-1.5, n_max=1.5)
)
actions = ObsTerm(func=mdp.action)
joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01))
joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-1.5, n_max=1.5))
actions = ObsTerm(func=mdp.last_action)
height_scan = ObsTerm(
func=mdp.height_scan,
params={"asset_cfg": SceneEntityCfg("robot"), "sensor_cfg": SceneEntityCfg("height_scanner")},
params={"sensor_cfg": SceneEntityCfg("height_scanner")},
noise=Unoise(n_min=-0.1, n_max=0.1),
)
......@@ -172,7 +163,6 @@ class RandomizationCfg:
func=mdp.reset_root_state,
mode="reset",
params={
"asset_cfg": SceneEntityCfg("robot"),
"pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)},
"velocity_range": {
"x": (-0.5, 0.5),
......@@ -189,7 +179,6 @@ class RandomizationCfg:
func=mdp.reset_joints_by_scale,
mode="reset",
params={
"asset_cfg": SceneEntityCfg("robot"),
"position_range": (0.5, 1.5),
"velocity_range": (0.0, 0.0),
},
......@@ -200,10 +189,7 @@ class RandomizationCfg:
func=mdp.push_by_setting_velocity,
mode="interval",
interval_range_s=(10.0, 15.0),
params={
"asset_cfg": SceneEntityCfg("robot"),
"velocity_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5)},
},
params={"velocity_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5)}},
)
......@@ -212,17 +198,13 @@ class RewardsCfg:
"""Reward terms for the MDP."""
# -- task
track_lin_vel_xy_exp = RewTerm(
func=mdp.track_lin_vel_xy_exp, weight=1.0, params={"asset_cfg": SceneEntityCfg("robot"), "std": math.sqrt(0.25)}
)
track_ang_vel_z_exp = RewTerm(
func=mdp.track_ang_vel_z_exp, weight=0.5, params={"asset_cfg": SceneEntityCfg("robot"), "std": math.sqrt(0.25)}
)
track_lin_vel_xy_exp = RewTerm(func=mdp.track_lin_vel_xy_exp, weight=1.0, params={"std": math.sqrt(0.25)})
track_ang_vel_z_exp = RewTerm(func=mdp.track_ang_vel_z_exp, weight=0.5, params={"std": math.sqrt(0.25)})
# -- penalties
lin_vel_z_l2 = RewTerm(func=mdp.lin_vel_z_l2, weight=-2.0, params={"asset_cfg": SceneEntityCfg("robot")})
ang_vel_xy_l2 = RewTerm(func=mdp.ang_vel_xy_l2, weight=-0.05, params={"asset_cfg": SceneEntityCfg("robot")})
dof_torques_l2 = RewTerm(func=mdp.joint_torques_l2, weight=-1.0e-5, params={"asset_cfg": SceneEntityCfg("robot")})
dof_acc_l2 = RewTerm(func=mdp.joint_acc_l2, weight=-2.5e-7, params={"asset_cfg": SceneEntityCfg("robot")})
lin_vel_z_l2 = RewTerm(func=mdp.lin_vel_z_l2, weight=-2.0)
ang_vel_xy_l2 = RewTerm(func=mdp.ang_vel_xy_l2, weight=-0.05)
dof_torques_l2 = RewTerm(func=mdp.joint_torques_l2, weight=-1.0e-5)
dof_acc_l2 = RewTerm(func=mdp.joint_acc_l2, weight=-2.5e-7)
action_rate_l2 = RewTerm(func=mdp.action_rate_l2, weight=-0.01)
feet_air_time = RewTerm(
func=mdp.feet_air_time,
......@@ -251,7 +233,7 @@ class TerminationsCfg:
class CurriculumCfg:
"""Curriculum terms for the MDP."""
terrain_levels = CurrTerm(func=mdp.terrain_levels_vel, params={"asset_cfg": SceneEntityCfg("robot")})
terrain_levels = CurrTerm(func=mdp.terrain_levels_vel)
##
......@@ -293,10 +275,7 @@ class LocomotionEnvRoughCfg(RLEnvCfg):
# simulation settings
self.sim.dt = 0.005
self.sim.disable_contact_processing = True
self.sim.physics_material.static_friction = 1.0
self.sim.physics_material.dynamic_friction = 1.0
self.sim.physics_material.friction_combine_mode = "multiply"
self.sim.physics_material.restitution_combine_mode = "multiply"
self.sim.physics_material = self.scene.terrain.physics_material
# update sensor update periods
# we tick all the sensors based on the smallest update period (physics update period)
self.scene.height_scanner.update_period = self.decimation * self.sim.dt
......
......@@ -118,7 +118,7 @@ class RslRlVecEnvWrapper(gym.Wrapper):
# return step information
obs = obs_dict["policy"]
extras["observations"] = obs_dict
return obs, rew, dones, extras
return obs, rew, dones.to(torch.long), extras
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment