Unverified Commit 40c5f4df authored by David Hoeller's avatar David Hoeller Committed by GitHub

Moves command generator into managers sub-package (#276)

# Description

This MR adds command manager terms so that it is possible to apply
multiple types of commands in the same environment. Before, you could
only use one. Now you can add multiple, for example, to generate a base
velocity command and an end effector pose command simultaneously.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarDavid Hoeller <dhoeller@ethz.ch>
Co-authored-by: 's avatarMayank Mittal <mittalma@leggedrobotics.com>
parent 6f4cc59d
......@@ -16,7 +16,6 @@ The following modules are available in the ``omni.isaac.orbit`` extension:
app
actuators
assets
command_generators
controllers
devices
envs
......
orbit.command\_generators
=========================
.. automodule:: omni.isaac.orbit.command_generators
.. rubric:: Classes
.. autosummary::
CommandGeneratorBase
CommandGeneratorBaseCfg
NullCommandGenerator
NullCommandGeneratorCfg
UniformVelocityCommandGenerator
UniformVelocityCommandGeneratorCfg
NormalVelocityCommandGenerator
NormalVelocityCommandGeneratorCfg
TerrainBasedPositionCommandGenerator
TerrainBasedPositionCommandGeneratorCfg
Command Generator Base
----------------------
.. autoclass:: CommandGeneratorBase
:members:
.. autoclass:: CommandGeneratorBaseCfg
:members:
:exclude-members: __init__, class_type
Null Command Generator
----------------------
.. autoclass:: NullCommandGenerator
:members:
:inherited-members:
:show-inheritance:
.. autoclass:: NullCommandGeneratorCfg
:members:
:inherited-members:
:show-inheritance:
:exclude-members: __init__, class_type, resampling_time_range
Uniform SE(2) Velocity Command Generator
----------------------------------------
.. autoclass:: UniformVelocityCommandGenerator
:members:
:inherited-members:
:show-inheritance:
.. autoclass:: UniformVelocityCommandGeneratorCfg
:members:
:inherited-members:
:show-inheritance:
:exclude-members: __init__, class_type
Normal SE(2) Velocity Command Generator
---------------------------------------
.. autoclass:: NormalVelocityCommandGenerator
:members:
:inherited-members:
:show-inheritance:
.. autoclass:: NormalVelocityCommandGeneratorCfg
:members:
:inherited-members:
:show-inheritance:
:exclude-members: __init__, class_type
Uniform SE(3) Pose Command Generator
------------------------------------
.. autoclass:: UniformPoseCommandGenerator
:members:
:inherited-members:
:show-inheritance:
.. autoclass:: UniformPoseCommandGeneratorCfg
:members:
:inherited-members:
:show-inheritance:
:exclude-members: __init__, class_type
Terrain-based SE(2) Position Command Generator
----------------------------------------------
.. note::
This command generator is currently not tested. It may not work as expected.
.. autoclass:: TerrainBasedPositionCommandGenerator
:members:
:inherited-members:
:show-inheritance:
.. autoclass:: TerrainBasedPositionCommandGeneratorCfg
:members:
:inherited-members:
:show-inheritance:
:exclude-members: __init__, class_type
......@@ -25,6 +25,16 @@ Randomization
.. automodule:: omni.isaac.orbit.envs.mdp.randomizations
:members:
Commands
--------
.. automodule:: omni.isaac.orbit.envs.mdp.commands
.. automodule:: omni.isaac.orbit.envs.mdp.commands.commands_cfg
:members:
:show-inheritance:
:exclude-members: __init__, class_type
Rewards
-------
......
......@@ -19,6 +19,9 @@
ActionTermCfg
RandomizationManager
RandomizationTermCfg
CommandManager
CommandTerm
CommandTermCfg
RewardManager
RewardTermCfg
TerminationManager
......@@ -91,6 +94,21 @@ Randomization Manager
:members:
:exclude-members: __init__
Command Manager
---------------
.. autoclass:: CommandManager
:members:
.. autoclass:: CommandTerm
:members:
:exclude-members: __init__, class_type
.. autoclass:: CommandTermCfg
:members:
:exclude-members: __init__, class_type
Reward Manager
--------------
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.10.0"
version = "0.10.1"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.10.1 (2023-12-06)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added command manager class with terms defined by :class:`omni.isaac.orbit.managers.CommandTerm`. This
allow for multiple types of command generators to be used in the same environment.
0.10.0 (2023-12-04)
~~~~~~~~~~~~~~~~~~~
......
# Copyright (c) 2022-2023, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Sub-package for different command generators implementations.
The command generators are used to generate commands for the agent to execute. The command generators act
as utility classes to make it convenient to switch between different command generation strategies within
the same environment. For instance, in an environment consisting of a quadrupedal robot, the command to it
could be a velocity command or position command. By keeping the command generation logic separate from the
environment, it is easy to switch between different command generation strategies.
The command generators are implemented as classes that inherit from the :class:`CommandGeneratorBase` class.
Each command generator class should also have a corresponding configuration class that inherits from the
:class:`CommandGeneratorBaseCfg` class.
"""
from .command_generator_base import CommandGeneratorBase
from .command_generator_cfg import (
CommandGeneratorBaseCfg,
NormalVelocityCommandGeneratorCfg,
NullCommandGeneratorCfg,
TerrainBasedPositionCommandGeneratorCfg,
UniformPoseCommandGeneratorCfg,
UniformVelocityCommandGeneratorCfg,
)
from .null_command_generator import NullCommandGenerator
from .pose_command_generator import UniformPoseCommandGenerator
from .position_command_generator import TerrainBasedPositionCommandGenerator
from .velocity_command_generator import NormalVelocityCommandGenerator, UniformVelocityCommandGenerator
......@@ -16,6 +16,7 @@ are used to define the environment through their managers.
"""
from .actions import * # noqa: F401, F403
from .commands import * # noqa: F401, F403
from .curriculums import * # noqa: F401, F403
from .observations import * # noqa: F401, F403
from .randomizations import * # noqa: F401, F403
......
# Copyright (c) 2022-2023, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Various command terms that can be used in the environment."""
from .commands_cfg import (
NormalVelocityCommandCfg,
NullCommandCfg,
TerrainBasedPositionCommandCfg,
UniformPoseCommandCfg,
UniformVelocityCommandCfg,
)
from .null_command import NullCommand
from .pose_command import UniformPoseCommand
from .position_command import TerrainBasedPositionCommand
from .velocity_command import NormalVelocityCommand, UniformVelocityCommand
......@@ -8,34 +8,13 @@ from __future__ import annotations
import math
from dataclasses import MISSING
from omni.isaac.orbit.managers import CommandTermCfg
from omni.isaac.orbit.utils import configclass
from .command_generator_base import CommandGeneratorBase
from .null_command_generator import NullCommandGenerator
from .pose_command_generator import UniformPoseCommandGenerator
from .position_command_generator import TerrainBasedPositionCommandGenerator
from .velocity_command_generator import NormalVelocityCommandGenerator, UniformVelocityCommandGenerator
"""
Base command generator.
"""
@configclass
class CommandGeneratorBaseCfg:
"""Configuration for the base command generator."""
class_type: type[CommandGeneratorBase] = MISSING
"""The associated command generator class to use.
The class should inherit from :class:`omni.isaac.orbit.command_generators.command_generator_base.CommandGeneratorBase`.
"""
resampling_time_range: tuple[float, float] = MISSING
"""Time before commands are changed [s]."""
debug_vis: bool = False
"""Whether to visualize debug information. Defaults to False."""
from .null_command import NullCommand
from .pose_command import UniformPoseCommand
from .position_command import TerrainBasedPositionCommand
from .velocity_command import NormalVelocityCommand, UniformVelocityCommand
"""
Null-command generator.
......@@ -43,10 +22,10 @@ Null-command generator.
@configclass
class NullCommandGeneratorCfg(CommandGeneratorBaseCfg):
class NullCommandCfg(CommandTermCfg):
"""Configuration for the null command generator."""
class_type: type = NullCommandGenerator
class_type: type = NullCommand
def __post_init__(self):
"""Post initialization."""
......@@ -60,10 +39,10 @@ Locomotion-specific command generators.
@configclass
class UniformVelocityCommandGeneratorCfg(CommandGeneratorBaseCfg):
class UniformVelocityCommandCfg(CommandTermCfg):
"""Configuration for the uniform velocity command generator."""
class_type: type = UniformVelocityCommandGenerator
class_type: type = UniformVelocityCommand
asset_name: str = MISSING
"""Name of the asset in the environment for which the commands are generated."""
......@@ -94,10 +73,10 @@ class UniformVelocityCommandGeneratorCfg(CommandGeneratorBaseCfg):
@configclass
class NormalVelocityCommandGeneratorCfg(UniformVelocityCommandGeneratorCfg):
class NormalVelocityCommandCfg(UniformVelocityCommandCfg):
"""Configuration for the normal velocity command generator."""
class_type: type = NormalVelocityCommandGenerator
class_type: type = NormalVelocityCommand
heading_command: bool = False # --> we don't use heading command for normal velocity command.
@configclass
......@@ -125,10 +104,10 @@ class NormalVelocityCommandGeneratorCfg(UniformVelocityCommandGeneratorCfg):
@configclass
class UniformPoseCommandGeneratorCfg(CommandGeneratorBaseCfg):
class UniformPoseCommandCfg(CommandTermCfg):
"""Configuration for uniform pose command generator."""
class_type: type = UniformPoseCommandGenerator
class_type: type = UniformPoseCommand
asset_name: str = MISSING
"""Name of the asset in the environment for which the commands are generated."""
......@@ -151,10 +130,10 @@ class UniformPoseCommandGeneratorCfg(CommandGeneratorBaseCfg):
@configclass
class TerrainBasedPositionCommandGeneratorCfg(CommandGeneratorBaseCfg):
class TerrainBasedPositionCommandCfg(CommandTermCfg):
"""Configuration for the terrain-based position command generator."""
class_type: type = TerrainBasedPositionCommandGenerator
class_type: type = TerrainBasedPositionCommand
asset_name: str = MISSING
"""Name of the asset in the environment for which the commands are generated."""
......
......@@ -9,24 +9,24 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
from omni.isaac.orbit.command_generators.command_generator_base import CommandGeneratorBase
from omni.isaac.orbit.managers import CommandTerm
if TYPE_CHECKING:
from .command_generator_cfg import NullCommandGeneratorCfg
from .commands_cfg import NullCommandCfg
class NullCommandGenerator(CommandGeneratorBase):
class NullCommand(CommandTerm):
"""Command generator that does nothing.
This command generator does not generate any commands. It is used for environments that do not
require any commands.
"""
cfg: NullCommandGeneratorCfg
cfg: NullCommandCfg
"""Configuration for the command generator."""
def __str__(self) -> str:
msg = "NullCommandGenerator:\n"
msg = "NullCommand:\n"
msg += "\tCommand dimension: N/A\n"
msg += f"\tResampling time range: {self.cfg.resampling_time_range}"
return msg
......@@ -42,7 +42,7 @@ class NullCommandGenerator(CommandGeneratorBase):
Raises:
RuntimeError: No command is generated. Always raises this error.
"""
raise RuntimeError("NullCommandGenerator does not generate any commands.")
raise RuntimeError("NullCommandTerm does not generate any commands.")
"""
Operations.
......
......@@ -11,19 +11,18 @@ import torch
from typing import TYPE_CHECKING, Sequence
from omni.isaac.orbit.assets import Articulation
from omni.isaac.orbit.managers import CommandTerm
from omni.isaac.orbit.markers import VisualizationMarkers
from omni.isaac.orbit.markers.config import FRAME_MARKER_CFG
from omni.isaac.orbit.utils.math import combine_frame_transforms, compute_pose_error, quat_from_euler_xyz
from .command_generator_base import CommandGeneratorBase
if TYPE_CHECKING:
from omni.isaac.orbit.envs import BaseEnv
from .command_generator_cfg import UniformPoseCommandGeneratorCfg
from .commands_cfg import UniformPoseCommandCfg
class UniformPoseCommandGenerator(CommandGeneratorBase):
class UniformPoseCommand(CommandTerm):
"""Command generator for generating pose commands uniformly.
The command generator generates poses by sampling positions uniformly within specified
......@@ -42,10 +41,10 @@ class UniformPoseCommandGenerator(CommandGeneratorBase):
"""
cfg: UniformPoseCommandGeneratorCfg
cfg: UniformPoseCommandCfg
"""Configuration for the command generator."""
def __init__(self, cfg: UniformPoseCommandGeneratorCfg, env: BaseEnv):
def __init__(self, cfg: UniformPoseCommandCfg, env: BaseEnv):
"""Initialize the command generator class.
Args:
......@@ -69,7 +68,7 @@ class UniformPoseCommandGenerator(CommandGeneratorBase):
self.metrics["orientation_error"] = torch.zeros(self.num_envs, device=self.device)
def __str__(self) -> str:
msg = "PoseCommandGenerator:\n"
msg = "UniformPoseCommand:\n"
msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n"
msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n"
return msg
......
......@@ -11,30 +11,29 @@ import torch
from typing import TYPE_CHECKING, Sequence
from omni.isaac.orbit.assets import Articulation
from omni.isaac.orbit.managers import CommandTerm
from omni.isaac.orbit.markers import VisualizationMarkers
from omni.isaac.orbit.markers.config import CUBOID_MARKER_CFG
from omni.isaac.orbit.terrains import TerrainImporter
from omni.isaac.orbit.utils.math import quat_rotate_inverse, wrap_to_pi, yaw_quat
from .command_generator_base import CommandGeneratorBase
if TYPE_CHECKING:
from omni.isaac.orbit.envs import BaseEnv
from .command_generator_cfg import TerrainBasedPositionCommandGeneratorCfg
from .commands_cfg import TerrainBasedPositionCommandCfg
class TerrainBasedPositionCommandGenerator(CommandGeneratorBase):
class TerrainBasedPositionCommand(CommandTerm):
"""Command generator that generates position commands based on the terrain.
The position commands are sampled from the terrain mesh and the heading commands are either set
to point towards the target or are sampled uniformly.
"""
cfg: TerrainBasedPositionCommandGeneratorCfg
cfg: TerrainBasedPositionCommandCfg
"""Configuration for the command generator."""
def __init__(self, cfg: TerrainBasedPositionCommandGeneratorCfg, env: BaseEnv):
def __init__(self, cfg: TerrainBasedPositionCommandCfg, env: BaseEnv):
"""Initialize the command generator class.
Args:
......@@ -61,7 +60,7 @@ class TerrainBasedPositionCommandGenerator(CommandGeneratorBase):
self.metrics["error_heading"] = torch.zeros(self.num_envs, device=self.device)
def __str__(self) -> str:
msg = "TerrainBasedPositionCommandGenerator:\n"
msg = "TerrainBasedPositionCommand:\n"
msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n"
msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n"
msg += f"\tStanding probability: {self.cfg.rel_standing_envs}"
......
......@@ -12,18 +12,17 @@ from typing import TYPE_CHECKING, Sequence
import omni.isaac.orbit.utils.math as math_utils
from omni.isaac.orbit.assets import Articulation
from omni.isaac.orbit.managers import CommandTerm
from omni.isaac.orbit.markers import VisualizationMarkers
from omni.isaac.orbit.markers.config import BLUE_ARROW_X_MARKER_CFG, GREEN_ARROW_X_MARKER_CFG
from .command_generator_base import CommandGeneratorBase
if TYPE_CHECKING:
from omni.isaac.orbit.envs import BaseEnv
from .command_generator_cfg import NormalVelocityCommandGeneratorCfg, UniformVelocityCommandGeneratorCfg
from .commands_cfg import NormalVelocityCommandCfg, UniformVelocityCommandCfg
class UniformVelocityCommandGenerator(CommandGeneratorBase):
class UniformVelocityCommand(CommandTerm):
r"""Command generator that generates a velocity command in SE(2) from uniform distribution.
The command comprises of a linear velocity in x and y direction and an angular velocity around
......@@ -41,10 +40,10 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
"""
cfg: UniformVelocityCommandGeneratorCfg
cfg: UniformVelocityCommandCfg
"""The configuration of the command generator."""
def __init__(self, cfg: UniformVelocityCommandGeneratorCfg, env: BaseEnv):
def __init__(self, cfg: UniformVelocityCommandCfg, env: BaseEnv):
"""Initialize the command generator.
Args:
......@@ -70,7 +69,7 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
def __str__(self) -> str:
"""Return a string representation of the command generator."""
msg = "UniformVelocityCommandGenerator:\n"
msg = "UniformVelocityCommand:\n"
msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n"
msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n"
msg += f"\tHeading command: {self.cfg.heading_command}\n"
......@@ -199,7 +198,7 @@ class UniformVelocityCommandGenerator(CommandGeneratorBase):
return arrow_scale, arrow_quat
class NormalVelocityCommandGenerator(UniformVelocityCommandGenerator):
class NormalVelocityCommand(UniformVelocityCommand):
"""Command generator that generates a velocity command in SE(2) from a normal distribution.
The command comprises of a linear velocity in x and y direction and an angular velocity around
......@@ -209,10 +208,10 @@ class NormalVelocityCommandGenerator(UniformVelocityCommandGenerator):
the configuration. With equal probability, the sign of the individual components is flipped.
"""
cfg: NormalVelocityCommandGeneratorCfg
cfg: NormalVelocityCommandCfg
"""The command generator configuration."""
def __init__(self, cfg: NormalVelocityCommandGeneratorCfg, env: object):
def __init__(self, cfg: NormalVelocityCommandCfg, env: object):
"""Initializes the command generator.
Args:
......@@ -227,7 +226,7 @@ class NormalVelocityCommandGenerator(UniformVelocityCommandGenerator):
def __str__(self) -> str:
"""Return a string representation of the command generator."""
msg = "NormalVelocityCommandGenerator:\n"
msg = "NormalVelocityCommand:\n"
msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n"
msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n"
msg += f"\tStanding probability: {self.cfg.rel_standing_envs}"
......
......@@ -11,48 +11,12 @@ the curriculum introduced by the function.
from __future__ import annotations
import torch
from typing import TYPE_CHECKING, Sequence
from omni.isaac.orbit.assets import RigidObject
from omni.isaac.orbit.managers import SceneEntityCfg
from omni.isaac.orbit.terrains import TerrainImporter
if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
def terrain_levels_vel(
env: RLTaskEnv, env_ids: Sequence[int], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Curriculum based on the distance the robot walked when commanded to move at a desired velocity.
This term is used to increase the difficulty of the terrain when the robot walks far enough and decrease the
difficulty when the robot walks less than half of the distance required by the commanded velocity.
.. note::
It is only possible to use this term with the terrain type ``generator``. For further information
on different terrain types, check the :class:`omni.isaac.orbit.terrains.TerrainImporter` class.
Returns:
The mean terrain level for the given environment ids.
"""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
terrain: TerrainImporter = env.scene.terrain
# compute the distance the robot walked
distance = torch.norm(asset.data.root_pos_w[env_ids, :2] - env.scene.env_origins[env_ids, :2], dim=1)
# robots that walked far enough progress to harder terrains
move_up = distance > terrain.cfg.terrain_generator.size[0] / 2
# robots that walked less than half of their required distance go to simpler terrains
move_down = distance < torch.norm(env.command_manager.command[env_ids, :2], dim=1) * env.max_episode_length_s * 0.5
move_down *= ~move_up
# update terrain levels
terrain.update_env_origins(env_ids, move_up, move_down)
# return the mean terrain level
return torch.mean(terrain.terrain_levels.float())
def modify_reward_weight(env: RLTaskEnv, env_ids: Sequence[int], term_name: str, weight: float, num_steps: int):
"""Curriculum that modifies a reward weight a given number of steps.
......
......@@ -123,6 +123,6 @@ Commands.
"""
def generated_commands(env: RLTaskEnv) -> torch.Tensor:
"""The generated command from the command generator."""
return env.command_manager.command
def generated_commands(env: RLTaskEnv, command_name: str) -> torch.Tensor:
"""The generated command from command term in the command manager with the given name."""
return env.command_manager.get_command(command_name)
......@@ -215,24 +215,25 @@ Velocity-tracking rewards.
def track_lin_vel_xy_exp(
env: RLTaskEnv, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
env: RLTaskEnv, std: float, command_name: str, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Reward tracking of linear velocity commands (xy axes) using exponential kernel."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
# compute the error
lin_vel_error = torch.sum(
torch.square(env.command_manager.command[:, :2] - asset.data.root_lin_vel_b[:, :2]), dim=1
torch.square(env.command_manager.get_command(command_name)[:, :2] - asset.data.root_lin_vel_b[:, :2]),
dim=1,
)
return torch.exp(-lin_vel_error / std**2)
def track_ang_vel_z_exp(
env: RLTaskEnv, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
env: RLTaskEnv, std: float, command_name: str, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Reward tracking of angular velocity commands (yaw) using exponential kernel."""
# extract the used quantities (to enable type-hinting)
asset: RigidObject = env.scene[asset_cfg.name]
# compute the error
ang_vel_error = torch.square(env.command_manager.command[:, 2] - asset.data.root_ang_vel_b[:, 2])
ang_vel_error = torch.square(env.command_manager.get_command(command_name)[:, 2] - asset.data.root_ang_vel_b[:, 2])
return torch.exp(-ang_vel_error / std**2)
......@@ -13,8 +13,7 @@ from typing import Any, ClassVar, Dict, Sequence, Tuple
from omni.isaac.version import get_version
from omni.isaac.orbit.command_generators import CommandGeneratorBase
from omni.isaac.orbit.managers import CurriculumManager, RewardManager, TerminationManager
from omni.isaac.orbit.managers import CommandManager, CurriculumManager, RewardManager, TerminationManager
from .base_env import BaseEnv, VecEnvObs
from .rl_task_env_cfg import RLTaskEnvCfg
......@@ -121,7 +120,7 @@ class RLTaskEnv(BaseEnv, gym.Env):
def load_managers(self):
# note: this order is important since observation manager needs to know the command and action managers
# -- command manager
self.command_manager: CommandGeneratorBase = self.cfg.commands.class_type(self.cfg.commands, self)
self.command_manager: CommandManager = CommandManager(self.cfg.commands, self)
print("[INFO] Command Manager: ", self.command_manager)
# call the parent class to load the managers for observations and actions.
super().load_managers()
......
......@@ -7,7 +7,6 @@ from __future__ import annotations
from dataclasses import MISSING
from omni.isaac.orbit.command_generators import CommandGeneratorBaseCfg
from omni.isaac.orbit.utils import configclass
from .base_env_cfg import BaseEnvCfg
......@@ -32,5 +31,5 @@ class RLTaskEnvCfg(BaseEnvCfg):
"""Termination settings."""
curriculum: object = MISSING
"""Curriculum settings."""
commands: CommandGeneratorBaseCfg = MISSING
"""Command generator settings."""
commands: object = MISSING
"""Command settings."""
......@@ -11,10 +11,12 @@ designed to be modular and can be easily extended to support new functionality.
"""
from .action_manager import ActionManager, ActionTerm
from .command_manager import CommandManager, CommandTerm
from .curriculum_manager import CurriculumManager
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import (
ActionTermCfg,
CommandTermCfg,
CurriculumTermCfg,
ManagerTermBaseCfg,
ObservationGroupCfg,
......
......@@ -3,52 +3,47 @@
#
# SPDX-License-Identifier: BSD-3-Clause
"""Base class for command generators.
This class defines an interface for command generators that can be used for goal-conditioned
tasks. Each command generator class should inherit from this class and implement the abstract
methods.
"""
"""Command manager for generating and updating commands."""
from __future__ import annotations
import inspect
import torch
import weakref
from abc import ABC, abstractmethod
from abc import abstractmethod
from prettytable import PrettyTable
from typing import TYPE_CHECKING, Sequence
import omni.kit.app
if TYPE_CHECKING:
from omni.isaac.orbit.envs import BaseEnv
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import CommandTermCfg
from .command_generator_cfg import CommandGeneratorBaseCfg
if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
class CommandGeneratorBase(ABC):
"""The base class for implementing a command generator.
class CommandTerm(ManagerTermBase):
"""The base class for implementing a command term.
A command generator is used to generate commands for goal-conditioned tasks. For example,
in the case of a goal-conditioned navigation task, the command generator can be used to
A command term is used to generate commands for goal-conditioned tasks. For example,
in the case of a goal-conditioned navigation task, the command term can be used to
generate a target position for the robot to navigate to.
The command generator implements a resampling mechanism that allows the command to be
resampled at a fixed frequency. The resampling frequency can be specified in the
configuration object. Additionally, it is possible to assign a visualization function
to the command generator that can be used to visualize the command in the simulator.
It implements a resampling mechanism that allows the command to be resampled at a fixed
frequency. The resampling frequency can be specified in the configuration object.
Additionally, it is possible to assign a visualization function to the command term
that can be used to visualize the command in the simulator.
"""
def __init__(self, cfg: CommandGeneratorBaseCfg, env: BaseEnv):
def __init__(self, cfg: CommandTermCfg, env: RLTaskEnv):
"""Initialize the command generator class.
Args:
cfg: The configuration parameters for the command generator.
env: The environment object.
"""
# store the inputs
self.cfg = cfg
self._env = env
super().__init__(cfg, env)
# create buffers to store the command
# -- metrics that can be used for logging
......@@ -73,16 +68,6 @@ class CommandGeneratorBase(ABC):
Properties
"""
@property
def num_envs(self) -> int:
"""Number of environments."""
return self._env.num_envs
@property
def device(self) -> str:
"""Device on which to perform computations."""
return self._env.device
@property
@abstractmethod
def command(self) -> torch.Tensor:
......@@ -141,7 +126,7 @@ class CommandGeneratorBase(ABC):
env_ids: The list of environment IDs to reset. Defaults to None.
Returns:
A dictionary containing the information to log under the "Metrics/{name}" key.
A dictionary containing the information to log under the "{name}" key.
"""
# resolve the environment IDs
if env_ids is None:
......@@ -154,7 +139,7 @@ class CommandGeneratorBase(ABC):
extras = {}
for metric_name, metric_value in self.metrics.items():
# compute the mean metric value
extras[f"Metrics/{metric_name}"] = torch.mean(metric_value[env_ids]).item()
extras[metric_name] = torch.mean(metric_value[env_ids]).item()
# reset the metric value
metric_value[env_ids] = 0.0
return extras
......@@ -230,3 +215,169 @@ class CommandGeneratorBase(ABC):
This function calls the visualization objects and sets the data to visualize into them.
"""
raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
class CommandManager(ManagerBase):
"""Manager for generating commands.
The command manager is used to generate commands for an agent to execute. It makes it convenient to switch
between different command generation strategies within the same environment. For instance, in an environment
consisting of a quadrupedal robot, the command to it could be a velocity command or position command.
By keeping the command generation logic separate from the environment, it is easy to switch between different
command generation strategies.
The command terms are implemented as classes that inherit from the :class:`CommandTerm` class.
Each command generator term should also have a corresponding configuration class that inherits from the
:class:`CommandTermCfg` class.
"""
_env: RLTaskEnv
"""The environment instance."""
def __init__(self, cfg: object, env: RLTaskEnv):
"""Initialize the command manager.
Args:
cfg: The configuration object or dictionary (``dict[str, CommandTermCfg]``).
env: The environment instance.
"""
super().__init__(cfg, env)
# store the commands
self._commands = dict()
self.cfg.debug_vis = False
for term in self._terms.values():
self.cfg.debug_vis |= term.cfg.debug_vis
def __str__(self) -> str:
"""Returns: A string representation for the command manager."""
msg = f"<CommandManager> contains {len(self._terms.values())} active terms.\n"
# create table for term information
table = PrettyTable()
table.title = "Active Command Terms"
table.field_names = ["Index", "Name", "Type"]
# set alignment of table columns
table.align["Name"] = "l"
# add info on each term
for index, (name, term) in enumerate(self._terms.items()):
table.add_row([index, name, term.__class__.__name__])
# convert table to string
msg += table.get_string()
return msg
"""
Properties.
"""
@property
def active_terms(self) -> list[str]:
"""Name of active command terms."""
return list(self._terms.keys())
@property
def has_debug_vis_implementation(self) -> bool:
"""Whether the command terms have debug visualization implemented."""
# check if function raises NotImplementedError
has_debug_vis = False
for term in self._terms.values():
has_debug_vis |= term.has_debug_vis_implementation
return has_debug_vis
"""
Operations.
"""
def set_debug_vis(self, debug_vis: bool) -> bool:
"""Sets whether to visualize the command data.
Args:
debug_vis: Whether to visualize the command data.
Returns:
Whether the debug visualization was successfully set. False if the command
generator does not support debug visualization.
"""
for term in self._terms.values():
term.set_debug_vis(debug_vis)
def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]:
"""Reset the command terms and log their metrics.
This function resets the command counter and resamples the command for each term. It should be called
at the beginning of each episode.
Args:
env_ids: The list of environment IDs to reset. Defaults to None.
Returns:
A dictionary containing the information to log under the "Metrics/{term_name}/{metric_name}" key.
"""
# resolve environment ids
if env_ids is None:
env_ids = slice(None)
# store information
extras = {}
for name, term in self._terms.items():
# reset the command term
metrics = term.reset(env_ids=env_ids)
# compute the mean metric value
for metric_name, metric_value in metrics.items():
extras[f"Metrics/{name}/{metric_name}"] = metric_value
# return logged information
return extras
def compute(self, dt: float):
"""Updates the commands.
This function calls each command term managed by the class.
Args:
dt: The time-step interval of the environment.
"""
# iterate over all the command terms
for term in self._terms.values():
# compute term's value
term.compute(dt)
def get_command(self, name: str) -> torch.Tensor:
"""Returns the command for the specified command term.
Args:
name: The name of the command term.
Returns:
The command tensor of the specified command term.
"""
return self._terms[name].command
"""
Helper functions.
"""
def _prepare_terms(self):
"""Prepares a list of command terms."""
# parse command terms from the config
self._terms: dict[str, CommandTerm] = dict()
# check if config is dict already
if isinstance(self.cfg, dict):
cfg_items = self.cfg.items()
else:
cfg_items = self.cfg.__dict__.items()
# iterate over all the terms
for term_name, term_cfg in cfg_items:
# check for non config
if term_cfg is None:
continue
# check for valid config type
if not isinstance(term_cfg, CommandTermCfg):
raise TypeError(
f"Configuration for the term '{term_name}' is not of type CommandTermCfg."
f" Received: '{type(term_cfg)}'."
)
# create the action term
term = term_cfg.class_type(term_cfg, self._env)
# add class to dict
self._terms[term_name] = term
......@@ -18,6 +18,7 @@ from .scene_entity_cfg import SceneEntityCfg
if TYPE_CHECKING:
from .action_manager import ActionTerm
from .command_manager import CommandTerm
from .manager_base import ManagerTermBase
......@@ -71,6 +72,27 @@ class ActionTermCfg:
"""
##
# Command manager.
##
@configclass
class CommandTermCfg:
"""Configuration for a command generator term."""
class_type: type[CommandTerm] = MISSING
"""The associated command term class to use.
The class should inherit from :class:`omni.isaac.orbit.managers.command_manager.CommandTerm`.
"""
resampling_time_range: tuple[float, float] = MISSING
"""Time before commands are changed [s]."""
debug_vis: bool = False
"""Whether to visualize debug information. Defaults to False."""
##
# Curriculum manager.
##
......
......@@ -18,10 +18,10 @@ simulation_app = SimulationApp(config)
import unittest
from collections import namedtuple
from omni.isaac.orbit.command_generators import NullCommandGeneratorCfg
from omni.isaac.orbit.envs.mdp import NullCommandCfg
class TestNullCommandGeneratorCfg(unittest.TestCase):
class TestNullCommandTerm(unittest.TestCase):
"""Test cases for null command generator."""
def setUp(self) -> None:
......@@ -29,24 +29,24 @@ class TestNullCommandGeneratorCfg(unittest.TestCase):
def test_str(self):
"""Test the string representation of the command manager."""
cfg = NullCommandGeneratorCfg()
command_manager = cfg.class_type(cfg, self.env)
cfg = NullCommandCfg()
command_term = cfg.class_type(cfg, self.env)
# print the expected string
print()
print(command_manager)
print(command_term)
def test_compute(self):
"""Test the compute function. For null command generator, it does nothing."""
cfg = NullCommandGeneratorCfg()
command_manager = cfg.class_type(cfg, self.env)
cfg = NullCommandCfg()
command_term = cfg.class_type(cfg, self.env)
# test the reset function
command_manager.reset()
command_term.reset()
# test the compute function
command_manager.compute(dt=self.env.dt)
command_term.compute(dt=self.env.dt)
# expect error
with self.assertRaises(RuntimeError):
command_manager.command
command_term.command
if __name__ == "__main__":
......
......@@ -8,7 +8,6 @@ from __future__ import annotations
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.actuators import ImplicitActuatorCfg
from omni.isaac.orbit.assets import ArticulationCfg, AssetBaseCfg
from omni.isaac.orbit.command_generators import NullCommandGeneratorCfg
from omni.isaac.orbit.envs import RLTaskEnvCfg
from omni.isaac.orbit.managers import ObservationGroupCfg as ObsGroup
from omni.isaac.orbit.managers import ObservationTermCfg as ObsTerm
......@@ -86,6 +85,19 @@ class MySceneCfg(InteractiveSceneCfg):
)
##
# MDP settings
##
@configclass
class CommandsCfg:
"""Command terms for the MDP."""
# no commands for this MDP
null = mdp.NullCommandCfg()
@configclass
class ActionsCfg:
"""Action specifications for the MDP."""
......@@ -199,7 +211,7 @@ class AntEnvCfg(RLTaskEnvCfg):
# Basic settings
observations: ObservationsCfg = ObservationsCfg()
actions: ActionsCfg = ActionsCfg()
commands: NullCommandGeneratorCfg = NullCommandGeneratorCfg()
commands: CommandsCfg = CommandsCfg()
# MDP settings
rewards: RewardsCfg = RewardsCfg()
......
......@@ -10,7 +10,6 @@ from omni.isaac.orbit_assets import ORBIT_ASSETS_DATA_DIR
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.actuators import ImplicitActuatorCfg
from omni.isaac.orbit.assets import ArticulationCfg, AssetBaseCfg
from omni.isaac.orbit.command_generators.command_generator_cfg import NullCommandGeneratorCfg
from omni.isaac.orbit.envs import RLTaskEnvCfg
from omni.isaac.orbit.managers import ObservationGroupCfg as ObsGroup
from omni.isaac.orbit.managers import ObservationTermCfg as ObsTerm
......@@ -90,6 +89,14 @@ class CartpoleSceneCfg(InteractiveSceneCfg):
##
@configclass
class CommandsCfg:
"""Command terms for the MDP."""
# no commands for this MDP
null = mdp.NullCommandCfg()
@configclass
class ActionsCfg:
"""Action specifications for the MDP."""
......@@ -211,7 +218,7 @@ class CartpoleEnvCfg(RLTaskEnvCfg):
rewards: RewardsCfg = RewardsCfg()
terminations: TerminationsCfg = TerminationsCfg()
# No command generator
commands: NullCommandGeneratorCfg = NullCommandGeneratorCfg()
commands: CommandsCfg = CommandsCfg()
def __post_init__(self) -> None:
"""Post initialization."""
......
......@@ -8,7 +8,6 @@ from __future__ import annotations
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.actuators import ImplicitActuatorCfg
from omni.isaac.orbit.assets import ArticulationCfg, AssetBaseCfg
from omni.isaac.orbit.command_generators import NullCommandGeneratorCfg
from omni.isaac.orbit.envs import RLTaskEnvCfg
from omni.isaac.orbit.managers import ObservationGroupCfg as ObsGroup
from omni.isaac.orbit.managers import ObservationTermCfg as ObsTerm
......@@ -100,6 +99,19 @@ class MySceneCfg(InteractiveSceneCfg):
)
##
# MDP settings
##
@configclass
class CommandsCfg:
"""Command terms for the MDP."""
# no commands for this MDP
null = mdp.NullCommandCfg()
@configclass
class ActionsCfg:
"""Action specifications for the MDP."""
......@@ -254,7 +266,7 @@ class HumanoidEnvCfg(RLTaskEnvCfg):
# Basic settings
observations: ObservationsCfg = ObservationsCfg()
actions: ActionsCfg = ActionsCfg()
commands: NullCommandGeneratorCfg = NullCommandGeneratorCfg()
commands: CommandsCfg = CommandsCfg()
# MDP settings
rewards: RewardsCfg = RewardsCfg()
......
......@@ -7,4 +7,5 @@
from omni.isaac.orbit.envs.mdp import * # noqa: F401, F403
from .curriculums import * # noqa: F401, F403
from .rewards import * # noqa: F401, F403
# Copyright (c) 2022-2023, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Common functions that can be used to create curriculum for the learning environment.
The functions can be passed to the :class:`omni.isaac.orbit.managers.CurriculumTermCfg` object to enable
the curriculum introduced by the function.
"""
from __future__ import annotations
import torch
from typing import TYPE_CHECKING, Sequence
from omni.isaac.orbit.assets import Articulation
from omni.isaac.orbit.managers import SceneEntityCfg
from omni.isaac.orbit.terrains import TerrainImporter
if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
def terrain_levels_vel(
env: RLTaskEnv, env_ids: Sequence[int], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Curriculum based on the distance the robot walked when commanded to move at a desired velocity.
This term is used to increase the difficulty of the terrain when the robot walks far enough and decrease the
difficulty when the robot walks less than half of the distance required by the commanded velocity.
.. note::
It is only possible to use this term with the terrain type ``generator``. For further information
on different terrain types, check the :class:`omni.isaac.orbit.terrains.TerrainImporter` class.
Returns:
The mean terrain level for the given environment ids.
"""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
terrain: TerrainImporter = env.scene.terrain
command = env.command_manager.get_command("base_velocity")
# compute the distance the robot walked
distance = torch.norm(asset.data.root_pos_w[env_ids, :2] - env.scene.env_origins[env_ids, :2], dim=1)
# robots that walked far enough progress to harder terrains
move_up = distance > terrain.cfg.terrain_generator.size[0] / 2
# robots that walked less than half of their required distance go to simpler terrains
move_down = distance < torch.norm(command[env_ids, :2], dim=1) * env.max_episode_length_s * 0.5
move_down *= ~move_up
# update terrain levels
terrain.update_env_origins(env_ids, move_up, move_down)
# return the mean terrain level
return torch.mean(terrain.terrain_levels.float())
......@@ -15,7 +15,7 @@ if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
def feet_air_time(env: RLTaskEnv, sensor_cfg: SceneEntityCfg, threshold: float) -> torch.Tensor:
def feet_air_time(env: RLTaskEnv, command_name: str, sensor_cfg: SceneEntityCfg, threshold: float) -> torch.Tensor:
"""Reward long steps taken by the feet using L2-kernel.
This function rewards the agent for taking steps that are longer than a threshold. This helps ensure
......@@ -31,5 +31,5 @@ def feet_air_time(env: RLTaskEnv, sensor_cfg: SceneEntityCfg, threshold: float)
first_contact = last_air_time > 0.0
reward = torch.sum((last_air_time - threshold) * first_contact, dim=1)
# no reward for zero command
reward *= torch.norm(env.command_manager.command[:, :2], dim=1) > 0.1
reward *= torch.norm(env.command_manager.get_command(command_name)[:, :2], dim=1) > 0.1
return reward
......@@ -10,7 +10,6 @@ from dataclasses import MISSING
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.assets import ArticulationCfg, AssetBaseCfg
from omni.isaac.orbit.command_generators import UniformVelocityCommandGeneratorCfg
from omni.isaac.orbit.envs import RLTaskEnvCfg
from omni.isaac.orbit.managers import CurriculumTermCfg as CurrTerm
from omni.isaac.orbit.managers import ObservationGroupCfg as ObsGroup
......@@ -89,6 +88,23 @@ class MySceneCfg(InteractiveSceneCfg):
##
@configclass
class CommandsCfg:
"""Command specifications for the MDP."""
base_velocity = mdp.UniformVelocityCommandCfg(
asset_name="robot",
resampling_time_range=(10.0, 10.0),
rel_standing_envs=0.02,
rel_heading_envs=1.0,
heading_command=True,
debug_vis=True,
ranges=mdp.UniformVelocityCommandCfg.Ranges(
lin_vel_x=(-1.0, 1.0), lin_vel_y=(-1.0, 1.0), ang_vel_z=(-1.0, 1.0), heading=(-math.pi, math.pi)
),
)
@configclass
class ActionsCfg:
"""Action specifications for the MDP."""
......@@ -111,7 +127,7 @@ class ObservationsCfg:
func=mdp.projected_gravity,
noise=Unoise(n_min=-0.05, n_max=0.05),
)
velocity_commands = ObsTerm(func=mdp.generated_commands)
velocity_commands = ObsTerm(func=mdp.generated_commands, params={"command_name": "base_velocity"})
joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01))
joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-1.5, n_max=1.5))
actions = ObsTerm(func=mdp.last_action)
......@@ -203,8 +219,12 @@ class RewardsCfg:
"""Reward terms for the MDP."""
# -- task
track_lin_vel_xy_exp = RewTerm(func=mdp.track_lin_vel_xy_exp, weight=1.0, params={"std": math.sqrt(0.25)})
track_ang_vel_z_exp = RewTerm(func=mdp.track_ang_vel_z_exp, weight=0.5, params={"std": math.sqrt(0.25)})
track_lin_vel_xy_exp = RewTerm(
func=mdp.track_lin_vel_xy_exp, weight=1.0, params={"command_name": "base_velocity", "std": math.sqrt(0.25)}
)
track_ang_vel_z_exp = RewTerm(
func=mdp.track_ang_vel_z_exp, weight=0.5, params={"command_name": "base_velocity", "std": math.sqrt(0.25)}
)
# -- penalties
lin_vel_z_l2 = RewTerm(func=mdp.lin_vel_z_l2, weight=-2.0)
ang_vel_xy_l2 = RewTerm(func=mdp.ang_vel_xy_l2, weight=-0.05)
......@@ -214,7 +234,11 @@ class RewardsCfg:
feet_air_time = RewTerm(
func=mdp.feet_air_time,
weight=0.5,
params={"sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*SHANK"), "threshold": 0.5},
params={
"sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*SHANK"),
"command_name": "base_velocity",
"threshold": 0.5,
},
)
undesired_contacts = RewTerm(
func=mdp.undesired_contacts,
......@@ -258,17 +282,7 @@ class LocomotionVelocityRoughEnvCfg(RLTaskEnvCfg):
# Basic settings
observations: ObservationsCfg = ObservationsCfg()
actions: ActionsCfg = ActionsCfg()
commands: UniformVelocityCommandGeneratorCfg = UniformVelocityCommandGeneratorCfg(
asset_name="robot",
resampling_time_range=(10.0, 10.0),
rel_standing_envs=0.02,
rel_heading_envs=1.0,
heading_command=True,
debug_vis=True,
ranges=UniformVelocityCommandGeneratorCfg.Ranges(
lin_vel_x=(-1.0, 1.0), lin_vel_y=(-1.0, 1.0), ang_vel_z=(-1.0, 1.0), heading=(-math.pi, math.pi)
),
)
commands: CommandsCfg = CommandsCfg()
# MDP settings
rewards: RewardsCfg = RewardsCfg()
terminations: TerminationsCfg = TerminationsCfg()
......
......@@ -42,7 +42,7 @@ class FrankaCubeLiftEnvCfg(LiftEnvCfg):
close_command_expr={"panda_finger_.*": 0.0},
)
# Set the body name for the end effector
self.commands.body_name = "panda_hand"
self.commands.object_pose.body_name = "panda_hand"
# Set Cube as object
self.scene.object = RigidObjectCfg(
......
......@@ -13,7 +13,6 @@ from dataclasses import MISSING
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.assets import ArticulationCfg, AssetBaseCfg, RigidObjectCfg
from omni.isaac.orbit.command_generators.command_generator_cfg import UniformPoseCommandGeneratorCfg
from omni.isaac.orbit.envs import RLTaskEnvCfg
from omni.isaac.orbit.managers import CurriculumTermCfg as CurrTerm
from omni.isaac.orbit.managers import ObservationGroupCfg as ObsGroup
......@@ -75,6 +74,21 @@ class ObjectTableSceneCfg(InteractiveSceneCfg):
##
@configclass
class CommandsCfg:
"""Command terms for the MDP."""
object_pose = mdp.UniformPoseCommandCfg(
asset_name="robot",
body_name=MISSING, # will be set by agent env cfg
resampling_time_range=(5.0, 5.0),
debug_vis=True,
ranges=mdp.UniformPoseCommandCfg.Ranges(
pos_x=(0.4, 0.6), pos_y=(-0.25, 0.25), pos_z=(0.25, 0.5), roll=(0.0, 0.0), pitch=(0.0, 0.0), yaw=(0.0, 0.0)
),
)
@configclass
class ActionsCfg:
"""Action specifications for the MDP."""
......@@ -95,7 +109,7 @@ class ObservationsCfg:
joint_pos = ObsTerm(func=mdp.joint_pos_rel)
joint_vel = ObsTerm(func=mdp.joint_vel_rel)
object_position = ObsTerm(func=mdp.object_position_in_robot_root_frame)
target_object_position = ObsTerm(func=mdp.generated_commands)
target_object_position = ObsTerm(func=mdp.generated_commands, params={"command_name": "object_pose"})
actions = ObsTerm(func=mdp.last_action)
def __post_init__(self):
......@@ -133,13 +147,13 @@ class RewardsCfg:
object_goal_tracking = RewTerm(
func=mdp.object_goal_distance,
params={"std": 0.3, "minimal_height": 0.06},
params={"std": 0.3, "minimal_height": 0.06, "command_name": "object_pose"},
weight=16.0,
)
object_goal_tracking_fine_grained = RewTerm(
func=mdp.object_goal_distance,
params={"std": 0.05, "minimal_height": 0.06},
params={"std": 0.05, "minimal_height": 0.06, "command_name": "object_pose"},
weight=5.0,
)
......@@ -188,22 +202,10 @@ class LiftEnvCfg(RLTaskEnvCfg):
# Scene settings
scene: ObjectTableSceneCfg = ObjectTableSceneCfg(num_envs=4096, env_spacing=2.5, replicate_physics=False)
# Basic settings
observations: ObservationsCfg = ObservationsCfg()
actions: ActionsCfg = ActionsCfg()
commands: UniformPoseCommandGeneratorCfg = UniformPoseCommandGeneratorCfg(
asset_name="robot",
body_name=MISSING, # will be set by agent env cfg
resampling_time_range=(5.0, 5.0),
debug_vis=True,
ranges=UniformPoseCommandGeneratorCfg.Ranges(
pos_x=(0.4, 0.6), pos_y=(-0.25, 0.25), pos_z=(0.25, 0.5), roll=(0.0, 0.0), pitch=(0.0, 0.0), yaw=(0.0, 0.0)
),
)
commands: CommandsCfg = CommandsCfg()
# MDP settings
rewards: RewardsCfg = RewardsCfg()
terminations: TerminationsCfg = TerminationsCfg()
......
......@@ -9,7 +9,6 @@ import torch
from typing import TYPE_CHECKING
from omni.isaac.orbit.assets import RigidObject
from omni.isaac.orbit.command_generators import UniformPoseCommandGenerator
from omni.isaac.orbit.managers import SceneEntityCfg
from omni.isaac.orbit.sensors import FrameTransformer
from omni.isaac.orbit.utils.math import combine_frame_transforms
......@@ -33,10 +32,10 @@ def object_ee_distance(
ee_frame_cfg: SceneEntityCfg = SceneEntityCfg("ee_frame"),
) -> torch.Tensor:
"""Reward the agent for reaching the object using tanh-kernel."""
# Target object position: (num_envs, 3)
# extract the used quantities (to enable type-hinting)
object: RigidObject = env.scene[object_cfg.name]
ee_frame: FrameTransformer = env.scene[ee_frame_cfg.name]
# Target object position: (num_envs, 3)
cube_pos_w = object.data.root_pos_w
# End-effector position: (num_envs, 3)
ee_w = ee_frame.data.target_pos_w[..., 0, :]
......@@ -50,15 +49,17 @@ def object_goal_distance(
env: RLTaskEnv,
std: float,
minimal_height: float,
command_name: str,
robot_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
) -> torch.Tensor:
"""Reward the agent for tracking the goal pose using tanh-kernel."""
# extract the used quantities (to enable type-hinting)
robot: RigidObject = env.scene[robot_cfg.name]
object: RigidObject = env.scene[object_cfg.name]
command_manager: UniformPoseCommandGenerator = env.command_manager
des_pos_b = command_manager.command[:, :3]
command = env.command_manager.get_command(command_name)
# compute the desired position in the world frame
des_pos_b = command[:, :3]
des_pos_w, _ = combine_frame_transforms(robot.data.root_state_w[:, :3], robot.data.root_state_w[:, 3:7], des_pos_b)
# distance of the end-effector to the object: (num_envs,)
distance = torch.norm(des_pos_w - object.data.root_pos_w[:, :3], dim=1)
......
......@@ -41,8 +41,8 @@ class FrankaReachEnvCfg(ReachEnvCfg):
)
# override command generator body
# end-effector is along z-direction
self.commands.body_name = "panda_hand"
self.commands.ranges.pitch = (math.pi, math.pi)
self.commands.ee_pose.body_name = "panda_hand"
self.commands.ee_pose.ranges.pitch = (math.pi, math.pi)
@configclass
......
......@@ -43,8 +43,8 @@ class UR10ReachEnvCfg(ReachEnvCfg):
)
# override command generator body
# end-effector is along x-direction
self.commands.body_name = "ee_link"
self.commands.ranges.pitch = (math.pi / 2, math.pi / 2)
self.commands.ee_pose.body_name = "ee_link"
self.commands.ee_pose.ranges.pitch = (math.pi / 2, math.pi / 2)
@configclass
......
......@@ -9,7 +9,6 @@ import torch
from typing import TYPE_CHECKING
from omni.isaac.orbit.assets import RigidObject
from omni.isaac.orbit.command_generators import UniformPoseCommandGenerator
from omni.isaac.orbit.managers import SceneEntityCfg
from omni.isaac.orbit.utils.math import combine_frame_transforms, quat_error_magnitude, quat_mul
......@@ -17,7 +16,7 @@ if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
def position_command_error(env: RLTaskEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def position_command_error(env: RLTaskEnv, command_name: str, asset_cfg: SceneEntityCfg) -> torch.Tensor:
"""Penalize tracking of the position error using L2-norm.
The function computes the position error between the desired position (from the command) and the
......@@ -26,15 +25,15 @@ def position_command_error(env: RLTaskEnv, asset_cfg: SceneEntityCfg) -> torch.T
"""
# extract the asset (to enable type hinting)
asset: RigidObject = env.scene[asset_cfg.name]
command_manager: UniformPoseCommandGenerator = env.command_manager
command = env.command_manager.get_command(command_name)
# obtain the desired and current positions
des_pos_b = command_manager.command[:, :3]
des_pos_b = command[:, :3]
des_pos_w, _ = combine_frame_transforms(asset.data.root_state_w[:, :3], asset.data.root_state_w[:, 3:7], des_pos_b)
curr_pos_w = asset.data.body_state_w[:, asset_cfg.body_ids[0], :3] # type: ignore
return torch.norm(curr_pos_w - des_pos_w, dim=1)
def orientation_command_error(env: RLTaskEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
def orientation_command_error(env: RLTaskEnv, command_name: str, asset_cfg: SceneEntityCfg) -> torch.Tensor:
"""Penalize tracking orientation error using shortest path.
The function computes the orientation error between the desired orientation (from the command) and the
......@@ -43,9 +42,9 @@ def orientation_command_error(env: RLTaskEnv, asset_cfg: SceneEntityCfg) -> torc
"""
# extract the asset (to enable type hinting)
asset: RigidObject = env.scene[asset_cfg.name]
command_manager: UniformPoseCommandGenerator = env.command_manager
command = env.command_manager.get_command(command_name)
# obtain the desired and current orientations
des_quat_b = command_manager.command[:, 3:7]
des_quat_b = command[:, 3:7]
des_quat_w = quat_mul(asset.data.root_state_w[:, 3:7], des_quat_b)
curr_quat_w = asset.data.body_state_w[:, asset_cfg.body_ids[0], 3:7] # type: ignore
return quat_error_magnitude(curr_quat_w, des_quat_w)
......@@ -9,7 +9,6 @@ from dataclasses import MISSING
import omni.isaac.orbit.sim as sim_utils
from omni.isaac.orbit.assets import ArticulationCfg, AssetBaseCfg
from omni.isaac.orbit.command_generators import UniformPoseCommandGeneratorCfg
from omni.isaac.orbit.envs import RLTaskEnvCfg
from omni.isaac.orbit.managers import ActionTermCfg as ActionTerm
from omni.isaac.orbit.managers import CurriculumTermCfg as CurrTerm
......@@ -65,6 +64,26 @@ class ReachSceneCfg(InteractiveSceneCfg):
##
@configclass
class CommandsCfg:
"""Command terms for the MDP."""
ee_pose = mdp.UniformPoseCommandCfg(
asset_name="robot",
body_name=MISSING,
resampling_time_range=(4.0, 4.0),
debug_vis=True,
ranges=mdp.UniformPoseCommandCfg.Ranges(
pos_x=(0.35, 0.65),
pos_y=(-0.2, 0.2),
pos_z=(0.15, 0.5),
roll=(0.0, 0.0),
pitch=MISSING, # depends on end-effector axis
yaw=(-3.14, 3.14),
),
)
@configclass
class ActionsCfg:
"""Action specifications for the MDP."""
......@@ -84,7 +103,7 @@ class ObservationsCfg:
# observation terms (order preserved)
joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01))
joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-0.01, n_max=0.01))
pose_command = ObsTerm(func=mdp.generated_commands)
pose_command = ObsTerm(func=mdp.generated_commands, params={"command_name": "ee_pose"})
actions = ObsTerm(func=mdp.last_action)
def __post_init__(self):
......@@ -117,12 +136,12 @@ class RewardsCfg:
end_effector_position_tracking = RewTerm(
func=mdp.position_command_error,
weight=-0.2,
params={"asset_cfg": SceneEntityCfg("robot", body_names=MISSING)},
params={"asset_cfg": SceneEntityCfg("robot", body_names=MISSING), "command_name": "ee_pose"},
)
end_effector_orientation_tracking = RewTerm(
func=mdp.orientation_command_error,
weight=-0.05,
params={"asset_cfg": SceneEntityCfg("robot", body_names=MISSING)},
params={"asset_cfg": SceneEntityCfg("robot", body_names=MISSING), "command_name": "ee_pose"},
)
# action penalty
......@@ -164,20 +183,7 @@ class ReachEnvCfg(RLTaskEnvCfg):
# Basic settings
observations: ObservationsCfg = ObservationsCfg()
actions: ActionsCfg = ActionsCfg()
commands: UniformPoseCommandGeneratorCfg = UniformPoseCommandGeneratorCfg(
asset_name="robot",
body_name=MISSING,
resampling_time_range=(4.0, 4.0),
debug_vis=True,
ranges=UniformPoseCommandGeneratorCfg.Ranges(
pos_x=(0.35, 0.65),
pos_y=(-0.2, 0.2),
pos_z=(0.15, 0.5),
roll=(0.0, 0.0),
pitch=MISSING, # depends on end-effector axis
yaw=(-3.14, 3.14),
),
)
commands: CommandsCfg = CommandsCfg()
# MDP settings
rewards: RewardsCfg = RewardsCfg()
terminations: TerminationsCfg = TerminationsCfg()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment