Commit e851048f authored by Toni-SM's avatar Toni-SM Committed by David Hoeller

Add the direct workflow cart double pendulum multi-agent environment (#94)

This PR adds the cart double pendulum multi-agent direct-workflow task
(`Isaac-Cart-Double-Pendulum-Direct-v0`)

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Signed-off-by: 's avatarToni-SM <toni.semu@gmail.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarAlexander <143108850+nv-apoddubny@users.noreply.github.com>
Co-authored-by: 's avatarAlexander Poddubny <apoddubny@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
parent 38f72c0f
...@@ -12,12 +12,14 @@ running the following command: ...@@ -12,12 +12,14 @@ running the following command:
We are actively working on adding more environments to the list. If you have any environments that We are actively working on adding more environments to the list. If you have any environments that
you would like to add to Isaac Lab, please feel free to open a pull request! you would like to add to Isaac Lab, please feel free to open a pull request!
Single-agent
------------
Classic Classic
------- ~~~~~~~
Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-style environments. Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-style environments.
.. table:: .. table::
:widths: 33 37 30 :widths: 33 37 30
...@@ -52,7 +54,7 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty ...@@ -52,7 +54,7 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty
Manipulation Manipulation
------------ ~~~~~~~~~~~~
Environments based on fixed-arm manipulation tasks. Environments based on fixed-arm manipulation tasks.
...@@ -108,7 +110,7 @@ for the reach environment: ...@@ -108,7 +110,7 @@ for the reach environment:
.. |cube-shadow-lstm-link| replace:: `Isaac-Repose-Cube-Shadow-OpenAI-LSTM-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/shadow_hand/shadow_hand_env_cfg.py>`__ .. |cube-shadow-lstm-link| replace:: `Isaac-Repose-Cube-Shadow-OpenAI-LSTM-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/shadow_hand/shadow_hand_env_cfg.py>`__
Locomotion Locomotion
---------- ~~~~~~~~~~
Environments based on legged locomotion tasks. Environments based on legged locomotion tasks.
...@@ -204,7 +206,7 @@ Environments based on legged locomotion tasks. ...@@ -204,7 +206,7 @@ Environments based on legged locomotion tasks.
.. |velocity-rough-g1| image:: ../_static/tasks/locomotion/g1_rough.jpg .. |velocity-rough-g1| image:: ../_static/tasks/locomotion/g1_rough.jpg
Navigation Navigation
---------- ~~~~~~~~~~
.. table:: .. table::
:widths: 33 37 30 :widths: 33 37 30
...@@ -221,7 +223,7 @@ Navigation ...@@ -221,7 +223,7 @@ Navigation
Others Others
------ ~~~~~~
.. table:: .. table::
:widths: 33 37 30 :widths: 33 37 30
...@@ -238,6 +240,27 @@ Others ...@@ -238,6 +240,27 @@ Others
.. |quadcopter| image:: ../_static/tasks/others/quadcopter.jpg .. |quadcopter| image:: ../_static/tasks/others/quadcopter.jpg
Multi-agent
------------
Classic
~~~~~~~
.. table::
:widths: 33 37 30
+------------------------+------------------------------------+-----------------------------------------------------------------------------------------------------------------------+
| World | Environment ID | Description |
+========================+====================================+=======================================================================================================================+
| |cart-double-pendulum| | |cart-double-pendulum-direct-link| | Move the cart and the pendulum to keep the last one upwards in the classic inverted double pendulum on a cart control |
+------------------------+------------------------------------+-----------------------------------------------------------------------------------------------------------------------+
.. |cart-double-pendulum| image:: ../_static/tasks/classic/cart_double_pendulum.jpg
.. |cart-double-pendulum-direct-link| replace:: `Isaac-Cart-Double-Pendulum-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/cart_double_pendulum/cart_double_pendulum_env.py>`__
Comprehensive List of Environments Comprehensive List of Environments
================================== ==================================
...@@ -255,6 +278,8 @@ Comprehensive List of Environments ...@@ -255,6 +278,8 @@ Comprehensive List of Environments
+------------------------------------------------+--------------------------------------------+---------------+-----------------------------+ +------------------------------------------------+--------------------------------------------+---------------+-----------------------------+
| Isaac-Velocity-Rough-Anymal-C-Direct-v0 | | Direct | rsl_rl, rl_games, skrl | | Isaac-Velocity-Rough-Anymal-C-Direct-v0 | | Direct | rsl_rl, rl_games, skrl |
+------------------------------------------------+--------------------------------------------+---------------+-----------------------------+ +------------------------------------------------+--------------------------------------------+---------------+-----------------------------+
| Isaac-Cart-Double-Pendulum-Direct-v0 | | Direct | rsl_rl, rl_games, skrl |
+------------------------------------------------+--------------------------------------------+---------------+-----------------------------+
| Isaac-Cartpole-Direct-v0 | | Direct | rsl_rl, rl_games, skrl, sb3 | | Isaac-Cartpole-Direct-v0 | | Direct | rsl_rl, rl_games, skrl, sb3 |
+------------------------------------------------+--------------------------------------------+---------------+-----------------------------+ +------------------------------------------------+--------------------------------------------+---------------+-----------------------------+
| Isaac-Cartpole-RGB-Camera-Direct-v0 | | Direct | rl_games | | Isaac-Cartpole-RGB-Camera-Direct-v0 | | Direct | rl_games |
......
[package] [package]
# Semantic Versioning is used: https://semver.org/ # Semantic Versioning is used: https://semver.org/
version = "0.1.3" version = "0.1.4"
# Description # Description
title = "Isaac Lab Assets" title = "Isaac Lab Assets"
......
Changelog Changelog
--------- ---------
0.1.4 (2024-08-21)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added configuration for the Inverted Double Pendulum on a Cart robot.
0.1.2 (2024-04-03) 0.1.2 (2024-04-03)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -29,6 +29,7 @@ __version__ = ISAACLAB_ASSETS_METADATA["package"]["version"] ...@@ -29,6 +29,7 @@ __version__ = ISAACLAB_ASSETS_METADATA["package"]["version"]
from .allegro import * from .allegro import *
from .ant import * from .ant import *
from .anymal import * from .anymal import *
from .cart_double_pendulum import *
from .cartpole import * from .cartpole import *
from .franka import * from .franka import *
from .humanoid import * from .humanoid import *
......
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Configuration for a simple inverted Double Pendulum on a Cart robot."""
import omni.isaac.lab.sim as sim_utils
from omni.isaac.lab.actuators import ImplicitActuatorCfg
from omni.isaac.lab.assets import ArticulationCfg
from omni.isaac.lab.utils.assets import ISAACLAB_NUCLEUS_DIR
##
# Configuration
##
CART_DOUBLE_PENDULUM_CFG = ArticulationCfg(
spawn=sim_utils.UsdFileCfg(
usd_path=f"{ISAACLAB_NUCLEUS_DIR}/Robots/Classic/CartDoublePendulum/cart_double_pendulum.usd",
rigid_props=sim_utils.RigidBodyPropertiesCfg(
rigid_body_enabled=True,
max_linear_velocity=1000.0,
max_angular_velocity=1000.0,
max_depenetration_velocity=100.0,
enable_gyroscopic_forces=True,
),
articulation_props=sim_utils.ArticulationRootPropertiesCfg(
enabled_self_collisions=False,
solver_position_iteration_count=4,
solver_velocity_iteration_count=0,
sleep_threshold=0.005,
stabilization_threshold=0.001,
),
),
init_state=ArticulationCfg.InitialStateCfg(
pos=(0.0, 0.0, 2.0), joint_pos={"slider_to_cart": 0.0, "cart_to_pole": 0.0, "pole_to_pendulum": 0.0}
),
actuators={
"cart_actuator": ImplicitActuatorCfg(
joint_names_expr=["slider_to_cart"],
effort_limit=400.0,
velocity_limit=100.0,
stiffness=0.0,
damping=10.0,
),
"pole_actuator": ImplicitActuatorCfg(
joint_names_expr=["cart_to_pole"], effort_limit=400.0, velocity_limit=100.0, stiffness=0.0, damping=0.0
),
"pendulum_actuator": ImplicitActuatorCfg(
joint_names_expr=["pole_to_pendulum"], effort_limit=400.0, velocity_limit=100.0, stiffness=0.0, damping=0.0
),
},
)
"""Configuration for a simple inverted Double Pendulum on a Cart robot."""
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.10.0" version = "0.10.1"
# Description # Description
title = "Isaac Lab Environments" title = "Isaac Lab Environments"
......
Changelog Changelog
--------- ---------
0.10.1 (2024-08-21)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added ``Isaac-Cart-Double-Pendulum-Direct-v0`` multi-agent environment
Changed
^^^^^^^
* Update skrl wrapper to support multi-agent environments.
0.10.0 (2024-08-14) 0.10.0 (2024-08-14)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""
Inverted Double Pendulum on a Cart balancing environment.
"""
import gymnasium as gym
from . import agents
from .cart_double_pendulum_env import CartDoublePendulumEnv, CartDoublePendulumEnvCfg
##
# Register Gym environments.
##
gym.register(
id="Isaac-Cart-Double-Pendulum-Direct-v0",
entry_point="omni.isaac.lab_tasks.direct.cart_double_pendulum:CartDoublePendulumEnv",
disable_env_checker=True,
kwargs={
"env_cfg_entry_point": CartDoublePendulumEnvCfg,
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
"rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CartDoublePendulumPPORunnerCfg",
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml",
"skrl_ippo_cfg_entry_point": f"{agents.__name__}:skrl_ippo_cfg.yaml",
"skrl_mappo_cfg_entry_point": f"{agents.__name__}:skrl_mappo_cfg.yaml",
},
)
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
params:
seed: 42
# environment wrapper clipping
env:
# added to the wrapper
clip_observations: 5.0
# can make custom wrapper?
clip_actions: 1.0
algo:
name: a2c_continuous
model:
name: continuous_a2c_logstd
# doesn't have this fine grained control but made it close
network:
name: actor_critic
separate: False
space:
continuous:
mu_activation: None
sigma_activation: None
mu_init:
name: default
sigma_init:
name: const_initializer
val: 0
fixed_sigma: True
mlp:
units: [32, 32]
activation: elu
d2rl: False
initializer:
name: default
regularizer:
name: None
load_checkpoint: False # flag which sets whether to load the checkpoint
load_path: '' # path to the checkpoint to load
config:
name: cart_double_pendulum_direct
env_name: rlgpu
device: 'cuda:0'
device_name: 'cuda:0'
multi_gpu: False
ppo: True
mixed_precision: False
normalize_input: True
normalize_value: True
num_actors: -1 # configured from the script (based on num_envs)
reward_shaper:
scale_value: 0.1
normalize_advantage: True
gamma: 0.99
tau : 0.95
learning_rate: 5e-4
lr_schedule: adaptive
kl_threshold: 0.008
score_to_win: 20000
max_epochs: 150
save_best_after: 50
save_frequency: 25
grad_norm: 1.0
entropy_coef: 0.0
truncate_grads: True
e_clip: 0.2
horizon_length: 32
minibatch_size: 16384
mini_epochs: 8
critic_coef: 4
clip_value: True
seq_length: 4
bounds_loss_coef: 0.0001
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from omni.isaac.lab.utils import configclass
from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import (
RslRlOnPolicyRunnerCfg,
RslRlPpoActorCriticCfg,
RslRlPpoAlgorithmCfg,
)
@configclass
class CartDoublePendulumPPORunnerCfg(RslRlOnPolicyRunnerCfg):
num_steps_per_env = 16
max_iterations = 150
save_interval = 50
experiment_name = "cart_double_pendulum_direct"
empirical_normalization = False
policy = RslRlPpoActorCriticCfg(
init_noise_std=1.0,
actor_hidden_dims=[32, 32],
critic_hidden_dims=[32, 32],
activation="elu",
)
algorithm = RslRlPpoAlgorithmCfg(
value_loss_coef=1.0,
use_clipped_value_loss=True,
clip_param=0.2,
entropy_coef=0.005,
num_learning_epochs=5,
num_mini_batches=4,
learning_rate=1.0e-3,
schedule="adaptive",
gamma=0.99,
lam=0.95,
desired_kl=0.01,
max_grad_norm=1.0,
)
seed: 42
# Models are instantiated using skrl's model instantiator utility
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
models:
separate: False
policy: # see gaussian_model parameters
class: "GaussianMixin"
clip_actions: False
clip_log_std: True
initial_log_std: 0
min_log_std: -20.0
max_log_std: 2.0
input_shape: "Shape.STATES"
hiddens: [32, 32]
hidden_activation: ["elu", "elu"]
output_shape: "Shape.ACTIONS"
output_activation: "tanh"
output_scale: 1.0
value: # see deterministic_model parameters
class: "DeterministicMixin"
clip_actions: False
input_shape: "Shape.STATES"
hiddens: [32, 32]
hidden_activation: ["elu", "elu"]
output_shape: "Shape.ONE"
output_activation: ""
output_scale: 1.0
# Memory
# https://skrl.readthedocs.io/en/latest/api/memories/random.html
memory:
class: "RandomMemory"
memory_size: -1 # automatically determined (same as agent:rollouts)
# IPPO agent configuration (field names are from IPPO_DEFAULT_CONFIG)
# https://skrl.readthedocs.io/en/latest/api/multi_agents/ippo.html
agent:
class: "IPPO"
rollouts: 16
learning_epochs: 8
mini_batches: 1
discount_factor: 0.99
lambda: 0.95
learning_rate: 3.e-4
learning_rate_scheduler: "KLAdaptiveLR"
learning_rate_scheduler_kwargs:
kl_threshold: 0.008
state_preprocessor: "RunningStandardScaler"
state_preprocessor_kwargs: null
value_preprocessor: "RunningStandardScaler"
value_preprocessor_kwargs: null
random_timesteps: 0
learning_starts: 0
grad_norm_clip: 1.0
ratio_clip: 0.2
value_clip: 0.2
clip_predicted_values: True
entropy_loss_scale: 0.0
value_loss_scale: 2.0
kl_threshold: 0
time_limit_bootstrap: True
rewards_shaper_scale: 1.0
# logging and checkpoint
experiment:
directory: "cart_double_pendulum_direct"
experiment_name: ""
write_interval: 16
checkpoint_interval: 80
# Sequential trainer
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer:
class: "SequentialTrainer"
timesteps: 1600
environment_info: "log"
seed: 42
# Models are instantiated using skrl's model instantiator utility
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
models:
separate: True
policy: # see gaussian_model parameters
class: "GaussianMixin"
clip_actions: False
clip_log_std: True
initial_log_std: 0
min_log_std: -20.0
max_log_std: 2.0
input_shape: "Shape.STATES"
hiddens: [32, 32]
hidden_activation: ["elu", "elu"]
output_shape: "Shape.ACTIONS"
output_activation: "tanh"
output_scale: 1.0
value: # see deterministic_model parameters
class: "DeterministicMixin"
clip_actions: False
input_shape: "Shape.STATES"
hiddens: [32, 32]
hidden_activation: ["elu", "elu"]
output_shape: "Shape.ONE"
output_activation: ""
output_scale: 1.0
# Memory
# https://skrl.readthedocs.io/en/latest/api/memories/random.html
memory:
class: "RandomMemory"
memory_size: -1 # automatically determined (same as agent:rollouts)
# MAPPO agent configuration (field names are from MAPPO_DEFAULT_CONFIG)
# https://skrl.readthedocs.io/en/latest/api/multi_agents/mappo.html
agent:
class: "MAPPO"
rollouts: 16
learning_epochs: 8
mini_batches: 1
discount_factor: 0.99
lambda: 0.95
learning_rate: 3.e-4
learning_rate_scheduler: "KLAdaptiveLR"
learning_rate_scheduler_kwargs:
kl_threshold: 0.008
state_preprocessor: "RunningStandardScaler"
state_preprocessor_kwargs: null
shared_state_preprocessor: "RunningStandardScaler"
shared_state_preprocessor_kwargs: null
value_preprocessor: "RunningStandardScaler"
value_preprocessor_kwargs: null
random_timesteps: 0
learning_starts: 0
grad_norm_clip: 1.0
ratio_clip: 0.2
value_clip: 0.2
clip_predicted_values: True
entropy_loss_scale: 0.0
value_loss_scale: 2.0
kl_threshold: 0
time_limit_bootstrap: True
rewards_shaper_scale: 1.0
# logging and checkpoint
experiment:
directory: "cart_double_pendulum_direct"
experiment_name: ""
write_interval: 16
checkpoint_interval: 80
# Sequential trainer
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer:
class: "SequentialTrainer"
timesteps: 1600
environment_info: "log"
seed: 42
# Models are instantiated using skrl's model instantiator utility
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
models:
separate: False
policy: # see gaussian_model parameters
class: "GaussianMixin"
clip_actions: False
clip_log_std: True
initial_log_std: 0
min_log_std: -20.0
max_log_std: 2.0
input_shape: "Shape.STATES"
hiddens: [32, 32]
hidden_activation: ["elu", "elu"]
output_shape: "Shape.ACTIONS"
output_activation: "tanh"
output_scale: 1.0
value: # see deterministic_model parameters
class: "DeterministicMixin"
clip_actions: False
input_shape: "Shape.STATES"
hiddens: [32, 32]
hidden_activation: ["elu", "elu"]
output_shape: "Shape.ONE"
output_activation: ""
output_scale: 1.0
# Memory
# https://skrl.readthedocs.io/en/latest/api/memories/random.html
memory:
class: "RandomMemory"
memory_size: -1 # automatically determined (same as agent:rollouts)
# PPO agent configuration (field names are from PPO_DEFAULT_CONFIG)
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
agent:
class: "PPO"
rollouts: 16
learning_epochs: 8
mini_batches: 1
discount_factor: 0.99
lambda: 0.95
learning_rate: 3.e-4
learning_rate_scheduler: "KLAdaptiveLR"
learning_rate_scheduler_kwargs:
kl_threshold: 0.008
state_preprocessor: "RunningStandardScaler"
state_preprocessor_kwargs: null
value_preprocessor: "RunningStandardScaler"
value_preprocessor_kwargs: null
random_timesteps: 0
learning_starts: 0
grad_norm_clip: 1.0
ratio_clip: 0.2
value_clip: 0.2
clip_predicted_values: True
entropy_loss_scale: 0.0
value_loss_scale: 2.0
kl_threshold: 0
time_limit_bootstrap: True
rewards_shaper_scale: 1.0
# logging and checkpoint
experiment:
directory: "cart_double_pendulum_direct"
experiment_name: ""
write_interval: 16
checkpoint_interval: 80
# Sequential trainer
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer:
class: "SequentialTrainer"
timesteps: 1600
environment_info: "log"
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import math
import torch
from collections.abc import Sequence
from omni.isaac.lab_assets.cart_double_pendulum import CART_DOUBLE_PENDULUM_CFG
import omni.isaac.lab.sim as sim_utils
from omni.isaac.lab.assets import Articulation, ArticulationCfg
from omni.isaac.lab.envs import DirectMARLEnv, DirectMARLEnvCfg
from omni.isaac.lab.scene import InteractiveSceneCfg
from omni.isaac.lab.sim import SimulationCfg
from omni.isaac.lab.sim.spawners.from_files import GroundPlaneCfg, spawn_ground_plane
from omni.isaac.lab.utils import configclass
from omni.isaac.lab.utils.math import sample_uniform
@configclass
class CartDoublePendulumEnvCfg(DirectMARLEnvCfg):
# env
decimation = 2
episode_length_s = 5.0
possible_agents = ["cart", "pendulum"]
num_actions = {"cart": 1, "pendulum": 1}
num_observations = {"cart": 4, "pendulum": 3}
num_states = -1
# simulation
sim: SimulationCfg = SimulationCfg(dt=1 / 120, render_interval=decimation)
# robot
robot_cfg: ArticulationCfg = CART_DOUBLE_PENDULUM_CFG.replace(prim_path="/World/envs/env_.*/Robot")
cart_dof_name = "slider_to_cart"
pole_dof_name = "cart_to_pole"
pendulum_dof_name = "pole_to_pendulum"
# scene
scene: InteractiveSceneCfg = InteractiveSceneCfg(num_envs=4096, env_spacing=4.0, replicate_physics=True)
# reset
max_cart_pos = 3.0 # the cart is reset if it exceeds that position [m]
initial_pole_angle_range = [-0.25, 0.25] # the range in which the pole angle is sampled from on reset [rad]
initial_pendulum_angle_range = [-0.25, 0.25] # the range in which the pendulum angle is sampled from on reset [rad]
# action scales
cart_action_scale = 100.0 # [N]
pendulum_action_scale = 50.0 # [Nm]
# reward scales
rew_scale_alive = 1.0
rew_scale_terminated = -2.0
rew_scale_cart_pos = 0
rew_scale_cart_vel = -0.01
rew_scale_pole_pos = -1.0
rew_scale_pole_vel = -0.01
rew_scale_pendulum_pos = -1.0
rew_scale_pendulum_vel = -0.01
class CartDoublePendulumEnv(DirectMARLEnv):
cfg: CartDoublePendulumEnvCfg
def __init__(self, cfg: CartDoublePendulumEnvCfg, render_mode: str | None = None, **kwargs):
super().__init__(cfg, render_mode, **kwargs)
self._cart_dof_idx, _ = self.robot.find_joints(self.cfg.cart_dof_name)
self._pole_dof_idx, _ = self.robot.find_joints(self.cfg.pole_dof_name)
self._pendulum_dof_idx, _ = self.robot.find_joints(self.cfg.pendulum_dof_name)
self.joint_pos = self.robot.data.joint_pos
self.joint_vel = self.robot.data.joint_vel
def _setup_scene(self):
self.robot = Articulation(self.cfg.robot_cfg)
# add ground plane
spawn_ground_plane(prim_path="/World/ground", cfg=GroundPlaneCfg())
# clone, filter, and replicate
self.scene.clone_environments(copy_from_source=False)
self.scene.filter_collisions(global_prim_paths=[])
# add articulation to scene
self.scene.articulations["robot"] = self.robot
# add lights
light_cfg = sim_utils.DomeLightCfg(intensity=2000.0, color=(0.75, 0.75, 0.75))
light_cfg.func("/World/Light", light_cfg)
def _pre_physics_step(self, actions: dict[str, torch.Tensor]) -> None:
self.actions = actions
def _apply_action(self) -> None:
self.robot.set_joint_effort_target(
self.actions["cart"] * self.cfg.cart_action_scale, joint_ids=self._cart_dof_idx
)
self.robot.set_joint_effort_target(
self.actions["pendulum"] * self.cfg.pendulum_action_scale, joint_ids=self._pendulum_dof_idx
)
def _get_observations(self) -> dict[str, torch.Tensor]:
pole_joint_pos = normalize_angle(self.joint_pos[:, self._pole_dof_idx[0]].unsqueeze(dim=1))
pendulum_joint_pos = normalize_angle(self.joint_pos[:, self._pendulum_dof_idx[0]].unsqueeze(dim=1))
observations = {
"cart": torch.cat(
(
self.joint_pos[:, self._cart_dof_idx[0]].unsqueeze(dim=1),
self.joint_vel[:, self._cart_dof_idx[0]].unsqueeze(dim=1),
pole_joint_pos,
self.joint_vel[:, self._pole_dof_idx[0]].unsqueeze(dim=1),
),
dim=-1,
),
"pendulum": torch.cat(
(
pole_joint_pos + pendulum_joint_pos,
pendulum_joint_pos,
self.joint_vel[:, self._pendulum_dof_idx[0]].unsqueeze(dim=1),
),
dim=-1,
),
}
return observations
def _get_rewards(self) -> dict[str, torch.Tensor]:
total_reward = compute_rewards(
self.cfg.rew_scale_alive,
self.cfg.rew_scale_terminated,
self.cfg.rew_scale_cart_pos,
self.cfg.rew_scale_cart_vel,
self.cfg.rew_scale_pole_pos,
self.cfg.rew_scale_pole_vel,
self.cfg.rew_scale_pendulum_pos,
self.cfg.rew_scale_pendulum_vel,
self.joint_pos[:, self._cart_dof_idx[0]],
self.joint_vel[:, self._cart_dof_idx[0]],
normalize_angle(self.joint_pos[:, self._pole_dof_idx[0]]),
self.joint_vel[:, self._pole_dof_idx[0]],
normalize_angle(self.joint_pos[:, self._pendulum_dof_idx[0]]),
self.joint_vel[:, self._pendulum_dof_idx[0]],
math.prod(self.terminated_dict.values()),
)
return total_reward
def _get_dones(self) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
self.joint_pos = self.robot.data.joint_pos
self.joint_vel = self.robot.data.joint_vel
time_out = self.episode_length_buf >= self.max_episode_length - 1
out_of_bounds = torch.any(torch.abs(self.joint_pos[:, self._cart_dof_idx]) > self.cfg.max_cart_pos, dim=1)
out_of_bounds = out_of_bounds | torch.any(torch.abs(self.joint_pos[:, self._pole_dof_idx]) > math.pi / 2, dim=1)
terminated = {agent: out_of_bounds for agent in self.cfg.possible_agents}
time_outs = {agent: time_out for agent in self.cfg.possible_agents}
return terminated, time_outs
def _reset_idx(self, env_ids: Sequence[int] | None):
if env_ids is None:
env_ids = self.robot._ALL_INDICES
super()._reset_idx(env_ids)
joint_pos = self.robot.data.default_joint_pos[env_ids]
joint_pos[:, self._pole_dof_idx] += sample_uniform(
self.cfg.initial_pole_angle_range[0] * math.pi,
self.cfg.initial_pole_angle_range[1] * math.pi,
joint_pos[:, self._pole_dof_idx].shape,
joint_pos.device,
)
joint_pos[:, self._pendulum_dof_idx] += sample_uniform(
self.cfg.initial_pendulum_angle_range[0] * math.pi,
self.cfg.initial_pendulum_angle_range[1] * math.pi,
joint_pos[:, self._pendulum_dof_idx].shape,
joint_pos.device,
)
joint_vel = self.robot.data.default_joint_vel[env_ids]
default_root_state = self.robot.data.default_root_state[env_ids]
default_root_state[:, :3] += self.scene.env_origins[env_ids]
self.joint_pos[env_ids] = joint_pos
self.joint_vel[env_ids] = joint_vel
self.robot.write_root_pose_to_sim(default_root_state[:, :7], env_ids)
self.robot.write_root_velocity_to_sim(default_root_state[:, 7:], env_ids)
self.robot.write_joint_state_to_sim(joint_pos, joint_vel, None, env_ids)
@torch.jit.script
def normalize_angle(angle):
return (angle + math.pi) % (2 * math.pi) - math.pi
@torch.jit.script
def compute_rewards(
rew_scale_alive: float,
rew_scale_terminated: float,
rew_scale_cart_pos: float,
rew_scale_cart_vel: float,
rew_scale_pole_pos: float,
rew_scale_pole_vel: float,
rew_scale_pendulum_pos: float,
rew_scale_pendulum_vel: float,
cart_pos: torch.Tensor,
cart_vel: torch.Tensor,
pole_pos: torch.Tensor,
pole_vel: torch.Tensor,
pendulum_pos: torch.Tensor,
pendulum_vel: torch.Tensor,
reset_terminated: torch.Tensor,
):
rew_alive = rew_scale_alive * (1.0 - reset_terminated.float())
rew_termination = rew_scale_terminated * reset_terminated.float()
rew_pole_pos = rew_scale_pole_pos * torch.sum(torch.square(pole_pos).unsqueeze(dim=1), dim=-1)
rew_pendulum_pos = rew_scale_pendulum_pos * torch.sum(
torch.square(pole_pos + pendulum_pos).unsqueeze(dim=1), dim=-1
)
rew_cart_vel = rew_scale_cart_vel * torch.sum(torch.abs(cart_vel).unsqueeze(dim=1), dim=-1)
rew_pole_vel = rew_scale_pole_vel * torch.sum(torch.abs(pole_vel).unsqueeze(dim=1), dim=-1)
rew_pendulum_vel = rew_scale_pendulum_vel * torch.sum(torch.abs(pendulum_vel).unsqueeze(dim=1), dim=-1)
total_reward = {
"cart": rew_alive + rew_termination + rew_pole_pos + rew_cart_vel + rew_pole_vel,
"pendulum": rew_alive + rew_termination + rew_pendulum_pos + rew_pendulum_vel,
}
return total_reward
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
"""Wrapper to configure a :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv` instance to skrl environment. """Wrapper to configure an Isaac Lab environment instance to skrl environment.
The following example shows how to wrap an environment for skrl: The following example shows how to wrap an environment for skrl:
...@@ -29,71 +29,7 @@ from __future__ import annotations ...@@ -29,71 +29,7 @@ from __future__ import annotations
from typing import Literal from typing import Literal
from omni.isaac.lab.envs import DirectRLEnv, ManagerBasedRLEnv from omni.isaac.lab.envs import DirectMARLEnv, DirectRLEnv, ManagerBasedRLEnv
"""
Configuration Parser.
"""
def process_skrl_cfg(cfg: dict, ml_framework: Literal["torch", "jax", "jax-numpy"] = "torch") -> dict:
"""Convert simple YAML types to skrl classes/components.
Args:
cfg: A configuration dictionary.
ml_framework: The ML framework to use for the wrapper. Defaults to "torch".
Returns:
A dictionary containing the converted configuration.
Raises:
ValueError: If the specified ML framework is not valid.
"""
_direct_eval = [
"learning_rate_scheduler",
"state_preprocessor",
"value_preprocessor",
"input_shape",
"output_shape",
]
def reward_shaper_function(scale):
def reward_shaper(rewards, timestep, timesteps):
return rewards * scale
return reward_shaper
def update_dict(d):
# import statements according to the ML framework
if ml_framework.startswith("torch"):
from skrl.resources.preprocessors.torch import RunningStandardScaler # noqa: F401
from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa: F401
from skrl.utils.model_instantiators.torch import Shape # noqa: F401
elif ml_framework.startswith("jax"):
from skrl.resources.preprocessors.jax import RunningStandardScaler # noqa: F401
from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa: F401
from skrl.utils.model_instantiators.jax import Shape # noqa: F401
else:
ValueError(
f"Invalid ML framework for skrl: {ml_framework}. Available options are: 'torch', 'jax' or 'jax-numpy'"
)
for key, value in d.items():
if isinstance(value, dict):
update_dict(value)
else:
if key in _direct_eval:
d[key] = eval(value)
elif key.endswith("_kwargs"):
d[key] = value if value is not None else {}
elif key in ["rewards_shaper_scale"]:
d["rewards_shaper"] = reward_shaper_function(value)
return d
# parse agent configuration and convert to classes
return update_dict(cfg)
""" """
Vectorized environment wrapper. Vectorized environment wrapper.
...@@ -101,30 +37,39 @@ Vectorized environment wrapper. ...@@ -101,30 +37,39 @@ Vectorized environment wrapper.
def SkrlVecEnvWrapper( def SkrlVecEnvWrapper(
env: ManagerBasedRLEnv | DirectRLEnv, ml_framework: Literal["torch", "jax", "jax-numpy"] = "torch" env: ManagerBasedRLEnv | DirectRLEnv | DirectMARLEnv,
ml_framework: Literal["torch", "jax", "jax-numpy"] = "torch",
wrapper: Literal["auto", "isaaclab", "isaaclab-single-agent", "isaaclab-multi-agent"] = "auto",
): ):
"""Wraps around Isaac Lab environment for skrl. """Wraps around Isaac Lab environment for skrl.
This function wraps around the Isaac Lab environment. Since the :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv` environment This function wraps around the Isaac Lab environment. Since the wrapping
wrapping functionality is defined within the skrl library itself, this implementation functionality is defined within the skrl library itself, this implementation
is maintained for compatibility with the structure of the extension that contains it. is maintained for compatibility with the structure of the extension that contains it.
Internally it calls the :func:`wrap_env` from the skrl library API. Internally it calls the :func:`wrap_env` from the skrl library API.
Args: Args:
env: The environment to wrap around. env: The environment to wrap around.
ml_framework: The ML framework to use for the wrapper. Defaults to "torch". ml_framework: The ML framework to use for the wrapper. Defaults to "torch".
wrapper: The wrapper to use. Defaults to "auto": leave it to skrl to determine if the environment
will be wrapped as single-agent or multi-agent.
Raises: Raises:
ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`. ValueError: When the environment is not an instance of any Isaac Lab environment interface.
ValueError: If the specified ML framework is not valid. ValueError: If the specified ML framework is not valid.
Reference: Reference:
https://skrl.readthedocs.io/en/latest/api/envs/wrapping.html https://skrl.readthedocs.io/en/latest/api/envs/wrapping.html
""" """
# check that input is valid # check that input is valid
if not isinstance(env.unwrapped, ManagerBasedRLEnv) and not isinstance(env.unwrapped, DirectRLEnv): if (
not isinstance(env.unwrapped, ManagerBasedRLEnv)
and not isinstance(env.unwrapped, DirectRLEnv)
and not isinstance(env.unwrapped, DirectMARLEnv)
):
raise ValueError( raise ValueError(
f"The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type: {type(env)}" "The environment must be inherited from ManagerBasedRLEnv, DirectRLEnv or DirectMARLEnv. Environment type:"
f" {type(env)}"
) )
# import statements according to the ML framework # import statements according to the ML framework
...@@ -138,4 +83,4 @@ def SkrlVecEnvWrapper( ...@@ -138,4 +83,4 @@ def SkrlVecEnvWrapper(
) )
# wrap and return the environment # wrap and return the environment
return wrap_env(env, wrapper="isaaclab") return wrap_env(env, wrapper)
...@@ -86,7 +86,13 @@ class TestEnvironments(unittest.TestCase): ...@@ -86,7 +86,13 @@ class TestEnvironments(unittest.TestCase):
# create a new stage # create a new stage
omni.usd.get_context().new_stage() omni.usd.get_context().new_stage()
# parse configuration # parse configuration
env_cfg = parse_env_cfg(task_name, device=device, num_envs=num_envs) env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=device, num_envs=num_envs)
# skip test if the environment is a multi-agent task
if hasattr(env_cfg, "possible_agents"):
print(f"[INFO]: Skipping {task_name} as it is a multi-agent task")
return
# create environment # create environment
env = gym.make(task_name, cfg=env_cfg) env = gym.make(task_name, cfg=env_cfg)
......
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Launch Isaac Sim Simulator first."""
from omni.isaac.lab.app import AppLauncher, run_tests
# launch the simulator
app_launcher = AppLauncher(headless=True, enable_cameras=True)
simulation_app = app_launcher.app
"""Rest everything follows."""
import gymnasium as gym
import torch
import unittest
import omni.usd
from omni.isaac.lab.envs import DirectMARLEnv, DirectMARLEnvCfg
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg
class TestEnvironments(unittest.TestCase):
"""Test cases for all registered multi-agent environments."""
@classmethod
def setUpClass(cls):
# acquire all Isaac environments names
cls.registered_tasks = list()
for task_spec in gym.registry.values():
if "Isaac" in task_spec.id and not task_spec.id.endswith("Play-v0"):
cls.registered_tasks.append(task_spec.id)
# sort environments by name
cls.registered_tasks.sort()
# print all existing task names
print(">>> All registered environments:", cls.registered_tasks)
"""
Test fixtures.
"""
def test_multiple_instances_gpu(self):
"""Run all environments with multiple instances and check environments return valid signals."""
# common parameters
num_envs = 32
device = "cuda"
# iterate over all registered environments
for task_name in self.registered_tasks:
with self.subTest(task_name=task_name):
print(f">>> Running test for environment: {task_name}")
# check environment
self._check_random_actions(task_name, device, num_envs, num_steps=100)
# close the environment
print(f">>> Closing environment: {task_name}")
print("-" * 80)
def test_single_instance_gpu(self):
"""Run all environments with single instance and check environments return valid signals."""
# common parameters
num_envs = 1
device = "cuda"
# iterate over all registered environments
for task_name in self.registered_tasks:
with self.subTest(task_name=task_name):
print(f">>> Running test for environment: {task_name}")
# check environment
self._check_random_actions(task_name, device, num_envs, num_steps=100)
# close the environment
print(f">>> Closing environment: {task_name}")
print("-" * 80)
"""
Helper functions.
"""
def _check_random_actions(self, task_name: str, device: str, num_envs: int, num_steps: int = 1000):
"""Run random actions and check environments return valid signals."""
# create a new stage
omni.usd.get_context().new_stage()
# parse configuration
env_cfg: DirectMARLEnvCfg = parse_env_cfg(task_name, device=device, num_envs=num_envs)
# skip test if the environment is not a multi-agent task
if not hasattr(env_cfg, "possible_agents"):
print(f"[INFO]: Skipping {task_name} as it is not a multi-agent task")
return
# create environment
env: DirectMARLEnv = gym.make(task_name, cfg=env_cfg)
# this flag is necessary to prevent a bug where the simulation gets stuck randomly when running the
# test on many environments.
env.sim.set_setting("/physics/cooking/ujitsoCollisionCooking", False)
# reset environment
obs, _ = env.reset()
# check signal
self.assertTrue(self._check_valid_tensor(obs))
# simulate environment for num_steps steps
with torch.inference_mode():
for _ in range(num_steps):
# sample actions from -1 to 1
actions = {
agent: 2 * torch.rand(env.action_space(agent).shape, device=env.unwrapped.device) - 1
for agent in env.unwrapped.possible_agents
}
# apply actions
transition = env.step(actions)
# check signals
for item in transition[:-1]: # exclude info
for agent, data in item.items():
self.assertTrue(self._check_valid_tensor(data), msg=f"Invalid data ('{agent}'): {data}")
# close the environment
env.close()
@staticmethod
def _check_valid_tensor(data: torch.Tensor | dict) -> bool:
"""Checks if given data does not have corrupted values.
Args:
data: Data buffer.
Returns:
True if the data is valid.
"""
if isinstance(data, torch.Tensor):
return not torch.any(torch.isnan(data))
elif isinstance(data, dict):
valid_tensor = True
for value in data.values():
if isinstance(value, dict):
valid_tensor &= TestEnvironments._check_valid_tensor(value)
elif isinstance(value, torch.Tensor):
valid_tensor &= not torch.any(torch.isnan(value))
return valid_tensor
else:
raise ValueError(f"Input data of invalid type: {type(data)}.")
if __name__ == "__main__":
run_tests()
...@@ -50,6 +50,7 @@ from rl_games.common import env_configurations, vecenv ...@@ -50,6 +50,7 @@ from rl_games.common import env_configurations, vecenv
from rl_games.common.player import BasePlayer from rl_games.common.player import BasePlayer
from rl_games.torch_runner import Runner from rl_games.torch_runner import Runner
from omni.isaac.lab.envs import DirectMARLEnv, multi_agent_to_single_agent
from omni.isaac.lab.utils.assets import retrieve_file_path from omni.isaac.lab.utils.assets import retrieve_file_path
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
...@@ -104,6 +105,11 @@ def main(): ...@@ -104,6 +105,11 @@ def main():
print("[INFO] Recording videos during training.") print("[INFO] Recording videos during training.")
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap around environment for rl-games # wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions) env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
......
...@@ -53,7 +53,13 @@ from rl_games.common import env_configurations, vecenv ...@@ -53,7 +53,13 @@ from rl_games.common import env_configurations, vecenv
from rl_games.common.algo_observer import IsaacAlgoObserver from rl_games.common.algo_observer import IsaacAlgoObserver
from rl_games.torch_runner import Runner from rl_games.torch_runner import Runner
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg from omni.isaac.lab.envs import (
DirectMARLEnv,
DirectMARLEnvCfg,
DirectRLEnvCfg,
ManagerBasedRLEnvCfg,
multi_agent_to_single_agent,
)
from omni.isaac.lab.utils.assets import retrieve_file_path from omni.isaac.lab.utils.assets import retrieve_file_path
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
...@@ -64,7 +70,7 @@ from omni.isaac.lab_tasks.utils.wrappers.rl_games import RlGamesGpuEnv, RlGamesV ...@@ -64,7 +70,7 @@ from omni.isaac.lab_tasks.utils.wrappers.rl_games import RlGamesGpuEnv, RlGamesV
@hydra_task_config(args_cli.task, "rl_games_cfg_entry_point") @hydra_task_config(args_cli.task, "rl_games_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict): def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
"""Train with RL-Games agent.""" """Train with RL-Games agent."""
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
...@@ -127,6 +133,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict): ...@@ -127,6 +133,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict):
print("[INFO] Recording videos during training.") print("[INFO] Recording videos during training.")
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap around environment for rl-games # wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions) env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
......
...@@ -44,6 +44,7 @@ import torch ...@@ -44,6 +44,7 @@ import torch
from rsl_rl.runners import OnPolicyRunner from rsl_rl.runners import OnPolicyRunner
from omni.isaac.lab.envs import DirectMARLEnv, multi_agent_to_single_agent
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
...@@ -84,6 +85,11 @@ def main(): ...@@ -84,6 +85,11 @@ def main():
print("[INFO] Recording videos during training.") print("[INFO] Recording videos during training.")
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap around environment for rsl-rl # wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env) env = RslRlVecEnvWrapper(env)
......
...@@ -51,7 +51,13 @@ from datetime import datetime ...@@ -51,7 +51,13 @@ from datetime import datetime
from rsl_rl.runners import OnPolicyRunner from rsl_rl.runners import OnPolicyRunner
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg from omni.isaac.lab.envs import (
DirectMARLEnv,
DirectMARLEnvCfg,
DirectRLEnvCfg,
ManagerBasedRLEnvCfg,
multi_agent_to_single_agent,
)
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
...@@ -67,7 +73,7 @@ torch.backends.cudnn.benchmark = False ...@@ -67,7 +73,7 @@ torch.backends.cudnn.benchmark = False
@hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point") @hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg): def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg):
"""Train with RSL-RL agent.""" """Train with RSL-RL agent."""
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli) agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
...@@ -103,6 +109,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: RslRlOnPolic ...@@ -103,6 +109,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: RslRlOnPolic
print("[INFO] Recording videos during training.") print("[INFO] Recording videos during training.")
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap around environment for rsl-rl # wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env) env = RslRlVecEnvWrapper(env)
......
...@@ -53,7 +53,13 @@ from stable_baselines3.common.callbacks import CheckpointCallback ...@@ -53,7 +53,13 @@ from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.logger import configure from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import VecNormalize from stable_baselines3.common.vec_env import VecNormalize
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg from omni.isaac.lab.envs import (
DirectMARLEnv,
DirectMARLEnvCfg,
DirectRLEnvCfg,
ManagerBasedRLEnvCfg,
multi_agent_to_single_agent,
)
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
...@@ -63,7 +69,7 @@ from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb ...@@ -63,7 +69,7 @@ from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb
@hydra_task_config(args_cli.task, "sb3_cfg_entry_point") @hydra_task_config(args_cli.task, "sb3_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict): def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
"""Train with stable-baselines agent.""" """Train with stable-baselines agent."""
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
...@@ -103,6 +109,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict): ...@@ -103,6 +109,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict):
print("[INFO] Recording videos during training.") print("[INFO] Recording videos during training.")
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap around environment for stable baselines # wrap around environment for stable baselines
env = Sb3VecEnvWrapper(env) env = Sb3VecEnvWrapper(env)
......
...@@ -12,7 +12,6 @@ a more user-friendly way. ...@@ -12,7 +12,6 @@ a more user-friendly way.
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
from omni.isaac.lab.app import AppLauncher from omni.isaac.lab.app import AppLauncher
...@@ -34,10 +33,16 @@ parser.add_argument( ...@@ -34,10 +33,16 @@ parser.add_argument(
choices=["torch", "jax", "jax-numpy"], choices=["torch", "jax", "jax-numpy"],
help="The ML framework used for training the skrl agent.", help="The ML framework used for training the skrl agent.",
) )
parser.add_argument(
"--algorithm",
type=str,
default="PPO",
choices=["PPO", "IPPO", "MAPPO"],
help="The RL algorithm used for training the skrl agent.",
)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args() args_cli = parser.parse_args()
# always enable cameras to record video # always enable cameras to record video
if args_cli.video: if args_cli.video:
...@@ -54,19 +59,31 @@ import os ...@@ -54,19 +59,31 @@ import os
import torch import torch
import skrl import skrl
from packaging import version
# check for minimum supported skrl version
SKRL_VERSION = "1.3.0"
if version.parse(skrl.__version__) < version.parse(SKRL_VERSION):
skrl.logger.error(
f"Unsupported skrl version: {skrl.__version__}. "
f"Install supported version using 'pip install skrl>={SKRL_VERSION}'"
)
exit()
if args_cli.ml_framework.startswith("torch"): if args_cli.ml_framework.startswith("torch"):
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG from skrl.utils.runner.torch import Runner
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
elif args_cli.ml_framework.startswith("jax"): elif args_cli.ml_framework.startswith("jax"):
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG from skrl.utils.runner.jax import Runner
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
from omni.isaac.lab.envs import DirectMARLEnv, multi_agent_to_single_agent
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper, process_skrl_cfg from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper
# config shortcuts
algorithm = args_cli.algorithm.lower()
def main(): def main():
...@@ -74,10 +91,14 @@ def main(): ...@@ -74,10 +91,14 @@ def main():
# configure the ML framework into the global skrl variable # configure the ML framework into the global skrl variable
if args_cli.ml_framework.startswith("jax"): if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy" skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
# parse configuration # parse configuration
env_cfg = parse_env_cfg( env_cfg = parse_env_cfg(
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
) )
try:
experiment_cfg = load_cfg_from_registry(args_cli.task, f"skrl_{algorithm}_cfg_entry_point")
except ValueError:
experiment_cfg = load_cfg_from_registry(args_cli.task, "skrl_cfg_entry_point") experiment_cfg = load_cfg_from_registry(args_cli.task, "skrl_cfg_entry_point")
# specify directory for logging experiments (load checkpoint) # specify directory for logging experiments (load checkpoint)
...@@ -104,73 +125,25 @@ def main(): ...@@ -104,73 +125,25 @@ def main():
print("[INFO] Recording videos during training.") print("[INFO] Recording videos during training.")
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
env = multi_agent_to_single_agent(env)
# wrap around environment for skrl # wrap around environment for skrl
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")` env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="auto")`
# instantiate models using skrl model instantiator utility # configure and instantiate the skrl runner
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html # https://skrl.readthedocs.io/en/latest/api/utils/runner.html
models = {} experiment_cfg["trainer"]["close_environment_at_exit"] = False
if args_cli.ml_framework.startswith("jax"): experiment_cfg["agent"]["experiment"]["write_interval"] = 0 # don't log to TensorBoard
experiment_cfg["models"]["separate"] = True # shared model is not supported in JAX experiment_cfg["agent"]["experiment"]["checkpoint_interval"] = 0 # don't generate checkpoints
# non-shared models runner = Runner(env, experiment_cfg)
if experiment_cfg["models"]["separate"]:
models["policy"] = gaussian_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
**process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
)
models["value"] = deterministic_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
**process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
)
# shared models
else:
models["policy"] = shared_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
structure=None,
roles=["policy", "value"],
parameters=[
process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
],
)
models["value"] = models["policy"]
# instantiate models' state dict
if args_cli.ml_framework.startswith("jax"):
for role, model in models.items():
model.init_state_dict(role)
# configure and instantiate PPO agent
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
agent_cfg = PPO_DEFAULT_CONFIG.copy()
experiment_cfg["agent"]["rewards_shaper"] = None # avoid 'dictionary changed size during iteration'
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"], ml_framework=args_cli.ml_framework))
agent_cfg["state_preprocessor_kwargs"].update({"size": env.observation_space, "device": env.device})
agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": env.device})
agent_cfg["experiment"]["write_interval"] = 0 # don't log to Tensorboard
agent_cfg["experiment"]["checkpoint_interval"] = 0 # don't generate checkpoints
agent = PPO(
models=models,
memory=None, # memory is optional during evaluation
cfg=agent_cfg,
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
)
# initialize agent
agent.init()
print(f"[INFO] Loading model checkpoint from: {resume_path}") print(f"[INFO] Loading model checkpoint from: {resume_path}")
agent.load(resume_path) runner.agent.load(resume_path)
# set agent to evaluation mode # set agent to evaluation mode
agent.set_running_mode("eval") runner.agent.set_running_mode("eval")
# reset environment # reset environment
obs, _ = env.reset() obs, _ = env.reset()
...@@ -180,7 +153,7 @@ def main(): ...@@ -180,7 +153,7 @@ def main():
# run everything in inference mode # run everything in inference mode
with torch.inference_mode(): with torch.inference_mode():
# agent stepping # agent stepping
actions = agent.act(obs, timestep=0, timesteps=0)[0] actions = runner.agent.act(obs, timestep=0, timesteps=0)[0]
# env stepping # env stepping
obs, _, _, _, _ = env.step(actions) obs, _, _, _, _ = env.step(actions)
if args_cli.video: if args_cli.video:
......
...@@ -12,7 +12,6 @@ a more user-friendly way. ...@@ -12,7 +12,6 @@ a more user-friendly way.
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
import sys import sys
...@@ -37,6 +36,13 @@ parser.add_argument( ...@@ -37,6 +36,13 @@ parser.add_argument(
choices=["torch", "jax", "jax-numpy"], choices=["torch", "jax", "jax-numpy"],
help="The ML framework used for training the skrl agent.", help="The ML framework used for training the skrl agent.",
) )
parser.add_argument(
"--algorithm",
type=str,
default="PPO",
choices=["PPO", "IPPO", "MAPPO"],
help="The RL algorithm used for training the skrl agent.",
)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
...@@ -60,44 +66,53 @@ import os ...@@ -60,44 +66,53 @@ import os
from datetime import datetime from datetime import datetime
import skrl import skrl
from skrl.utils import set_seed from packaging import version
# check for minimum supported skrl version
SKRL_VERSION = "1.3.0"
if version.parse(skrl.__version__) < version.parse(SKRL_VERSION):
skrl.logger.error(
f"Unsupported skrl version: {skrl.__version__}. "
f"Install supported version using 'pip install skrl>={SKRL_VERSION}'"
)
exit()
if args_cli.ml_framework.startswith("torch"): if args_cli.ml_framework.startswith("torch"):
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG from skrl.utils.runner.torch import Runner
from skrl.memories.torch import RandomMemory
from skrl.trainers.torch import SequentialTrainer
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
elif args_cli.ml_framework.startswith("jax"): elif args_cli.ml_framework.startswith("jax"):
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG from skrl.utils.runner.jax import Runner
from skrl.memories.jax import RandomMemory
from skrl.trainers.jax import SequentialTrainer from omni.isaac.lab.envs import (
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model DirectMARLEnv,
DirectMARLEnvCfg,
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg DirectRLEnvCfg,
ManagerBasedRLEnvCfg,
multi_agent_to_single_agent,
)
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper, process_skrl_cfg from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper
# config shortcuts
algorithm = args_cli.algorithm.lower()
agent_cfg_entry_point = "skrl_cfg_entry_point" if algorithm in ["ppo"] else f"skrl_{algorithm}_cfg_entry_point"
@hydra_task_config(args_cli.task, "skrl_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict): @hydra_task_config(args_cli.task, agent_cfg_entry_point)
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
"""Train with skrl agent.""" """Train with skrl agent."""
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
set_seed(args_cli.seed if args_cli.seed is not None else agent_cfg["seed"])
# multi-gpu training config # multi-gpu training config
if args_cli.distributed: if args_cli.distributed:
if args_cli.ml_framework.startswith("jax"):
raise ValueError("Multi-GPU distributed training not yet supported in JAX")
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}" env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# max iterations for training # max iterations for training
if args_cli.max_iterations: if args_cli.max_iterations:
agent_cfg["trainer"]["timesteps"] = args_cli.max_iterations * agent_cfg["agent"]["rollouts"] agent_cfg["trainer"]["timesteps"] = args_cli.max_iterations * agent_cfg["agent"]["rollouts"]
agent_cfg["trainer"]["close_environment_at_exit"] = False
# configure the ML framework into the global skrl variable # configure the ML framework into the global skrl variable
if args_cli.ml_framework.startswith("jax"): if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy" skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
...@@ -120,6 +135,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict): ...@@ -120,6 +135,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict):
# update log_dir # update log_dir
log_dir = os.path.join(log_root_path, log_dir) log_dir = os.path.join(log_root_path, log_dir)
# multi-gpu training config
if args_cli.distributed:
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# dump the configuration into log-directory # dump the configuration into log-directory
dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg) dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg)
dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg) dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg)
...@@ -139,78 +159,20 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict): ...@@ -139,78 +159,20 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: dict):
print("[INFO] Recording videos during training.") print("[INFO] Recording videos during training.")
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for skrl
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")`
# instantiate models using skrl model instantiator utility # convert to single-agent instance if required by the RL algorithm
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
models = {} env = multi_agent_to_single_agent(env)
if args_cli.ml_framework.startswith("jax"):
agent_cfg["models"]["separate"] = True # shared model is not supported in JAX # wrap around environment for skrl
# non-shared models env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="auto")`
if agent_cfg["models"]["separate"]:
models["policy"] = gaussian_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
**process_skrl_cfg(agent_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
)
models["value"] = deterministic_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
**process_skrl_cfg(agent_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
)
# shared models
else:
models["policy"] = shared_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
structure=None,
roles=["policy", "value"],
parameters=[
process_skrl_cfg(agent_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
process_skrl_cfg(agent_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
],
)
models["value"] = models["policy"]
# instantiate models' state dict
if args_cli.ml_framework.startswith("jax"):
for role, model in models.items():
model.init_state_dict(role)
# instantiate a RandomMemory as rollout buffer (any memory can be used for this)
# https://skrl.readthedocs.io/en/latest/api/memories/random.html
memory_size = agent_cfg["agent"]["rollouts"] # memory_size is the agent's number of rollouts
memory = RandomMemory(memory_size=memory_size, num_envs=env.num_envs, device=env.device)
# configure and instantiate PPO agent
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
default_agent_cfg = PPO_DEFAULT_CONFIG.copy()
agent_cfg["agent"]["rewards_shaper"] = None # avoid 'dictionary changed size during iteration'
default_agent_cfg.update(process_skrl_cfg(agent_cfg["agent"], ml_framework=args_cli.ml_framework))
default_agent_cfg["state_preprocessor_kwargs"].update({"size": env.observation_space, "device": env.device})
default_agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": env.device})
agent = PPO(
models=models,
memory=memory,
cfg=default_agent_cfg,
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
)
# configure and instantiate a custom RL trainer for logging episode events # configure and instantiate the skrl runner
# https://skrl.readthedocs.io/en/latest/api/trainers.html # https://skrl.readthedocs.io/en/latest/api/utils/runner.html
trainer_cfg = agent_cfg["trainer"] runner = Runner(env, agent_cfg)
trainer_cfg["close_environment_at_exit"] = False
trainer = SequentialTrainer(cfg=trainer_cfg, env=env, agents=agent)
# train the agent # run training
trainer.train() runner.run()
# close the simulator # close the simulator
env.close() env.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment