Commit 28762da0 authored by peterd-NV's avatar peterd-NV Committed by Kelly Guo

Adds a CLI argument to set epochs for Robomimic training script (#449)

# Description

<!--
Thank you for your interest in sending a pull request. Please make sure
to check the contribution guidelines.

Link:
https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html
-->

1. Adds an optional CLI argument to the robomimic training script that
can be used to set the number of training epochs. If set, the epochs
defined by the JSON training config is overwritten.

2. Save the last training epoch regardless if it does not satisfy the
training interval defined in the JSON training config. This ensures that
a model will always be saved even if the user specifies an arbitrary
epoch number that does not divide evenly by the save interval defined in
the JSON.



## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)


## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent 9c8ea7d6
......@@ -285,7 +285,8 @@ def train(config: Config, device: str, log_dir: str, ckpt_dir: str, video_dir: s
and (epoch % config.experiment.save.every_n_epochs == 0)
)
epoch_list_check = epoch in config.experiment.save.epochs
should_save_ckpt = time_check or epoch_check or epoch_list_check
last_epoch_check = epoch == config.train.num_epochs
should_save_ckpt = time_check or epoch_check or epoch_list_check or last_epoch_check
ckpt_reason = None
if should_save_ckpt:
last_ckpt_time = time.time()
......@@ -383,6 +384,9 @@ def main(args: argparse.Namespace):
if args.name is not None:
config.experiment.name = args.name
if args.epochs is not None:
config.train.num_epochs = args.epochs
# change location of experiment directory
config.train.output_dir = os.path.abspath(os.path.join("./logs", args.log_dir, args.task))
......@@ -428,6 +432,15 @@ if __name__ == "__main__":
parser.add_argument("--algo", type=str, default=None, help="Name of the algorithm.")
parser.add_argument("--log_dir", type=str, default="robomimic", help="Path to log directory")
parser.add_argument("--normalize_training_actions", action="store_true", default=False, help="Normalize actions")
parser.add_argument(
"--epochs",
type=int,
default=None,
help=(
"Optional: Number of training epochs. If specified, overrides the number of epochs from the JSON training"
" config."
),
)
args = parser.parse_args()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment