Unverified commit 477b6a92, authored by Clemens Schwarke and committed by GitHub

Add configs and adapt exporter for RSL-RL distillation (#2182)

# Description

This PR adds configuration classes for Student-Teacher Distillation and
adapts the policy exporters to be able to export student policies.

## Type of change

- Non-breaking change

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: Mayank Mittal <12863862+Mayankm96@users.noreply.github.com>
Co-authored-by: Mayank Mittal <mittalma@leggedrobotics.com>
parent 09590912
......@@ -46,6 +46,7 @@ Guidelines for modifications:
* Calvin Yu
* Cheng-Rong Lai
* Chenyu Yang
* Clemens Schwarke
* CY (Chien-Ying) Chen
* David Yang
* Dorsa Rohani
......
......@@ -5,21 +5,6 @@
"""Script to play a checkpoint if an RL agent from RSL-RL."""
import platform
from importlib.metadata import version
if version("rsl-rl-lib") != "2.3.0":
if platform.system() == "Windows":
cmd = [r".\isaaclab.bat", "-p", "-m", "pip", "install", "rsl-rl-lib==2.3.0"]
else:
cmd = ["./isaaclab.sh", "-p", "-m", "pip", "install", "rsl-rl-lib==2.3.0"]
print(
f"Please install the correct version of RSL-RL.\nExisting version is: '{version('rsl-rl-lib')}'"
" and required version is: '2.3.0'.\nTo install the correct version, run:"
f"\n\n\t{' '.join(cmd)}\n"
)
exit(1)
"""Launch Isaac Sim Simulator first."""
import argparse
......@@ -133,11 +118,20 @@ def main():
# obtain the trained policy for inference
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)
# extract the neural network module
# we do this in a try-except to maintain backwards compatibility.
try:
# version 2.3 onwards
policy_nn = ppo_runner.alg.policy
except AttributeError:
# version 2.2 and below
policy_nn = ppo_runner.alg.actor_critic
# export policy to onnx/jit
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
export_policy_as_jit(ppo_runner.alg.policy, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt")
export_policy_as_jit(policy_nn, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt")
export_policy_as_onnx(
ppo_runner.alg.policy, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
policy_nn, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
)
dt = env.unwrapped.step_dt
......
......@@ -5,21 +5,6 @@
"""Script to train RL agent with RSL-RL."""
import platform
from importlib.metadata import version
if version("rsl-rl-lib") != "2.3.0":
if platform.system() == "Windows":
cmd = [r".\isaaclab.bat", "-p", "-m", "pip", "install", "rsl-rl-lib==2.3.0"]
else:
cmd = ["./isaaclab.sh", "-p", "-m", "pip", "install", "rsl-rl-lib==2.3.0"]
print(
f"Please install the correct version of RSL-RL.\nExisting version is: '{version('rsl-rl-lib')}'"
" and required version is: '2.3.0'.\nTo install the correct version, run:"
f"\n\n\t{' '.join(cmd)}\n"
)
exit(1)
"""Launch Isaac Sim Simulator first."""
import argparse
......@@ -60,6 +45,28 @@ sys.argv = [sys.argv[0]] + hydra_args
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
"""Check for minimum supported RSL-RL version."""
import importlib.metadata as metadata
import platform
from packaging import version
# for distributed training, check minimum supported rsl-rl version
RSL_RL_VERSION = "2.3.1"
installed_version = metadata.version("rsl-rl-lib")
if args_cli.distributed and version.parse(installed_version) < version.parse(RSL_RL_VERSION):
if platform.system() == "Windows":
cmd = [r".\isaaclab.bat", "-p", "-m", "pip", "install", f"rsl-rl-lib=={RSL_RL_VERSION}"]
else:
cmd = ["./isaaclab.sh", "-p", "-m", "pip", "install", f"rsl-rl-lib=={RSL_RL_VERSION}"]
print(
f"Please install the correct version of RSL-RL.\nExisting version is: '{installed_version}'"
f" and required version is: '{RSL_RL_VERSION}'.\nTo install the correct version, run:"
f"\n\n\t{' '.join(cmd)}\n"
)
exit(1)
"""Rest everything follows."""
import gymnasium as gym
......@@ -138,7 +145,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env = multi_agent_to_single_agent(env)
# save resume path before creating a new log_dir
if agent_cfg.resume:
if agent_cfg.resume or agent_cfg.algorithm.class_name == "Distillation":
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
# wrap for video recording
......@@ -161,7 +168,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# write git state to logs
runner.add_git_repo_to_log(__file__)
# load the checkpoint
if agent_cfg.resume:
if agent_cfg.resume or agent_cfg.algorithm.class_name == "Distillation":
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
# load previously trained model
runner.load(resume_path)
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.1.3"
version = "0.1.4"
# Description
title = "Isaac Lab RL"
......
Changelog
---------
0.1.4 (2025-04-10)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added configurations for distillation implementation in RSL-RL.
* Added configuration for recurrent actor-critic in RSL-RL.
0.1.3 (2025-03-31)
~~~~~~~~~~~~~~~~~~
......
......@@ -15,8 +15,9 @@ The following example shows how to wrap an environment for RSL-RL:
"""
from .distillation_cfg import *
from .exporter import export_policy_as_jit, export_policy_as_onnx
from .rl_cfg import RslRlOnPolicyRunnerCfg, RslRlPpoActorCriticCfg, RslRlPpoAlgorithmCfg
from .rl_cfg import *
from .rnd_cfg import RslRlRndCfg
from .symmetry_cfg import RslRlSymmetryCfg
from .vecenv_wrapper import RslRlVecEnvWrapper
# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
from dataclasses import MISSING
from typing import Literal
from isaaclab.utils import configclass
#########################
# Policy configurations #
#########################
@configclass
class RslRlDistillationStudentTeacherCfg:
    """Settings for the student and teacher networks used in distillation."""

    class_name: str = "StudentTeacher"
    """Name of the policy class. Defaults to StudentTeacher."""

    init_noise_std: float = MISSING
    """Initial standard deviation of the noise for the student policy."""

    noise_std_type: Literal["scalar", "log"] = "scalar"
    """Type of the noise standard deviation used by the policy. Defaults to scalar."""

    student_hidden_dims: list[int] = MISSING
    """Hidden dimensions of the student network."""

    teacher_hidden_dims: list[int] = MISSING
    """Hidden dimensions of the teacher network."""

    activation: str = MISSING
    """Activation function shared by the student and teacher networks."""
@configclass
class RslRlDistillationStudentTeacherRecurrentCfg(RslRlDistillationStudentTeacherCfg):
    """Settings for the recurrent variant of the distillation student-teacher networks."""

    class_name: str = "StudentTeacherRecurrent"
    """Name of the policy class. Defaults to StudentTeacherRecurrent."""

    rnn_type: str = MISSING
    """Kind of RNN network to use: either "lstm" or "gru"."""

    rnn_hidden_dim: int = MISSING
    """Hidden dimension of the RNN network."""

    rnn_num_layers: int = MISSING
    """Number of layers in the RNN network."""

    teacher_recurrent: bool = MISSING
    """Whether the teacher network is recurrent as well."""
############################
# Algorithm configurations #
############################
@configclass
class RslRlDistillationAlgorithmCfg:
    """Settings for the distillation algorithm."""

    class_name: str = "Distillation"
    """Name of the algorithm class. Defaults to Distillation."""

    num_learning_epochs: int = MISSING
    """How many updates are performed with each sample."""

    learning_rate: float = MISSING
    """Learning rate applied to the student policy."""

    gradient_length: int = MISSING
    """Number of environment steps over which the gradient flows back."""
......@@ -8,26 +8,26 @@ import os
import torch
def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
def export_policy_as_jit(policy: object, normalizer: object | None, path: str, filename="policy.pt"):
"""Export policy into a Torch JIT file.
Args:
actor_critic: The actor-critic torch module.
policy: The policy torch module.
normalizer: The empirical normalizer module. If None, Identity is used.
path: The path to the saving directory.
filename: The name of exported JIT file. Defaults to "policy.pt".
"""
policy_exporter = _TorchPolicyExporter(actor_critic, normalizer)
policy_exporter = _TorchPolicyExporter(policy, normalizer)
policy_exporter.export(path, filename)
def export_policy_as_onnx(
actor_critic: object, path: str, normalizer: object | None = None, filename="policy.onnx", verbose=False
policy: object, path: str, normalizer: object | None = None, filename="policy.onnx", verbose=False
):
"""Export policy into a Torch ONNX file.
Args:
actor_critic: The actor-critic torch module.
policy: The policy torch module.
normalizer: The empirical normalizer module. If None, Identity is used.
path: The path to the saving directory.
filename: The name of exported ONNX file. Defaults to "policy.onnx".
......@@ -35,7 +35,7 @@ def export_policy_as_onnx(
"""
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
policy_exporter = _OnnxPolicyExporter(policy, normalizer, verbose)
policy_exporter.export(path, filename)
......@@ -47,12 +47,22 @@ Helper Classes - Private.
class _TorchPolicyExporter(torch.nn.Module):
"""Exporter of actor-critic into JIT file."""
def __init__(self, actor_critic, normalizer=None):
def __init__(self, policy, normalizer=None):
super().__init__()
self.actor = copy.deepcopy(actor_critic.actor)
self.is_recurrent = actor_critic.is_recurrent
self.is_recurrent = policy.is_recurrent
# copy policy parameters
if hasattr(policy, "actor"):
self.actor = copy.deepcopy(policy.actor)
if self.is_recurrent:
self.rnn = copy.deepcopy(policy.memory_a.rnn)
elif hasattr(policy, "student"):
self.actor = copy.deepcopy(policy.student)
if self.is_recurrent:
self.rnn = copy.deepcopy(policy.memory_s.rnn)
else:
raise ValueError("Policy does not have an actor/student module.")
# set up recurrent network
if self.is_recurrent:
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
self.rnn.cpu()
self.register_buffer("hidden_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
......@@ -94,13 +104,23 @@ class _TorchPolicyExporter(torch.nn.Module):
class _OnnxPolicyExporter(torch.nn.Module):
"""Exporter of actor-critic into ONNX file."""
def __init__(self, actor_critic, normalizer=None, verbose=False):
def __init__(self, policy, normalizer=None, verbose=False):
super().__init__()
self.verbose = verbose
self.actor = copy.deepcopy(actor_critic.actor)
self.is_recurrent = actor_critic.is_recurrent
self.is_recurrent = policy.is_recurrent
# copy policy parameters
if hasattr(policy, "actor"):
self.actor = copy.deepcopy(policy.actor)
if self.is_recurrent:
self.rnn = copy.deepcopy(policy.memory_a.rnn)
elif hasattr(policy, "student"):
self.actor = copy.deepcopy(policy.student)
if self.is_recurrent:
self.rnn = copy.deepcopy(policy.memory_s.rnn)
else:
raise ValueError("Policy does not have an actor/student module.")
# set up recurrent network
if self.is_recurrent:
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
self.rnn.cpu()
self.forward = self.forward_lstm
# copy normalizer if exists
......
......@@ -3,14 +3,21 @@
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
from dataclasses import MISSING
from typing import Literal
from isaaclab.utils import configclass
from .distillation_cfg import RslRlDistillationAlgorithmCfg, RslRlDistillationStudentTeacherCfg
from .rnd_cfg import RslRlRndCfg
from .symmetry_cfg import RslRlSymmetryCfg
#########################
# Policy configurations #
#########################
@configclass
class RslRlPpoActorCriticCfg:
......@@ -36,23 +43,33 @@ class RslRlPpoActorCriticCfg:
@configclass
class RslRlPpoAlgorithmCfg:
"""Configuration for the PPO algorithm."""
class RslRlPpoActorCriticRecurrentCfg(RslRlPpoActorCriticCfg):
"""Configuration for the PPO actor-critic networks with recurrent layers."""
class_name: str = "PPO"
"""The algorithm class name. Default is PPO."""
class_name: str = "ActorCriticRecurrent"
"""The policy class name. Default is ActorCriticRecurrent."""
value_loss_coef: float = MISSING
"""The coefficient for the value loss."""
rnn_type: str = MISSING
"""The type of RNN to use. Either "lstm" or "gru"."""
use_clipped_value_loss: bool = MISSING
"""Whether to use clipped value loss."""
rnn_hidden_dim: int = MISSING
"""The dimension of the RNN layers."""
clip_param: float = MISSING
"""The clipping parameter for the policy."""
rnn_num_layers: int = MISSING
"""The number of RNN layers."""
entropy_coef: float = MISSING
"""The coefficient for the entropy loss."""
############################
# Algorithm configurations #
############################
@configclass
class RslRlPpoAlgorithmCfg:
"""Configuration for the PPO algorithm."""
class_name: str = "PPO"
"""The algorithm class name. Default is PPO."""
num_learning_epochs: int = MISSING
"""The number of learning epochs per update."""
......@@ -72,12 +89,24 @@ class RslRlPpoAlgorithmCfg:
lam: float = MISSING
"""The lambda parameter for Generalized Advantage Estimation (GAE)."""
entropy_coef: float = MISSING
"""The coefficient for the entropy loss."""
desired_kl: float = MISSING
"""The desired KL divergence."""
max_grad_norm: float = MISSING
"""The maximum gradient norm."""
value_loss_coef: float = MISSING
"""The coefficient for the value loss."""
use_clipped_value_loss: bool = MISSING
"""Whether to use clipped value loss."""
clip_param: float = MISSING
"""The clipping parameter for the policy."""
normalize_advantage_per_mini_batch: bool = False
"""Whether to normalize the advantage per mini-batch. Default is False.
......@@ -94,6 +123,11 @@ class RslRlPpoAlgorithmCfg:
"""
#########################
# Runner configurations #
#########################
@configclass
class RslRlOnPolicyRunnerCfg:
"""Configuration of the runner for on-policy algorithms."""
......@@ -113,10 +147,10 @@ class RslRlOnPolicyRunnerCfg:
empirical_normalization: bool = MISSING
"""Whether to use empirical normalization."""
policy: RslRlPpoActorCriticCfg = MISSING
policy: RslRlPpoActorCriticCfg | RslRlDistillationStudentTeacherCfg = MISSING
"""The policy configuration."""
algorithm: RslRlPpoAlgorithmCfg = MISSING
algorithm: RslRlPpoAlgorithmCfg | RslRlDistillationAlgorithmCfg = MISSING
"""The algorithm configuration."""
clip_actions: float | None = None
......@@ -126,10 +160,6 @@ class RslRlOnPolicyRunnerCfg:
This clipping is performed inside the :class:`RslRlVecEnvWrapper` wrapper.
"""
##
# Checkpointing parameters
##
save_interval: int = MISSING
"""The number of iterations between saves."""
......@@ -144,10 +174,6 @@ class RslRlOnPolicyRunnerCfg:
``{time-stamp}_{run_name}``.
"""
##
# Logging parameters
##
logger: Literal["tensorboard", "neptune", "wandb"] = "tensorboard"
"""The logger to use. Default is tensorboard."""
......@@ -157,10 +183,6 @@ class RslRlOnPolicyRunnerCfg:
wandb_project: str = "isaaclab"
"""The wandb project name. Default is "isaaclab"."""
##
# Loading parameters
##
resume: bool = False
"""Whether to resume. Default is False."""
......
......@@ -44,7 +44,7 @@ EXTRAS_REQUIRE = {
"sb3": ["stable-baselines3>=2.1"],
"skrl": ["skrl>=1.4.2"],
"rl-games": ["rl-games==1.6.1", "gym"], # rl-games still needs gym :(
"rsl-rl": ["rsl-rl-lib==2.3.0"],
"rsl-rl": ["rsl-rl-lib==2.3.1"],
}
# Add the names with hyphens as aliases for convenience
EXTRAS_REQUIRE["rl_games"] = EXTRAS_REQUIRE["rl-games"]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment