Unverified Commit 5d44141e authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Fixes SB-3 and RL-Games RL wrappers (#242)

# Description

This MR goes over the current implementations of Stable-Baselines3 and
RL-Games wrapper. It fixes the wrapper implementations as well as the
checkpoint loader to work for the logging format of these wrappers.

The changes have been tested against the `Isaac-Cartpole-v0` environment
from MR #241.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 13327b8a
...@@ -28,6 +28,7 @@ extra_standard_library = [ ...@@ -28,6 +28,7 @@ extra_standard_library = [
"bpy", "bpy",
"matplotlib", "matplotlib",
"gymnasium", "gymnasium",
"gym",
"scipy", "scipy",
"hid", "hid",
"yaml", "yaml",
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.5.1" version = "0.5.2"
# Description # Description
title = "ORBIT Environments" title = "ORBIT Environments"
......
Changelog Changelog
--------- ---------
0.5.2 (2023-11-08)
~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Fixed the RL wrappers for Stable-Baselines3 and RL-Games. It now works with their most recent versions.
* Fixed the :meth:`get_checkpoint_path` to allow any in-between sub-folders between the run directory and the
checkpoint directory.
0.5.1 (2023-11-04) 0.5.1 (2023-11-04)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -140,21 +140,26 @@ def parse_env_cfg(task_name: str, use_gpu: bool | None = None, num_envs: int | N ...@@ -140,21 +140,26 @@ def parse_env_cfg(task_name: str, use_gpu: bool | None = None, num_envs: int | N
def get_checkpoint_path( def get_checkpoint_path(
log_path: str, run_dir: str = "*", checkpoint: str = "*", sort_alphabetical: bool = True log_path: str, run_dir: str = ".*", checkpoint: str = ".*", other_dirs: list[str] = None, sort_alpha: bool = True
) -> str: ) -> str:
"""Get path to the model checkpoint in input directory. """Get path to the model checkpoint in input directory.
The checkpoint file is resolved as: <log_path>/<run_dir>/<checkpoint>. The checkpoint file is resolved as: <log_path>/<run_dir>/<*other_dirs>/<checkpoint>, where the
If run_dir and checkpoint are regex expressions then the most recent (highest alphabetical order) run and checkpoint are selected. :attr:`other_dirs` are intermediate folder names to concatenate. These cannot be regex expressions.
If :attr:`run_dir` and :attr:`checkpoint` are regex expressions then the most recent (highest alphabetical order)
run and checkpoint are selected. To disable this behavior, set the flag :attr:`sort_alpha` to False.
Args: Args:
log_path: The log directory path to find models in. log_path: The log directory path to find models in.
run_dir: Regex expression for the name of the directory containing the run. Defaults to the most run_dir: The regex expression for the name of the directory containing the run. Defaults to the most
recent directory created inside :obj:`log_dir`. recent directory created inside :obj:`log_dir`.
checkpoint: The model checkpoint file or directory name. Defaults to the most recent other_dirs: The intermediate directories between the run directory and the checkpoint file. Defaults to
None, which implies that checkpoint file is directly under the run directory.
checkpoint: The regex expression for the model checkpoint file. Defaults to the most recent
torch-model saved in the :obj:`run_dir` directory. torch-model saved in the :obj:`run_dir` directory.
sort_alphabetical: Whether to sort the runs and checkpoints by alphabetical order. Defaults to True. sort_alpha: Whether to sort the runs by alphabetical order. Defaults to True.
If False, the checkpoints are sorted by the last modified time. If False, the folders in :attr:`run_dir` are sorted by the last modified time.
Raises: Raises:
ValueError: When no runs are found in the input directory. ValueError: When no runs are found in the input directory.
...@@ -173,11 +178,14 @@ def get_checkpoint_path( ...@@ -173,11 +178,14 @@ def get_checkpoint_path(
os.path.join(log_path, run) for run in os.scandir(log_path) if run.is_dir() and re.match(run_dir, run.name) os.path.join(log_path, run) for run in os.scandir(log_path) if run.is_dir() and re.match(run_dir, run.name)
] ]
# sort matched runs by alphabetical order (latest run should be last) # sort matched runs by alphabetical order (latest run should be last)
if sort_alphabetical: if sort_alpha:
runs.sort() runs.sort()
else: else:
runs = sorted(runs, key=os.path.getmtime) runs = sorted(runs, key=os.path.getmtime)
# create last run file path # create last run file path
if other_dirs is not None:
run_path = os.path.join(runs[-1], *other_dirs)
else:
run_path = runs[-1] run_path = runs[-1]
except IndexError: except IndexError:
raise ValueError(f"No runs present in the directory: '{log_path}' match: '{run_dir}'.") raise ValueError(f"No runs present in the directory: '{log_path}' match: '{run_dir}'.")
......
...@@ -33,7 +33,8 @@ for RL-Games :class:`Runner` class: ...@@ -33,7 +33,8 @@ for RL-Games :class:`Runner` class:
from __future__ import annotations from __future__ import annotations
import gymnasium as gym import gym.spaces # needed for rl-games incompatibility: https://github.com/Denys88/rl_games/issues/261
import gymnasium
import torch import torch
from rl_games.common import env_configurations from rl_games.common import env_configurations
...@@ -61,13 +62,12 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -61,13 +62,12 @@ class RlGamesVecEnvWrapper(IVecEnv):
observations. This dictionary contains "obs" and "states" which typically correspond observations. This dictionary contains "obs" and "states" which typically correspond
to the actor and critic observations respectively. to the actor and critic observations respectively.
To use asymmetric actor-critic, the environment instance must have the attributes To use asymmetric actor-critic, the environment observations from :class:`RLTaskEnv`
must have the key or group name "critic". The observation group is used to set the
:attr:`num_states` (int) and :attr:`state_space` (:obj:`gym.spaces.Box`). These are :attr:`num_states` (int) and :attr:`state_space` (:obj:`gym.spaces.Box`). These are
used by the learning agent to allocate buffers in the trajectory memory. Additionally, used by the learning agent in RL-Games to allocate buffers in the trajectory memory.
the method :meth:`_get_observations()` should have the key "critic" which corresponds Since this is optional for some environments, the wrapper checks if these attributes exist.
to the privileged observations. Since this is optional for some environments, the wrapper If they don't then the wrapper defaults to zero as number of privileged observations.
checks if these attributes exist. If they don't then the wrapper defaults to zero as number
of privileged observations.
.. caution:: .. caution::
...@@ -104,19 +104,11 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -104,19 +104,11 @@ class RlGamesVecEnvWrapper(IVecEnv):
self._clip_obs = clip_obs self._clip_obs = clip_obs
self._clip_actions = clip_actions self._clip_actions = clip_actions
self._sim_device = env.unwrapped.device self._sim_device = env.unwrapped.device
# information about spaces for the wrapper
# note: rl-games only wants single observation and action spaces
self.rlg_observation_space = self.unwrapped.single_observation_space["policy"]
self.rlg_action_space = self.unwrapped.single_action_space
# information for privileged observations # information for privileged observations
self.rlg_state_space = self.unwrapped.single_observation_space.get("critic") if self.state_space is None:
if self.rlg_state_space is not None:
if not isinstance(self.rlg_state_space, gym.spaces.Box):
raise ValueError(f"Privileged observations must be of type Box. Type: {type(self.rlg_state_space)}")
self.rlg_num_states = self.rlg_state_space.shape[0]
else:
self.rlg_num_states = 0 self.rlg_num_states = 0
else:
self.rlg_num_states = self.state_space.shape[0]
def __str__(self): def __str__(self):
"""Returns the wrapper name and the :attr:`env` representation string.""" """Returns the wrapper name and the :attr:`env` representation string."""
...@@ -142,14 +134,35 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -142,14 +134,35 @@ class RlGamesVecEnvWrapper(IVecEnv):
return self.env.render_mode return self.env.render_mode
@property @property
def observation_space(self) -> gym.Space: def observation_space(self) -> gym.spaces.Box:
"""Returns the :attr:`Env` :attr:`observation_space`.""" """Returns the :attr:`Env` :attr:`observation_space`."""
return self.env.observation_space # note: rl-games only wants single observation space
policy_obs_space = self.unwrapped.single_observation_space["policy"]
if not isinstance(policy_obs_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support observation space: '{type(policy_obs_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in RLTaskEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)
@property @property
def action_space(self) -> gym.Space: def action_space(self) -> gym.Space:
"""Returns the :attr:`Env` :attr:`action_space`.""" """Returns the :attr:`Env` :attr:`action_space`."""
return self.env.action_space # note: rl-games only wants single action space
action_space = self.unwrapped.single_action_space
if not isinstance(action_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support action space: '{type(action_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# return casted space in gym.spaces.Box (OpenAI Gym)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in RLTaskEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_actions, self._clip_actions, action_space.shape)
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
...@@ -168,6 +181,35 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -168,6 +181,35 @@ class RlGamesVecEnvWrapper(IVecEnv):
Properties Properties
""" """
@property
def num_envs(self) -> int:
"""Returns the number of sub-environment instances."""
return self.unwrapped.num_envs
@property
def device(self) -> str:
"""Returns the base environment simulation device."""
return self.unwrapped.device
@property
def state_space(self) -> gym.spaces.Box | None:
"""Returns the :attr:`Env` :attr:`observation_space`."""
# note: rl-games only wants single observation space
critic_obs_space = self.unwrapped.single_observation_space.get("critic")
# check if we even have a critic obs
if critic_obs_space is None:
return None
elif not isinstance(critic_obs_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support state space: '{type(critic_obs_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# return casted space in gym.spaces.Box (OpenAI Gym)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in RLTaskEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_obs, self._clip_obs, critic_obs_space.shape)
def get_number_of_agents(self) -> int: def get_number_of_agents(self) -> int:
"""Returns number of actors in the environment.""" """Returns number of actors in the environment."""
return getattr(self, "num_agents", 1) return getattr(self, "num_agents", 1)
...@@ -175,9 +217,9 @@ class RlGamesVecEnvWrapper(IVecEnv): ...@@ -175,9 +217,9 @@ class RlGamesVecEnvWrapper(IVecEnv):
def get_env_info(self) -> dict: def get_env_info(self) -> dict:
"""Returns the Gym spaces for the environment.""" """Returns the Gym spaces for the environment."""
return { return {
"observation_space": self.rlg_observation_space, "observation_space": self.observation_space,
"action_space": self.rlg_action_space, "action_space": self.action_space,
"state_space": self.rlg_state_space, "state_space": self.state_space,
} }
""" """
......
...@@ -17,10 +17,13 @@ The following example shows how to wrap an environment for Stable-Baselines3: ...@@ -17,10 +17,13 @@ The following example shows how to wrap an environment for Stable-Baselines3:
from __future__ import annotations from __future__ import annotations
import gymnasium as gym
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn # noqa: F401
from typing import Any from typing import Any
from stable_baselines3.common.utils import constant_fn
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn
from omni.isaac.orbit.envs import RLTaskEnv from omni.isaac.orbit.envs import RLTaskEnv
...@@ -44,16 +47,28 @@ def process_sb3_cfg(cfg: dict) -> dict: ...@@ -44,16 +47,28 @@ def process_sb3_cfg(cfg: dict) -> dict:
Reference: Reference:
https://github.com/DLR-RM/rl-baselines3-zoo/blob/0e5eb145faefa33e7d79c7f8c179788574b20da5/utils/exp_manager.py#L358 https://github.com/DLR-RM/rl-baselines3-zoo/blob/0e5eb145faefa33e7d79c7f8c179788574b20da5/utils/exp_manager.py#L358
""" """
_direct_eval = ["policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"]
def update_dict(d): def update_dict(hyperparams: dict[str, Any]) -> dict[str, Any]:
for key, value in d.items(): for key, value in hyperparams.items():
if isinstance(value, dict): if isinstance(value, dict):
update_dict(value) update_dict(value)
else: else:
if key in _direct_eval: if key in ["policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"]:
d[key] = eval(value) hyperparams[key] = eval(value)
return d elif key in ["learning_rate", "clip_range", "clip_range_vf", "delta_std"]:
if isinstance(value, str):
_, initial_value = value.split("_")
initial_value = float(initial_value)
hyperparams[key] = lambda progress_remaining: progress_remaining * initial_value
elif isinstance(value, (float, int)):
# Negative value: ignore (ex: for clipping)
if value < 0:
continue
hyperparams[key] = constant_fn(float(value))
else:
raise ValueError(f"Invalid value for {key}: {hyperparams[key]}")
return hyperparams
# parse agent configuration and convert to classes # parse agent configuration and convert to classes
return update_dict(cfg) return update_dict(cfg)
...@@ -127,9 +142,14 @@ class Sb3VecEnvWrapper(VecEnv): ...@@ -127,9 +142,14 @@ class Sb3VecEnvWrapper(VecEnv):
self.num_envs = self.unwrapped.num_envs self.num_envs = self.unwrapped.num_envs
self.sim_device = self.unwrapped.device self.sim_device = self.unwrapped.device
self.render_mode = self.unwrapped.render_mode self.render_mode = self.unwrapped.render_mode
# initialize vec-env # obtain gym spaces
# note: stable-baselines3 does not like when we have unbounded action space so
# we set it to some high value here. Maybe this is not general but something to think about.
observation_space = self.unwrapped.single_observation_space["policy"] observation_space = self.unwrapped.single_observation_space["policy"]
action_space = self.unwrapped.single_action_space action_space = self.unwrapped.single_action_space
if isinstance(action_space, gym.spaces.Box) and action_space.is_bounded() != "both":
action_space = gym.spaces.Box(low=-100, high=100, shape=action_space.shape)
# initialize vec-env
VecEnv.__init__(self, self.num_envs, observation_space, action_space) VecEnv.__init__(self, self.num_envs, observation_space, action_space)
# add buffer for logging episodic information # add buffer for logging episodic information
self._ep_rew_buf = torch.zeros(self.num_envs, device=self.sim_device) self._ep_rew_buf = torch.zeros(self.num_envs, device=self.sim_device)
......
...@@ -34,8 +34,8 @@ INSTALL_REQUIRES = [ ...@@ -34,8 +34,8 @@ INSTALL_REQUIRES = [
# Extra dependencies for RL agents # Extra dependencies for RL agents
EXTRAS_REQUIRE = { EXTRAS_REQUIRE = {
"sb3": ["stable-baselines3>=2.0"], "sb3": ["stable-baselines3>=2.0"],
"skrl": ["skrl>=0.10.0"], "skrl": ["skrl==0.10.0"],
"rl_games": ["rl-games==1.6.1"], "rl_games": ["rl-games==1.6.1", "gym"], # rl-games still needs gym :(
"rsl_rl": ["rsl_rl@git+https://github.com/leggedrobotics/rsl_rl.git"], "rsl_rl": ["rsl_rl@git+https://github.com/leggedrobotics/rsl_rl.git"],
"robomimic": ["robomimic@git+https://github.com/ARISE-Initiative/robomimic.git"], "robomimic": ["robomimic@git+https://github.com/ARISE-Initiative/robomimic.git"],
} }
......
...@@ -82,7 +82,7 @@ class TestRlGamesVecEnvWrapper(unittest.TestCase): ...@@ -82,7 +82,7 @@ class TestRlGamesVecEnvWrapper(unittest.TestCase):
with torch.inference_mode(): with torch.inference_mode():
for _ in range(100): for _ in range(100):
# sample actions from -1 to 1 # sample actions from -1 to 1
actions = 2 * torch.rand(env.action_space.shape, device=env.device) - 1 actions = 2 * torch.rand(env.num_envs, *env.action_space.shape, device=env.device) - 1
# apply actions # apply actions
transition = env.step(actions) transition = env.step(actions)
# check signals # check signals
......
...@@ -83,7 +83,7 @@ class TestStableBaselines3VecEnvWrapper(unittest.TestCase): ...@@ -83,7 +83,7 @@ class TestStableBaselines3VecEnvWrapper(unittest.TestCase):
with torch.inference_mode(): with torch.inference_mode():
for _ in range(1000): for _ in range(1000):
# sample actions from -1 to 1 # sample actions from -1 to 1
actions = 2 * np.random.rand(env.num_envs, env.action_space.shape) - 1 actions = 2 * np.random.rand(env.num_envs, *env.action_space.shape) - 1
# apply actions # apply actions
transition = env.step(actions) transition = env.step(actions)
# check signals # check signals
......
...@@ -84,17 +84,15 @@ def main(): ...@@ -84,17 +84,15 @@ def main():
# find checkpoint # find checkpoint
if args_cli.checkpoint is None: if args_cli.checkpoint is None:
# specify directory for logging runs # specify directory for logging runs
if "full_experiment_name" not in agent_cfg["params"]["config"]: run_dir = agent_cfg["params"]["config"].get("full_experiment_name", ".*")
run_dir = os.path.join("*", "nn")
else:
run_dir = os.path.join(agent_cfg["params"]["config"]["full_experiment_name"], "nn")
# specify name of checkpoint # specify name of checkpoint
if args_cli.use_last_checkpoint: if args_cli.use_last_checkpoint:
checkpoint_file = None checkpoint_file = ".*"
else: else:
# this loads the best checkpoint
checkpoint_file = f"{agent_cfg['params']['config']['name']}.pth" checkpoint_file = f"{agent_cfg['params']['config']['name']}.pth"
# get path to previous checkpoint # get path to previous checkpoint
resume_path = get_checkpoint_path(log_root_path, run_dir, checkpoint_file) resume_path = get_checkpoint_path(log_root_path, run_dir, checkpoint_file, other_dirs=["nn"])
else: else:
resume_path = os.path.abspath(args_cli.checkpoint) resume_path = os.path.abspath(args_cli.checkpoint)
# load previously trained model # load previously trained model
......
...@@ -15,7 +15,7 @@ import argparse ...@@ -15,7 +15,7 @@ import argparse
from omni.isaac.orbit.app import AppLauncher from omni.isaac.orbit.app import AppLauncher
# local imports # local imports
import source.standalone.workflows.rsl_rl.cli_args as cli_args # isort: skip import cli_args # isort: skip
# add argparse arguments # add argparse arguments
parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.") parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.")
......
...@@ -16,7 +16,7 @@ import os ...@@ -16,7 +16,7 @@ import os
from omni.isaac.orbit.app import AppLauncher from omni.isaac.orbit.app import AppLauncher
# local imports # local imports
import source.standalone.workflows.rsl_rl.cli_args as cli_args # isort: skip import cli_args # isort: skip
# add argparse arguments # add argparse arguments
......
...@@ -21,6 +21,11 @@ parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU p ...@@ -21,6 +21,11 @@ parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU p
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.") parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument(
"--use_last_checkpoint",
action="store_true",
help="When no checkpoint provided, use the last saved model. Otherwise use the best saved model.",
)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
...@@ -34,6 +39,7 @@ simulation_app = app_launcher.app ...@@ -34,6 +39,7 @@ simulation_app = app_launcher.app
import gymnasium as gym import gymnasium as gym
import os
import torch import torch
import traceback import traceback
...@@ -43,7 +49,7 @@ from stable_baselines3.common.vec_env import VecNormalize ...@@ -43,7 +49,7 @@ from stable_baselines3.common.vec_env import VecNormalize
import omni.isaac.contrib_tasks # noqa: F401 import omni.isaac.contrib_tasks # noqa: F401
import omni.isaac.orbit_tasks # noqa: F401 import omni.isaac.orbit_tasks # noqa: F401
from omni.isaac.orbit_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg from omni.isaac.orbit_tasks.utils.parse_cfg import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
from omni.isaac.orbit_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb3_cfg from omni.isaac.orbit_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb3_cfg
...@@ -72,12 +78,21 @@ def main(): ...@@ -72,12 +78,21 @@ def main():
clip_reward=np.inf, clip_reward=np.inf,
) )
# directory for logging into
log_root_path = os.path.join("logs", "sb3", args_cli.task)
log_root_path = os.path.abspath(log_root_path)
# check checkpoint is valid # check checkpoint is valid
if args_cli.checkpoint is None: if args_cli.checkpoint is None:
raise ValueError("Checkpoint path is not valid.") if args_cli.use_last_checkpoint:
checkpoint = "model_.*.zip"
else:
checkpoint = "model.zip"
checkpoint_path = get_checkpoint_path(log_root_path, ".*", checkpoint)
else:
checkpoint_path = args_cli.checkpoint
# create agent from stable baselines # create agent from stable baselines
print(f"Loading checkpoint from: {args_cli.checkpoint}") print(f"Loading checkpoint from: {checkpoint_path}")
agent = PPO.load(args_cli.checkpoint, env, print_system_info=True) agent = PPO.load(checkpoint_path, env, print_system_info=True)
# reset environment # reset environment
obs = env.reset() obs = env.reset()
......
...@@ -3,7 +3,12 @@ ...@@ -3,7 +3,12 @@
# #
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
"""Script to train RL agent with Stable Baselines3.""" """Script to train RL agent with Stable Baselines3.
Since Stable-Baselines3 does not support buffers living on GPU directly,
we recommend using smaller number of environments. Otherwise,
there will be significant overhead in GPU->CPU transfer.
"""
from __future__ import annotations from __future__ import annotations
...@@ -68,8 +73,6 @@ def main(): ...@@ -68,8 +73,6 @@ def main():
# parse configuration # parse configuration
env_cfg = parse_env_cfg(args_cli.task, use_gpu=not args_cli.cpu, num_envs=args_cli.num_envs) env_cfg = parse_env_cfg(args_cli.task, use_gpu=not args_cli.cpu, num_envs=args_cli.num_envs)
agent_cfg = load_cfg_from_registry(args_cli.task, "sb3_cfg_entry_point") agent_cfg = load_cfg_from_registry(args_cli.task, "sb3_cfg_entry_point")
# post-process agent configuration
agent_cfg = process_sb3_cfg(agent_cfg)
# override configuration with command line arguments # override configuration with command line arguments
if args_cli.seed is not None: if args_cli.seed is not None:
...@@ -83,6 +86,8 @@ def main(): ...@@ -83,6 +86,8 @@ def main():
dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg) dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg)
dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg) dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg)
# post-process agent configuration
agent_cfg = process_sb3_cfg(agent_cfg)
# read configurations about the agent-training # read configurations about the agent-training
policy_arch = agent_cfg.pop("policy") policy_arch = agent_cfg.pop("policy")
n_timesteps = agent_cfg.pop("n_timesteps") n_timesteps = agent_cfg.pop("n_timesteps")
......
...@@ -123,7 +123,7 @@ def main(): ...@@ -123,7 +123,7 @@ def main():
if args_cli.checkpoint: if args_cli.checkpoint:
resume_path = os.path.abspath(args_cli.checkpoint) resume_path = os.path.abspath(args_cli.checkpoint)
else: else:
resume_path = get_checkpoint_path(log_root_path, os.path.join("*", "checkpoints"), None) resume_path = get_checkpoint_path(log_root_path, other_dirs=["checkpoints"])
print(f"[INFO] Loading model checkpoint from: {resume_path}") print(f"[INFO] Loading model checkpoint from: {resume_path}")
# initialize agent # initialize agent
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment