Unverified Commit a4118d7f authored by Toni-SM's avatar Toni-SM Committed by GitHub

Support other gymnasium spaces in Direct workflow (#1117)

# Description

This PR add supports for different Gymnasium spaces (`Box`, `Discrete`,
`MultiDiscrete`, `Tuple` and `Dict`) to define observation, action and
state spaces in the direct workflow.

See
https://github.com/isaac-sim/IsaacLab/issues/864#issuecomment-2351819930

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent 15a2c508
...@@ -115,7 +115,7 @@ For example, for the configuration of the Cartpole camera depth environment: ...@@ -115,7 +115,7 @@ For example, for the configuration of the Cartpole camera depth environment:
:emphasize-lines: 16 :emphasize-lines: 16
If the user were to modify the width of the camera, i.e. ``env.tiled_camera.width=128``, then the parameter If the user were to modify the width of the camera, i.e. ``env.tiled_camera.width=128``, then the parameter
``env.num_observations=10240`` (1*80*128) must be updated and given as input as well. ``env.observation_space=[80,128,1]`` must be updated and given as input as well.
Similarly, the ``__post_init__`` method is not updated with the command line inputs. In the ``LocomotionVelocityRoughEnvCfg``, for example, Similarly, the ``__post_init__`` method is not updated with the command line inputs. In the ``LocomotionVelocityRoughEnvCfg``, for example,
the post init update is as follows: the post init update is as follows:
......
...@@ -45,9 +45,9 @@ Below is an example skeleton of a task config class: ...@@ -45,9 +45,9 @@ Below is an example skeleton of a task config class:
# env # env
decimation = 2 decimation = 2
episode_length_s = 5.0 episode_length_s = 5.0
num_actions = 1 action_space = 1
num_observations = 4 observation_space = 4
num_states = 0 state_space = 0
# task-specific parameters # task-specific parameters
... ...
...@@ -135,9 +135,9 @@ The following parameters must be set for each environment config: ...@@ -135,9 +135,9 @@ The following parameters must be set for each environment config:
decimation = 2 decimation = 2
episode_length_s = 5.0 episode_length_s = 5.0
num_actions = 1 action_space = 1
num_observations = 4 observation_space = 4
num_states = 0 state_space = 0
Note that the maximum episode length parameter (now ``episode_length_s``) is in seconds instead of steps as it was Note that the maximum episode length parameter (now ``episode_length_s``) is in seconds instead of steps as it was
in IsaacGymEnvs. To convert between step count to seconds, use the equation: in IsaacGymEnvs. To convert between step count to seconds, use the equation:
...@@ -569,9 +569,9 @@ Task Config ...@@ -569,9 +569,9 @@ Task Config
| | decimation = 2 | | | decimation = 2 |
| asset: | episode_length_s = 5.0 | | asset: | episode_length_s = 5.0 |
| assetRoot: "../../assets" | action_scale = 100.0 # [N] | | assetRoot: "../../assets" | action_scale = 100.0 # [N] |
| assetFileName: "urdf/cartpole.urdf" | num_actions = 1 | | assetFileName: "urdf/cartpole.urdf" | action_space = 1 |
| | num_observations = 4 | | | observation_space = 4 |
| enableCameraSensors: False | num_states = 0 | | enableCameraSensors: False | state_space = 0 |
| | # reset | | | # reset |
| sim: | max_cart_pos = 3.0 | | sim: | max_cart_pos = 3.0 |
| dt: 0.0166 # 1/60 s | initial_pole_angle_range = [-0.25, 0.25] | | dt: 0.0166 # 1/60 s | initial_pole_angle_range = [-0.25, 0.25] |
......
...@@ -46,9 +46,9 @@ Below is an example skeleton of a task config class: ...@@ -46,9 +46,9 @@ Below is an example skeleton of a task config class:
# env # env
decimation = 2 decimation = 2
episode_length_s = 5.0 episode_length_s = 5.0
num_actions = 1 action_space = 1
num_observations = 4 observation_space = 4
num_states = 0 state_space = 0
# task-specific parameters # task-specific parameters
... ...
...@@ -158,9 +158,9 @@ The following parameters must be set for each environment config: ...@@ -158,9 +158,9 @@ The following parameters must be set for each environment config:
decimation = 2 decimation = 2
episode_length_s = 5.0 episode_length_s = 5.0
num_actions = 1 action_space = 1
num_observations = 4 observation_space = 4
num_states = 0 state_space = 0
RL Config Setup RL Config Setup
...@@ -501,9 +501,9 @@ Task config in Isaac Lab can be split into the main task configuration class and ...@@ -501,9 +501,9 @@ Task config in Isaac Lab can be split into the main task configuration class and
| clipObservations: 5.0 | decimation = 2 | | clipObservations: 5.0 | decimation = 2 |
| clipActions: 1.0 | episode_length_s = 5.0 | | clipActions: 1.0 | episode_length_s = 5.0 |
| controlFrequencyInv: 2 # 60 Hz | action_scale = 100.0 # [N] | | controlFrequencyInv: 2 # 60 Hz | action_scale = 100.0 # [N] |
| | num_actions = 1 | | | action_space = 1 |
| sim: | num_observations = 4 | | sim: | observation_space = 4 |
| | num_states = 0 | | | state_space = 0 |
| dt: 0.0083 # 1/120 s | # reset | | dt: 0.0083 # 1/120 s | # reset |
| use_gpu_pipeline: ${eq:${...pipeline},"gpu"} | max_cart_pos = 3.0 | | use_gpu_pipeline: ${eq:${...pipeline},"gpu"} | max_cart_pos = 3.0 |
| gravity: [0.0, 0.0, -9.81] | initial_pole_angle_range = [-0.25, 0.25] | | gravity: [0.0, 0.0, -9.81] | initial_pole_angle_range = [-0.25, 0.25] |
......
...@@ -28,8 +28,8 @@ from omni.isaac.lab_assets import H1_CFG ...@@ -28,8 +28,8 @@ from omni.isaac.lab_assets import H1_CFG
# [end-h1_env-import] # [end-h1_env-import]
# [start-h1_env-spaces] # [start-h1_env-spaces]
num_actions = 19 action_space = 19
num_observations = 69 observation_space = 69
# [end-h1_env-spaces] # [end-h1_env-spaces]
# [start-h1_env-robot] # [start-h1_env-robot]
......
...@@ -48,9 +48,9 @@ config should define the number of actions and observations for the environment. ...@@ -48,9 +48,9 @@ config should define the number of actions and observations for the environment.
@configclass @configclass
class CartpoleEnvCfg(DirectRLEnvCfg): class CartpoleEnvCfg(DirectRLEnvCfg):
... ...
num_actions = 1 action_space = 1
num_observations = 4 observation_space = 4
num_states = 0 state_space = 0
The config class can also be used to define task-specific attributes, such as scaling for reward terms The config class can also be used to define task-specific attributes, such as scaling for reward terms
and thresholds for reset conditions. and thresholds for reset conditions.
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.25.1" version = "0.25.2"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.25.2 (2024-10-16)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added support for different Gymnasium spaces (``Box``, ``Discrete``, ``MultiDiscrete``, ``Tuple`` and ``Dict``)
to define observation, action and state spaces in the direct workflow.
* Added :meth:`sample_space` to environment utils to sample supported spaces where data containers are torch tensors.
Changed
^^^^^^^
* Mark the :attr:`num_observations`, :attr:`num_actions` and :attr:`num_states` in :class:`DirectRLEnvCfg` as deprecated
in favor of :attr:`observation_space`, :attr:`action_space` and :attr:`state_space` respectively.
* Mark the :attr:`num_observations`, :attr:`num_actions` and :attr:`num_states` in :class:`DirectMARLEnvCfg` as deprecated
in favor of :attr:`observation_spaces`, :attr:`action_spaces` and :attr:`state_space` respectively.
0.25.1 (2024-10-10) 0.25.1 (2024-10-10)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -52,4 +52,4 @@ from .manager_based_env import ManagerBasedEnv ...@@ -52,4 +52,4 @@ from .manager_based_env import ManagerBasedEnv
from .manager_based_env_cfg import ManagerBasedEnvCfg from .manager_based_env_cfg import ManagerBasedEnvCfg
from .manager_based_rl_env import ManagerBasedRLEnv from .manager_based_rl_env import ManagerBasedRLEnv
from .manager_based_rl_env_cfg import ManagerBasedRLEnvCfg from .manager_based_rl_env_cfg import ManagerBasedRLEnvCfg
from .utils import multi_agent_to_single_agent, multi_agent_with_one_agent from .utils.marl import multi_agent_to_single_agent, multi_agent_with_one_agent
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import gymnasium as gym
import torch import torch
from typing import Dict, Literal, TypeVar from typing import Dict, Literal, TypeVar
...@@ -62,6 +63,9 @@ class ViewerCfg: ...@@ -62,6 +63,9 @@ class ViewerCfg:
# Types. # Types.
## ##
SpaceType = TypeVar("SpaceType", gym.spaces.Space, int, set, tuple, list, dict)
"""A sentinel object to indicate a valid space type to specify states, observations and actions."""
VecEnvObs = Dict[str, torch.Tensor | Dict[str, torch.Tensor]] VecEnvObs = Dict[str, torch.Tensor | Dict[str, torch.Tensor]]
"""Observation returned by the environment. """Observation returned by the environment.
......
...@@ -14,6 +14,7 @@ import torch ...@@ -14,6 +14,7 @@ import torch
import weakref import weakref
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import MISSING
from typing import Any, ClassVar from typing import Any, ClassVar
import omni.isaac.core.utils.torch as torch_utils import omni.isaac.core.utils.torch as torch_utils
...@@ -30,6 +31,7 @@ from omni.isaac.lab.utils.timer import Timer ...@@ -30,6 +31,7 @@ from omni.isaac.lab.utils.timer import Timer
from .common import ActionType, AgentID, EnvStepReturn, ObsType, StateType from .common import ActionType, AgentID, EnvStepReturn, ObsType, StateType
from .direct_marl_env_cfg import DirectMARLEnvCfg from .direct_marl_env_cfg import DirectMARLEnvCfg
from .ui import ViewportCameraController from .ui import ViewportCameraController
from .utils.spaces import sample_space, spec_to_gym_space
class DirectMARLEnv: class DirectMARLEnv:
...@@ -164,10 +166,6 @@ class DirectMARLEnv: ...@@ -164,10 +166,6 @@ class DirectMARLEnv:
# -- init buffers # -- init buffers
self.episode_length_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long) self.episode_length_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)
self.reset_buf = torch.zeros(self.num_envs, dtype=torch.bool, device=self.sim.device) self.reset_buf = torch.zeros(self.num_envs, dtype=torch.bool, device=self.sim.device)
self.actions = {
agent: torch.zeros(self.num_envs, self.cfg.num_actions[agent], device=self.sim.device)
for agent in self.cfg.possible_agents
}
# setup the observation, state and action spaces # setup the observation, state and action spaces
self._configure_env_spaces() self._configure_env_spaces()
...@@ -406,16 +404,19 @@ class DirectMARLEnv: ...@@ -406,16 +404,19 @@ class DirectMARLEnv:
"""Returns the state for the environment. """Returns the state for the environment.
The state-space is used for centralized training or asymmetric actor-critic architectures. It is configured The state-space is used for centralized training or asymmetric actor-critic architectures. It is configured
using the :attr:`DirectMARLEnvCfg.num_states` parameter. using the :attr:`DirectMARLEnvCfg.state_space` parameter.
Returns: Returns:
The states for the environment, or None if :attr:`DirectMARLEnvCfg.num_states` parameter is zero. The states for the environment, or None if :attr:`DirectMARLEnvCfg.state_space` parameter is zero.
""" """
if not self.cfg.num_states: if not self.cfg.state_space:
return None return None
# concatenate and return the observations as state # concatenate and return the observations as state
if self.cfg.num_states < 0: # FIXME: This implementation assumes the spaces are fundamental ones. Fix it to support composite spaces
self.state_buf = torch.cat([self.obs_dict[agent] for agent in self.cfg.possible_agents], dim=-1) if isinstance(self.cfg.state_space, int) and self.cfg.state_space < 0:
self.state_buf = torch.cat(
[self.obs_dict[agent].reshape(self.num_envs, -1) for agent in self.cfg.possible_agents], dim=-1
)
# compute and return custom environment state # compute and return custom environment state
else: else:
self.state_buf = self._get_states() self.state_buf = self._get_states()
...@@ -568,25 +569,45 @@ class DirectMARLEnv: ...@@ -568,25 +569,45 @@ class DirectMARLEnv:
self.agents = self.cfg.possible_agents self.agents = self.cfg.possible_agents
self.possible_agents = self.cfg.possible_agents self.possible_agents = self.cfg.possible_agents
# show deprecation message and overwrite configuration
if self.cfg.num_actions is not None:
omni.log.warn("DirectMARLEnvCfg.num_actions is deprecated. Use DirectMARLEnvCfg.action_spaces instead.")
if isinstance(self.cfg.action_spaces, type(MISSING)):
self.cfg.action_spaces = self.cfg.num_actions
if self.cfg.num_observations is not None:
omni.log.warn(
"DirectMARLEnvCfg.num_observations is deprecated. Use DirectMARLEnvCfg.observation_spaces instead."
)
if isinstance(self.cfg.observation_spaces, type(MISSING)):
self.cfg.observation_spaces = self.cfg.num_observations
if self.cfg.num_states is not None:
omni.log.warn("DirectMARLEnvCfg.num_states is deprecated. Use DirectMARLEnvCfg.state_space instead.")
if isinstance(self.cfg.state_space, type(MISSING)):
self.cfg.state_space = self.cfg.num_states
# set up observation and action spaces # set up observation and action spaces
self.observation_spaces = { self.observation_spaces = {
agent: gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.cfg.num_observations[agent],)) agent: spec_to_gym_space(self.cfg.observation_spaces[agent]) for agent in self.cfg.possible_agents
for agent in self.cfg.possible_agents
} }
self.action_spaces = { self.action_spaces = {
agent: gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.cfg.num_actions[agent],)) agent: spec_to_gym_space(self.cfg.action_spaces[agent]) for agent in self.cfg.possible_agents
for agent in self.cfg.possible_agents
} }
# set up state space # set up state space
if not self.cfg.num_states: if not self.cfg.state_space:
self.state_space = None self.state_space = None
if self.cfg.num_states < 0: if isinstance(self.cfg.state_space, int) and self.cfg.state_space < 0:
self.state_space = gym.spaces.Box( self.state_space = gym.spaces.flatten_space(
low=-np.inf, high=np.inf, shape=(sum(self.cfg.num_observations.values()),) gym.spaces.Tuple([self.observation_spaces[agent] for agent in self.cfg.possible_agents])
) )
else: else:
self.state_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.cfg.num_states,)) self.state_space = spec_to_gym_space(self.cfg.state_space)
# instantiate actions (needed for tasks for which the observations computation is dependent on the actions)
self.actions = {
agent: sample_space(self.action_spaces[agent], self.sim.device, batch_size=self.num_envs, fill_value=0)
for agent in self.cfg.possible_agents
}
def _reset_idx(self, env_ids: Sequence[int]): def _reset_idx(self, env_ids: Sequence[int]):
"""Reset environments based on specified indices. """Reset environments based on specified indices.
...@@ -664,8 +685,8 @@ class DirectMARLEnv: ...@@ -664,8 +685,8 @@ class DirectMARLEnv:
def _get_states(self) -> StateType: def _get_states(self) -> StateType:
"""Compute and return the states for the environment. """Compute and return the states for the environment.
This method is only called (and therefore has to be implemented) when the :attr:`DirectMARLEnvCfg.num_states` This method is only called (and therefore has to be implemented) when the :attr:`DirectMARLEnvCfg.state_space`
parameter is greater than zero. parameter is not a number less than or equal to zero.
Returns: Returns:
The states for the environment. The states for the environment.
......
...@@ -10,7 +10,7 @@ from omni.isaac.lab.sim import SimulationCfg ...@@ -10,7 +10,7 @@ from omni.isaac.lab.sim import SimulationCfg
from omni.isaac.lab.utils import configclass from omni.isaac.lab.utils import configclass
from omni.isaac.lab.utils.noise import NoiseModelCfg from omni.isaac.lab.utils.noise import NoiseModelCfg
from .common import AgentID, ViewerCfg from .common import AgentID, SpaceType, ViewerCfg
from .ui import BaseEnvWindow from .ui import BaseEnvWindow
...@@ -104,11 +104,39 @@ class DirectMARLEnvCfg: ...@@ -104,11 +104,39 @@ class DirectMARLEnvCfg:
Please refer to the :class:`omni.isaac.lab.managers.EventManager` class for more details. Please refer to the :class:`omni.isaac.lab.managers.EventManager` class for more details.
""" """
num_observations: dict[AgentID, int] = MISSING observation_spaces: dict[AgentID, SpaceType] = MISSING
"""The dimension of the observation space from each agent.""" """Observation space definition for each agent.
The space can be defined either using Gymnasium :py:mod:`~gymnasium.spaces` (when a more detailed
specification of the space is desired) or basic Python data types (for simplicity).
.. list-table::
:header-rows: 1
* - Gymnasium space
- Python data type
* - :class:`~gymnasium.spaces.Box`
- Integer or list of integers (e.g.: ``7``, ``[64, 64, 3]``)
* - :class:`~gymnasium.spaces.Discrete`
- Single-element set (e.g.: ``{2}``)
* - :class:`~gymnasium.spaces.MultiDiscrete`
- List of single-element sets (e.g.: ``[{2}, {5}]``)
* - :class:`~gymnasium.spaces.Dict`
- Dictionary (e.g.: ``{"joints": 7, "rgb": [64, 64, 3], "gripper": {2}}``)
* - :class:`~gymnasium.spaces.Tuple`
- Tuple (e.g.: ``(7, [64, 64, 3], {2})``)
"""
num_states: int = MISSING num_observations: dict[AgentID, int] | None = None
"""The dimension of the state space from each environment instance. """The dimension of the observation space for each agent.
.. warning::
This attribute is deprecated. Use :attr:`~omni.isaac.lab.envs.DirectMARLEnvCfg.observation_spaces` instead.
"""
state_space: SpaceType = MISSING
"""State space definition.
The following values are supported: The following values are supported:
...@@ -116,6 +144,33 @@ class DirectMARLEnvCfg: ...@@ -116,6 +144,33 @@ class DirectMARLEnvCfg:
* 0: No state-space will be constructed (`state_space` is None). * 0: No state-space will be constructed (`state_space` is None).
This is useful to save computational resources when the algorithm to be trained does not need it. This is useful to save computational resources when the algorithm to be trained does not need it.
* greater than 0: Custom state-space dimension to be provided by the task implementation. * greater than 0: Custom state-space dimension to be provided by the task implementation.
The space can be defined either using Gymnasium :py:mod:`~gymnasium.spaces` (when a more detailed
specification of the space is desired) or basic Python data types (for simplicity).
.. list-table::
:header-rows: 1
* - Gymnasium space
- Python data type
* - :class:`~gymnasium.spaces.Box`
- Integer or list of integers (e.g.: ``7``, ``[64, 64, 3]``)
* - :class:`~gymnasium.spaces.Discrete`
- Single-element set (e.g.: ``{2}``)
* - :class:`~gymnasium.spaces.MultiDiscrete`
- List of single-element sets (e.g.: ``[{2}, {5}]``)
* - :class:`~gymnasium.spaces.Dict`
- Dictionary (e.g.: ``{"joints": 7, "rgb": [64, 64, 3], "gripper": {2}}``)
* - :class:`~gymnasium.spaces.Tuple`
- Tuple (e.g.: ``(7, [64, 64, 3], {2})``)
"""
num_states: int | None = None
"""The dimension of the state space from each environment instance.
.. warning::
This attribute is deprecated. Use :attr:`~omni.isaac.lab.envs.DirectMARLEnvCfg.state_space` instead.
""" """
observation_noise_model: dict[AgentID, NoiseModelCfg | None] | None = None observation_noise_model: dict[AgentID, NoiseModelCfg | None] | None = None
...@@ -124,8 +179,36 @@ class DirectMARLEnvCfg: ...@@ -124,8 +179,36 @@ class DirectMARLEnvCfg:
Please refer to the :class:`omni.isaac.lab.utils.noise.NoiseModel` class for more details. Please refer to the :class:`omni.isaac.lab.utils.noise.NoiseModel` class for more details.
""" """
num_actions: dict[AgentID, int] = MISSING action_spaces: dict[AgentID, SpaceType] = MISSING
"""The dimension of the action space for each agent.""" """Action space definition for each agent.
The space can be defined either using Gymnasium :py:mod:`~gymnasium.spaces` (when a more detailed
specification of the space is desired) or basic Python data types (for simplicity).
.. list-table::
:header-rows: 1
* - Gymnasium space
- Python data type
* - :class:`~gymnasium.spaces.Box`
- Integer or list of integers (e.g.: ``7``, ``[64, 64, 3]``)
* - :class:`~gymnasium.spaces.Discrete`
- Single-element set (e.g.: ``{2}``)
* - :class:`~gymnasium.spaces.MultiDiscrete`
- List of single-element sets (e.g.: ``[{2}, {5}]``)
* - :class:`~gymnasium.spaces.Dict`
- Dictionary (e.g.: ``{"joints": 7, "rgb": [64, 64, 3], "gripper": {2}}``)
* - :class:`~gymnasium.spaces.Tuple`
- Tuple (e.g.: ``(7, [64, 64, 3], {2})``)
"""
num_actions: dict[AgentID, int] | None = None
"""The dimension of the action space for each agent.
.. warning::
This attribute is deprecated. Use :attr:`~omni.isaac.lab.envs.DirectMARLEnvCfg.action_spaces` instead.
"""
action_noise_model: dict[AgentID, NoiseModelCfg | None] | None = None action_noise_model: dict[AgentID, NoiseModelCfg | None] | None = None
"""The noise model applied to the actions provided to the environment. Default is None, which means no noise is added. """The noise model applied to the actions provided to the environment. Default is None, which means no noise is added.
......
...@@ -14,6 +14,7 @@ import torch ...@@ -14,6 +14,7 @@ import torch
import weakref import weakref
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import MISSING
from typing import Any, ClassVar from typing import Any, ClassVar
import omni.isaac.core.utils.torch as torch_utils import omni.isaac.core.utils.torch as torch_utils
...@@ -30,6 +31,7 @@ from omni.isaac.lab.utils.timer import Timer ...@@ -30,6 +31,7 @@ from omni.isaac.lab.utils.timer import Timer
from .common import VecEnvObs, VecEnvStepReturn from .common import VecEnvObs, VecEnvStepReturn
from .direct_rl_env_cfg import DirectRLEnvCfg from .direct_rl_env_cfg import DirectRLEnvCfg
from .ui import ViewportCameraController from .ui import ViewportCameraController
from .utils.spaces import sample_space, spec_to_gym_space
class DirectRLEnv(gym.Env): class DirectRLEnv(gym.Env):
...@@ -171,7 +173,6 @@ class DirectRLEnv(gym.Env): ...@@ -171,7 +173,6 @@ class DirectRLEnv(gym.Env):
self.reset_terminated = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) self.reset_terminated = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
self.reset_time_outs = torch.zeros_like(self.reset_terminated) self.reset_time_outs = torch.zeros_like(self.reset_terminated)
self.reset_buf = torch.zeros(self.num_envs, dtype=torch.bool, device=self.sim.device) self.reset_buf = torch.zeros(self.num_envs, dtype=torch.bool, device=self.sim.device)
self.actions = torch.zeros(self.num_envs, self.cfg.num_actions, device=self.sim.device)
# setup the action and observation spaces for Gym # setup the action and observation spaces for Gym
self._configure_gym_env_spaces() self._configure_gym_env_spaces()
...@@ -507,27 +508,40 @@ class DirectRLEnv(gym.Env): ...@@ -507,27 +508,40 @@ class DirectRLEnv(gym.Env):
def _configure_gym_env_spaces(self): def _configure_gym_env_spaces(self):
"""Configure the action and observation spaces for the Gym environment.""" """Configure the action and observation spaces for the Gym environment."""
# observation space (unbounded since we don't impose any limits) # show deprecation message and overwrite configuration
self.num_actions = self.cfg.num_actions if self.cfg.num_actions is not None:
self.num_observations = self.cfg.num_observations omni.log.warn("DirectRLEnvCfg.num_actions is deprecated. Use DirectRLEnvCfg.action_space instead.")
self.num_states = self.cfg.num_states if isinstance(self.cfg.action_space, type(MISSING)):
self.cfg.action_space = self.cfg.num_actions
if self.cfg.num_observations is not None:
omni.log.warn(
"DirectRLEnvCfg.num_observations is deprecated. Use DirectRLEnvCfg.observation_space instead."
)
if isinstance(self.cfg.observation_space, type(MISSING)):
self.cfg.observation_space = self.cfg.num_observations
if self.cfg.num_states is not None:
omni.log.warn("DirectRLEnvCfg.num_states is deprecated. Use DirectRLEnvCfg.state_space instead.")
if isinstance(self.cfg.state_space, type(MISSING)):
self.cfg.state_space = self.cfg.num_states
# set up spaces # set up spaces
self.single_observation_space = gym.spaces.Dict() self.single_observation_space = gym.spaces.Dict()
self.single_observation_space["policy"] = gym.spaces.Box( self.single_observation_space["policy"] = spec_to_gym_space(self.cfg.observation_space)
low=-np.inf, high=np.inf, shape=(self.num_observations,) self.single_action_space = spec_to_gym_space(self.cfg.action_space)
)
self.single_action_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.num_actions,))
# batch the spaces for vectorized environments # batch the spaces for vectorized environments
self.observation_space = gym.vector.utils.batch_space(self.single_observation_space["policy"], self.num_envs) self.observation_space = gym.vector.utils.batch_space(self.single_observation_space["policy"], self.num_envs)
self.action_space = gym.vector.utils.batch_space(self.single_action_space, self.num_envs) self.action_space = gym.vector.utils.batch_space(self.single_action_space, self.num_envs)
# optional state space for asymmetric actor-critic architectures # optional state space for asymmetric actor-critic architectures
if self.num_states > 0: self.state_space = None
self.single_observation_space["critic"] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.num_states,)) if self.cfg.state_space > 0:
self.single_observation_space["critic"] = spec_to_gym_space(self.cfg.state_space)
self.state_space = gym.vector.utils.batch_space(self.single_observation_space["critic"], self.num_envs) self.state_space = gym.vector.utils.batch_space(self.single_observation_space["critic"], self.num_envs)
# instantiate actions (needed for tasks for which the observations computation is dependent on the actions)
self.actions = sample_space(self.single_action_space, self.sim.device, batch_size=self.num_envs, fill_value=0)
def _reset_idx(self, env_ids: Sequence[int]): def _reset_idx(self, env_ids: Sequence[int]):
"""Reset environments based on specified indices. """Reset environments based on specified indices.
...@@ -601,7 +615,7 @@ class DirectRLEnv(gym.Env): ...@@ -601,7 +615,7 @@ class DirectRLEnv(gym.Env):
"""Compute and return the states for the environment. """Compute and return the states for the environment.
The state-space is used for asymmetric actor-critic architectures. It is configured The state-space is used for asymmetric actor-critic architectures. It is configured
using the :attr:`DirectRLEnvCfg.num_states` parameter. using the :attr:`DirectRLEnvCfg.state_space` parameter.
Returns: Returns:
The states for the environment. If the environment does not have a state-space, the function The states for the environment. If the environment does not have a state-space, the function
......
...@@ -10,7 +10,7 @@ from omni.isaac.lab.sim import SimulationCfg ...@@ -10,7 +10,7 @@ from omni.isaac.lab.sim import SimulationCfg
from omni.isaac.lab.utils import configclass from omni.isaac.lab.utils import configclass
from omni.isaac.lab.utils.noise import NoiseModelCfg from omni.isaac.lab.utils.noise import NoiseModelCfg
from .common import ViewerCfg from .common import SpaceType, ViewerCfg
from .ui import BaseEnvWindow from .ui import BaseEnvWindow
...@@ -104,13 +104,68 @@ class DirectRLEnvCfg: ...@@ -104,13 +104,68 @@ class DirectRLEnvCfg:
Please refer to the :class:`omni.isaac.lab.managers.EventManager` class for more details. Please refer to the :class:`omni.isaac.lab.managers.EventManager` class for more details.
""" """
num_observations: int = MISSING observation_space: SpaceType = MISSING
"""The dimension of the observation space from each environment instance.""" """Observation space definition.
The space can be defined either using Gymnasium :py:mod:`~gymnasium.spaces` (when a more detailed
specification of the space is desired) or basic Python data types (for simplicity).
.. list-table::
:header-rows: 1
* - Gymnasium space
- Python data type
* - :class:`~gymnasium.spaces.Box`
- Integer or list of integers (e.g.: ``7``, ``[64, 64, 3]``)
* - :class:`~gymnasium.spaces.Discrete`
- Single-element set (e.g.: ``{2}``)
* - :class:`~gymnasium.spaces.MultiDiscrete`
- List of single-element sets (e.g.: ``[{2}, {5}]``)
* - :class:`~gymnasium.spaces.Dict`
- Dictionary (e.g.: ``{"joints": 7, "rgb": [64, 64, 3], "gripper": {2}}``)
* - :class:`~gymnasium.spaces.Tuple`
- Tuple (e.g.: ``(7, [64, 64, 3], {2})``)
"""
num_observations: int | None = None
"""The dimension of the observation space from each environment instance.
.. warning::
This attribute is deprecated. Use :attr:`~omni.isaac.lab.envs.DirectRLEnvCfg.observation_space` instead.
"""
num_states: int = 0 state_space: SpaceType = MISSING
"""The dimension of the state-space from each environment instance. Default is 0, which means no state-space is defined. """State space definition.
This is useful for asymmetric actor-critic and defines the observation space for the critic. This is useful for asymmetric actor-critic and defines the observation space for the critic.
The space can be defined either using Gymnasium :py:mod:`~gymnasium.spaces` (when a more detailed
specification of the space is desired) or basic Python data types (for simplicity).
.. list-table::
:header-rows: 1
* - Gymnasium space
- Python data type
* - :class:`~gymnasium.spaces.Box`
- Integer or list of integers (e.g.: ``7``, ``[64, 64, 3]``)
* - :class:`~gymnasium.spaces.Discrete`
- Single-element set (e.g.: ``{2}``)
* - :class:`~gymnasium.spaces.MultiDiscrete`
- List of single-element sets (e.g.: ``[{2}, {5}]``)
* - :class:`~gymnasium.spaces.Dict`
- Dictionary (e.g.: ``{"joints": 7, "rgb": [64, 64, 3], "gripper": {2}}``)
* - :class:`~gymnasium.spaces.Tuple`
- Tuple (e.g.: ``(7, [64, 64, 3], {2})``)
"""
num_states: int | None = None
"""The dimension of the state-space from each environment instance.
.. warning::
This attribute is deprecated. Use :attr:`~omni.isaac.lab.envs.DirectRLEnvCfg.state_space` instead.
""" """
observation_noise_model: NoiseModelCfg | None = None observation_noise_model: NoiseModelCfg | None = None
...@@ -119,8 +174,36 @@ class DirectRLEnvCfg: ...@@ -119,8 +174,36 @@ class DirectRLEnvCfg:
Please refer to the :class:`omni.isaac.lab.utils.noise.NoiseModel` class for more details. Please refer to the :class:`omni.isaac.lab.utils.noise.NoiseModel` class for more details.
""" """
num_actions: int = MISSING action_space: SpaceType = MISSING
"""The dimension of the action space for each environment.""" """Action space definition.
The space can be defined either using Gymnasium :py:mod:`~gymnasium.spaces` (when a more detailed
specification of the space is desired) or basic Python data types (for simplicity).
.. list-table::
:header-rows: 1
* - Gymnasium space
- Python data type
* - :class:`~gymnasium.spaces.Box`
- Integer or list of integers (e.g.: ``7``, ``[64, 64, 3]``)
* - :class:`~gymnasium.spaces.Discrete`
- Single-element set (e.g.: ``{2}``)
* - :class:`~gymnasium.spaces.MultiDiscrete`
- List of single-element sets (e.g.: ``[{2}, {5}]``)
* - :class:`~gymnasium.spaces.Dict`
- Dictionary (e.g.: ``{"joints": 7, "rgb": [64, 64, 3], "gripper": {2}}``)
* - :class:`~gymnasium.spaces.Tuple`
- Tuple (e.g.: ``(7, [64, 64, 3], {2})``)
"""
num_actions: int | None = None
"""The dimension of the action space for each environment.
.. warning::
This attribute is deprecated. Use :attr:`~omni.isaac.lab.envs.DirectRLEnvCfg.action_space` instead.
"""
action_noise_model: NoiseModelCfg | None = None action_noise_model: NoiseModelCfg | None = None
"""The noise model applied to the actions provided to the environment. Default is None, which means no noise is added. """The noise model applied to the actions provided to the environment. Default is None, which means no noise is added.
......
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Sub-package for environment utils."""
...@@ -9,9 +9,9 @@ import numpy as np ...@@ -9,9 +9,9 @@ import numpy as np
import torch import torch
from typing import Any from typing import Any
from .common import ActionType, AgentID, EnvStepReturn, ObsType, StateType, VecEnvObs, VecEnvStepReturn from ..common import ActionType, AgentID, EnvStepReturn, ObsType, StateType, VecEnvObs, VecEnvStepReturn
from .direct_marl_env import DirectMARLEnv from ..direct_marl_env import DirectMARLEnv
from .direct_rl_env import DirectRLEnv from ..direct_rl_env import DirectRLEnv
def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool = False) -> DirectRLEnv: def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool = False) -> DirectRLEnv:
...@@ -39,7 +39,7 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -39,7 +39,7 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool =
Raises: Raises:
AssertionError: If the environment state cannot be used as observation since it was explicitly defined AssertionError: If the environment state cannot be used as observation since it was explicitly defined
as unconstructed (:attr:`DirectMARLEnvCfg.num_states`). as unconstructed (:attr:`DirectMARLEnvCfg.state_space`).
""" """
class Env(DirectRLEnv): class Env(DirectRLEnv):
...@@ -49,7 +49,7 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -49,7 +49,7 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool =
# check if it is possible to use the multi-agent environment state as single-agent observation # check if it is possible to use the multi-agent environment state as single-agent observation
self._state_as_observation = state_as_observation self._state_as_observation = state_as_observation
if self._state_as_observation: if self._state_as_observation:
assert self.env.cfg.num_states != 0, ( assert self.env.cfg.state_space != 0, (
"The environment state cannot be used as observation since it was explicitly defined as" "The environment state cannot be used as observation since it was explicitly defined as"
" unconstructed" " unconstructed"
) )
...@@ -58,18 +58,17 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -58,18 +58,17 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool =
self.cfg = self.env.cfg self.cfg = self.env.cfg
self.sim = self.env.sim self.sim = self.env.sim
self.scene = self.env.scene self.scene = self.env.scene
self.num_actions = sum(self.env.cfg.num_actions.values())
self.num_observations = sum(self.env.cfg.num_observations.values())
self.num_states = self.env.cfg.num_states
self.single_observation_space = gym.spaces.Dict() self.single_observation_space = gym.spaces.Dict()
if self._state_as_observation: if self._state_as_observation:
self.single_observation_space["policy"] = self.env.state_space self.single_observation_space["policy"] = self.env.state_space
else: else:
self.single_observation_space["policy"] = gym.spaces.Box( self.single_observation_space["policy"] = gym.spaces.flatten_space(
low=-np.inf, high=np.inf, shape=(self.num_observations,) gym.spaces.Tuple([self.env.observation_spaces[agent] for agent in self.env.possible_agents])
)
self.single_action_space = gym.spaces.flatten_space(
gym.spaces.Tuple([self.env.action_spaces[agent] for agent in self.env.possible_agents])
) )
self.single_action_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.num_actions,))
# batch the spaces for vectorized environments # batch the spaces for vectorized environments
self.observation_space = gym.vector.utils.batch_space( self.observation_space = gym.vector.utils.batch_space(
...@@ -84,18 +83,25 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -84,18 +83,25 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool =
if self._state_as_observation: if self._state_as_observation:
obs = {"policy": self.env.state()} obs = {"policy": self.env.state()}
# concatenate agents' observations # concatenate agents' observations
# FIXME: This implementation assumes the spaces are fundamental ones. Fix it to support composite spaces
else: else:
obs = {"policy": torch.cat([obs[agent] for agent in self.env.possible_agents], dim=-1)} obs = {
"policy": torch.cat(
[obs[agent].reshape(self.num_envs, -1) for agent in self.env.possible_agents], dim=-1
)
}
return obs, extras return obs, extras
def step(self, action: torch.Tensor) -> VecEnvStepReturn: def step(self, action: torch.Tensor) -> VecEnvStepReturn:
# split single-agent actions to build the multi-agent ones # split single-agent actions to build the multi-agent ones
# FIXME: This implementation assumes the spaces are fundamental ones. Fix it to support composite spaces
index = 0 index = 0
_actions = {} _actions = {}
for agent in self.env.possible_agents: for agent in self.env.possible_agents:
_actions[agent] = action[:, index : index + self.env.cfg.num_actions[agent]] delta = gym.spaces.flatdim(self.env.action_spaces[agent])
index += self.env.cfg.num_actions[agent] _actions[agent] = action[:, index : index + delta]
index += delta
# step the environment # step the environment
obs, rewards, terminated, time_outs, extras = self.env.step(_actions) obs, rewards, terminated, time_outs, extras = self.env.step(_actions)
...@@ -104,8 +110,13 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -104,8 +110,13 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool =
if self._state_as_observation: if self._state_as_observation:
obs = {"policy": self.env.state()} obs = {"policy": self.env.state()}
# concatenate agents' observations # concatenate agents' observations
# FIXME: This implementation assumes the spaces are fundamental ones. Fix it to support composite spaces
else: else:
obs = {"policy": torch.cat([obs[agent] for agent in self.env.possible_agents], dim=-1)} obs = {
"policy": torch.cat(
[obs[agent].reshape(self.num_envs, -1) for agent in self.env.possible_agents], dim=-1
)
}
# process environment outputs to return single-agent data # process environment outputs to return single-agent data
rewards = sum(rewards.values()) rewards = sum(rewards.values())
...@@ -147,7 +158,7 @@ def multi_agent_with_one_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -147,7 +158,7 @@ def multi_agent_with_one_agent(env: DirectMARLEnv, state_as_observation: bool =
Raises: Raises:
AssertionError: If the environment state cannot be used as observation since it was explicitly defined AssertionError: If the environment state cannot be used as observation since it was explicitly defined
as unconstructed (:attr:`DirectMARLEnvCfg.num_states`). as unconstructed (:attr:`DirectMARLEnvCfg.state_space`).
""" """
class Env(DirectMARLEnv): class Env(DirectMARLEnv):
...@@ -157,7 +168,7 @@ def multi_agent_with_one_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -157,7 +168,7 @@ def multi_agent_with_one_agent(env: DirectMARLEnv, state_as_observation: bool =
# check if it is possible to use the multi-agent environment state as agent observation # check if it is possible to use the multi-agent environment state as agent observation
self._state_as_observation = state_as_observation self._state_as_observation = state_as_observation
if self._state_as_observation: if self._state_as_observation:
assert self.env.cfg.num_states != 0, ( assert self.env.cfg.state_space != 0, (
"The environment state cannot be used as observation since it was explicitly defined as" "The environment state cannot be used as observation since it was explicitly defined as"
" unconstructed" " unconstructed"
) )
...@@ -170,13 +181,13 @@ def multi_agent_with_one_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -170,13 +181,13 @@ def multi_agent_with_one_agent(env: DirectMARLEnv, state_as_observation: bool =
self._exported_observation_spaces = {self._agent_id: self.env.state_space} self._exported_observation_spaces = {self._agent_id: self.env.state_space}
else: else:
self._exported_observation_spaces = { self._exported_observation_spaces = {
self._agent_id: gym.spaces.Box( self._agent_id: gym.spaces.flatten_space(
low=-np.inf, high=np.inf, shape=(sum(self.env.cfg.num_observations.values()),) gym.spaces.Tuple([self.env.observation_spaces[agent] for agent in self.env.possible_agents])
) )
} }
self._exported_action_spaces = { self._exported_action_spaces = {
self._agent_id: gym.spaces.Box( self._agent_id: gym.spaces.flatten_space(
low=-np.inf, high=np.inf, shape=(sum(self.env.cfg.num_actions.values()),) gym.spaces.Tuple([self.env.action_spaces[agent] for agent in self.env.possible_agents])
) )
} }
...@@ -208,18 +219,25 @@ def multi_agent_with_one_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -208,18 +219,25 @@ def multi_agent_with_one_agent(env: DirectMARLEnv, state_as_observation: bool =
if self._state_as_observation: if self._state_as_observation:
obs = {self._agent_id: self.env.state()} obs = {self._agent_id: self.env.state()}
# concatenate agents' observations # concatenate agents' observations
# FIXME: This implementation assumes the spaces are fundamental ones. Fix it to support composite spaces
else: else:
obs = {self._agent_id: torch.cat([obs[agent] for agent in self.env.possible_agents], dim=-1)} obs = {
self._agent_id: torch.cat(
[obs[agent].reshape(self.num_envs, -1) for agent in self.env.possible_agents], dim=-1
)
}
return obs, extras return obs, extras
def step(self, actions: dict[AgentID, ActionType]) -> EnvStepReturn: def step(self, actions: dict[AgentID, ActionType]) -> EnvStepReturn:
# split agent actions to build the multi-agent ones # split agent actions to build the multi-agent ones
# FIXME: This implementation assumes the spaces are fundamental ones. Fix it to support composite spaces
index = 0 index = 0
_actions = {} _actions = {}
for agent in self.env.possible_agents: for agent in self.env.possible_agents:
_actions[agent] = actions[self._agent_id][:, index : index + self.env.cfg.num_actions[agent]] delta = gym.spaces.flatdim(self.env.action_spaces[agent])
index += self.env.cfg.num_actions[agent] _actions[agent] = actions[self._agent_id][:, index : index + delta]
index += delta
# step the environment # step the environment
obs, rewards, terminated, time_outs, extras = self.env.step(_actions) obs, rewards, terminated, time_outs, extras = self.env.step(_actions)
...@@ -228,8 +246,13 @@ def multi_agent_with_one_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -228,8 +246,13 @@ def multi_agent_with_one_agent(env: DirectMARLEnv, state_as_observation: bool =
if self._state_as_observation: if self._state_as_observation:
obs = {self._agent_id: self.env.state()} obs = {self._agent_id: self.env.state()}
# concatenate agents' observations # concatenate agents' observations
# FIXME: This implementation assumes the spaces are fundamental ones. Fix it to support composite spaces
else: else:
obs = {self._agent_id: torch.cat([obs[agent] for agent in self.env.possible_agents], dim=-1)} obs = {
self._agent_id: torch.cat(
[obs[agent].reshape(self.num_envs, -1) for agent in self.env.possible_agents], dim=-1
)
}
# process environment outputs to return agent data # process environment outputs to return agent data
rewards = {self._agent_id: sum(rewards.values())} rewards = {self._agent_id: sum(rewards.values())}
......
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import gymnasium as gym
import numpy as np
import torch
from typing import Any
from ..common import SpaceType
def spec_to_gym_space(spec: SpaceType) -> gym.spaces.Space:
"""Generate an appropriate Gymnasium space according to the given space specification.
Args:
spec: Space specification.
Returns:
Gymnasium space.
Raises:
ValueError: If the given space specification is not valid/supported.
"""
if isinstance(spec, gym.spaces.Space):
return spec
# fundamental spaces
# Box
elif isinstance(spec, int):
return gym.spaces.Box(low=-np.inf, high=np.inf, shape=(spec,))
elif isinstance(spec, list) and all(isinstance(x, int) for x in spec):
return gym.spaces.Box(low=-np.inf, high=np.inf, shape=spec)
# Discrete
elif isinstance(spec, set) and len(spec) == 1:
return gym.spaces.Discrete(n=next(iter(spec)))
# MultiDiscrete
elif isinstance(spec, list) and all(isinstance(x, set) and len(x) == 1 for x in spec):
return gym.spaces.MultiDiscrete(nvec=[next(iter(x)) for x in spec])
# composite spaces
# Tuple
elif isinstance(spec, tuple):
return gym.spaces.Tuple([spec_to_gym_space(x) for x in spec])
# Dict
elif isinstance(spec, dict):
return gym.spaces.Dict({k: spec_to_gym_space(v) for k, v in spec.items()})
raise ValueError(f"Unsupported space specification: {spec}")
def sample_space(space: gym.spaces.Space, device: str, batch_size: int = -1, fill_value: float | None = None) -> Any:
"""Sample a Gymnasium space where the data container are PyTorch tensors.
Args:
space: Gymnasium space.
device: The device where the tensor should be created.
batch_size: Batch size. If the specified value is greater than zero, a batched space will be created and sampled from it.
fill_value: The value to fill the created tensors with. If None (default value), tensors will keep their random values.
Returns:
Tensorized sampled space.
"""
def tensorize(s, x):
if isinstance(s, gym.spaces.Box):
tensor = torch.tensor(x, device=device, dtype=torch.float32).reshape(batch_size, *s.shape)
if fill_value is not None:
tensor.fill_(fill_value)
return tensor
elif isinstance(s, gym.spaces.Discrete):
if isinstance(x, np.ndarray):
tensor = torch.tensor(x, device=device, dtype=torch.int64).reshape(batch_size, 1)
if fill_value is not None:
tensor.fill_(int(fill_value))
return tensor
elif isinstance(x, np.number) or type(x) in [int, float]:
tensor = torch.tensor([x], device=device, dtype=torch.int64).reshape(batch_size, 1)
if fill_value is not None:
tensor.fill_(int(fill_value))
return tensor
elif isinstance(s, gym.spaces.MultiDiscrete):
if isinstance(x, np.ndarray):
tensor = torch.tensor(x, device=device, dtype=torch.int64).reshape(batch_size, *s.shape)
if fill_value is not None:
tensor.fill_(int(fill_value))
return tensor
elif isinstance(s, gym.spaces.Dict):
return {k: tensorize(_s, x[k]) for k, _s in s.items()}
elif isinstance(s, gym.spaces.Tuple):
return tuple([tensorize(_s, v) for _s, v in zip(s, x)])
sample = (gym.vector.utils.batch_space(space, batch_size) if batch_size > 0 else space).sample()
return tensorize(space, sample)
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
# ignore private usage of variables warning
# pyright: reportPrivateUsage=none
from __future__ import annotations
"""Launch Isaac Sim Simulator first."""
from omni.isaac.lab.app import AppLauncher, run_tests
# Can set this to False to see the GUI for debugging
HEADLESS = True
# launch omniverse app
app_launcher = AppLauncher(headless=HEADLESS)
simulation_app = app_launcher.app
"""Rest everything follows."""
import numpy as np
import torch
import unittest
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
from omni.isaac.lab.envs.utils.spaces import sample_space, spec_to_gym_space
class TestSpacesUtils(unittest.TestCase):
"""Test for spaces utils' functions"""
"""
Tests
"""
def test_spec_to_gym_space(self):
# fundamental spaces
# Box
space = spec_to_gym_space(1)
self.assertIsInstance(space, Box)
self.assertEqual(space.shape, (1,))
space = spec_to_gym_space([1, 2, 3, 4, 5])
self.assertIsInstance(space, Box)
self.assertEqual(space.shape, (1, 2, 3, 4, 5))
space = spec_to_gym_space(Box(low=-1.0, high=1.0, shape=(1, 2)))
self.assertIsInstance(space, Box)
# Discrete
space = spec_to_gym_space({2})
self.assertIsInstance(space, Discrete)
self.assertEqual(space.n, 2)
space = spec_to_gym_space(Discrete(2))
self.assertIsInstance(space, Discrete)
# MultiDiscrete
space = spec_to_gym_space([{1}, {2}, {3}])
self.assertIsInstance(space, MultiDiscrete)
self.assertEqual(space.nvec.shape, (3,))
space = spec_to_gym_space(MultiDiscrete(np.array([1, 2, 3])))
self.assertIsInstance(space, MultiDiscrete)
# composite spaces
# Tuple
space = spec_to_gym_space(([1, 2, 3, 4, 5], {2}, [{1}, {2}, {3}]))
self.assertIsInstance(space, Tuple)
self.assertEqual(len(space), 3)
self.assertIsInstance(space[0], Box)
self.assertIsInstance(space[1], Discrete)
self.assertIsInstance(space[2], MultiDiscrete)
space = spec_to_gym_space(Tuple((Box(-1, 1, shape=(1,)), Discrete(2))))
self.assertIsInstance(space, Tuple)
# Dict
space = spec_to_gym_space({"box": [1, 2, 3, 4, 5], "discrete": {2}, "multi_discrete": [{1}, {2}, {3}]})
self.assertIsInstance(space, Dict)
self.assertEqual(len(space), 3)
self.assertIsInstance(space["box"], Box)
self.assertIsInstance(space["discrete"], Discrete)
self.assertIsInstance(space["multi_discrete"], MultiDiscrete)
space = spec_to_gym_space(Dict({"box": Box(-1, 1, shape=(1,)), "discrete": Discrete(2)}))
self.assertIsInstance(space, Dict)
def test_sample_space(self):
device = "cpu"
# fundamental spaces
# Box
sample = sample_space(Box(low=-1.0, high=1.0, shape=(1, 2)), device, batch_size=1)
self.assertIsInstance(sample, torch.Tensor)
self._check_tensorized(sample, batch_size=1)
# Discrete
sample = sample_space(Discrete(2), device, batch_size=2)
self.assertIsInstance(sample, torch.Tensor)
self._check_tensorized(sample, batch_size=2)
# MultiDiscrete
sample = sample_space(MultiDiscrete(np.array([1, 2, 3])), device, batch_size=3)
self.assertIsInstance(sample, torch.Tensor)
self._check_tensorized(sample, batch_size=3)
# composite spaces
# Tuple
sample = sample_space(Tuple((Box(-1, 1, shape=(1,)), Discrete(2))), device, batch_size=4)
self.assertIsInstance(sample, (tuple, list))
self._check_tensorized(sample, batch_size=4)
# Dict
sample = sample_space(Dict({"box": Box(-1, 1, shape=(1,)), "discrete": Discrete(2)}), device, batch_size=5)
self.assertIsInstance(sample, dict)
self._check_tensorized(sample, batch_size=5)
"""
Helper functions.
"""
def _check_tensorized(self, sample, batch_size):
if isinstance(sample, (tuple, list)):
list(map(self._check_tensorized, sample, [batch_size] * len(sample)))
elif isinstance(sample, dict):
list(map(self._check_tensorized, sample.values(), [batch_size] * len(sample)))
else:
self.assertIsInstance(sample, torch.Tensor)
self.assertEqual(sample.shape[0], batch_size)
if __name__ == "__main__":
run_tests()
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.10.5" version = "0.10.7"
# Description # Description
title = "Isaac Lab Environments" title = "Isaac Lab Environments"
......
Changelog Changelog
--------- ---------
0.10.7 (2024-10-02)
~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Replace deprecated :attr:`num_observations`, :attr:`num_actions` and :attr:`num_states` in single-agent direct tasks
by :attr:`observation_space`, :attr:`action_space` and :attr:`state_space` respectively.
* Replace deprecated :attr:`num_observations`, :attr:`num_actions` and :attr:`num_states` in multi-agent direct tasks
by :attr:`observation_spaces`, :attr:`action_spaces` and :attr:`state_space` respectively.
0.10.6 (2024-09-25) 0.10.6 (2024-09-25)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
Added Added
^^^^^ ^^^^^
* Added ``Isaac-Cartpole-RGB-Camera-v0`` and ``Isaac-Cartpole-Depth-Camera-v0`` * Added ``Isaac-Cartpole-RGB-Camera-v0`` and ``Isaac-Cartpole-Depth-Camera-v0``
manager based camera cartpole environments. manager based camera cartpole environments.
......
...@@ -22,9 +22,9 @@ class AllegroHandEnvCfg(DirectRLEnvCfg): ...@@ -22,9 +22,9 @@ class AllegroHandEnvCfg(DirectRLEnvCfg):
# env # env
decimation = 4 decimation = 4
episode_length_s = 10.0 episode_length_s = 10.0
num_actions = 16 action_space = 16
num_observations = 124 # (full) observation_space = 124 # (full)
num_states = 0 state_space = 0
asymmetric_obs = False asymmetric_obs = False
obs_type = "full" obs_type = "full"
# simulation # simulation
......
...@@ -24,9 +24,9 @@ class AntEnvCfg(DirectRLEnvCfg): ...@@ -24,9 +24,9 @@ class AntEnvCfg(DirectRLEnvCfg):
episode_length_s = 15.0 episode_length_s = 15.0
decimation = 2 decimation = 2
action_scale = 0.5 action_scale = 0.5
num_actions = 8 action_space = 8
num_observations = 36 observation_space = 36
num_states = 0 state_space = 0
# simulation # simulation
sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation) sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import gymnasium as gym
import torch import torch
import omni.isaac.lab.envs.mdp as mdp import omni.isaac.lab.envs.mdp as mdp
...@@ -59,9 +60,9 @@ class AnymalCFlatEnvCfg(DirectRLEnvCfg): ...@@ -59,9 +60,9 @@ class AnymalCFlatEnvCfg(DirectRLEnvCfg):
episode_length_s = 20.0 episode_length_s = 20.0
decimation = 4 decimation = 4
action_scale = 0.5 action_scale = 0.5
num_actions = 12 action_space = 12
num_observations = 48 observation_space = 48
num_states = 0 state_space = 0
# simulation # simulation
sim: SimulationCfg = SimulationCfg( sim: SimulationCfg = SimulationCfg(
...@@ -118,7 +119,7 @@ class AnymalCFlatEnvCfg(DirectRLEnvCfg): ...@@ -118,7 +119,7 @@ class AnymalCFlatEnvCfg(DirectRLEnvCfg):
@configclass @configclass
class AnymalCRoughEnvCfg(AnymalCFlatEnvCfg): class AnymalCRoughEnvCfg(AnymalCFlatEnvCfg):
# env # env
num_observations = 235 observation_space = 235
terrain = TerrainImporterCfg( terrain = TerrainImporterCfg(
prim_path="/World/ground", prim_path="/World/ground",
...@@ -160,8 +161,10 @@ class AnymalCEnv(DirectRLEnv): ...@@ -160,8 +161,10 @@ class AnymalCEnv(DirectRLEnv):
super().__init__(cfg, render_mode, **kwargs) super().__init__(cfg, render_mode, **kwargs)
# Joint position command (deviation from default joint positions) # Joint position command (deviation from default joint positions)
self._actions = torch.zeros(self.num_envs, self.cfg.num_actions, device=self.device) self._actions = torch.zeros(self.num_envs, gym.spaces.flatdim(self.single_action_space), device=self.device)
self._previous_actions = torch.zeros(self.num_envs, self.cfg.num_actions, device=self.device) self._previous_actions = torch.zeros(
self.num_envs, gym.spaces.flatdim(self.single_action_space), device=self.device
)
# X/Y linear velocity and yaw angular velocity commands # X/Y linear velocity and yaw angular velocity commands
self._commands = torch.zeros(self.num_envs, 3, device=self.device) self._commands = torch.zeros(self.num_envs, 3, device=self.device)
......
...@@ -27,9 +27,9 @@ class CartDoublePendulumEnvCfg(DirectMARLEnvCfg): ...@@ -27,9 +27,9 @@ class CartDoublePendulumEnvCfg(DirectMARLEnvCfg):
decimation = 2 decimation = 2
episode_length_s = 5.0 episode_length_s = 5.0
possible_agents = ["cart", "pendulum"] possible_agents = ["cart", "pendulum"]
num_actions = {"cart": 1, "pendulum": 1} action_spaces = {"cart": 1, "pendulum": 1}
num_observations = {"cart": 4, "pendulum": 3} observation_spaces = {"cart": 4, "pendulum": 3}
num_states = -1 state_space = -1
# simulation # simulation
sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation) sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation)
......
...@@ -5,9 +5,7 @@ ...@@ -5,9 +5,7 @@
from __future__ import annotations from __future__ import annotations
import gymnasium as gym
import math import math
import numpy as np
import torch import torch
from collections.abc import Sequence from collections.abc import Sequence
...@@ -29,9 +27,6 @@ class CartpoleRGBCameraEnvCfg(DirectRLEnvCfg): ...@@ -29,9 +27,6 @@ class CartpoleRGBCameraEnvCfg(DirectRLEnvCfg):
decimation = 2 decimation = 2
episode_length_s = 5.0 episode_length_s = 5.0
action_scale = 100.0 # [N] action_scale = 100.0 # [N]
num_actions = 1
num_channels = 3
num_states = 0
# simulation # simulation
sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation) sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation)
...@@ -52,9 +47,13 @@ class CartpoleRGBCameraEnvCfg(DirectRLEnvCfg): ...@@ -52,9 +47,13 @@ class CartpoleRGBCameraEnvCfg(DirectRLEnvCfg):
width=80, width=80,
height=80, height=80,
) )
num_observations = num_channels * tiled_camera.height * tiled_camera.width
write_image_to_file = False write_image_to_file = False
# spaces
action_space = 1
state_space = 0
observation_space = [tiled_camera.height, tiled_camera.width, 3]
# change viewer settings # change viewer settings
viewer = ViewerCfg(eye=(20.0, 20.0, 20.0)) viewer = ViewerCfg(eye=(20.0, 20.0, 20.0))
...@@ -87,9 +86,8 @@ class CartpoleDepthCameraEnvCfg(CartpoleRGBCameraEnvCfg): ...@@ -87,9 +86,8 @@ class CartpoleDepthCameraEnvCfg(CartpoleRGBCameraEnvCfg):
height=80, height=80,
) )
# env # spaces
num_channels = 1 observation_space = [tiled_camera.height, tiled_camera.width, 1]
num_observations = num_channels * tiled_camera.height * tiled_camera.width
class CartpoleCameraEnv(DirectRLEnv): class CartpoleCameraEnv(DirectRLEnv):
...@@ -118,35 +116,6 @@ class CartpoleCameraEnv(DirectRLEnv): ...@@ -118,35 +116,6 @@ class CartpoleCameraEnv(DirectRLEnv):
"""Cleanup for the environment.""" """Cleanup for the environment."""
super().close() super().close()
def _configure_gym_env_spaces(self):
"""Configure the action and observation spaces for the Gym environment."""
# observation space (unbounded since we don't impose any limits)
self.num_actions = self.cfg.num_actions
self.num_observations = self.cfg.num_observations
self.num_states = self.cfg.num_states
# set up spaces
self.single_observation_space = gym.spaces.Dict()
self.single_observation_space["policy"] = gym.spaces.Box(
low=-np.inf,
high=np.inf,
shape=(self.cfg.tiled_camera.height, self.cfg.tiled_camera.width, self.cfg.num_channels),
)
if self.num_states > 0:
self.single_observation_space["critic"] = gym.spaces.Box(
low=-np.inf,
high=np.inf,
shape=(self.cfg.tiled_camera.height, self.cfg.tiled_camera.width, self.cfg.num_channels),
)
self.single_action_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.num_actions,))
# batch the spaces for vectorized environments
self.observation_space = gym.vector.utils.batch_space(self.single_observation_space, self.num_envs)
self.action_space = gym.vector.utils.batch_space(self.single_action_space, self.num_envs)
# RL specifics
self.actions = torch.zeros(self.num_envs, self.num_actions, device=self.sim.device)
def _setup_scene(self): def _setup_scene(self):
"""Setup the scene with the cartpole and camera.""" """Setup the scene with the cartpole and camera."""
self._cartpole = Articulation(self.cfg.robot_cfg) self._cartpole = Articulation(self.cfg.robot_cfg)
......
...@@ -27,9 +27,9 @@ class CartpoleEnvCfg(DirectRLEnvCfg): ...@@ -27,9 +27,9 @@ class CartpoleEnvCfg(DirectRLEnvCfg):
decimation = 2 decimation = 2
episode_length_s = 5.0 episode_length_s = 5.0
action_scale = 100.0 # [N] action_scale = 100.0 # [N]
num_actions = 1 action_space = 1
num_observations = 4 observation_space = 4
num_states = 0 state_space = 0
# simulation # simulation
sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation) sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation)
......
...@@ -28,9 +28,9 @@ class FrankaCabinetEnvCfg(DirectRLEnvCfg): ...@@ -28,9 +28,9 @@ class FrankaCabinetEnvCfg(DirectRLEnvCfg):
# env # env
episode_length_s = 8.3333 # 500 timesteps episode_length_s = 8.3333 # 500 timesteps
decimation = 2 decimation = 2
num_actions = 9 action_space = 9
num_observations = 23 observation_space = 23
num_states = 0 state_space = 0
# simulation # simulation
sim: SimulationCfg = SimulationCfg( sim: SimulationCfg = SimulationCfg(
......
...@@ -24,9 +24,9 @@ class HumanoidEnvCfg(DirectRLEnvCfg): ...@@ -24,9 +24,9 @@ class HumanoidEnvCfg(DirectRLEnvCfg):
episode_length_s = 15.0 episode_length_s = 15.0
decimation = 2 decimation = 2
action_scale = 1.0 action_scale = 1.0
num_actions = 21 action_space = 21
num_observations = 75 observation_space = 75
num_states = 0 state_space = 0
# simulation # simulation
sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation) sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from __future__ import annotations from __future__ import annotations
import gymnasium as gym
import torch import torch
import omni.isaac.lab.sim as sim_utils import omni.isaac.lab.sim as sim_utils
...@@ -50,9 +51,9 @@ class QuadcopterEnvCfg(DirectRLEnvCfg): ...@@ -50,9 +51,9 @@ class QuadcopterEnvCfg(DirectRLEnvCfg):
# env # env
episode_length_s = 10.0 episode_length_s = 10.0
decimation = 2 decimation = 2
num_actions = 4 action_space = 4
num_observations = 12 observation_space = 12
num_states = 0 state_space = 0
debug_vis = True debug_vis = True
ui_window_class_type = QuadcopterEnvWindow ui_window_class_type = QuadcopterEnvWindow
...@@ -105,7 +106,7 @@ class QuadcopterEnv(DirectRLEnv): ...@@ -105,7 +106,7 @@ class QuadcopterEnv(DirectRLEnv):
super().__init__(cfg, render_mode, **kwargs) super().__init__(cfg, render_mode, **kwargs)
# Total thrust and moment applied to the base of the quadcopter # Total thrust and moment applied to the base of the quadcopter
self._actions = torch.zeros(self.num_envs, self.cfg.num_actions, device=self.device) self._actions = torch.zeros(self.num_envs, gym.spaces.flatdim(self.single_action_space), device=self.device)
self._thrust = torch.zeros(self.num_envs, 1, 3, device=self.device) self._thrust = torch.zeros(self.num_envs, 1, 3, device=self.device)
self._moment = torch.zeros(self.num_envs, 1, 3, device=self.device) self._moment = torch.zeros(self.num_envs, 1, 3, device=self.device)
# Goal position # Goal position
......
...@@ -119,9 +119,9 @@ class ShadowHandEnvCfg(DirectRLEnvCfg): ...@@ -119,9 +119,9 @@ class ShadowHandEnvCfg(DirectRLEnvCfg):
# env # env
decimation = 2 decimation = 2
episode_length_s = 10.0 episode_length_s = 10.0
num_actions = 20 action_space = 20
num_observations = 157 # (full) observation_space = 157 # (full)
num_states = 0 state_space = 0
asymmetric_obs = False asymmetric_obs = False
obs_type = "full" obs_type = "full"
...@@ -232,9 +232,9 @@ class ShadowHandOpenAIEnvCfg(ShadowHandEnvCfg): ...@@ -232,9 +232,9 @@ class ShadowHandOpenAIEnvCfg(ShadowHandEnvCfg):
# env # env
decimation = 3 decimation = 3
episode_length_s = 8.0 episode_length_s = 8.0
num_actions = 20 action_space = 20
num_observations = 42 observation_space = 42
num_states = 187 state_space = 187
asymmetric_obs = True asymmetric_obs = True
obs_type = "openai" obs_type = "openai"
# simulation # simulation
......
...@@ -48,8 +48,8 @@ class ShadowHandVisionEnvCfg(ShadowHandEnvCfg): ...@@ -48,8 +48,8 @@ class ShadowHandVisionEnvCfg(ShadowHandEnvCfg):
feature_extractor = FeatureExtractorCfg() feature_extractor = FeatureExtractorCfg()
# env # env
num_observations = 164 + 27 # state observation + vision CNN embedding observation_space = 164 + 27 # state observation + vision CNN embedding
num_states = 187 + 27 # asymettric states + vision CNN embedding state_space = 187 + 27 # asymettric states + vision CNN embedding
@configclass @configclass
......
...@@ -118,9 +118,9 @@ class ShadowHandOverEnvCfg(DirectMARLEnvCfg): ...@@ -118,9 +118,9 @@ class ShadowHandOverEnvCfg(DirectMARLEnvCfg):
decimation = 2 decimation = 2
episode_length_s = 7.5 episode_length_s = 7.5
possible_agents = ["right_hand", "left_hand"] possible_agents = ["right_hand", "left_hand"]
num_actions = {"right_hand": 20, "left_hand": 20} action_spaces = {"right_hand": 20, "left_hand": 20}
num_observations = {"right_hand": 157, "left_hand": 157} observation_spaces = {"right_hand": 157, "left_hand": 157}
num_states = 290 state_space = 290
# simulation # simulation
sim: SimulationCfg = SimulationCfg( sim: SimulationCfg = SimulationCfg(
......
...@@ -70,19 +70,19 @@ class RslRlVecEnvWrapper(VecEnv): ...@@ -70,19 +70,19 @@ class RslRlVecEnvWrapper(VecEnv):
if hasattr(self.unwrapped, "action_manager"): if hasattr(self.unwrapped, "action_manager"):
self.num_actions = self.unwrapped.action_manager.total_action_dim self.num_actions = self.unwrapped.action_manager.total_action_dim
else: else:
self.num_actions = self.unwrapped.num_actions self.num_actions = gym.spaces.flatdim(self.unwrapped.single_action_space)
if hasattr(self.unwrapped, "observation_manager"): if hasattr(self.unwrapped, "observation_manager"):
self.num_obs = self.unwrapped.observation_manager.group_obs_dim["policy"][0] self.num_obs = self.unwrapped.observation_manager.group_obs_dim["policy"][0]
else: else:
self.num_obs = self.unwrapped.num_observations self.num_obs = gym.spaces.flatdim(self.unwrapped.single_observation_space["policy"])
# -- privileged observations # -- privileged observations
if ( if (
hasattr(self.unwrapped, "observation_manager") hasattr(self.unwrapped, "observation_manager")
and "critic" in self.unwrapped.observation_manager.group_obs_dim and "critic" in self.unwrapped.observation_manager.group_obs_dim
): ):
self.num_privileged_obs = self.unwrapped.observation_manager.group_obs_dim["critic"][0] self.num_privileged_obs = self.unwrapped.observation_manager.group_obs_dim["critic"][0]
elif hasattr(self.unwrapped, "num_states"): elif hasattr(self.unwrapped, "num_states") and "critic" in self.unwrapped.single_observation_space:
self.num_privileged_obs = self.unwrapped.num_states self.num_privileged_obs = gym.spaces.flatdim(self.unwrapped.single_observation_space["critic"])
else: else:
self.num_privileged_obs = 0 self.num_privileged_obs = 0
# reset at the start since the RSL-RL runner does not call reset # reset at the start since the RSL-RL runner does not call reset
......
...@@ -22,6 +22,7 @@ import carb ...@@ -22,6 +22,7 @@ import carb
import omni.usd import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
from omni.isaac.lab.envs.utils import sample_space
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg
...@@ -108,12 +109,12 @@ class TestEnvironments(unittest.TestCase): ...@@ -108,12 +109,12 @@ class TestEnvironments(unittest.TestCase):
# simulate environment for num_steps steps # simulate environment for num_steps steps
with torch.inference_mode(): with torch.inference_mode():
for _ in range(num_steps): for _ in range(num_steps):
# sample actions from -1 to 1 # sample actions according to the defined space
actions = 2 * torch.rand(env.action_space.shape, device=env.unwrapped.device) - 1 actions = sample_space(env.single_action_space, device=env.unwrapped.device, batch_size=num_envs)
# apply actions # apply actions
transition = env.step(actions) transition = env.step(actions)
# check signals # check signals
for data in transition: for data in transition[:-1]: # exclude info
self.assertTrue(self._check_valid_tensor(data), msg=f"Invalid data: {data}") self.assertTrue(self._check_valid_tensor(data), msg=f"Invalid data: {data}")
# close the environment # close the environment
...@@ -131,14 +132,10 @@ class TestEnvironments(unittest.TestCase): ...@@ -131,14 +132,10 @@ class TestEnvironments(unittest.TestCase):
""" """
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return not torch.any(torch.isnan(data)) return not torch.any(torch.isnan(data))
elif isinstance(data, (tuple, list)):
return all(TestEnvironments._check_valid_tensor(value) for value in data)
elif isinstance(data, dict): elif isinstance(data, dict):
valid_tensor = True return all(TestEnvironments._check_valid_tensor(value) for value in data.values())
for value in data.values():
if isinstance(value, dict):
valid_tensor &= TestEnvironments._check_valid_tensor(value)
elif isinstance(value, torch.Tensor):
valid_tensor &= not torch.any(torch.isnan(value))
return valid_tensor
else: else:
raise ValueError(f"Input data of invalid type: {type(data)}.") raise ValueError(f"Input data of invalid type: {type(data)}.")
......
...@@ -21,6 +21,7 @@ import unittest ...@@ -21,6 +21,7 @@ import unittest
import omni.usd import omni.usd
from omni.isaac.lab.envs import DirectMARLEnv, DirectMARLEnvCfg from omni.isaac.lab.envs import DirectMARLEnv, DirectMARLEnvCfg
from omni.isaac.lab.envs.utils import sample_space
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg
...@@ -104,9 +105,9 @@ class TestEnvironments(unittest.TestCase): ...@@ -104,9 +105,9 @@ class TestEnvironments(unittest.TestCase):
# simulate environment for num_steps steps # simulate environment for num_steps steps
with torch.inference_mode(): with torch.inference_mode():
for _ in range(num_steps): for _ in range(num_steps):
# sample actions from -1 to 1 # sample actions according to the defined space
actions = { actions = {
agent: 2 * torch.rand(env.action_space(agent).shape, device=env.unwrapped.device) - 1 agent: sample_space(env.action_spaces[agent], device=env.unwrapped.device)
for agent in env.unwrapped.possible_agents for agent in env.unwrapped.possible_agents
} }
# apply actions # apply actions
...@@ -131,14 +132,10 @@ class TestEnvironments(unittest.TestCase): ...@@ -131,14 +132,10 @@ class TestEnvironments(unittest.TestCase):
""" """
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return not torch.any(torch.isnan(data)) return not torch.any(torch.isnan(data))
elif isinstance(data, (tuple, list)):
return all(TestEnvironments._check_valid_tensor(value) for value in data)
elif isinstance(data, dict): elif isinstance(data, dict):
valid_tensor = True return all(TestEnvironments._check_valid_tensor(value) for value in data.values())
for value in data.values():
if isinstance(value, dict):
valid_tensor &= TestEnvironments._check_valid_tensor(value)
elif isinstance(value, torch.Tensor):
valid_tensor &= not torch.any(torch.isnan(value))
return valid_tensor
else: else:
raise ValueError(f"Input data of invalid type: {type(data)}.") raise ValueError(f"Input data of invalid type: {type(data)}.")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment