Unverified Commit 1b8e2c0e authored by sizsJEon's avatar sizsJEon Committed by GitHub

Adds the ability to resume training from a checkpoint with rl_games (#797)

# Description

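The rl_games `train.py` script did not accept a checkpoint path or an initial sigma:
no checkpoint was loaded, and `sigma` was always passed to the runner as `None`.
This change adds `--checkpoint` and `--sigma` command-line arguments, and I have verified that both work.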

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 6451d235
...@@ -39,6 +39,7 @@ Guidelines for modifications: ...@@ -39,6 +39,7 @@ Guidelines for modifications:
* Brayden Zhang * Brayden Zhang
* Calvin Yu * Calvin Yu
* Chenyu Yang * Chenyu Yang
* HoJin Jeon
* Jia Lin Yuan * Jia Lin Yuan
* Jingzhou Liu * Jingzhou Liu
* Johnson Sun * Johnson Sun
......
...@@ -22,9 +22,13 @@ parser.add_argument( ...@@ -22,9 +22,13 @@ parser.add_argument(
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument( parser.add_argument(
"--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes." "--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes."
) )
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument("--sigma", type=str, default=None, help="The policy's initial standard deviation.")
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.") parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
# append AppLauncher cli args # append AppLauncher cli args
...@@ -50,6 +54,7 @@ from rl_games.common import env_configurations, vecenv ...@@ -50,6 +54,7 @@ from rl_games.common import env_configurations, vecenv
from rl_games.common.algo_observer import IsaacAlgoObserver from rl_games.common.algo_observer import IsaacAlgoObserver
from rl_games.torch_runner import Runner from rl_games.torch_runner import Runner
from omni.isaac.lab.utils.assets import retrieve_file_path
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
...@@ -130,6 +135,17 @@ def main(): ...@@ -130,6 +135,17 @@ def main():
) )
env_configurations.register("rlgpu", {"vecenv_type": "IsaacRlgWrapper", "env_creator": lambda **kwargs: env}) env_configurations.register("rlgpu", {"vecenv_type": "IsaacRlgWrapper", "env_creator": lambda **kwargs: env})
if args_cli.checkpoint is not None:
resume_path = retrieve_file_path(args_cli.checkpoint)
agent_cfg["params"]["load_checkpoint"] = True
agent_cfg["params"]["load_path"] = resume_path
print(f"[INFO]: Loading model checkpoint from: {agent_cfg['params']['load_path']}")
if args_cli.sigma is not None:
train_sigma = float(args_cli.sigma)
else:
train_sigma = None
# set number of actors into agent config # set number of actors into agent config
agent_cfg["params"]["config"]["num_actors"] = env.unwrapped.num_envs agent_cfg["params"]["config"]["num_actors"] = env.unwrapped.num_envs
# create runner from rl-games # create runner from rl-games
...@@ -141,7 +157,10 @@ def main(): ...@@ -141,7 +157,10 @@ def main():
# reset the agent and env # reset the agent and env
runner.reset() runner.reset()
# train the agent # train the agent
runner.run({"train": True, "play": False, "sigma": None}) if args_cli.checkpoint is not None:
runner.run({"train": True, "play": False, "sigma": train_sigma, "checkpoint": resume_path})
else:
runner.run({"train": True, "play": False, "sigma": train_sigma})
# close the simulator # close the simulator
env.close() env.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment