Commit 390d28a7 authored by ooctipus, committed by Kelly Guo

Enables sb3 to load checkpoint to continue training (#2954)

# Description

This PR extends `script/reinforcement_learning/sb3/train.py` with a
`--checkpoint` option that loads a saved checkpoint so training can
continue from it.
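
A minimal sketch of how the new flag behaves at the command line, using only `argparse`. The `build_parser` helper, the task name, and the checkpoint path are hypothetical and not taken from the script:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical helper: a cut-down parser with only the flags relevant here."""
    parser = argparse.ArgumentParser(description="Train an RL agent with Stable-Baselines3.")
    parser.add_argument("--task", type=str, default=None, help="Name of the task.")
    # Added by this change: optional path to a saved model to resume from.
    parser.add_argument("--checkpoint", type=str, default=None, help="Continue the training from checkpoint.")
    return parser


# Without the flag, `checkpoint` stays None and training starts from scratch.
fresh = build_parser().parse_args(["--task", "demo-task"])

# With the flag, the path is passed through unchanged as a string.
resumed = build_parser().parse_args(["--task", "demo-task", "--checkpoint", "logs/model.zip"])
```

Because the default is `None`, existing invocations that omit `--checkpoint` keep their current behavior.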

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: Kelly Guo <kellyg@nvidia.com>
Co-authored-by: Kelly Guo <kellyguo123@hotmail.com>
parent 442a7763
@@ -25,6 +25,7 @@ parser.add_argument("--num_envs", type=int, default=None, help="Number of enviro
 parser.add_argument("--task", type=str, default=None, help="Name of the task.")
 parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
 parser.add_argument("--log_interval", type=int, default=100_000, help="Log data every n timesteps.")
+parser.add_argument("--checkpoint", type=str, default=None, help="Continue the training from checkpoint.")
 parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
 parser.add_argument(
     "--keep_all_info",
@@ -179,6 +180,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
     # create agent from stable baselines
     agent = PPO(policy_arch, env, verbose=1, tensorboard_log=log_dir, **agent_cfg)
+    if args_cli.checkpoint is not None:
+        agent = agent.load(args_cli.checkpoint, env, print_system_info=True)
     # callbacks for agent
     checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_dir, name_prefix="model", verbose=2)
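
The resume branch can also be read as a small "create or restore" helper. The sketch below is an illustration, not the code from this commit: `create_or_resume` and `FakeAlgo` are hypothetical names, and `FakeAlgo` only stands in for a Stable-Baselines3 algorithm class such as `PPO` so that the example runs without SB3 installed. It relies on one fact about SB3: `load` is a classmethod that returns a new model built from the saved file, which is why the commit reassigns `agent` to the return value.

```python
from typing import Any, Optional


def create_or_resume(algo_cls: Any, policy: Any, env: Any, checkpoint: Optional[str] = None, **agent_kwargs: Any) -> Any:
    """Return a fresh agent, or one restored from `checkpoint` when a path is given."""
    if checkpoint is not None:
        # `load` is called on the class: it builds and returns a new model,
        # so no agent needs to be constructed beforehand.
        return algo_cls.load(checkpoint, env=env, print_system_info=True)
    return algo_cls(policy, env, **agent_kwargs)


class FakeAlgo:
    """Stand-in for an SB3 algorithm class (assumption: same `load` calling shape)."""

    def __init__(self, policy: Any, env: Any, **kwargs: Any) -> None:
        self.policy = policy
        self.env = env
        self.source = "new"  # records where this agent came from

    @classmethod
    def load(cls, path: str, env: Any = None, print_system_info: bool = False) -> "FakeAlgo":
        agent = cls("restored-policy", env)
        agent.source = path
        return agent


fresh_agent = create_or_resume(FakeAlgo, "MlpPolicy", env="env")
resumed_agent = create_or_resume(FakeAlgo, "MlpPolicy", env="env", checkpoint="model.zip")
```

With the real library, `algo_cls` would be `PPO`. When a checkpoint is given, the restored agent is built from what was saved in the file, so the `agent_kwargs` passed for a fresh agent are not applied in this sketch.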