Unverified Commit 4264e4f5 authored by Toni-SM's avatar Toni-SM Committed by GitHub

Updates skrl integration to support training/evaluation using JAX (#592)

# Description

This PR updates the skrl integration to support training/evaluation
using the JAX ML framework.

## Type of change

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have run all the tests with `./isaaclab.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 1c0a4abe
......@@ -184,14 +184,31 @@ from the environments into the respective libraries function argument and return
- Training an agent with
`SKRL <https://skrl.readthedocs.io>`__ on ``Isaac-Reach-Franka-v0``:
.. code:: bash
.. tab-set::
# install python module (for skrl)
./isaaclab.sh -i skrl
# run script for training
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --checkpoint /PATH/TO/model.pt
.. tab-item:: PyTorch
.. code:: bash
# install python module (for skrl)
./isaaclab.sh -i skrl
# run script for training
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --checkpoint /PATH/TO/model.pt
.. tab-item:: JAX
.. code:: bash
# install python module (for skrl)
./isaaclab.sh -i skrl
# install skrl dependencies for JAX. Visit https://skrl.readthedocs.io/en/latest/intro/installation.html for more details
./isaaclab.sh -p -m pip install skrl["jax"]
# run script for training
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless --ml_framework jax
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --ml_framework jax --checkpoint /PATH/TO/model.pt
- Training an agent with
`RL-Games <https://github.com/Denys88/rl_games>`__ on ``Isaac-Ant-v0``:
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.7.9"
version = "0.7.10"
# Description
title = "Isaac Lab Environments"
......
Changelog
---------
0.7.10 (2024-07-02)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Extended skrl wrapper to support training/evaluation using JAX
0.7.9 (2024-07-01)
~~~~~~~~~~~~~~~~~~
......
......@@ -11,13 +11,14 @@ The following example shows how to wrap an environment for skrl:
from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper
env = SkrlVecEnvWrapper(env)
env = SkrlVecEnvWrapper(env, ml_framework="torch") # or ml_framework="jax"
Or, equivalently, by directly calling the skrl library API as follows:
.. code-block:: python
from skrl.envs.torch.wrappers import wrap_env
from skrl.envs.wrappers.torch import wrap_env # for PyTorch, or...
from skrl.envs.wrappers.jax import wrap_env # for JAX
env = wrap_env(env, wrapper="isaaclab")
......@@ -26,10 +27,7 @@ Or, equivalently, by directly calling the skrl library API as follows:
# needed to import for type hinting: Agent | list[Agent]
from __future__ import annotations
from skrl.envs.wrappers.torch import wrap_env
from skrl.resources.preprocessors.torch import RunningStandardScaler # noqa: F401
from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa: F401
from skrl.utils.model_instantiators.torch import Shape # noqa: F401
from typing import Literal
from omni.isaac.lab.envs import DirectRLEnv, ManagerBasedRLEnv
......@@ -38,14 +36,18 @@ Configuration Parser.
"""
def process_skrl_cfg(cfg: dict) -> dict:
def process_skrl_cfg(cfg: dict, ml_framework: Literal["torch", "jax", "jax-numpy"] = "torch") -> dict:
"""Convert simple YAML types to skrl classes/components.
Args:
cfg: A configuration dictionary.
ml_framework: The ML framework to use for the wrapper. Defaults to "torch".
Returns:
A dictionary containing the converted configuration.
Raises:
ValueError: If the specified ML framework is not valid.
"""
_direct_eval = [
"learning_rate_scheduler",
......@@ -62,6 +64,20 @@ def process_skrl_cfg(cfg: dict) -> dict:
return reward_shaper
def update_dict(d):
# import statements according to the ML framework
# NOTE: matching is by prefix, so "jax-numpy" takes the "jax" branch.
# The imported names look unused here (hence the noqa markers); presumably they
# are resolved by name further down in this function -- not visible in this hunk.
if ml_framework.startswith("torch"):
    from skrl.resources.preprocessors.torch import RunningStandardScaler  # noqa: F401
    from skrl.resources.schedulers.torch import KLAdaptiveLR  # noqa: F401
    from skrl.utils.model_instantiators.torch import Shape  # noqa: F401
elif ml_framework.startswith("jax"):
    from skrl.resources.preprocessors.jax import RunningStandardScaler  # noqa: F401
    from skrl.resources.schedulers.jax import KLAdaptiveLR  # noqa: F401
    from skrl.utils.model_instantiators.jax import Shape  # noqa: F401
else:
    # The exception must be raised, not merely constructed: without `raise`
    # an invalid framework name was silently accepted, contradicting the
    # documented "Raises: ValueError" contract.
    raise ValueError(
        f"Invalid ML framework for skrl: {ml_framework}. Available options are: 'torch', 'jax' or 'jax-numpy'"
    )
for key, value in d.items():
if isinstance(value, dict):
update_dict(value)
......@@ -84,7 +100,7 @@ Vectorized environment wrapper.
"""
def SkrlVecEnvWrapper(env: ManagerBasedRLEnv):
def SkrlVecEnvWrapper(env: ManagerBasedRLEnv, ml_framework: Literal["torch", "jax", "jax-numpy"] = "torch"):
"""Wraps around Isaac Lab environment for skrl.
This function wraps around the Isaac Lab environment. Since the :class:`ManagerBasedRLEnv` environment
......@@ -94,9 +110,11 @@ def SkrlVecEnvWrapper(env: ManagerBasedRLEnv):
Args:
env: The environment to wrap around.
ml_framework: The ML framework to use for the wrapper. Defaults to "torch".
Raises:
ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
ValueError: If the specified ML framework is not valid.
Reference:
https://skrl.readthedocs.io/en/latest/api/envs/wrapping.html
......@@ -106,5 +124,16 @@ def SkrlVecEnvWrapper(env: ManagerBasedRLEnv):
raise ValueError(
f"The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type: {type(env)}"
)
# import statements according to the ML framework
# NOTE: matching is by prefix, so "jax-numpy" takes the "jax" branch.
if ml_framework.startswith("torch"):
    from skrl.envs.wrappers.torch import wrap_env
elif ml_framework.startswith("jax"):
    from skrl.envs.wrappers.jax import wrap_env
else:
    # The exception must be raised, not merely constructed: without `raise`
    # `wrap_env` stays unbound and the call below fails with an unrelated
    # error instead of the documented ValueError.
    raise ValueError(
        f"Invalid ML framework for skrl: {ml_framework}. Available options are: 'torch', 'jax' or 'jax-numpy'"
    )
# wrap and return the environment
return wrap_env(env, wrapper="isaaclab")
......@@ -26,6 +26,14 @@ parser.add_argument(
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument(
"--ml_framework",
type=str,
default="torch",
choices=["torch", "jax", "jax-numpy"],
help="The ML framework used for training the skrl agent.",
)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
......@@ -41,8 +49,14 @@ import gymnasium as gym
import os
import torch
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
import skrl
if args_cli.ml_framework.startswith("torch"):
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
elif args_cli.ml_framework.startswith("jax"):
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
......@@ -51,7 +65,10 @@ from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper, process_
def main():
"""Play with skrl agent."""
# parse env configuration
# configure the ML framework into the global skrl variable
if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
# parse configuration
env_cfg = parse_env_cfg(
args_cli.task, use_gpu=not args_cli.cpu, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
)
......@@ -60,24 +77,26 @@ def main():
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg)
# wrap around environment for skrl
env = SkrlVecEnvWrapper(env) # same as: `wrap_env(env, wrapper="isaaclab")`
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")`
# instantiate models using skrl model instantiator utility
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
models = {}
if args_cli.ml_framework.startswith("jax"):
experiment_cfg["models"]["separate"] = True # shared model is not supported in JAX
# non-shared models
if experiment_cfg["models"]["separate"]:
models["policy"] = gaussian_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
**process_skrl_cfg(experiment_cfg["models"]["policy"]),
**process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
)
models["value"] = deterministic_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
**process_skrl_cfg(experiment_cfg["models"]["value"]),
**process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
)
# shared models
else:
......@@ -88,17 +107,21 @@ def main():
structure=None,
roles=["policy", "value"],
parameters=[
process_skrl_cfg(experiment_cfg["models"]["policy"]),
process_skrl_cfg(experiment_cfg["models"]["value"]),
process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
],
)
models["value"] = models["policy"]
# instantiate models' state dict
if args_cli.ml_framework.startswith("jax"):
for role, model in models.items():
model.init_state_dict(role)
# configure and instantiate PPO agent
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
agent_cfg = PPO_DEFAULT_CONFIG.copy()
experiment_cfg["agent"]["rewards_shaper"] = None # avoid 'dictionary changed size during iteration'
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"]))
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"], ml_framework=args_cli.ml_framework))
agent_cfg["state_preprocessor_kwargs"].update({"size": env.observation_space, "device": env.device})
agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": env.device})
......
......@@ -33,6 +33,13 @@ parser.add_argument(
"--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes."
)
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
parser.add_argument(
"--ml_framework",
type=str,
default="torch",
choices=["torch", "jax", "jax-numpy"],
help="The ML framework used for training the skrl agent.",
)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
......@@ -52,11 +59,19 @@ import gymnasium as gym
import os
from datetime import datetime
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.memories.torch import RandomMemory
from skrl.trainers.torch import SequentialTrainer
import skrl
from skrl.utils import set_seed
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
if args_cli.ml_framework.startswith("torch"):
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.memories.torch import RandomMemory
from skrl.trainers.torch import SequentialTrainer
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
elif args_cli.ml_framework.startswith("jax"):
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.memories.jax import RandomMemory
from skrl.trainers.jax import SequentialTrainer
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
......@@ -68,6 +83,10 @@ from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper, process_
def main():
"""Train with skrl agent."""
# configure the ML framework into the global skrl variable
if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
# read the seed from command line
args_cli_seed = args_cli.seed
......@@ -93,6 +112,8 @@ def main():
# multi-gpu training config
if args_cli.distributed:
if args_cli.ml_framework.startswith("jax"):
raise ValueError("Multi-GPU distributed training not yet supported in JAX")
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
......@@ -120,7 +141,7 @@ def main():
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for skrl
env = SkrlVecEnvWrapper(env) # same as: `wrap_env(env, wrapper="isaaclab")`
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")`
# set seed for the experiment (override from command line)
set_seed(args_cli_seed if args_cli_seed is not None else experiment_cfg["seed"])
......@@ -128,19 +149,21 @@ def main():
# instantiate models using skrl model instantiator utility
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
models = {}
if args_cli.ml_framework.startswith("jax"):
experiment_cfg["models"]["separate"] = True # shared model is not supported in JAX
# non-shared models
if experiment_cfg["models"]["separate"]:
models["policy"] = gaussian_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
**process_skrl_cfg(experiment_cfg["models"]["policy"]),
**process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
)
models["value"] = deterministic_model(
observation_space=env.observation_space,
action_space=env.action_space,
device=env.device,
**process_skrl_cfg(experiment_cfg["models"]["value"]),
**process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
)
# shared models
else:
......@@ -151,11 +174,15 @@ def main():
structure=None,
roles=["policy", "value"],
parameters=[
process_skrl_cfg(experiment_cfg["models"]["policy"]),
process_skrl_cfg(experiment_cfg["models"]["value"]),
process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
],
)
models["value"] = models["policy"]
# instantiate models' state dict
if args_cli.ml_framework.startswith("jax"):
for role, model in models.items():
model.init_state_dict(role)
# instantiate a RandomMemory as rollout buffer (any memory can be used for this)
# https://skrl.readthedocs.io/en/latest/api/memories/random.html
......@@ -166,7 +193,7 @@ def main():
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
agent_cfg = PPO_DEFAULT_CONFIG.copy()
experiment_cfg["agent"]["rewards_shaper"] = None # avoid 'dictionary changed size during iteration'
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"]))
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"], ml_framework=args_cli.ml_framework))
agent_cfg["state_preprocessor_kwargs"].update({"size": env.observation_space, "device": env.device})
agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": env.device})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment