Unverified Commit 648e1568 authored by matthewtrepte's avatar matthewtrepte Committed by GitHub

Fixes parsing for play envs (#582)

# Description

<!--
Thank you for your interest in sending a pull request. Please make sure
to check the contribution guidelines.

Link:
https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html
-->

Fix checkpoint path parsing when a `-Play` env is provided to play scripts
using the `--task` argument.

Fixes # (issue)

<!-- As a practice, it is recommended to open an issue to have
discussions on the proposed pull request.
This makes it easier for the community to keep track of what is being
developed or added, and if a given feature
is demanded by more than one party. -->

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- This change requires a documentation update

## Screenshots

Please attach before and after screenshots of the change if applicable.

<!--
Example:

| Before | After |
| ------ | ----- |
| _gif/png before_ | _gif/png after_ |

To upload images to a PR -- simply drag and drop an image while in edit
mode and it should upload the image directly. You can then paste that
source into the above before/after sections.
-->

## Checklist

- [ ] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent a396e450
...@@ -83,7 +83,10 @@ from isaaclab_tasks.utils.hydra import hydra_task_config ...@@ -83,7 +83,10 @@ from isaaclab_tasks.utils.hydra import hydra_task_config
@hydra_task_config(args_cli.task, "rl_games_cfg_entry_point") @hydra_task_config(args_cli.task, "rl_games_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict): def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
"""Play with RL-Games agent.""" """Play with RL-Games agent."""
# grab task name for checkpoint path
task_name = args_cli.task.split(":")[-1] task_name = args_cli.task.split(":")[-1]
train_task_name = task_name.replace("-Play", "")
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
...@@ -94,7 +97,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -94,7 +97,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
print(f"[INFO] Loading experiment from directory: {log_root_path}") print(f"[INFO] Loading experiment from directory: {log_root_path}")
# find checkpoint # find checkpoint
if args_cli.use_pretrained_checkpoint: if args_cli.use_pretrained_checkpoint:
resume_path = get_published_pretrained_checkpoint("rl_games", task_name) resume_path = get_published_pretrained_checkpoint("rl_games", train_task_name)
if not resume_path: if not resume_path:
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.") print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
return return
......
...@@ -79,7 +79,10 @@ from isaaclab_tasks.utils.hydra import hydra_task_config ...@@ -79,7 +79,10 @@ from isaaclab_tasks.utils.hydra import hydra_task_config
@hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point") @hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg): def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg):
"""Play with RSL-RL agent.""" """Play with RSL-RL agent."""
# grab task name for checkpoint path
task_name = args_cli.task.split(":")[-1] task_name = args_cli.task.split(":")[-1]
train_task_name = task_name.replace("-Play", "")
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli) agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
...@@ -90,7 +93,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -90,7 +93,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
log_root_path = os.path.abspath(log_root_path) log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}") print(f"[INFO] Loading experiment from directory: {log_root_path}")
if args_cli.use_pretrained_checkpoint: if args_cli.use_pretrained_checkpoint:
resume_path = get_published_pretrained_checkpoint("rsl_rl", task_name) resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name)
if not resume_path: if not resume_path:
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.") print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
return return
......
...@@ -87,13 +87,14 @@ from isaaclab_tasks.utils.parse_cfg import get_checkpoint_path ...@@ -87,13 +87,14 @@ from isaaclab_tasks.utils.parse_cfg import get_checkpoint_path
@hydra_task_config(args_cli.task, "sb3_cfg_entry_point") @hydra_task_config(args_cli.task, "sb3_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict): def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
"""Play with stable-baselines agent.""" """Play with stable-baselines agent."""
# grab task name for checkpoint path
task_name = args_cli.task.split(":")[-1]
train_task_name = task_name.replace("-Play", "")
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
task_name = args_cli.task.split(":")[-1]
train_task_name = task_name.replace("-Play", "")
# directory for logging into # directory for logging into
log_root_path = os.path.join("logs", "sb3", train_task_name) log_root_path = os.path.join("logs", "sb3", train_task_name)
log_root_path = os.path.abspath(log_root_path) log_root_path = os.path.abspath(log_root_path)
......
...@@ -112,6 +112,10 @@ agent_cfg_entry_point = "skrl_cfg_entry_point" if algorithm in ["ppo"] else f"sk ...@@ -112,6 +112,10 @@ agent_cfg_entry_point = "skrl_cfg_entry_point" if algorithm in ["ppo"] else f"sk
@hydra_task_config(args_cli.task, agent_cfg_entry_point) @hydra_task_config(args_cli.task, agent_cfg_entry_point)
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, experiment_cfg: dict): def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, experiment_cfg: dict):
"""Play with skrl agent.""" """Play with skrl agent."""
# grab task name for checkpoint path
task_name = args_cli.task.split(":")[-1]
train_task_name = task_name.replace("-Play", "")
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
...@@ -120,15 +124,13 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, expe ...@@ -120,15 +124,13 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, expe
if args_cli.ml_framework.startswith("jax"): if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy" skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
task_name = args_cli.task.split(":")[-1]
# specify directory for logging experiments (load checkpoint) # specify directory for logging experiments (load checkpoint)
log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"]) log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.abspath(log_root_path) log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}") print(f"[INFO] Loading experiment from directory: {log_root_path}")
# get checkpoint path # get checkpoint path
if args_cli.use_pretrained_checkpoint: if args_cli.use_pretrained_checkpoint:
resume_path = get_published_pretrained_checkpoint("skrl", task_name) resume_path = get_published_pretrained_checkpoint("skrl", train_task_name)
if not resume_path: if not resume_path:
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.") print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
return return
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment