Commit 11474763 authored by Toni-SM's avatar Toni-SM Committed by Kelly Guo

Adds checkpoint CLI argument to skrl's train script to resume training (#244)

# Description

Add `checkpoint` CLI argument to skrl's train script to resume training.
It solves https://github.com/isaac-sim/IsaacLab/issues/1635

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent cec2918c
......@@ -28,6 +28,7 @@ parser.add_argument("--seed", type=int, default=None, help="Seed used for the en
parser.add_argument(
"--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes."
)
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint to resume training.")
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
parser.add_argument(
"--ml_framework",
......@@ -48,7 +49,7 @@ parser.add_argument(
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
......@@ -92,6 +93,7 @@ from isaaclab.envs import (
ManagerBasedRLEnvCfg,
multi_agent_to_single_agent,
)
from isaaclab.utils.assets import retrieve_file_path
from isaaclab.utils.dict import print_dict
from isaaclab.utils.io import dump_pickle, dump_yaml
......@@ -151,6 +153,9 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg)
dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg)
# get checkpoint path (to resume training)
resume_path = retrieve_file_path(args_cli.checkpoint) if args_cli.checkpoint else None
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
......@@ -177,6 +182,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# https://skrl.readthedocs.io/en/latest/api/utils/runner.html
runner = Runner(env, agent_cfg)
# load checkpoint (if specified)
if resume_path:
print(f"[INFO] Loading model checkpoint from: {resume_path}")
runner.agent.load(resume_path)
# run training
runner.run()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment