Commit 6c06a58b (unverified), authored by Clemens Schwarke and committed by GitHub

Adds a configuration example for Student-Teacher Distillation (#3100)

# Description

This PR adds a configuration class to distill a walking policy for
ANYmal D as an example. The training is run almost the same way as a
normal PPO training. The only differences are that the distillation
configuration is selected via `--agent rsl_rl_distillation_cfg_entry_point`
and that a trained policy run needs to be passed via the `--load_run`
CLI argument, to serve as the teacher.

Additionally, `RslRlDistillationRunnerCfg` was moved from `rl_cfg.py`
to `distillation_cfg.py`, where it belongs.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: Kelly Guo <kellyg@nvidia.com>
parent d7613ce8
...@@ -21,6 +21,7 @@ Guidelines for modifications: ...@@ -21,6 +21,7 @@ Guidelines for modifications:
* Antonio Serrano-Muñoz * Antonio Serrano-Muñoz
* Ben Johnston * Ben Johnston
* Clemens Schwarke
* David Hoeller * David Hoeller
* Farbod Farshidian * Farbod Farshidian
* Hunter Hansen * Hunter Hansen
...@@ -54,7 +55,6 @@ Guidelines for modifications: ...@@ -54,7 +55,6 @@ Guidelines for modifications:
* Calvin Yu * Calvin Yu
* Cheng-Rong Lai * Cheng-Rong Lai
* Chenyu Yang * Chenyu Yang
* Clemens Schwarke
* Connor Smith * Connor Smith
* CY (Chien-Ying) Chen * CY (Chien-Ying) Chen
* David Yang * David Yang
......
...@@ -87,6 +87,40 @@ RSL-RL ...@@ -87,6 +87,40 @@ RSL-RL
:: run script for recording video of a trained agent (requires installing `ffmpeg`) :: run script for recording video of a trained agent (requires installing `ffmpeg`)
isaaclab.bat -p scripts\reinforcement_learning\rsl_rl\play.py --task Isaac-Reach-Franka-v0 --headless --video --video_length 200 isaaclab.bat -p scripts\reinforcement_learning\rsl_rl\play.py --task Isaac-Reach-Franka-v0 --headless --video --video_length 200
- Training and distilling an agent with
`RSL-RL <https://github.com/leggedrobotics/rsl_rl>`__ on ``Isaac-Velocity-Flat-Anymal-D-v0``:
.. tab-set::
:sync-group: os
.. tab-item:: :icon:`fa-brands fa-linux` Linux
:sync: linux
.. code:: bash
# install python module (for rsl-rl)
./isaaclab.sh -i rsl_rl
# run script for rl training of the teacher agent
./isaaclab.sh -p scripts/reinforcement_learning/rsl_rl/train.py --task Isaac-Velocity-Flat-Anymal-D-v0 --headless
# run script for distilling the teacher agent into a student agent
./isaaclab.sh -p scripts/reinforcement_learning/rsl_rl/train.py --task Isaac-Velocity-Flat-Anymal-D-v0 --headless --agent rsl_rl_distillation_cfg_entry_point --load_run teacher_run_folder_name
# run script for playing the student with 64 environments
./isaaclab.sh -p scripts/reinforcement_learning/rsl_rl/play.py --task Isaac-Velocity-Flat-Anymal-D-v0 --num_envs 64 --agent rsl_rl_distillation_cfg_entry_point
.. tab-item:: :icon:`fa-brands fa-windows` Windows
:sync: windows
.. code:: batch
:: install python module (for rsl-rl)
isaaclab.bat -i rsl_rl
:: run script for rl training of the teacher agent
isaaclab.bat -p scripts\reinforcement_learning\rsl_rl\train.py --task Isaac-Velocity-Flat-Anymal-D-v0 --headless
:: run script for distilling the teacher agent into a student agent
isaaclab.bat -p scripts\reinforcement_learning\rsl_rl\train.py --task Isaac-Velocity-Flat-Anymal-D-v0 --headless --agent rsl_rl_distillation_cfg_entry_point --load_run teacher_run_folder_name
:: run script for playing the student with 64 environments
isaaclab.bat -p scripts\reinforcement_learning\rsl_rl\play.py --task Isaac-Velocity-Flat-Anymal-D-v0 --num_envs 64 --agent rsl_rl_distillation_cfg_entry_point
SKRL SKRL
---- ----
......
...@@ -10,6 +10,8 @@ from typing import Literal ...@@ -10,6 +10,8 @@ from typing import Literal
from isaaclab.utils import configclass from isaaclab.utils import configclass
from .rl_cfg import RslRlBaseRunnerCfg
######################### #########################
# Policy configurations # # Policy configurations #
######################### #########################
...@@ -93,3 +95,22 @@ class RslRlDistillationAlgorithmCfg: ...@@ -93,3 +95,22 @@ class RslRlDistillationAlgorithmCfg:
loss_type: Literal["mse", "huber"] = "mse" loss_type: Literal["mse", "huber"] = "mse"
"""The loss type to use for the student policy.""" """The loss type to use for the student policy."""
#########################
# Runner configurations #
#########################
@configclass
class RslRlDistillationRunnerCfg(RslRlBaseRunnerCfg):
    """Runner settings used when training with a distillation algorithm.

    Extends the base runner configuration with the student-teacher policy
    and the distillation algorithm settings.
    """

    class_name: str = "DistillationRunner"
    """Name of the runner class. Defaults to ``DistillationRunner``."""

    policy: RslRlDistillationStudentTeacherCfg = MISSING
    """Student-teacher policy settings. No default is given (``MISSING``)."""

    algorithm: RslRlDistillationAlgorithmCfg = MISSING
    """Distillation algorithm settings. No default is given (``MISSING``)."""
...@@ -10,7 +10,6 @@ from typing import Literal ...@@ -10,7 +10,6 @@ from typing import Literal
from isaaclab.utils import configclass from isaaclab.utils import configclass
from .distillation_cfg import RslRlDistillationAlgorithmCfg, RslRlDistillationStudentTeacherCfg
from .rnd_cfg import RslRlRndCfg from .rnd_cfg import RslRlRndCfg
from .symmetry_cfg import RslRlSymmetryCfg from .symmetry_cfg import RslRlSymmetryCfg
...@@ -237,17 +236,3 @@ class RslRlOnPolicyRunnerCfg(RslRlBaseRunnerCfg): ...@@ -237,17 +236,3 @@ class RslRlOnPolicyRunnerCfg(RslRlBaseRunnerCfg):
algorithm: RslRlPpoAlgorithmCfg = MISSING algorithm: RslRlPpoAlgorithmCfg = MISSING
"""The algorithm configuration.""" """The algorithm configuration."""
@configclass
class RslRlDistillationRunnerCfg(RslRlBaseRunnerCfg):
    """Runner settings used when training with a distillation algorithm.

    Extends the base runner configuration with the student-teacher policy
    and the distillation algorithm settings.
    """

    class_name: str = "DistillationRunner"
    """Name of the runner class. Defaults to ``DistillationRunner``."""

    policy: RslRlDistillationStudentTeacherCfg = MISSING
    """Student-teacher policy settings. No default is given (``MISSING``)."""

    algorithm: RslRlDistillationAlgorithmCfg = MISSING
    """Distillation algorithm settings. No default is given (``MISSING``)."""
...@@ -18,6 +18,9 @@ gym.register( ...@@ -18,6 +18,9 @@ gym.register(
kwargs={ kwargs={
"env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalDFlatEnvCfg", "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalDFlatEnvCfg",
"rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerCfg", "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerCfg",
"rsl_rl_distillation_cfg_entry_point": (
f"{agents.__name__}.rsl_rl_distillation_cfg:AnymalDFlatDistillationRunnerCfg"
),
"rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerWithSymmetryCfg", "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerWithSymmetryCfg",
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml",
}, },
...@@ -30,6 +33,9 @@ gym.register( ...@@ -30,6 +33,9 @@ gym.register(
kwargs={ kwargs={
"env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalDFlatEnvCfg_PLAY", "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalDFlatEnvCfg_PLAY",
"rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerCfg", "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerCfg",
"rsl_rl_distillation_cfg_entry_point": (
f"{agents.__name__}.rsl_rl_distillation_cfg:AnymalDFlatDistillationRunnerCfg"
),
"rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerWithSymmetryCfg", "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerWithSymmetryCfg",
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml",
}, },
......
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from isaaclab.utils import configclass
from isaaclab_rl.rsl_rl import (
RslRlDistillationAlgorithmCfg,
RslRlDistillationRunnerCfg,
RslRlDistillationStudentTeacherCfg,
)
@configclass
class AnymalDFlatDistillationRunnerCfg(RslRlDistillationRunnerCfg):
    """Distillation runner settings for the ANYmal D flat-terrain velocity task.

    Serves as the example configuration for student-teacher distillation.
    The teacher is not configured here by path; a previously trained run is
    supplied at launch time through the ``--load_run`` CLI argument.
    """

    # Runner-level overrides (fields presumably declared on the base runner
    # configuration -- they are not defined on RslRlDistillationRunnerCfg itself).
    num_steps_per_env = 120
    max_iterations = 300
    save_interval = 50
    experiment_name = "anymal_d_flat"

    # Both the "policy" and the "teacher" entries draw from the "policy"
    # observation group.
    obs_groups = {
        "policy": ["policy"],
        "teacher": ["policy"],
    }

    # Student and teacher use identical layer sizes and the same activation;
    # observation normalization is disabled for both.
    policy = RslRlDistillationStudentTeacherCfg(
        student_hidden_dims=[128, 128, 128],
        teacher_hidden_dims=[128, 128, 128],
        activation="elu",
        student_obs_normalization=False,
        teacher_obs_normalization=False,
        init_noise_std=0.1,
        noise_std_type="scalar",
    )

    # Distillation algorithm hyper-parameters.
    algorithm = RslRlDistillationAlgorithmCfg(
        learning_rate=1.0e-3,
        num_learning_epochs=2,
        gradient_length=15,
    )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment