Unverified Commit 00702aa3 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Adds support for keyword arguments into `ManagerBase` (#198)

# Description

This MR adds support for keyword arguments into the `ManagerBase` class.
This helps in defining some default arguments into the term function
call, which makes the configuration for those terms easier. Earlier, we
had a lot of `SceneEntity("robot")` references, which is redundant for
most cases that we are interested in.

## Type of change

- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file

---------
Signed-off-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Co-authored-by: 's avatarFarbod Farshidian <ffarshidian@theaiinstitute.com>
parent cd645f74
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.9.15" version = "0.9.16"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.9.16 (2023-10-22)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added support for keyword arguments for terms in the :class:`omni.isaac.orbit.managers.ManagerBase`.
Fixed
^^^^^
* Fixed resetting of buffers in the :class:`TerminationManager` class. Earlier, the values were being set
to ``0.0`` instead of ``False``.
0.9.15 (2023-10-22) 0.9.15 (2023-10-22)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -22,7 +22,9 @@ if TYPE_CHECKING: ...@@ -22,7 +22,9 @@ if TYPE_CHECKING:
from omni.isaac.orbit.envs.rl_env import RLEnv from omni.isaac.orbit.envs.rl_env import RLEnv
def terrain_levels_vel(env: RLEnv, env_ids: Sequence[int], asset_cfg: SceneEntityCfg) -> torch.Tensor: def terrain_levels_vel(
env: RLEnv, env_ids: Sequence[int], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Curriculum based on the distance the robot walked when commanded to move at a desired velocity. """Curriculum based on the distance the robot walked when commanded to move at a desired velocity.
This term is used to increase the difficulty of the terrain when the robot walks far enough and decrease the This term is used to increase the difficulty of the terrain when the robot walks far enough and decrease the
......
...@@ -26,21 +26,21 @@ Root state. ...@@ -26,21 +26,21 @@ Root state.
""" """
def base_lin_vel(env: BaseEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def base_lin_vel(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Root linear velocity in the asset's root frame.""" """Root linear velocity in the asset's root frame."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name] asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_lin_vel_b return asset.data.root_lin_vel_b
def base_ang_vel(env: BaseEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def base_ang_vel(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Root angular velocity in the asset's root frame.""" """Root angular velocity in the asset's root frame."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name] asset: RigidObject = env.scene[asset_cfg.name]
return asset.data.root_ang_vel_b return asset.data.root_ang_vel_b
def projected_gravity(env: BaseEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def projected_gravity(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Gravity projection on the asset's root frame.""" """Gravity projection on the asset's root frame."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name] asset: RigidObject = env.scene[asset_cfg.name]
...@@ -52,14 +52,14 @@ Joint state. ...@@ -52,14 +52,14 @@ Joint state.
""" """
def joint_pos_rel(env: BaseEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def joint_pos_rel(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""The joint positions of the asset w.r.t. the default joint positions.""" """The joint positions of the asset w.r.t. the default joint positions."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
return asset.data.joint_pos - asset.data.default_joint_pos return asset.data.joint_pos - asset.data.default_joint_pos
def joint_vel_rel(env: BaseEnv, asset_cfg: SceneEntityCfg): def joint_vel_rel(env: BaseEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")):
"""The joint velocities of the asset w.r.t. the default joint velocities.""" """The joint velocities of the asset w.r.t. the default joint velocities."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
...@@ -71,16 +71,12 @@ Sensors. ...@@ -71,16 +71,12 @@ Sensors.
""" """
def height_scan(env: BaseEnv, asset_cfg: SceneEntityCfg, sensor_cfg: SceneEntityCfg) -> torch.Tensor: def height_scan(env: BaseEnv, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
"""Height scan from the given sensor w.r.t. the asset's root frame.""" """Height scan from the given sensor w.r.t. the sensor's frame."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
sensor: RayCaster = env.scene.sensors[sensor_cfg.name] sensor: RayCaster = env.scene.sensors[sensor_cfg.name]
# TODO (@dhoeller): is this sensor specific or we can generalize it?
hit_points_z = torch.nan_to_num(sensor.data.ray_hits_w[..., 2], posinf=-1.0)
# compute the height scan: robot_z - ground_z - offset
heights = asset.data.root_state_w[:, 2].unsqueeze(1) - hit_points_z - 0.5
# return the height scan # return the height scan
heights = sensor.data.pos_w[:, 2].unsqueeze(1) - sensor.data.ray_hits_w[..., 2] - 0.5
return heights return heights
...@@ -89,7 +85,7 @@ Actions. ...@@ -89,7 +85,7 @@ Actions.
""" """
def action(env: BaseEnv) -> torch.Tensor: def last_action(env: BaseEnv) -> torch.Tensor:
"""The last input action to the environment.""" """The last input action to the environment."""
return env.action_manager.action return env.action_manager.action
......
...@@ -28,11 +28,11 @@ if TYPE_CHECKING: ...@@ -28,11 +28,11 @@ if TYPE_CHECKING:
def randomize_rigid_body_material( def randomize_rigid_body_material(
env: RLEnv, env: RLEnv,
env_ids: torch.Tensor | None, env_ids: torch.Tensor | None,
asset_cfg: SceneEntityCfg,
static_friction_range: tuple[float, float], static_friction_range: tuple[float, float],
dynamic_friction_range: tuple[float, float], dynamic_friction_range: tuple[float, float],
restitution_range: tuple[float, float], restitution_range: tuple[float, float],
num_buckets: int, num_buckets: int,
asset_cfg: SceneEntityCfg,
): ):
"""Randomize the physics materials on all geometries of the asset. """Randomize the physics materials on all geometries of the asset.
...@@ -79,7 +79,7 @@ def randomize_rigid_body_material( ...@@ -79,7 +79,7 @@ def randomize_rigid_body_material(
asset.body_physx_view.set_material_properties(materials, indices) asset.body_physx_view.set_material_properties(materials, indices)
def add_body_mass(env: RLEnv, env_ids: torch.Tensor | None, asset_cfg: SceneEntityCfg, mass_range: tuple[float, float]): def add_body_mass(env: RLEnv, env_ids: torch.Tensor | None, mass_range: tuple[float, float], asset_cfg: SceneEntityCfg):
"""Randomize the mass of the bodies by adding a random value sampled from the given range. """Randomize the mass of the bodies by adding a random value sampled from the given range.
.. tip:: .. tip::
...@@ -109,9 +109,9 @@ def add_body_mass(env: RLEnv, env_ids: torch.Tensor | None, asset_cfg: SceneEnti ...@@ -109,9 +109,9 @@ def add_body_mass(env: RLEnv, env_ids: torch.Tensor | None, asset_cfg: SceneEnti
def apply_external_force_torque( def apply_external_force_torque(
env: RLEnv, env: RLEnv,
env_ids: torch.Tensor, env_ids: torch.Tensor,
asset_cfg: SceneEntityCfg,
force_range: tuple[float, float], force_range: tuple[float, float],
torque_range: tuple[float, float], torque_range: tuple[float, float],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
): ):
"""Randomize the external forces and torques applied to the bodies. """Randomize the external forces and torques applied to the bodies.
...@@ -137,7 +137,10 @@ def apply_external_force_torque( ...@@ -137,7 +137,10 @@ def apply_external_force_torque(
def push_by_setting_velocity( def push_by_setting_velocity(
env: RLEnv, env_ids: torch.Tensor, asset_cfg: SceneEntityCfg, velocity_range: dict[str, tuple[float, float]] env: RLEnv,
env_ids: torch.Tensor,
velocity_range: dict[str, tuple[float, float]],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
): ):
"""Push the asset by setting the root velocity to a random value within the given ranges. """Push the asset by setting the root velocity to a random value within the given ranges.
...@@ -167,9 +170,9 @@ def push_by_setting_velocity( ...@@ -167,9 +170,9 @@ def push_by_setting_velocity(
def reset_root_state( def reset_root_state(
env: RLEnv, env: RLEnv,
env_ids: torch.Tensor, env_ids: torch.Tensor,
asset_cfg: SceneEntityCfg,
pose_range: dict[str, tuple[float, float]], pose_range: dict[str, tuple[float, float]],
velocity_range: dict[str, tuple[float, float]], velocity_range: dict[str, tuple[float, float]],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
): ):
"""Reset the asset root state to a random position and velocity within the given ranges. """Reset the asset root state to a random position and velocity within the given ranges.
...@@ -218,9 +221,9 @@ def reset_root_state( ...@@ -218,9 +221,9 @@ def reset_root_state(
def reset_joints_by_scale( def reset_joints_by_scale(
env: RLEnv, env: RLEnv,
env_ids: torch.Tensor, env_ids: torch.Tensor,
asset_cfg: SceneEntityCfg,
position_range: tuple[float, float], position_range: tuple[float, float],
velocity_range: tuple[float, float], velocity_range: tuple[float, float],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
): ):
"""Reset the robot joints by scaling the default position and velocity by the given ranges. """Reset the robot joints by scaling the default position and velocity by the given ranges.
......
...@@ -21,27 +21,36 @@ from omni.isaac.orbit.sensors import ContactSensor ...@@ -21,27 +21,36 @@ from omni.isaac.orbit.sensors import ContactSensor
if TYPE_CHECKING: if TYPE_CHECKING:
from omni.isaac.orbit.envs.rl_env import RLEnv from omni.isaac.orbit.envs.rl_env import RLEnv
"""
General.
"""
def termination_penalty(env: RLEnv) -> torch.Tensor:
"""Penalize terminated episodes that don't correspond to episodic timeouts."""
return env.reset_buf * (~env.termination_manager.time_outs)
""" """
Root penalties. Root penalties.
""" """
def lin_vel_z_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def lin_vel_z_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize z-axis base linear velocity using L2-kernel.""" """Penalize z-axis base linear velocity using L2-kernel."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name] asset: RigidObject = env.scene[asset_cfg.name]
return torch.square(asset.data.root_lin_vel_b[:, 2]) return torch.square(asset.data.root_lin_vel_b[:, 2])
def ang_vel_xy_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def ang_vel_xy_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize xy-axis base angular velocity using L2-kernel.""" """Penalize xy-axis base angular velocity using L2-kernel."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name] asset: RigidObject = env.scene[asset_cfg.name]
return torch.sum(torch.square(asset.data.root_ang_vel_b[:, :2]), dim=1) return torch.sum(torch.square(asset.data.root_ang_vel_b[:, :2]), dim=1)
def flat_orientation_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def flat_orientation_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize non-flat base orientation using L2-kernel. """Penalize non-flat base orientation using L2-kernel.
This is computed by penalizing the xy-components of the projected gravity vector. This is computed by penalizing the xy-components of the projected gravity vector.
...@@ -51,7 +60,9 @@ def flat_orientation_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: ...@@ -51,7 +60,9 @@ def flat_orientation_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
return torch.sum(torch.square(asset.data.projected_gravity_b[:, :2]), dim=1) return torch.sum(torch.square(asset.data.projected_gravity_b[:, :2]), dim=1)
def base_height_l2(env: RLEnv, asset_cfg: SceneEntityCfg, target_height: float) -> torch.Tensor: def base_height_l2(
env: RLEnv, target_height: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Penalize asset height from its target using L2-kernel. """Penalize asset height from its target using L2-kernel.
Note: Note:
...@@ -63,33 +74,39 @@ def base_height_l2(env: RLEnv, asset_cfg: SceneEntityCfg, target_height: float) ...@@ -63,33 +74,39 @@ def base_height_l2(env: RLEnv, asset_cfg: SceneEntityCfg, target_height: float)
return torch.square(asset.data.root_pos_w[:, 2] - target_height) return torch.square(asset.data.root_pos_w[:, 2] - target_height)
def body_lin_acc_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize the linear acceleration of bodies using L2-kernel."""
asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.norm(asset.data.body_lin_acc_w[:, asset_cfg.body_ids, :], dim=-1), dim=1)
""" """
Joint penalties. Joint penalties.
""" """
def joint_torques_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def joint_torques_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize torques applied on the articulation using L2-kernel.""" """Penalize torques applied on the articulation using L2-kernel."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.square(asset.data.applied_torque), dim=1) return torch.sum(torch.square(asset.data.applied_torque), dim=1)
def joint_vel_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def joint_vel_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize joint velocities on the articulation.""" """Penalize joint velocities on the articulation."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.square(asset.data.joint_vel), dim=1) return torch.sum(torch.square(asset.data.joint_vel), dim=1)
def joint_acc_l2(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def joint_acc_l2(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize joint accelerations on the articulation using L2-kernel.""" """Penalize joint accelerations on the articulation using L2-kernel."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.square(asset.data.joint_acc), dim=1) return torch.sum(torch.square(asset.data.joint_acc), dim=1)
def joint_pos_limits(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def joint_pos_limits(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize joint positions if they cross the soft limits. """Penalize joint positions if they cross the soft limits.
This is computed as a sum of the absolute value of the difference between the joint position and the soft limits. This is computed as a sum of the absolute value of the difference between the joint position and the soft limits.
...@@ -102,7 +119,9 @@ def joint_pos_limits(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: ...@@ -102,7 +119,9 @@ def joint_pos_limits(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
return torch.sum(out_of_limits, dim=1) return torch.sum(out_of_limits, dim=1)
def joint_vel_limits(env: RLEnv, asset_cfg: SceneEntityCfg, soft_ratio: float) -> torch.Tensor: def joint_vel_limits(
env: RLEnv, soft_ratio: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Penalize joint velocities if they cross the soft limits. """Penalize joint velocities if they cross the soft limits.
This is computed as a sum of the absolute value of the difference between the joint velocity and the soft limits. This is computed as a sum of the absolute value of the difference between the joint velocity and the soft limits.
...@@ -124,7 +143,7 @@ Action penalties. ...@@ -124,7 +143,7 @@ Action penalties.
""" """
def applied_torque_limits(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def applied_torque_limits(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize applied torques if they cross the limits. """Penalize applied torques if they cross the limits.
This is computed as a sum of the absolute value of the difference between the applied torques and the limits. This is computed as a sum of the absolute value of the difference between the applied torques and the limits.
...@@ -151,7 +170,7 @@ Contact sensor. ...@@ -151,7 +170,7 @@ Contact sensor.
""" """
def undesired_contacts(env: RLEnv, sensor_cfg: SceneEntityCfg, threshold: float) -> torch.Tensor: def undesired_contacts(env: RLEnv, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
"""Penalize undesired contacts as the number of violations that are above a threshold.""" """Penalize undesired contacts as the number of violations that are above a threshold."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
...@@ -162,7 +181,7 @@ def undesired_contacts(env: RLEnv, sensor_cfg: SceneEntityCfg, threshold: float) ...@@ -162,7 +181,7 @@ def undesired_contacts(env: RLEnv, sensor_cfg: SceneEntityCfg, threshold: float)
return torch.sum(is_contact, dim=1) return torch.sum(is_contact, dim=1)
def contact_forces(env: RLEnv, sensor_cfg: SceneEntityCfg, threshold: float) -> torch.Tensor: def contact_forces(env: RLEnv, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
"""Penalize contact forces as the amount of violations of the net contact force.""" """Penalize contact forces as the amount of violations of the net contact force."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
...@@ -178,7 +197,7 @@ Velocity-tracking rewards. ...@@ -178,7 +197,7 @@ Velocity-tracking rewards.
""" """
def track_lin_vel_xy_exp(env: RLEnv, asset_cfg: SceneEntityCfg, std: float) -> torch.Tensor: def track_lin_vel_xy_exp(env: RLEnv, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Reward tracking of linear velocity commands (xy axes) using exponential kernel.""" """Reward tracking of linear velocity commands (xy axes) using exponential kernel."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name] asset: RigidObject = env.scene[asset_cfg.name]
...@@ -189,7 +208,7 @@ def track_lin_vel_xy_exp(env: RLEnv, asset_cfg: SceneEntityCfg, std: float) -> t ...@@ -189,7 +208,7 @@ def track_lin_vel_xy_exp(env: RLEnv, asset_cfg: SceneEntityCfg, std: float) -> t
return torch.exp(-lin_vel_error / std**2) return torch.exp(-lin_vel_error / std**2)
def track_ang_vel_z_exp(env: RLEnv, asset_cfg: SceneEntityCfg, std: float) -> torch.Tensor: def track_ang_vel_z_exp(env: RLEnv, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Reward tracking of angular velocity commands (yaw) using exponential kernel.""" """Reward tracking of angular velocity commands (yaw) using exponential kernel."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name] asset: RigidObject = env.scene[asset_cfg.name]
......
...@@ -31,8 +31,15 @@ def time_out(env: RLEnv) -> torch.Tensor: ...@@ -31,8 +31,15 @@ def time_out(env: RLEnv) -> torch.Tensor:
return env.episode_length_buf >= env.max_episode_length return env.episode_length_buf >= env.max_episode_length
def command_resample(env: RLEnv, num_commands: torch.Tensor) -> torch.Tensor: def command_resample(env: RLEnv, num_resamples: int = 1) -> torch.Tensor:
return torch.logical_and((env.command_manager.time_left <= 0.0), (env.command_manager.num_commands == num_commands)) """Terminate the episode based on the total number of times commands have been re-sampled.
This makes the maximum episode length fluid in nature as it depends on how the commands are
sampled. It is useful in situations where delayed rewards are used :cite:`rudin2022advanced`.
"""
return torch.logical_and(
(env.command_manager.time_left <= env.step_dt), (env.command_manager.command_counter == num_resamples)
)
""" """
...@@ -40,7 +47,9 @@ Root terminations. ...@@ -40,7 +47,9 @@ Root terminations.
""" """
def bad_orientation(env: RLEnv, asset_cfg: SceneEntityCfg, limit_angle: float) -> torch.Tensor: def bad_orientation(
env: RLEnv, limit_angle: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Terminate when the asset's orientation is too far from the desired orientation limits. """Terminate when the asset's orientation is too far from the desired orientation limits.
This is computed by checking the angle between the projected gravity vector and the z-axis. This is computed by checking the angle between the projected gravity vector and the z-axis.
...@@ -50,7 +59,7 @@ def bad_orientation(env: RLEnv, asset_cfg: SceneEntityCfg, limit_angle: float) - ...@@ -50,7 +59,7 @@ def bad_orientation(env: RLEnv, asset_cfg: SceneEntityCfg, limit_angle: float) -
return torch.acos(-asset.data.projected_gravity_b[:, 2]).abs() > limit_angle return torch.acos(-asset.data.projected_gravity_b[:, 2]).abs() > limit_angle
def base_height(env: RLEnv, asset_cfg: SceneEntityCfg, minimum_height: float) -> torch.Tensor: def base_height(env: RLEnv, minimum_height: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Terminate when the asset's height is below the minimum height. """Terminate when the asset's height is below the minimum height.
Note: Note:
...@@ -66,7 +75,7 @@ Joint terminations. ...@@ -66,7 +75,7 @@ Joint terminations.
""" """
def joint_pos_limit(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def joint_pos_limit(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Terminate when the asset's joint positions are outside of the soft joint limits.""" """Terminate when the asset's joint positions are outside of the soft joint limits."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
...@@ -76,7 +85,7 @@ def joint_pos_limit(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: ...@@ -76,7 +85,7 @@ def joint_pos_limit(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
return torch.logical_or(out_of_upper_limits, out_of_lower_limits) return torch.logical_or(out_of_upper_limits, out_of_lower_limits)
def joint_velocity_limit(env: RLEnv, asset_cfg: SceneEntityCfg, max_velocity) -> torch.Tensor: def joint_velocity_limit(env: RLEnv, max_velocity, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Terminate when the asset's joint velocities are outside of the soft joint limits.""" """Terminate when the asset's joint velocities are outside of the soft joint limits."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
...@@ -84,7 +93,7 @@ def joint_velocity_limit(env: RLEnv, asset_cfg: SceneEntityCfg, max_velocity) -> ...@@ -84,7 +93,7 @@ def joint_velocity_limit(env: RLEnv, asset_cfg: SceneEntityCfg, max_velocity) ->
return torch.any(torch.abs(asset.data.joint_vel) > max_velocity, dim=1) return torch.any(torch.abs(asset.data.joint_vel) > max_velocity, dim=1)
def joint_torque_limit(env: RLEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: def joint_torque_limit(env: RLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Terminate when torque applied on the asset's joints are are outside of the soft joint limits.""" """Terminate when torque applied on the asset's joints are are outside of the soft joint limits."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name] asset: Articulation = env.scene[asset_cfg.name]
...@@ -99,7 +108,7 @@ Contact sensor. ...@@ -99,7 +108,7 @@ Contact sensor.
""" """
def illegal_contact(env: RLEnv, sensor_cfg: SceneEntityCfg, threshold: float) -> torch.Tensor: def illegal_contact(env: RLEnv, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
"""Terminate when the contact force on the sensor exceeds the force threshold.""" """Terminate when the contact force on the sensor exceeds the force threshold."""
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
......
...@@ -309,10 +309,10 @@ class RLEnv(BaseEnv, gym.Env): ...@@ -309,10 +309,10 @@ class RLEnv(BaseEnv, gym.Env):
""" """
# update the curriculum for environments that need a reset # update the curriculum for environments that need a reset
self.curriculum_manager.compute(env_ids=env_ids) self.curriculum_manager.compute(env_ids=env_ids)
# randomize the MDP for environments that need a reset
self.randomization_manager.randomize(env_ids=env_ids, mode="reset")
# reset the internal buffers of the scene elements # reset the internal buffers of the scene elements
self.scene.reset(env_ids) self.scene.reset(env_ids)
# randomize the MDP for environments that need a reset
self.randomization_manager.randomize(env_ids=env_ids, mode="reset")
# iterate over all managers and reset them # iterate over all managers and reset them
# this returns a dictionary of information which is stored in the extras # this returns a dictionary of information which is stored in the extras
......
...@@ -166,10 +166,15 @@ class ManagerBase(ABC): ...@@ -166,10 +166,15 @@ class ManagerBase(ABC):
raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}") raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}")
# check if term's arguments are matched by params # check if term's arguments are matched by params
term_params = list(term_cfg.params.keys()) term_params = list(term_cfg.params.keys())
args = inspect.getfullargspec(term_cfg.func).args args = inspect.signature(term_cfg.func).parameters
args_with_defaults = [arg for arg in args if args[arg].default is not inspect.Parameter.empty]
args_without_defaults = [arg for arg in args if args[arg].default is inspect.Parameter.empty]
args = args_without_defaults + args_with_defaults
# ignore first two arguments for env and env_ids # ignore first two arguments for env and env_ids
# Think: Check for cases when kwargs are set inside the function? # Think: Check for cases when kwargs are set inside the function?
if len(args) > min_argc: if len(args) > min_argc:
if set(args[min_argc:]) != set(term_params): if set(args[min_argc:]) != set(term_params + args_with_defaults):
msg = f"The term '{term_name}' expects parameters: {args[min_argc:]}, but {term_params} provided." raise ValueError(
raise ValueError(msg) f"The term '{term_name}' expects mandatory parameters: {args_without_defaults[min_argc:]}"
f" and optional parameters: {args_with_defaults}, but received: {term_params}."
)
...@@ -123,8 +123,8 @@ class TerminationManager(ManagerBase): ...@@ -123,8 +123,8 @@ class TerminationManager(ManagerBase):
The combined termination signal of shape ``(num_envs,)``. The combined termination signal of shape ``(num_envs,)``.
""" """
# reset computation # reset computation
self._done_buf[:] = 0.0 self._done_buf[:] = False
self._time_out_buf[:] = 0.0 self._time_out_buf[:] = False
# iterate over all the termination terms # iterate over all the termination terms
for name, term_cfg in zip(self._term_names, self._term_cfgs): for name, term_cfg in zip(self._term_names, self._term_cfgs):
value = term_cfg.func(self._env, **term_cfg.params) value = term_cfg.func(self._env, **term_cfg.params)
......
...@@ -39,6 +39,10 @@ def grilled_chicken_with_yoghurt(env, hot: bool, bland: float): ...@@ -39,6 +39,10 @@ def grilled_chicken_with_yoghurt(env, hot: bool, bland: float):
return hot * bland * torch.ones(env.num_envs, 5, device=env.device) return hot * bland * torch.ones(env.num_envs, 5, device=env.device)
def grilled_chicken_with_yoghurt_and_bbq(env, hot: bool, bland: float, bbq: bool = False):
return hot * bland * bbq * torch.ones(env.num_envs, 3, device=env.device)
class complex_function_class: class complex_function_class:
def __init__(self, cfg: ObservationTermCfg, env: object): def __init__(self, cfg: ObservationTermCfg, env: object):
self.cfg = cfg self.cfg = cfg
...@@ -84,13 +88,16 @@ class TestObservationManager(unittest.TestCase): ...@@ -84,13 +88,16 @@ class TestObservationManager(unittest.TestCase):
term_4 = ObservationTermCfg( term_4 = ObservationTermCfg(
func=grilled_chicken_with_yoghurt, scale=1.0, params={"hot": False, "bland": 2.0} func=grilled_chicken_with_yoghurt, scale=1.0, params={"hot": False, "bland": 2.0}
) )
term_5 = ObservationTermCfg(
func=grilled_chicken_with_yoghurt_and_bbq, scale=1.0, params={"hot": False, "bland": 2.0}
)
policy: ObservationGroupCfg = SampleGroupCfg() policy: ObservationGroupCfg = SampleGroupCfg()
# create observation manager # create observation manager
cfg = MyObservationManagerCfg() cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env) self.obs_man = ObservationManager(cfg, self.env)
self.assertEqual(len(self.obs_man.active_terms["policy"]), 4) self.assertEqual(len(self.obs_man.active_terms["policy"]), 5)
# print the expected string # print the expected string
print() print()
print(self.obs_man) print(self.obs_man)
......
...@@ -101,28 +101,19 @@ class ObservationsCfg: ...@@ -101,28 +101,19 @@ class ObservationsCfg:
"""Observations for policy group.""" """Observations for policy group."""
# observation terms (order preserved) # observation terms (order preserved)
base_lin_vel = ObsTerm( base_lin_vel = ObsTerm(func=mdp.base_lin_vel, noise=Unoise(n_min=-0.1, n_max=0.1))
func=mdp.base_lin_vel, params={"asset_cfg": SceneEntityCfg("robot")}, noise=Unoise(n_min=-0.1, n_max=0.1) base_ang_vel = ObsTerm(func=mdp.base_ang_vel, noise=Unoise(n_min=-0.2, n_max=0.2))
)
base_ang_vel = ObsTerm(
func=mdp.base_ang_vel, params={"asset_cfg": SceneEntityCfg("robot")}, noise=Unoise(n_min=-0.2, n_max=0.2)
)
projected_gravity = ObsTerm( projected_gravity = ObsTerm(
func=mdp.projected_gravity, func=mdp.projected_gravity,
params={"asset_cfg": SceneEntityCfg("robot")},
noise=Unoise(n_min=-0.05, n_max=0.05), noise=Unoise(n_min=-0.05, n_max=0.05),
) )
velocity_commands = ObsTerm(func=mdp.generated_commands) velocity_commands = ObsTerm(func=mdp.generated_commands)
joint_pos = ObsTerm( joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01))
func=mdp.joint_pos_rel, params={"asset_cfg": SceneEntityCfg("robot")}, noise=Unoise(n_min=-0.01, n_max=0.01) joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-1.5, n_max=1.5))
) actions = ObsTerm(func=mdp.last_action)
joint_vel = ObsTerm(
func=mdp.joint_vel_rel, params={"asset_cfg": SceneEntityCfg("robot")}, noise=Unoise(n_min=-1.5, n_max=1.5)
)
actions = ObsTerm(func=mdp.action)
height_scan = ObsTerm( height_scan = ObsTerm(
func=mdp.height_scan, func=mdp.height_scan,
params={"asset_cfg": SceneEntityCfg("robot"), "sensor_cfg": SceneEntityCfg("height_scanner")}, params={"sensor_cfg": SceneEntityCfg("height_scanner")},
noise=Unoise(n_min=-0.1, n_max=0.1), noise=Unoise(n_min=-0.1, n_max=0.1),
) )
...@@ -172,7 +163,6 @@ class RandomizationCfg: ...@@ -172,7 +163,6 @@ class RandomizationCfg:
func=mdp.reset_root_state, func=mdp.reset_root_state,
mode="reset", mode="reset",
params={ params={
"asset_cfg": SceneEntityCfg("robot"),
"pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)},
"velocity_range": { "velocity_range": {
"x": (-0.5, 0.5), "x": (-0.5, 0.5),
...@@ -189,7 +179,6 @@ class RandomizationCfg: ...@@ -189,7 +179,6 @@ class RandomizationCfg:
func=mdp.reset_joints_by_scale, func=mdp.reset_joints_by_scale,
mode="reset", mode="reset",
params={ params={
"asset_cfg": SceneEntityCfg("robot"),
"position_range": (0.5, 1.5), "position_range": (0.5, 1.5),
"velocity_range": (0.0, 0.0), "velocity_range": (0.0, 0.0),
}, },
...@@ -200,10 +189,7 @@ class RandomizationCfg: ...@@ -200,10 +189,7 @@ class RandomizationCfg:
func=mdp.push_by_setting_velocity, func=mdp.push_by_setting_velocity,
mode="interval", mode="interval",
interval_range_s=(10.0, 15.0), interval_range_s=(10.0, 15.0),
params={ params={"velocity_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5)}},
"asset_cfg": SceneEntityCfg("robot"),
"velocity_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5)},
},
) )
...@@ -212,17 +198,13 @@ class RewardsCfg: ...@@ -212,17 +198,13 @@ class RewardsCfg:
"""Reward terms for the MDP.""" """Reward terms for the MDP."""
# -- task # -- task
track_lin_vel_xy_exp = RewTerm( track_lin_vel_xy_exp = RewTerm(func=mdp.track_lin_vel_xy_exp, weight=1.0, params={"std": math.sqrt(0.25)})
func=mdp.track_lin_vel_xy_exp, weight=1.0, params={"asset_cfg": SceneEntityCfg("robot"), "std": math.sqrt(0.25)} track_ang_vel_z_exp = RewTerm(func=mdp.track_ang_vel_z_exp, weight=0.5, params={"std": math.sqrt(0.25)})
)
track_ang_vel_z_exp = RewTerm(
func=mdp.track_ang_vel_z_exp, weight=0.5, params={"asset_cfg": SceneEntityCfg("robot"), "std": math.sqrt(0.25)}
)
# -- penalties # -- penalties
lin_vel_z_l2 = RewTerm(func=mdp.lin_vel_z_l2, weight=-2.0, params={"asset_cfg": SceneEntityCfg("robot")}) lin_vel_z_l2 = RewTerm(func=mdp.lin_vel_z_l2, weight=-2.0)
ang_vel_xy_l2 = RewTerm(func=mdp.ang_vel_xy_l2, weight=-0.05, params={"asset_cfg": SceneEntityCfg("robot")}) ang_vel_xy_l2 = RewTerm(func=mdp.ang_vel_xy_l2, weight=-0.05)
dof_torques_l2 = RewTerm(func=mdp.joint_torques_l2, weight=-1.0e-5, params={"asset_cfg": SceneEntityCfg("robot")}) dof_torques_l2 = RewTerm(func=mdp.joint_torques_l2, weight=-1.0e-5)
dof_acc_l2 = RewTerm(func=mdp.joint_acc_l2, weight=-2.5e-7, params={"asset_cfg": SceneEntityCfg("robot")}) dof_acc_l2 = RewTerm(func=mdp.joint_acc_l2, weight=-2.5e-7)
action_rate_l2 = RewTerm(func=mdp.action_rate_l2, weight=-0.01) action_rate_l2 = RewTerm(func=mdp.action_rate_l2, weight=-0.01)
feet_air_time = RewTerm( feet_air_time = RewTerm(
func=mdp.feet_air_time, func=mdp.feet_air_time,
...@@ -251,7 +233,7 @@ class TerminationsCfg: ...@@ -251,7 +233,7 @@ class TerminationsCfg:
class CurriculumCfg: class CurriculumCfg:
"""Curriculum terms for the MDP.""" """Curriculum terms for the MDP."""
terrain_levels = CurrTerm(func=mdp.terrain_levels_vel, params={"asset_cfg": SceneEntityCfg("robot")}) terrain_levels = CurrTerm(func=mdp.terrain_levels_vel)
## ##
...@@ -293,10 +275,7 @@ class LocomotionEnvRoughCfg(RLEnvCfg): ...@@ -293,10 +275,7 @@ class LocomotionEnvRoughCfg(RLEnvCfg):
# simulation settings # simulation settings
self.sim.dt = 0.005 self.sim.dt = 0.005
self.sim.disable_contact_processing = True self.sim.disable_contact_processing = True
self.sim.physics_material.static_friction = 1.0 self.sim.physics_material = self.scene.terrain.physics_material
self.sim.physics_material.dynamic_friction = 1.0
self.sim.physics_material.friction_combine_mode = "multiply"
self.sim.physics_material.restitution_combine_mode = "multiply"
# update sensor update periods # update sensor update periods
# we tick all the sensors based on the smallest update period (physics update period) # we tick all the sensors based on the smallest update period (physics update period)
self.scene.height_scanner.update_period = self.decimation * self.sim.dt self.scene.height_scanner.update_period = self.decimation * self.sim.dt
......
...@@ -118,7 +118,7 @@ class RslRlVecEnvWrapper(gym.Wrapper): ...@@ -118,7 +118,7 @@ class RslRlVecEnvWrapper(gym.Wrapper):
# return step information # return step information
obs = obs_dict["policy"] obs = obs_dict["policy"]
extras["observations"] = obs_dict extras["observations"] = obs_dict
return obs, rew, dones, extras return obs, rew, dones.to(torch.long), extras
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment