Unverified Commit ad14a674 authored by Antonin RAFFIN's avatar Antonin RAFFIN Committed by GitHub

Adds optimizations and additional training configs for SB3 (#2022)

# Description

<!--
Thank you for your interest in sending a pull request. Please make sure
to check the contribution guidelines.

Link:
https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html
-->

Please include a summary of the change and which issue is fixed. Please
also include relevant motivation and context.
List any dependencies that are required for this change.

Implement part of https://github.com/isaac-sim/IsaacLab/issues/1769
(optimization)

This is a breaking change because the fast variant is now enabled by
default.

I also improve sb3 training script, fixed loading of normalization and
fixed the humanoid hyperparameters to be similar to rsl-rl, so we can
compare apples to apples in terms of training speed.

I will probably open another PR for the rest of the proposals.

<!-- As a practice, it is recommended to open an issue to have
discussions on the proposed pull request.
This makes it easier for the community to keep track of what is being
developed or added, and if a given feature
is demanded by more than one party. -->

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- This change requires a documentation update

With respect to testing, how do you run a single test?
and is there anything I should add?


## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Signed-off-by: 's avatarAntonin RAFFIN <antonin.raffin@ensta.org>
Signed-off-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
Co-authored-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
parent 9980e665
...@@ -40,6 +40,7 @@ Guidelines for modifications: ...@@ -40,6 +40,7 @@ Guidelines for modifications:
* Amr Mousa * Amr Mousa
* Andrej Orsula * Andrej Orsula
* Anton Bjørndahl Mortensen * Anton Bjørndahl Mortensen
* Antonin Raffin
* Arjun Bhardwaj * Arjun Bhardwaj
* Ashwin Varghese Kuruttukulam * Ashwin Varghese Kuruttukulam
* Bikram Pandit * Bikram Pandit
......
...@@ -884,7 +884,7 @@ Comprehensive List of Environments ...@@ -884,7 +884,7 @@ Comprehensive List of Environments
* - Isaac-Velocity-Flat-Unitree-A1-v0 * - Isaac-Velocity-Flat-Unitree-A1-v0
- Isaac-Velocity-Flat-Unitree-A1-Play-v0 - Isaac-Velocity-Flat-Unitree-A1-Play-v0
- Manager Based - Manager Based
- **rsl_rl** (PPO), **skrl** (PPO) - **rsl_rl** (PPO), **skrl** (PPO), **sb3** (PPO)
* - Isaac-Velocity-Flat-Unitree-Go1-v0 * - Isaac-Velocity-Flat-Unitree-Go1-v0
- Isaac-Velocity-Flat-Unitree-Go1-Play-v0 - Isaac-Velocity-Flat-Unitree-Go1-Play-v0
- Manager Based - Manager Based
...@@ -924,7 +924,7 @@ Comprehensive List of Environments ...@@ -924,7 +924,7 @@ Comprehensive List of Environments
* - Isaac-Velocity-Rough-Unitree-A1-v0 * - Isaac-Velocity-Rough-Unitree-A1-v0
- Isaac-Velocity-Rough-Unitree-A1-Play-v0 - Isaac-Velocity-Rough-Unitree-A1-Play-v0
- Manager Based - Manager Based
- **rsl_rl** (PPO), **skrl** (PPO) - **rsl_rl** (PPO), **skrl** (PPO), **sb3** (PPO)
* - Isaac-Velocity-Rough-Unitree-Go1-v0 * - Isaac-Velocity-Rough-Unitree-Go1-v0
- Isaac-Velocity-Rough-Unitree-Go1-Play-v0 - Isaac-Velocity-Rough-Unitree-Go1-Play-v0
- Manager Based - Manager Based
......
...@@ -187,7 +187,7 @@ Stable-Baselines3 ...@@ -187,7 +187,7 @@ Stable-Baselines3
- Training an agent with - Training an agent with
`Stable-Baselines3 <https://stable-baselines3.readthedocs.io/en/master/index.html>`__ `Stable-Baselines3 <https://stable-baselines3.readthedocs.io/en/master/index.html>`__
on ``Isaac-Cartpole-v0``: on ``Isaac-Velocity-Flat-Unitree-A1-v0``:
.. tab-set:: .. tab-set::
:sync-group: os :sync-group: os
...@@ -200,14 +200,13 @@ Stable-Baselines3 ...@@ -200,14 +200,13 @@ Stable-Baselines3
# install python module (for stable-baselines3) # install python module (for stable-baselines3)
./isaaclab.sh -i sb3 ./isaaclab.sh -i sb3
# run script for training # run script for training
# note: we set the device to cpu since SB3 doesn't optimize for GPU anyway ./isaaclab.sh -p scripts/reinforcement_learning/sb3/train.py --task Isaac-Velocity-Flat-Unitree-A1-v0 --headless
./isaaclab.sh -p scripts/reinforcement_learning/sb3/train.py --task Isaac-Cartpole-v0 --headless --device cpu
# run script for playing with 32 environments # run script for playing with 32 environments
./isaaclab.sh -p scripts/reinforcement_learning/sb3/play.py --task Isaac-Cartpole-v0 --num_envs 32 --checkpoint /PATH/TO/model.zip ./isaaclab.sh -p scripts/reinforcement_learning/sb3/play.py --task Isaac-Velocity-Flat-Unitree-A1-v0 --num_envs 32 --checkpoint /PATH/TO/model.zip
# run script for playing a pre-trained checkpoint with 32 environments # run script for playing a pre-trained checkpoint with 32 environments
./isaaclab.sh -p scripts/reinforcement_learning/sb3/play.py --task Isaac-Cartpole-v0 --num_envs 32 --use_pretrained_checkpoint ./isaaclab.sh -p scripts/reinforcement_learning/sb3/play.py --task Isaac-Velocity-Flat-Unitree-A1-v0 --num_envs 32 --use_pretrained_checkpoint
# run script for recording video of a trained agent (requires installing `ffmpeg`) # run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p scripts/reinforcement_learning/sb3/play.py --task Isaac-Cartpole-v0 --headless --video --video_length 200 ./isaaclab.sh -p scripts/reinforcement_learning/sb3/play.py --task Isaac-Velocity-Flat-Unitree-A1-v0 --headless --video --video_length 200
.. tab-item:: :icon:`fa-brands fa-windows` Windows .. tab-item:: :icon:`fa-brands fa-windows` Windows
:sync: windows :sync: windows
...@@ -217,14 +216,13 @@ Stable-Baselines3 ...@@ -217,14 +216,13 @@ Stable-Baselines3
:: install python module (for stable-baselines3) :: install python module (for stable-baselines3)
isaaclab.bat -i sb3 isaaclab.bat -i sb3
:: run script for training :: run script for training
:: note: we set the device to cpu since SB3 doesn't optimize for GPU anyway isaaclab.bat -p scripts\reinforcement_learning\sb3\train.py --task Isaac-Velocity-Flat-Unitree-A1-v0 --headless
isaaclab.bat -p scripts\reinforcement_learning\sb3\train.py --task Isaac-Cartpole-v0 --headless --device cpu
:: run script for playing with 32 environments :: run script for playing with 32 environments
isaaclab.bat -p scripts\reinforcement_learning\sb3\play.py --task Isaac-Cartpole-v0 --num_envs 32 --checkpoint /PATH/TO/model.zip isaaclab.bat -p scripts\reinforcement_learning\sb3\play.py --task Isaac-Velocity-Flat-Unitree-A1-v0 --num_envs 32 --checkpoint /PATH/TO/model.zip
:: run script for playing a pre-trained checkpoint with 32 environments :: run script for playing a pre-trained checkpoint with 32 environments
isaaclab.bat -p scripts\reinforcement_learning\sb3\play.py --task Isaac-Cartpole-v0 --num_envs 32 --use_pretrained_checkpoint isaaclab.bat -p scripts\reinforcement_learning\sb3\play.py --task Isaac-Velocity-Flat-Unitree-A1-v0 --num_envs 32 --use_pretrained_checkpoint
:: run script for recording video of a trained agent (requires installing `ffmpeg`) :: run script for recording video of a trained agent (requires installing `ffmpeg`)
isaaclab.bat -p scripts\reinforcement_learning\sb3\play.py --task Isaac-Cartpole-v0 --headless --video --video_length 200 isaaclab.bat -p scripts\reinforcement_learning\sb3\play.py --task Isaac-Velocity-Flat-Unitree-A1-v0 --headless --video --video_length 200
All the scripts above log the training progress to `Tensorboard`_ in the ``logs`` directory in the root of All the scripts above log the training progress to `Tensorboard`_ in the ``logs`` directory in the root of
the repository. The logs directory follows the pattern ``logs/<library>/<task>/<date-time>``, where ``<library>`` the repository. The logs directory follows the pattern ``logs/<library>/<task>/<date-time>``, where ``<library>``
......
...@@ -71,9 +71,12 @@ Training Performance ...@@ -71,9 +71,12 @@ Training Performance
-------------------- --------------------
We performed training with each RL library on the same ``Isaac-Humanoid-v0`` environment We performed training with each RL library on the same ``Isaac-Humanoid-v0`` environment
with ``--headless`` on a single RTX 4090 GPU with ``--headless`` on a single RTX 4090 GPU using 4096 environments
and logged the total training time for 65.5M steps for each RL library. and logged the total training time for 65.5M steps for each RL library.
..
Note: SB3 need to be re-run (current number comes from a GeForce RTX 3070)
+--------------------+-----------------+ +--------------------+-----------------+
| RL Library | Time in seconds | | RL Library | Time in seconds |
+====================+=================+ +====================+=================+
...@@ -83,5 +86,5 @@ and logged the total training time for 65.5M steps for each RL library. ...@@ -83,5 +86,5 @@ and logged the total training time for 65.5M steps for each RL library.
+--------------------+-----------------+ +--------------------+-----------------+
| RSL RL | 207 | | RSL RL | 207 |
+--------------------+-----------------+ +--------------------+-----------------+
| Stable-Baselines3 | 6320 | | Stable-Baselines3 | 550 |
+--------------------+-----------------+ +--------------------+-----------------+
...@@ -188,7 +188,7 @@ def main(): ...@@ -188,7 +188,7 @@ def main():
s[:, dones, :] = 0.0 s[:, dones, :] = 0.0
if args_cli.video: if args_cli.video:
timestep += 1 timestep += 1
# Exit the play loop after recording one video # exit the play loop after recording one video
if timestep == args_cli.video_length: if timestep == args_cli.video_length:
break break
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
from pathlib import Path
from isaaclab.app import AppLauncher from isaaclab.app import AppLauncher
...@@ -32,6 +33,12 @@ parser.add_argument( ...@@ -32,6 +33,12 @@ parser.add_argument(
help="When no checkpoint provided, use the last saved model. Otherwise use the best saved model.", help="When no checkpoint provided, use the last saved model. Otherwise use the best saved model.",
) )
parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.") parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.")
parser.add_argument(
"--keep_all_info",
action="store_true",
default=False,
help="Use a slower SB3 wrapper but keep all the extra training info.",
)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
...@@ -47,7 +54,6 @@ simulation_app = app_launcher.app ...@@ -47,7 +54,6 @@ simulation_app = app_launcher.app
"""Rest everything follows.""" """Rest everything follows."""
import gymnasium as gym import gymnasium as gym
import numpy as np
import os import os
import time import time
import torch import torch
...@@ -57,12 +63,13 @@ from stable_baselines3.common.vec_env import VecNormalize ...@@ -57,12 +63,13 @@ from stable_baselines3.common.vec_env import VecNormalize
from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent
from isaaclab.utils.dict import print_dict from isaaclab.utils.dict import print_dict
from isaaclab.utils.io import load_yaml
from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
from isaaclab_rl.sb3 import Sb3VecEnvWrapper, process_sb3_cfg from isaaclab_rl.sb3 import Sb3VecEnvWrapper, process_sb3_cfg
import isaaclab_tasks # noqa: F401 import isaaclab_tasks # noqa: F401
from isaaclab_tasks.utils.parse_cfg import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg from isaaclab_tasks.utils.parse_cfg import get_checkpoint_path, parse_env_cfg
# PLACEHOLDER: Extension template (do not remove this comment) # PLACEHOLDER: Extension template (do not remove this comment)
...@@ -73,7 +80,6 @@ def main(): ...@@ -73,7 +80,6 @@ def main():
env_cfg = parse_env_cfg( env_cfg = parse_env_cfg(
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
) )
agent_cfg = load_cfg_from_registry(args_cli.task, "sb3_cfg_entry_point")
task_name = args_cli.task.split(":")[-1] task_name = args_cli.task.split(":")[-1]
...@@ -87,6 +93,7 @@ def main(): ...@@ -87,6 +93,7 @@ def main():
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.") print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
return return
elif args_cli.checkpoint is None: elif args_cli.checkpoint is None:
# FIXME: last checkpoint doesn't seem to really use the last one'
if args_cli.use_last_checkpoint: if args_cli.use_last_checkpoint:
checkpoint = "model_.*.zip" checkpoint = "model_.*.zip"
else: else:
...@@ -96,12 +103,14 @@ def main(): ...@@ -96,12 +103,14 @@ def main():
checkpoint_path = args_cli.checkpoint checkpoint_path = args_cli.checkpoint
log_dir = os.path.dirname(checkpoint_path) log_dir = os.path.dirname(checkpoint_path)
# post-process agent configuration
agent_cfg = process_sb3_cfg(agent_cfg)
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# load the exact config used for training (instead of the default config)
agent_cfg = load_yaml(os.path.join(log_dir, "params", "agent.yaml"))
# post-process agent configuration
agent_cfg = process_sb3_cfg(agent_cfg, env.unwrapped.num_envs)
# convert to single-agent instance if required by the RL algorithm # convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv): if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env) env = multi_agent_to_single_agent(env)
...@@ -118,18 +127,25 @@ def main(): ...@@ -118,18 +127,25 @@ def main():
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for stable baselines # wrap around environment for stable baselines
env = Sb3VecEnvWrapper(env) env = Sb3VecEnvWrapper(env, fast_variant=not args_cli.keep_all_info)
vec_norm_path = checkpoint_path.replace("/model", "/model_vecnormalize").replace(".zip", ".pkl")
vec_norm_path = Path(vec_norm_path)
# normalize environment (if needed) # normalize environment (if needed)
if "normalize_input" in agent_cfg: if vec_norm_path.exists():
print(f"Loading saved normalization: {vec_norm_path}")
env = VecNormalize.load(vec_norm_path, env)
# do not update them at test time
env.training = False
# reward normalization is not needed at test time
env.norm_reward = False
elif "normalize_input" in agent_cfg:
env = VecNormalize( env = VecNormalize(
env, env,
training=True, training=True,
norm_obs="normalize_input" in agent_cfg and agent_cfg.pop("normalize_input"), norm_obs="normalize_input" in agent_cfg and agent_cfg.pop("normalize_input"),
norm_reward="normalize_value" in agent_cfg and agent_cfg.pop("normalize_value"),
clip_obs="clip_obs" in agent_cfg and agent_cfg.pop("clip_obs"), clip_obs="clip_obs" in agent_cfg and agent_cfg.pop("clip_obs"),
gamma=agent_cfg["gamma"],
clip_reward=np.inf,
) )
# create agent from stable baselines # create agent from stable baselines
......
...@@ -3,17 +3,16 @@ ...@@ -3,17 +3,16 @@
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
"""Script to train RL agent with Stable Baselines3.
Since Stable-Baselines3 does not support buffers living on GPU directly, """Script to train RL agent with Stable Baselines3."""
we recommend using smaller number of environments. Otherwise,
there will be significant overhead in GPU->CPU transfer.
"""
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
import contextlib
import signal
import sys import sys
from pathlib import Path
from isaaclab.app import AppLauncher from isaaclab.app import AppLauncher
...@@ -25,7 +24,14 @@ parser.add_argument("--video_interval", type=int, default=2000, help="Interval b ...@@ -25,7 +24,14 @@ parser.add_argument("--video_interval", type=int, default=2000, help="Interval b
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument("--log_interval", type=int, default=100_000, help="Log data every n timesteps.")
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.") parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
parser.add_argument(
"--keep_all_info",
action="store_true",
default=False,
help="Use a slower SB3 wrapper but keep all the extra training info.",
)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
...@@ -41,6 +47,24 @@ sys.argv = [sys.argv[0]] + hydra_args ...@@ -41,6 +47,24 @@ sys.argv = [sys.argv[0]] + hydra_args
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app simulation_app = app_launcher.app
def cleanup_pbar(*args):
"""
A small helper to stop training and
cleanup progress bar properly on ctrl+c
"""
import gc
tqdm_objects = [obj for obj in gc.get_objects() if "tqdm" in type(obj).__name__]
for tqdm_object in tqdm_objects:
if "tqdm_rich" in type(tqdm_object).__name__:
tqdm_object.close()
raise KeyboardInterrupt
# disable KeyboardInterrupt override
signal.signal(signal.SIGINT, cleanup_pbar)
"""Rest everything follows.""" """Rest everything follows."""
import gymnasium as gym import gymnasium as gym
...@@ -50,8 +74,7 @@ import random ...@@ -50,8 +74,7 @@ import random
from datetime import datetime from datetime import datetime
from stable_baselines3 import PPO from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback from stable_baselines3.common.callbacks import CheckpointCallback, LogEveryNTimesteps
from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import VecNormalize from stable_baselines3.common.vec_env import VecNormalize
from isaaclab.envs import ( from isaaclab.envs import (
...@@ -104,8 +127,12 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -104,8 +127,12 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg) dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg)
dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg) dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg)
# save command used to run the script
command = " ".join(sys.orig_argv)
(Path(log_dir) / "command.txt").write_text(command)
# post-process agent configuration # post-process agent configuration
agent_cfg = process_sb3_cfg(agent_cfg) agent_cfg = process_sb3_cfg(agent_cfg, env_cfg.scene.num_envs)
# read configurations about the agent-training # read configurations about the agent-training
policy_arch = agent_cfg.pop("policy") policy_arch = agent_cfg.pop("policy")
n_timesteps = agent_cfg.pop("n_timesteps") n_timesteps = agent_cfg.pop("n_timesteps")
...@@ -130,31 +157,49 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -130,31 +157,49 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for stable baselines # wrap around environment for stable baselines
env = Sb3VecEnvWrapper(env) env = Sb3VecEnvWrapper(env, fast_variant=not args_cli.keep_all_info)
norm_keys = {"normalize_input", "normalize_value", "clip_obs"}
norm_args = {}
for key in norm_keys:
if key in agent_cfg:
norm_args[key] = agent_cfg.pop(key)
if "normalize_input" in agent_cfg: if norm_args and norm_args.get("normalize_input"):
print(f"Normalizing input, {norm_args=}")
env = VecNormalize( env = VecNormalize(
env, env,
training=True, training=True,
norm_obs="normalize_input" in agent_cfg and agent_cfg.pop("normalize_input"), norm_obs=norm_args["normalize_input"],
norm_reward="normalize_value" in agent_cfg and agent_cfg.pop("normalize_value"), norm_reward=norm_args.get("normalize_value", False),
clip_obs="clip_obs" in agent_cfg and agent_cfg.pop("clip_obs"), clip_obs=norm_args.get("clip_obs", 100.0),
gamma=agent_cfg["gamma"], gamma=agent_cfg["gamma"],
clip_reward=np.inf, clip_reward=np.inf,
) )
# create agent from stable baselines # create agent from stable baselines
agent = PPO(policy_arch, env, verbose=1, **agent_cfg) agent = PPO(policy_arch, env, verbose=1, tensorboard_log=log_dir, **agent_cfg)
# configure the logger
new_logger = configure(log_dir, ["stdout", "tensorboard"])
agent.set_logger(new_logger)
# callbacks for agent # callbacks for agent
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_dir, name_prefix="model", verbose=2) checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_dir, name_prefix="model", verbose=2)
callbacks = [checkpoint_callback, LogEveryNTimesteps(n_steps=args_cli.log_interval)]
# train the agent # train the agent
agent.learn(total_timesteps=n_timesteps, callback=checkpoint_callback) with contextlib.suppress(KeyboardInterrupt):
agent.learn(
total_timesteps=n_timesteps,
callback=callbacks,
progress_bar=True,
log_interval=None,
)
# save the final model # save the final model
agent.save(os.path.join(log_dir, "model")) agent.save(os.path.join(log_dir, "model"))
print("Saving to:")
print(os.path.join(log_dir, "model.zip"))
if isinstance(env, VecNormalize):
print("Saving normalization")
env.save(os.path.join(log_dir, "model_vecnormalize.pkl"))
# close the simulator # close the simulator
env.close() env.close()
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.1.4" version = "0.1.5"
# Description # Description
title = "Isaac Lab RL" title = "Isaac Lab RL"
......
Changelog Changelog
--------- ---------
0.1.5 (2025-04-11)
~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Optimized Stable-Baselines3 wrapper ``Sb3VecEnvWrapper`` (now 4x faster) by using Numpy buffers and only logging episode and truncation information by default.
* Upgraded minimum SB3 version to 2.6.0 and added optional dependencies for progress bar
0.1.4 (2025-04-10) 0.1.4 (2025-04-10)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -22,6 +22,7 @@ import gymnasium as gym ...@@ -22,6 +22,7 @@ import gymnasium as gym
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn # noqa: F401 import torch.nn as nn # noqa: F401
import warnings
from typing import Any from typing import Any
from stable_baselines3.common.utils import constant_fn from stable_baselines3.common.utils import constant_fn
...@@ -29,16 +30,20 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, Vec ...@@ -29,16 +30,20 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, Vec
from isaaclab.envs import DirectRLEnv, ManagerBasedRLEnv from isaaclab.envs import DirectRLEnv, ManagerBasedRLEnv
# remove SB3 warnings because PPO with bigger net actually benefits from GPU
warnings.filterwarnings("ignore", message="You are trying to run PPO on the GPU")
""" """
Configuration Parser. Configuration Parser.
""" """
def process_sb3_cfg(cfg: dict) -> dict: def process_sb3_cfg(cfg: dict, num_envs: int) -> dict:
"""Convert simple YAML types to Stable-Baselines classes/components. """Convert simple YAML types to Stable-Baselines classes/components.
Args: Args:
cfg: A configuration dictionary. cfg: A configuration dictionary.
num_envs: the number of parallel environments (used to compute `batch_size` for a desired number of minibatches)
Returns: Returns:
A dictionary containing the converted configuration. A dictionary containing the converted configuration.
...@@ -54,19 +59,24 @@ def process_sb3_cfg(cfg: dict) -> dict: ...@@ -54,19 +59,24 @@ def process_sb3_cfg(cfg: dict) -> dict:
else: else:
if key in ["policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"]: if key in ["policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"]:
hyperparams[key] = eval(value) hyperparams[key] = eval(value)
elif key in ["learning_rate", "clip_range", "clip_range_vf", "delta_std"]: elif key in ["learning_rate", "clip_range", "clip_range_vf"]:
if isinstance(value, str): if isinstance(value, str):
_, initial_value = value.split("_") _, initial_value = value.split("_")
initial_value = float(initial_value) initial_value = float(initial_value)
hyperparams[key] = lambda progress_remaining: progress_remaining * initial_value hyperparams[key] = lambda progress_remaining: progress_remaining * initial_value
elif isinstance(value, (float, int)): elif isinstance(value, (float, int)):
# Negative value: ignore (ex: for clipping) # negative value: ignore (ex: for clipping)
if value < 0: if value < 0:
continue continue
hyperparams[key] = constant_fn(float(value)) hyperparams[key] = constant_fn(float(value))
else: else:
raise ValueError(f"Invalid value for {key}: {hyperparams[key]}") raise ValueError(f"Invalid value for {key}: {hyperparams[key]}")
# Convert to a desired batch_size (n_steps=2048 by default for SB3 PPO)
if "n_minibatches" in hyperparams:
hyperparams["batch_size"] = (hyperparams.get("n_steps", 2048) * num_envs) // hyperparams["n_minibatches"]
del hyperparams["n_minibatches"]
return hyperparams return hyperparams
# parse agent configuration and convert to classes # parse agent configuration and convert to classes
...@@ -89,8 +99,8 @@ class Sb3VecEnvWrapper(VecEnv): ...@@ -89,8 +99,8 @@ class Sb3VecEnvWrapper(VecEnv):
Note: Note:
While Stable-Baselines3 supports Gym 0.26+ API, their vectorized environment While Stable-Baselines3 supports Gym 0.26+ API, their vectorized environment
still uses the old API (i.e. it is closer to Gym 0.21). Thus, we implement uses their own API (i.e. it is closer to Gym 0.21). Thus, we implement
the old API for the vectorized environment. the API for the vectorized environment.
We also add monitoring functionality that computes the un-discounted episode We also add monitoring functionality that computes the un-discounted episode
return and length. This information is added to the info dicts under key `episode`. return and length. This information is added to the info dicts under key `episode`.
...@@ -123,12 +133,13 @@ class Sb3VecEnvWrapper(VecEnv): ...@@ -123,12 +133,13 @@ class Sb3VecEnvWrapper(VecEnv):
""" """
def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv): def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, fast_variant: bool = True):
"""Initialize the wrapper. """Initialize the wrapper.
Args: Args:
env: The environment to wrap around. env: The environment to wrap around.
fast_variant: Use fast variant for processing info
(Only episodic reward, lengths and truncation info are included)
Raises: Raises:
ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`. ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
""" """
...@@ -140,6 +151,7 @@ class Sb3VecEnvWrapper(VecEnv): ...@@ -140,6 +151,7 @@ class Sb3VecEnvWrapper(VecEnv):
) )
# initialize the wrapper # initialize the wrapper
self.env = env self.env = env
self.fast_variant = fast_variant
# collect common information # collect common information
self.num_envs = self.unwrapped.num_envs self.num_envs = self.unwrapped.num_envs
self.sim_device = self.unwrapped.device self.sim_device = self.unwrapped.device
...@@ -156,8 +168,8 @@ class Sb3VecEnvWrapper(VecEnv): ...@@ -156,8 +168,8 @@ class Sb3VecEnvWrapper(VecEnv):
# initialize vec-env # initialize vec-env
VecEnv.__init__(self, self.num_envs, observation_space, action_space) VecEnv.__init__(self, self.num_envs, observation_space, action_space)
# add buffer for logging episodic information # add buffer for logging episodic information
self._ep_rew_buf = torch.zeros(self.num_envs, device=self.sim_device) self._ep_rew_buf = np.zeros(self.num_envs)
self._ep_len_buf = torch.zeros(self.num_envs, device=self.sim_device) self._ep_len_buf = np.zeros(self.num_envs)
def __str__(self): def __str__(self):
"""Returns the wrapper name and the :attr:`env` representation string.""" """Returns the wrapper name and the :attr:`env` representation string."""
...@@ -190,11 +202,11 @@ class Sb3VecEnvWrapper(VecEnv): ...@@ -190,11 +202,11 @@ class Sb3VecEnvWrapper(VecEnv):
def get_episode_rewards(self) -> list[float]: def get_episode_rewards(self) -> list[float]:
"""Returns the rewards of all the episodes.""" """Returns the rewards of all the episodes."""
return self._ep_rew_buf.cpu().tolist() return self._ep_rew_buf.tolist()
def get_episode_lengths(self) -> list[int]: def get_episode_lengths(self) -> list[int]:
"""Returns the number of time-steps of all the episodes.""" """Returns the number of time-steps of all the episodes."""
return self._ep_len_buf.cpu().tolist() return self._ep_len_buf.tolist()
""" """
Operations - MDP Operations - MDP
...@@ -206,8 +218,8 @@ class Sb3VecEnvWrapper(VecEnv): ...@@ -206,8 +218,8 @@ class Sb3VecEnvWrapper(VecEnv):
def reset(self) -> VecEnvObs: # noqa: D102 def reset(self) -> VecEnvObs: # noqa: D102
obs_dict, _ = self.env.reset() obs_dict, _ = self.env.reset()
# reset episodic information buffers # reset episodic information buffers
self._ep_rew_buf.zero_() self._ep_rew_buf = np.zeros(self.num_envs)
self._ep_len_buf.zero_() self._ep_len_buf = np.zeros(self.num_envs)
# convert data types to numpy depending on backend # convert data types to numpy depending on backend
return self._process_obs(obs_dict) return self._process_obs(obs_dict)
...@@ -224,28 +236,30 @@ class Sb3VecEnvWrapper(VecEnv): ...@@ -224,28 +236,30 @@ class Sb3VecEnvWrapper(VecEnv):
def step_wait(self) -> VecEnvStepReturn: # noqa: D102 def step_wait(self) -> VecEnvStepReturn: # noqa: D102
# record step information # record step information
obs_dict, rew, terminated, truncated, extras = self.env.step(self._async_actions) obs_dict, rew, terminated, truncated, extras = self.env.step(self._async_actions)
# update episode un-discounted return and length
self._ep_rew_buf += rew
self._ep_len_buf += 1
# compute reset ids # compute reset ids
dones = terminated | truncated dones = terminated | truncated
reset_ids = (dones > 0).nonzero(as_tuple=False)
# convert data types to numpy depending on backend # convert data types to numpy depending on backend
# note: ManagerBasedRLEnv uses torch backend (by default). # note: ManagerBasedRLEnv uses torch backend (by default).
obs = self._process_obs(obs_dict) obs = self._process_obs(obs_dict)
rew = rew.detach().cpu().numpy() rewards = rew.detach().cpu().numpy()
terminated = terminated.detach().cpu().numpy() terminated = terminated.detach().cpu().numpy()
truncated = truncated.detach().cpu().numpy() truncated = truncated.detach().cpu().numpy()
dones = dones.detach().cpu().numpy() dones = dones.detach().cpu().numpy()
reset_ids = dones.nonzero()[0]
# update episode un-discounted return and length
self._ep_rew_buf += rewards
self._ep_len_buf += 1
# convert extra information to list of dicts # convert extra information to list of dicts
infos = self._process_extras(obs, terminated, truncated, extras, reset_ids) infos = self._process_extras(obs, terminated, truncated, extras, reset_ids)
# reset info for terminated environments # reset info for terminated environments
self._ep_rew_buf[reset_ids] = 0 self._ep_rew_buf[reset_ids] = 0.0
self._ep_len_buf[reset_ids] = 0 self._ep_len_buf[reset_ids] = 0
return obs, rew, dones, infos return obs, rewards, dones, infos
def close(self): # noqa: D102 def close(self): # noqa: D102
self.env.close() self.env.close()
...@@ -279,7 +293,8 @@ class Sb3VecEnvWrapper(VecEnv): ...@@ -279,7 +293,8 @@ class Sb3VecEnvWrapper(VecEnv):
return env_method(*method_args, indices=indices, **method_kwargs) return env_method(*method_args, indices=indices, **method_kwargs)
def env_is_wrapped(self, wrapper_class, indices=None): # noqa: D102 def env_is_wrapped(self, wrapper_class, indices=None): # noqa: D102
raise NotImplementedError("Checking if environment is wrapped is not supported.") # fake implementation to be able to use `evaluate_policy()` helper
return [False]
def get_images(self): # noqa: D102 def get_images(self): # noqa: D102
raise NotImplementedError("Getting images is not supported.") raise NotImplementedError("Getting images is not supported.")
...@@ -306,6 +321,29 @@ class Sb3VecEnvWrapper(VecEnv): ...@@ -306,6 +321,29 @@ class Sb3VecEnvWrapper(VecEnv):
self, obs: np.ndarray, terminated: np.ndarray, truncated: np.ndarray, extras: dict, reset_ids: np.ndarray self, obs: np.ndarray, terminated: np.ndarray, truncated: np.ndarray, extras: dict, reset_ids: np.ndarray
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Convert miscellaneous information into dictionary for each sub-environment.""" """Convert miscellaneous information into dictionary for each sub-environment."""
# faster version: only process env that terminated and add bootstrapping info
if self.fast_variant:
infos = [{} for _ in range(self.num_envs)]
for idx in reset_ids:
# fill-in episode monitoring info
infos[idx]["episode"] = {
"r": self._ep_rew_buf[idx],
"l": self._ep_len_buf[idx],
}
# fill-in bootstrap information
infos[idx]["TimeLimit.truncated"] = truncated[idx] and not terminated[idx]
# add information about terminal observation separately
if isinstance(obs, dict):
terminal_obs = {key: value[idx] for key, value in obs.items()}
else:
terminal_obs = obs[idx]
infos[idx]["terminal_observation"] = terminal_obs
return infos
# create empty list of dictionaries to fill # create empty list of dictionaries to fill
infos: list[dict[str, Any]] = [dict.fromkeys(extras.keys()) for _ in range(self.num_envs)] infos: list[dict[str, Any]] = [dict.fromkeys(extras.keys()) for _ in range(self.num_envs)]
# fill-in information for each sub-environment # fill-in information for each sub-environment
......
...@@ -41,7 +41,7 @@ PYTORCH_INDEX_URL = ["https://download.pytorch.org/whl/cu118"] ...@@ -41,7 +41,7 @@ PYTORCH_INDEX_URL = ["https://download.pytorch.org/whl/cu118"]
# Extra dependencies for RL agents # Extra dependencies for RL agents
EXTRAS_REQUIRE = { EXTRAS_REQUIRE = {
"sb3": ["stable-baselines3>=2.1"], "sb3": ["stable-baselines3>=2.6", "tqdm", "rich"], # tqdm/rich for progress bar
"skrl": ["skrl>=1.4.2"], "skrl": ["skrl>=1.4.2"],
"rl-games": ["rl-games==1.6.1", "gym"], # rl-games still needs gym :( "rl-games": ["rl-games==1.6.1", "gym"], # rl-games still needs gym :(
"rsl-rl": ["rsl-rl-lib==2.3.3"], "rsl-rl": ["rsl-rl-lib==2.3.3"],
......
# Reference: https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml#L245 # Adapted from rsl_rl config
seed: 42 seed: 42
policy: "MlpPolicy"
policy: 'MlpPolicy'
n_timesteps: !!float 5e7 n_timesteps: !!float 5e7
batch_size: 256 # For 4 minibatches with 4096 envs
n_steps: 512 # batch_size = (n_envs * n_steps) / n_minibatches = 32768
n_minibatches: 4
n_steps: 32
gamma: 0.99 gamma: 0.99
learning_rate: !!float 2.5e-4 learning_rate: !!float 5e-4
ent_coef: 0.0 ent_coef: 0.0
clip_range: 0.2 clip_range: 0.2
n_epochs: 10 n_epochs: 5
gae_lambda: 0.95 gae_lambda: 0.95
max_grad_norm: 1.0 max_grad_norm: 1.0
vf_coef: 0.5 vf_coef: 0.5
device: "cuda:0"
policy_kwargs: "dict( policy_kwargs: "dict(
log_std_init=-1, activation_fn=nn.ELU,
ortho_init=False, net_arch=[400, 200, 100],
activation_fn=nn.ReLU, optimizer_kwargs=dict(eps=1e-8),
net_arch=dict(pi=[256, 256], vf=[256, 256]) ortho_init=False,
)" )"
...@@ -19,6 +19,7 @@ gym.register( ...@@ -19,6 +19,7 @@ gym.register(
"env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeA1FlatEnvCfg", "env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeA1FlatEnvCfg",
"rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1FlatPPORunnerCfg", "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1FlatPPORunnerCfg",
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml",
"sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml",
}, },
) )
...@@ -30,6 +31,7 @@ gym.register( ...@@ -30,6 +31,7 @@ gym.register(
"env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeA1FlatEnvCfg_PLAY", "env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeA1FlatEnvCfg_PLAY",
"rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1FlatPPORunnerCfg", "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1FlatPPORunnerCfg",
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml",
"sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml",
}, },
) )
...@@ -41,6 +43,7 @@ gym.register( ...@@ -41,6 +43,7 @@ gym.register(
"env_cfg_entry_point": f"{__name__}.rough_env_cfg:UnitreeA1RoughEnvCfg", "env_cfg_entry_point": f"{__name__}.rough_env_cfg:UnitreeA1RoughEnvCfg",
"rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1RoughPPORunnerCfg", "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1RoughPPORunnerCfg",
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml", "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml",
"sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml",
}, },
) )
...@@ -52,5 +55,6 @@ gym.register( ...@@ -52,5 +55,6 @@ gym.register(
"env_cfg_entry_point": f"{__name__}.rough_env_cfg:UnitreeA1RoughEnvCfg_PLAY", "env_cfg_entry_point": f"{__name__}.rough_env_cfg:UnitreeA1RoughEnvCfg_PLAY",
"rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1RoughPPORunnerCfg", "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1RoughPPORunnerCfg",
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml", "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml",
"sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml",
}, },
) )
# Adapted from rsl_rl config
seed: 42
n_timesteps: !!float 5e7
policy: 'MlpPolicy'
n_steps: 24
n_minibatches: 4 # batch_size=24576 for n_envs=4096 and n_steps=24
gae_lambda: 0.95
gamma: 0.99
n_epochs: 5
ent_coef: 0.005
learning_rate: !!float 1e-3
clip_range: !!float 0.2
policy_kwargs: "dict(
activation_fn=nn.ELU,
net_arch=[512, 256, 128],
optimizer_kwargs=dict(eps=1e-8),
ortho_init=False,
)"
vf_coef: 1.0
max_grad_norm: 1.0
normalize_input: True
normalize_value: False
clip_obs: 10.0
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment