Unverified Commit f879aa6a authored by Wei Yang's avatar Wei Yang Committed by GitHub

Fixes the checkpoint loading error in RSL-RL training script (#1210)

# Description

An error of `No checkpoints in the directory` will throw when resume
from a previous training with `--video` set. This is because a new log
folder will be created before the check.

This MR fixes this issue by loading the checkpoint before.

Fixes #1209 

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 1b8943c1
...@@ -63,6 +63,7 @@ Guidelines for modifications: ...@@ -63,6 +63,7 @@ Guidelines for modifications:
* Rosario Scalise * Rosario Scalise
* Shafeef Omar * Shafeef Omar
* Vladimir Fokow * Vladimir Fokow
* Wei Yang
* Xavier Nal * Xavier Nal
* Yang Jin * Yang Jin
* Zhengyu Zhang * Zhengyu Zhang
......
...@@ -99,6 +99,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -99,6 +99,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# save resume path before creating a new log_dir
if agent_cfg.resume:
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
...@@ -122,10 +127,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -122,10 +127,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device) runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device)
# write git state to logs # write git state to logs
runner.add_git_repo_to_log(__file__) runner.add_git_repo_to_log(__file__)
# save resume path before creating a new log_dir # load the checkpoint
if agent_cfg.resume: if agent_cfg.resume:
# get path to previous checkpoint
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
print(f"[INFO]: Loading model checkpoint from: {resume_path}") print(f"[INFO]: Loading model checkpoint from: {resume_path}")
# load previously trained model # load previously trained model
runner.load(resume_path) runner.load(resume_path)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment