Commit db31bb3c authored by ooctipus, committed by Kelly Guo

Enables hydra for all play.py scripts (#2995)

# Description
This PR enables Hydra overrides for all play.py scripts.
I have manually tested all RL frameworks, and they all worked.

I remember there is a related issue, but I couldn't find it; feel free to
link it here if you find it.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->
- New feature (non-breaking change which adds functionality)

## Screenshots

Please attach before and after screenshots of the change if applicable.

<!--
Example:

| Before | After |
| ------ | ----- |
| _gif/png before_ | _gif/png after_ |

To upload images to a PR -- simply drag and drop an image while in edit
mode and it should upload the image directly. You can then paste that
source into the above before/after sections.
-->

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Co-authored-by: Kelly Guo <kellyguo123@hotmail.com>
parent bfd6fb0d
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
import sys
from isaaclab.app import AppLauncher from isaaclab.app import AppLauncher
...@@ -35,11 +36,13 @@ parser.add_argument("--real-time", action="store_true", default=False, help="Run ...@@ -35,11 +36,13 @@ parser.add_argument("--real-time", action="store_true", default=False, help="Run
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
args_cli = parser.parse_args() args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video # always enable cameras to record video
if args_cli.video: if args_cli.video:
args_cli.enable_cameras = True args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app simulation_app = app_launcher.app
...@@ -57,7 +60,13 @@ from rl_games.common import env_configurations, vecenv ...@@ -57,7 +60,13 @@ from rl_games.common import env_configurations, vecenv
from rl_games.common.player import BasePlayer from rl_games.common.player import BasePlayer
from rl_games.torch_runner import Runner from rl_games.torch_runner import Runner
from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent from isaaclab.envs import (
DirectMARLEnv,
DirectMARLEnvCfg,
DirectRLEnvCfg,
ManagerBasedRLEnvCfg,
multi_agent_to_single_agent,
)
from isaaclab.utils.assets import retrieve_file_path from isaaclab.utils.assets import retrieve_file_path
from isaaclab.utils.dict import print_dict from isaaclab.utils.dict import print_dict
from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
...@@ -65,19 +74,19 @@ from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkp ...@@ -65,19 +74,19 @@ from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkp
from isaaclab_rl.rl_games import RlGamesGpuEnv, RlGamesVecEnvWrapper from isaaclab_rl.rl_games import RlGamesGpuEnv, RlGamesVecEnvWrapper
import isaaclab_tasks # noqa: F401 import isaaclab_tasks # noqa: F401
from isaaclab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg from isaaclab_tasks.utils import get_checkpoint_path
from isaaclab_tasks.utils.hydra import hydra_task_config
# PLACEHOLDER: Extension template (do not remove this comment) # PLACEHOLDER: Extension template (do not remove this comment)
def main(): @hydra_task_config(args_cli.task, "rl_games_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
"""Play with RL-Games agent.""" """Play with RL-Games agent."""
task_name = args_cli.task.split(":")[-1] task_name = args_cli.task.split(":")[-1]
# parse env configuration # override configurations with non-hydra CLI arguments
env_cfg = parse_env_cfg( env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
)
agent_cfg = load_cfg_from_registry(args_cli.task, "rl_games_cfg_entry_point")
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"]) log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"])
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
import sys
from isaaclab.app import AppLauncher from isaaclab.app import AppLauncher
...@@ -33,11 +34,15 @@ parser.add_argument("--real-time", action="store_true", default=False, help="Run ...@@ -33,11 +34,15 @@ parser.add_argument("--real-time", action="store_true", default=False, help="Run
cli_args.add_rsl_rl_args(parser) cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args() # parse the arguments
args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video # always enable cameras to record video
if args_cli.video: if args_cli.video:
args_cli.enable_cameras = True args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app simulation_app = app_launcher.app
...@@ -51,7 +56,13 @@ import torch ...@@ -51,7 +56,13 @@ import torch
from rsl_rl.runners import OnPolicyRunner from rsl_rl.runners import OnPolicyRunner
from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent from isaaclab.envs import (
DirectMARLEnv,
DirectMARLEnvCfg,
DirectRLEnvCfg,
ManagerBasedRLEnvCfg,
multi_agent_to_single_agent,
)
from isaaclab.utils.assets import retrieve_file_path from isaaclab.utils.assets import retrieve_file_path
from isaaclab.utils.dict import print_dict from isaaclab.utils.dict import print_dict
from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
...@@ -59,19 +70,20 @@ from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkp ...@@ -59,19 +70,20 @@ from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkp
from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx
import isaaclab_tasks # noqa: F401 import isaaclab_tasks # noqa: F401
from isaaclab_tasks.utils import get_checkpoint_path, parse_env_cfg from isaaclab_tasks.utils import get_checkpoint_path
from isaaclab_tasks.utils.hydra import hydra_task_config
# PLACEHOLDER: Extension template (do not remove this comment) # PLACEHOLDER: Extension template (do not remove this comment)
def main(): @hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg):
"""Play with RSL-RL agent.""" """Play with RSL-RL agent."""
task_name = args_cli.task.split(":")[-1] task_name = args_cli.task.split(":")[-1]
# parse configuration # override configurations with non-hydra CLI arguments
env_cfg = parse_env_cfg( agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
) env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(task_name, args_cli)
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
import sys
from pathlib import Path from pathlib import Path
from isaaclab.app import AppLauncher from isaaclab.app import AppLauncher
...@@ -42,11 +43,14 @@ parser.add_argument( ...@@ -42,11 +43,14 @@ parser.add_argument(
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
args_cli = parser.parse_args() args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video # always enable cameras to record video
if args_cli.video: if args_cli.video:
args_cli.enable_cameras = True args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app simulation_app = app_launcher.app
...@@ -61,25 +65,31 @@ import torch ...@@ -61,25 +65,31 @@ import torch
from stable_baselines3 import PPO from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize from stable_baselines3.common.vec_env import VecNormalize
from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent from isaaclab.envs import (
DirectMARLEnv,
DirectMARLEnvCfg,
DirectRLEnvCfg,
ManagerBasedRLEnvCfg,
multi_agent_to_single_agent,
)
from isaaclab.utils.dict import print_dict from isaaclab.utils.dict import print_dict
from isaaclab.utils.io import load_yaml
from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
from isaaclab_rl.sb3 import Sb3VecEnvWrapper, process_sb3_cfg from isaaclab_rl.sb3 import Sb3VecEnvWrapper, process_sb3_cfg
import isaaclab_tasks # noqa: F401 import isaaclab_tasks # noqa: F401
from isaaclab_tasks.utils.parse_cfg import get_checkpoint_path, parse_env_cfg from isaaclab_tasks.utils.hydra import hydra_task_config
from isaaclab_tasks.utils.parse_cfg import get_checkpoint_path
# PLACEHOLDER: Extension template (do not remove this comment) # PLACEHOLDER: Extension template (do not remove this comment)
def main(): @hydra_task_config(args_cli.task, "sb3_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
"""Play with stable-baselines agent.""" """Play with stable-baselines agent."""
# parse configuration # override configurations with non-hydra CLI arguments
env_cfg = parse_env_cfg( env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
)
task_name = args_cli.task.split(":")[-1] task_name = args_cli.task.split(":")[-1]
train_task_name = task_name.replace("-Play", "") train_task_name = task_name.replace("-Play", "")
...@@ -107,8 +117,6 @@ def main(): ...@@ -107,8 +117,6 @@ def main():
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# load the exact config used for training (instead of the default config)
agent_cfg = load_yaml(os.path.join(log_dir, "params", "agent.yaml"))
# post-process agent configuration # post-process agent configuration
agent_cfg = process_sb3_cfg(agent_cfg, env.unwrapped.num_envs) agent_cfg = process_sb3_cfg(agent_cfg, env.unwrapped.num_envs)
......
...@@ -13,6 +13,7 @@ a more user-friendly way. ...@@ -13,6 +13,7 @@ a more user-friendly way.
"""Launch Isaac Sim Simulator first.""" """Launch Isaac Sim Simulator first."""
import argparse import argparse
import sys
from isaaclab.app import AppLauncher from isaaclab.app import AppLauncher
...@@ -49,11 +50,14 @@ parser.add_argument("--real-time", action="store_true", default=False, help="Run ...@@ -49,11 +50,14 @@ parser.add_argument("--real-time", action="store_true", default=False, help="Run
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args() # parse the arguments
args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video # always enable cameras to record video
if args_cli.video: if args_cli.video:
args_cli.enable_cameras = True args_cli.enable_cameras = True
# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app simulation_app = app_launcher.app
...@@ -82,38 +86,42 @@ if args_cli.ml_framework.startswith("torch"): ...@@ -82,38 +86,42 @@ if args_cli.ml_framework.startswith("torch"):
elif args_cli.ml_framework.startswith("jax"): elif args_cli.ml_framework.startswith("jax"):
from skrl.utils.runner.jax import Runner from skrl.utils.runner.jax import Runner
from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent from isaaclab.envs import (
DirectMARLEnv,
DirectMARLEnvCfg,
DirectRLEnvCfg,
ManagerBasedRLEnvCfg,
multi_agent_to_single_agent,
)
from isaaclab.utils.dict import print_dict from isaaclab.utils.dict import print_dict
from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
from isaaclab_rl.skrl import SkrlVecEnvWrapper from isaaclab_rl.skrl import SkrlVecEnvWrapper
import isaaclab_tasks # noqa: F401 import isaaclab_tasks # noqa: F401
from isaaclab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg from isaaclab_tasks.utils import get_checkpoint_path
from isaaclab_tasks.utils.hydra import hydra_task_config
# PLACEHOLDER: Extension template (do not remove this comment) # PLACEHOLDER: Extension template (do not remove this comment)
# config shortcuts # config shortcuts
algorithm = args_cli.algorithm.lower() algorithm = args_cli.algorithm.lower()
agent_cfg_entry_point = "skrl_cfg_entry_point" if algorithm in ["ppo"] else f"skrl_{algorithm}_cfg_entry_point"
def main(): @hydra_task_config(args_cli.task, agent_cfg_entry_point)
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, experiment_cfg: dict):
"""Play with skrl agent.""" """Play with skrl agent."""
# override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
# configure the ML framework into the global skrl variable # configure the ML framework into the global skrl variable
if args_cli.ml_framework.startswith("jax"): if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy" skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
task_name = args_cli.task.split(":")[-1] task_name = args_cli.task.split(":")[-1]
# parse configuration
env_cfg = parse_env_cfg(
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
)
try:
experiment_cfg = load_cfg_from_registry(task_name, f"skrl_{algorithm}_cfg_entry_point")
except ValueError:
experiment_cfg = load_cfg_from_registry(task_name, "skrl_cfg_entry_point")
# specify directory for logging experiments (load checkpoint) # specify directory for logging experiments (load checkpoint)
log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"]) log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.abspath(log_root_path) log_root_path = os.path.abspath(log_root_path)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment