Unverified Commit 5d44141e authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Fixes SB-3 and RL-Games RL wrappers (#242)

# Description

This MR goes over the current implementations of Stable-Baselines3 and
RL-Games wrapper. It fixes the wrapper implementations as well as the
checkpoint loader to work for the logging format of these wrappers.

The changes have been tested against the `Isaac-Cartpole-v0` environment
from MR #241.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 13327b8a
......@@ -28,6 +28,7 @@ extra_standard_library = [
"bpy",
"matplotlib",
"gymnasium",
"gym",
"scipy",
"hid",
"yaml",
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.5.1"
version = "0.5.2"
# Description
title = "ORBIT Environments"
......
Changelog
---------
0.5.2 (2023-11-08)
~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Fixed the RL wrappers for Stable-Baselines3 and RL-Games. It now works with their most recent versions.
* Fixed the :meth:`get_checkpoint_path` to allow any in-between sub-folders between the run directory and the
checkpoint directory.
0.5.1 (2023-11-04)
~~~~~~~~~~~~~~~~~~
......
......@@ -140,21 +140,26 @@ def parse_env_cfg(task_name: str, use_gpu: bool | None = None, num_envs: int | N
def get_checkpoint_path(
log_path: str, run_dir: str = "*", checkpoint: str = "*", sort_alphabetical: bool = True
log_path: str, run_dir: str = ".*", checkpoint: str = ".*", other_dirs: list[str] = None, sort_alpha: bool = True
) -> str:
"""Get path to the model checkpoint in input directory.
The checkpoint file is resolved as: <log_path>/<run_dir>/<checkpoint>.
If run_dir and checkpoint are regex expressions then the most recent (highest alphabetical order) run and checkpoint are selected.
The checkpoint file is resolved as: <log_path>/<run_dir>/<*other_dirs>/<checkpoint>, where the
:attr:`other_dirs` are intermediate folder names to concatenate. These cannot be regex expressions.
If :attr:`run_dir` and :attr:`checkpoint` are regex expressions then the most recent (highest alphabetical order)
run and checkpoint are selected. To disable this behavior, set the flag :attr:`sort_alpha` to False.
Args:
log_path: The log directory path to find models in.
run_dir: Regex expression for the name of the directory containing the run. Defaults to the most
run_dir: The regex expression for the name of the directory containing the run. Defaults to the most
recent directory created inside :obj:`log_dir`.
checkpoint: The model checkpoint file or directory name. Defaults to the most recent
other_dirs: The intermediate directories between the run directory and the checkpoint file. Defaults to
None, which implies that checkpoint file is directly under the run directory.
checkpoint: The regex expression for the model checkpoint file. Defaults to the most recent
torch-model saved in the :obj:`run_dir` directory.
sort_alphabetical: Whether to sort the runs and checkpoints by alphabetical order. Defaults to True.
If False, the checkpoints are sorted by the last modified time.
sort_alpha: Whether to sort the runs by alphabetical order. Defaults to True.
If False, the folders in :attr:`run_dir` are sorted by the last modified time.
Raises:
ValueError: When no runs are found in the input directory.
......@@ -173,12 +178,15 @@ def get_checkpoint_path(
os.path.join(log_path, run) for run in os.scandir(log_path) if run.is_dir() and re.match(run_dir, run.name)
]
# sort matched runs by alphabetical order (latest run should be last)
if sort_alphabetical:
if sort_alpha:
runs.sort()
else:
runs = sorted(runs, key=os.path.getmtime)
# create last run file path
run_path = runs[-1]
if other_dirs is not None:
run_path = os.path.join(runs[-1], *other_dirs)
else:
run_path = runs[-1]
except IndexError:
raise ValueError(f"No runs present in the directory: '{log_path}' match: '{run_dir}'.")
......
......@@ -33,7 +33,8 @@ for RL-Games :class:`Runner` class:
from __future__ import annotations
import gymnasium as gym
import gym.spaces # needed for rl-games incompatibility: https://github.com/Denys88/rl_games/issues/261
import gymnasium
import torch
from rl_games.common import env_configurations
......@@ -61,13 +62,12 @@ class RlGamesVecEnvWrapper(IVecEnv):
observations. This dictionary contains "obs" and "states" which typically correspond
to the actor and critic observations respectively.
To use asymmetric actor-critic, the environment instance must have the attributes
To use asymmetric actor-critic, the environment observations from :class:`RLTaskEnv`
must have the key or group name "critic". The observation group is used to set the
:attr:`num_states` (int) and :attr:`state_space` (:obj:`gym.spaces.Box`). These are
used by the learning agent to allocate buffers in the trajectory memory. Additionally,
the method :meth:`_get_observations()` should have the key "critic" which corresponds
to the privileged observations. Since this is optional for some environments, the wrapper
checks if these attributes exist. If they don't then the wrapper defaults to zero as number
of privileged observations.
used by the learning agent in RL-Games to allocate buffers in the trajectory memory.
Since this is optional for some environments, the wrapper checks if these attributes exist.
If they don't then the wrapper defaults to zero as number of privileged observations.
.. caution::
......@@ -104,19 +104,11 @@ class RlGamesVecEnvWrapper(IVecEnv):
self._clip_obs = clip_obs
self._clip_actions = clip_actions
self._sim_device = env.unwrapped.device
# information about spaces for the wrapper
# note: rl-games only wants single observation and action spaces
self.rlg_observation_space = self.unwrapped.single_observation_space["policy"]
self.rlg_action_space = self.unwrapped.single_action_space
# information for privileged observations
self.rlg_state_space = self.unwrapped.single_observation_space.get("critic")
if self.rlg_state_space is not None:
if not isinstance(self.rlg_state_space, gym.spaces.Box):
raise ValueError(f"Privileged observations must be of type Box. Type: {type(self.rlg_state_space)}")
self.rlg_num_states = self.rlg_state_space.shape[0]
else:
if self.state_space is None:
self.rlg_num_states = 0
else:
self.rlg_num_states = self.state_space.shape[0]
def __str__(self):
"""Returns the wrapper name and the :attr:`env` representation string."""
......@@ -142,14 +134,35 @@ class RlGamesVecEnvWrapper(IVecEnv):
return self.env.render_mode
@property
def observation_space(self) -> gym.Space:
def observation_space(self) -> gym.spaces.Box:
"""Returns the :attr:`Env` :attr:`observation_space`."""
return self.env.observation_space
# note: rl-games only wants single observation space
policy_obs_space = self.unwrapped.single_observation_space["policy"]
if not isinstance(policy_obs_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support observation space: '{type(policy_obs_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in RLTaskEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)
@property
def action_space(self) -> gym.Space:
"""Returns the :attr:`Env` :attr:`action_space`."""
return self.env.action_space
# note: rl-games only wants single action space
action_space = self.unwrapped.single_action_space
if not isinstance(action_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support action space: '{type(action_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# return casted space in gym.spaces.Box (OpenAI Gym)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in RLTaskEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_actions, self._clip_actions, action_space.shape)
@classmethod
def class_name(cls) -> str:
......@@ -168,6 +181,35 @@ class RlGamesVecEnvWrapper(IVecEnv):
Properties
"""
@property
def num_envs(self) -> int:
"""Returns the number of sub-environment instances."""
return self.unwrapped.num_envs
@property
def device(self) -> str:
"""Returns the base environment simulation device."""
return self.unwrapped.device
@property
def state_space(self) -> gym.spaces.Box | None:
"""Returns the :attr:`Env` :attr:`observation_space`."""
# note: rl-games only wants single observation space
critic_obs_space = self.unwrapped.single_observation_space.get("critic")
# check if we even have a critic obs
if critic_obs_space is None:
return None
elif not isinstance(critic_obs_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support state space: '{type(critic_obs_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# return casted space in gym.spaces.Box (OpenAI Gym)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in RLTaskEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_obs, self._clip_obs, critic_obs_space.shape)
def get_number_of_agents(self) -> int:
"""Returns number of actors in the environment."""
return getattr(self, "num_agents", 1)
......@@ -175,9 +217,9 @@ class RlGamesVecEnvWrapper(IVecEnv):
def get_env_info(self) -> dict:
"""Returns the Gym spaces for the environment."""
return {
"observation_space": self.rlg_observation_space,
"action_space": self.rlg_action_space,
"state_space": self.rlg_state_space,
"observation_space": self.observation_space,
"action_space": self.action_space,
"state_space": self.state_space,
}
"""
......
......@@ -17,10 +17,13 @@ The following example shows how to wrap an environment for Stable-Baselines3:
from __future__ import annotations
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn # noqa: F401
from typing import Any
from stable_baselines3.common.utils import constant_fn
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn
from omni.isaac.orbit.envs import RLTaskEnv
......@@ -44,16 +47,28 @@ def process_sb3_cfg(cfg: dict) -> dict:
Reference:
https://github.com/DLR-RM/rl-baselines3-zoo/blob/0e5eb145faefa33e7d79c7f8c179788574b20da5/utils/exp_manager.py#L358
"""
_direct_eval = ["policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"]
def update_dict(d):
for key, value in d.items():
def update_dict(hyperparams: dict[str, Any]) -> dict[str, Any]:
for key, value in hyperparams.items():
if isinstance(value, dict):
update_dict(value)
else:
if key in _direct_eval:
d[key] = eval(value)
return d
if key in ["policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"]:
hyperparams[key] = eval(value)
elif key in ["learning_rate", "clip_range", "clip_range_vf", "delta_std"]:
if isinstance(value, str):
_, initial_value = value.split("_")
initial_value = float(initial_value)
hyperparams[key] = lambda progress_remaining: progress_remaining * initial_value
elif isinstance(value, (float, int)):
# Negative value: ignore (ex: for clipping)
if value < 0:
continue
hyperparams[key] = constant_fn(float(value))
else:
raise ValueError(f"Invalid value for {key}: {hyperparams[key]}")
return hyperparams
# parse agent configuration and convert to classes
return update_dict(cfg)
......@@ -127,9 +142,14 @@ class Sb3VecEnvWrapper(VecEnv):
self.num_envs = self.unwrapped.num_envs
self.sim_device = self.unwrapped.device
self.render_mode = self.unwrapped.render_mode
# initialize vec-env
# obtain gym spaces
# note: stable-baselines3 does not like when we have unbounded action space so
# we set it to some high value here. Maybe this is not general but something to think about.
observation_space = self.unwrapped.single_observation_space["policy"]
action_space = self.unwrapped.single_action_space
if isinstance(action_space, gym.spaces.Box) and action_space.is_bounded() != "both":
action_space = gym.spaces.Box(low=-100, high=100, shape=action_space.shape)
# initialize vec-env
VecEnv.__init__(self, self.num_envs, observation_space, action_space)
# add buffer for logging episodic information
self._ep_rew_buf = torch.zeros(self.num_envs, device=self.sim_device)
......
......@@ -34,8 +34,8 @@ INSTALL_REQUIRES = [
# Extra dependencies for RL agents
EXTRAS_REQUIRE = {
"sb3": ["stable-baselines3>=2.0"],
"skrl": ["skrl>=0.10.0"],
"rl_games": ["rl-games==1.6.1"],
"skrl": ["skrl==0.10.0"],
"rl_games": ["rl-games==1.6.1", "gym"], # rl-games still needs gym :(
"rsl_rl": ["rsl_rl@git+https://github.com/leggedrobotics/rsl_rl.git"],
"robomimic": ["robomimic@git+https://github.com/ARISE-Initiative/robomimic.git"],
}
......
......@@ -82,7 +82,7 @@ class TestRlGamesVecEnvWrapper(unittest.TestCase):
with torch.inference_mode():
for _ in range(100):
# sample actions from -1 to 1
actions = 2 * torch.rand(env.action_space.shape, device=env.device) - 1
actions = 2 * torch.rand(env.num_envs, *env.action_space.shape, device=env.device) - 1
# apply actions
transition = env.step(actions)
# check signals
......
......@@ -83,7 +83,7 @@ class TestStableBaselines3VecEnvWrapper(unittest.TestCase):
with torch.inference_mode():
for _ in range(1000):
# sample actions from -1 to 1
actions = 2 * np.random.rand(env.num_envs, env.action_space.shape) - 1
actions = 2 * np.random.rand(env.num_envs, *env.action_space.shape) - 1
# apply actions
transition = env.step(actions)
# check signals
......
......@@ -84,17 +84,15 @@ def main():
# find checkpoint
if args_cli.checkpoint is None:
# specify directory for logging runs
if "full_experiment_name" not in agent_cfg["params"]["config"]:
run_dir = os.path.join("*", "nn")
else:
run_dir = os.path.join(agent_cfg["params"]["config"]["full_experiment_name"], "nn")
run_dir = agent_cfg["params"]["config"].get("full_experiment_name", ".*")
# specify name of checkpoint
if args_cli.use_last_checkpoint:
checkpoint_file = None
checkpoint_file = ".*"
else:
# this loads the best checkpoint
checkpoint_file = f"{agent_cfg['params']['config']['name']}.pth"
# get path to previous checkpoint
resume_path = get_checkpoint_path(log_root_path, run_dir, checkpoint_file)
resume_path = get_checkpoint_path(log_root_path, run_dir, checkpoint_file, other_dirs=["nn"])
else:
resume_path = os.path.abspath(args_cli.checkpoint)
# load previously trained model
......
......@@ -15,7 +15,7 @@ import argparse
from omni.isaac.orbit.app import AppLauncher
# local imports
import source.standalone.workflows.rsl_rl.cli_args as cli_args # isort: skip
import cli_args # isort: skip
# add argparse arguments
parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.")
......
......@@ -16,7 +16,7 @@ import os
from omni.isaac.orbit.app import AppLauncher
# local imports
import source.standalone.workflows.rsl_rl.cli_args as cli_args # isort: skip
import cli_args # isort: skip
# add argparse arguments
......
......@@ -21,6 +21,11 @@ parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU p
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument(
"--use_last_checkpoint",
action="store_true",
help="When no checkpoint provided, use the last saved model. Otherwise use the best saved model.",
)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
......@@ -34,6 +39,7 @@ simulation_app = app_launcher.app
import gymnasium as gym
import os
import torch
import traceback
......@@ -43,7 +49,7 @@ from stable_baselines3.common.vec_env import VecNormalize
import omni.isaac.contrib_tasks # noqa: F401
import omni.isaac.orbit_tasks # noqa: F401
from omni.isaac.orbit_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg
from omni.isaac.orbit_tasks.utils.parse_cfg import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
from omni.isaac.orbit_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb3_cfg
......@@ -72,12 +78,21 @@ def main():
clip_reward=np.inf,
)
# directory for logging into
log_root_path = os.path.join("logs", "sb3", args_cli.task)
log_root_path = os.path.abspath(log_root_path)
# check checkpoint is valid
if args_cli.checkpoint is None:
raise ValueError("Checkpoint path is not valid.")
if args_cli.use_last_checkpoint:
checkpoint = "model_.*.zip"
else:
checkpoint = "model.zip"
checkpoint_path = get_checkpoint_path(log_root_path, ".*", checkpoint)
else:
checkpoint_path = args_cli.checkpoint
# create agent from stable baselines
print(f"Loading checkpoint from: {args_cli.checkpoint}")
agent = PPO.load(args_cli.checkpoint, env, print_system_info=True)
print(f"Loading checkpoint from: {checkpoint_path}")
agent = PPO.load(checkpoint_path, env, print_system_info=True)
# reset environment
obs = env.reset()
......
......@@ -3,7 +3,12 @@
#
# SPDX-License-Identifier: BSD-3-Clause
"""Script to train RL agent with Stable Baselines3."""
"""Script to train RL agent with Stable Baselines3.
Since Stable-Baselines3 does not support buffers living on GPU directly,
we recommend using smaller number of environments. Otherwise,
there will be significant overhead in GPU->CPU transfer.
"""
from __future__ import annotations
......@@ -68,8 +73,6 @@ def main():
# parse configuration
env_cfg = parse_env_cfg(args_cli.task, use_gpu=not args_cli.cpu, num_envs=args_cli.num_envs)
agent_cfg = load_cfg_from_registry(args_cli.task, "sb3_cfg_entry_point")
# post-process agent configuration
agent_cfg = process_sb3_cfg(agent_cfg)
# override configuration with command line arguments
if args_cli.seed is not None:
......@@ -83,6 +86,8 @@ def main():
dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg)
dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg)
# post-process agent configuration
agent_cfg = process_sb3_cfg(agent_cfg)
# read configurations about the agent-training
policy_arch = agent_cfg.pop("policy")
n_timesteps = agent_cfg.pop("n_timesteps")
......
......@@ -123,7 +123,7 @@ def main():
if args_cli.checkpoint:
resume_path = os.path.abspath(args_cli.checkpoint)
else:
resume_path = get_checkpoint_path(log_root_path, os.path.join("*", "checkpoints"), None)
resume_path = get_checkpoint_path(log_root_path, other_dirs=["checkpoints"])
print(f"[INFO] Loading model checkpoint from: {resume_path}")
# initialize agent
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment