Commit 1a2d7048 authored by nv-cupright's avatar nv-cupright Committed by Kelly Guo

Adds pre-trained checkpoints and tools for generating and uploading checkpoints (#151)

Created a standalone script that can train all our environments and
publish their checkpoints to a Nucleus server. The play.py scripts were
modified to add a --use_pretrained_checkpoint flag. This downloads and
caches the pre-trained checkpoint to a .pretrained_checkpoints
directory.

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarpeterd-NV <peterd@nvidia.com>
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Signed-off-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
Co-authored-by: 's avatarpeterd-NV <peterd@nvidia.com>
Co-authored-by: 's avatarCY Chen <cyc@nvidia.com>
Co-authored-by: 's avataroahmednv <oahmed@Nvidia.com>
Co-authored-by: 's avatarToni-SM <aserranomuno@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
parent 981ee227
......@@ -57,3 +57,6 @@ _build
**/runs/*
**/logs/*
**/recordings/*
# Pre-Trained Checkpoints
/.pretrained_checkpoints/
......@@ -39,6 +39,7 @@ Guidelines for modifications:
* Anton Bjørndahl Mortensen
* Arjun Bhardwaj
* Brayden Zhang
* Cameron Upright
* Calvin Yu
* Chenyu Yang
* CY (Chien-Ying) Chen
......
......@@ -25,6 +25,8 @@ RL-Games
./isaaclab.sh -p source/standalone/workflows/rl_games/train.py --task Isaac-Ant-v0 --headless
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/rl_games/play.py --task Isaac-Ant-v0 --num_envs 32 --checkpoint /PATH/TO/model.pth
# run script for playing a pre-trained checkpoint with 32 environments
./isaaclab.sh -p source/standalone/workflows/rl_games/play.py --task Isaac-Ant-v0 --num_envs 32 --use_pretrained_checkpoint
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/rl_games/play.py --task Isaac-Ant-v0 --headless --video --video_length 200
......@@ -39,6 +41,8 @@ RL-Games
isaaclab.bat -p source\standalone\workflows\rl_games\train.py --task Isaac-Ant-v0 --headless
:: run script for playing with 32 environments
isaaclab.bat -p source\standalone\workflows\rl_games\play.py --task Isaac-Ant-v0 --num_envs 32 --checkpoint /PATH/TO/model.pth
:: run script for playing a pre-trained checkpoint with 32 environments
isaaclab.bat -p source\standalone\workflows\rl_games\play.py --task Isaac-Ant-v0 --num_envs 32 --use_pretrained_checkpoint
:: run script for recording video of a trained agent (requires installing `ffmpeg`)
isaaclab.bat -p source\standalone\workflows\rl_games\play.py --task Isaac-Ant-v0 --headless --video --video_length 200
......@@ -62,6 +66,8 @@ RSL-RL
./isaaclab.sh -p source/standalone/workflows/rsl_rl/train.py --task Isaac-Reach-Franka-v0 --headless
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/rsl_rl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --load_run run_folder_name --checkpoint model.pt
# run script for playing a pre-trained checkpoint with 32 environments
./isaaclab.sh -p source/standalone/workflows/rsl_rl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --use_pretrained_checkpoint
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/rsl_rl/play.py --task Isaac-Reach-Franka-v0 --headless --video --video_length 200
......@@ -76,6 +82,8 @@ RSL-RL
isaaclab.bat -p source\standalone\workflows\rsl_rl\train.py --task Isaac-Reach-Franka-v0 --headless
:: run script for playing with 32 environments
isaaclab.bat -p source\standalone\workflows\rsl_rl\play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --load_run run_folder_name --checkpoint model.pt
:: run script for playing a pre-trained checkpoint with 32 environments
isaaclab.bat -p source\standalone\workflows\rsl_rl\play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --use_pretrained_checkpoint
:: run script for recording video of a trained agent (requires installing `ffmpeg`)
isaaclab.bat -p source\standalone\workflows\rsl_rl\play.py --task Isaac-Reach-Franka-v0 --headless --video --video_length 200
......@@ -103,6 +111,8 @@ SKRL
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --checkpoint /PATH/TO/model.pt
# run script for playing a pre-trained checkpoint with 32 environments
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --use_pretrained_checkpoint
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --headless --video --video_length 200
......@@ -117,6 +127,8 @@ SKRL
isaaclab.bat -p source\standalone\workflows\skrl\train.py --task Isaac-Reach-Franka-v0 --headless
:: run script for playing with 32 environments
isaaclab.bat -p source\standalone\workflows\skrl\play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --checkpoint /PATH/TO/model.pt
:: run script for playing a pre-trained checkpoint with 32 environments
isaaclab.bat -p source\standalone\workflows\skrl\play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --use_pretrained_checkpoint
:: run script for recording video of a trained agent (requires installing `ffmpeg`)
isaaclab.bat -p source\standalone\workflows\skrl\play.py --task Isaac-Reach-Franka-v0 --headless --video --video_length 200
......@@ -191,6 +203,8 @@ Stable-Baselines3
./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --headless --device cpu
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/sb3/play.py --task Isaac-Cartpole-v0 --num_envs 32 --checkpoint /PATH/TO/model.zip
# run script for playing a pre-trained checkpoint with 32 environments
./isaaclab.sh -p source/standalone/workflows/sb3/play.py --task Isaac-Cartpole-v0 --num_envs 32 --use_pretrained_checkpoint
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/sb3/play.py --task Isaac-Cartpole-v0 --headless --video --video_length 200
......@@ -206,6 +220,8 @@ Stable-Baselines3
isaaclab.bat -p source\standalone\workflows\sb3\train.py --task Isaac-Cartpole-v0 --headless --device cpu
:: run script for playing with 32 environments
isaaclab.bat -p source\standalone\workflows\sb3\play.py --task Isaac-Cartpole-v0 --num_envs 32 --checkpoint /PATH/TO/model.zip
:: run script for playing a pre-trained checkpoint with 32 environments
isaaclab.bat -p source\standalone\workflows\sb3\play.py --task Isaac-Cartpole-v0 --num_envs 32 --use_pretrained_checkpoint
:: run script for recording video of a trained agent (requires installing `ffmpeg`)
isaaclab.bat -p source\standalone\workflows\sb3\play.py --task Isaac-Cartpole-v0 --headless --video --video_length 200
......
......@@ -163,7 +163,7 @@ Added
* Added full buffer property to :class:`omni.isaac.lab.utils.buffers.circular_buffer.CircularBuffer`
0.27.32 (2024-12-15)
0.27.33 (2024-12-15)
~~~~~~~~~~~~~~~~~~~~
Added
......@@ -172,7 +172,7 @@ Added
* Added action clip to all :class:`omni.isaac.lab.envs.mdp.actions`.
0.27.31 (2024-12-14)
0.27.32 (2024-12-14)
~~~~~~~~~~~~~~~~~~~~
Changed
......@@ -181,7 +181,7 @@ Changed
* Added check for error below threshold in state machines to ensure the state has been reached.
0.27.30 (2024-12-13)
0.27.31 (2024-12-13)
~~~~~~~~~~~~~~~~~~~~
Fixed
......@@ -190,7 +190,7 @@ Fixed
* Fixed the shape of ``quat_w`` in the ``apply_actions`` method of :attr:`~omni.isaac.lab.env.mdp.NonHolonomicAction` (previously (N,B,4), now (N,4) since the number of root bodies B is required to be 1). Previously ``apply_actions`` errored because ``euler_xyz_from_quat`` requires inputs of shape (N,4).
0.27.29 (2024-12-11)
0.27.30 (2024-12-11)
~~~~~~~~~~~~~~~~~~~~
Changed
......@@ -201,7 +201,7 @@ Changed
* Improved documentation to clarify the usage of the :meth:`~omni.isaac.lab.envs.mdp.rewards.base_height_l2` function in both flat and rough terrain settings.
0.27.28 (2024-12-11)
0.27.29 (2024-12-11)
~~~~~~~~~~~~~~~~~~~~
Fixed
......@@ -211,7 +211,7 @@ Fixed
Jacobian computed w.r.t. to the root frame of the robot. This helps ensure that root pose does not affect the tracking.
0.27.27 (2024-12-09)
0.27.28 (2024-12-09)
~~~~~~~~~~~~~~~~~~~~
Fixed
......@@ -221,7 +221,7 @@ Fixed
return only the states of the specified environment IDs.
0.27.26 (2024-12-06)
0.27.27 (2024-12-06)
~~~~~~~~~~~~~~~~~~~~
Fixed
......@@ -231,7 +231,7 @@ Fixed
:attr:`~omni.isaac.lab.assets.Articulation.root_physx_view` level.
0.27.25 (2024-12-06)
0.27.26 (2024-12-06)
~~~~~~~~~~~~~~~~~~~~
Changed
......@@ -242,7 +242,7 @@ Changed
disabled. Using an articulation root for rigid bodies is not needed and decreases overall performance.
0.27.24 (2024-12-06)
0.27.25 (2024-12-06)
~~~~~~~~~~~~~~~~~~~~
Fixed
......@@ -252,7 +252,7 @@ Fixed
Earlier, the projection names used snakecase instead of camelcase.
0.27.23 (2024-12-06)
0.27.24 (2024-12-06)
~~~~~~~~~~~~~~~~~~~~
Added
......@@ -270,7 +270,7 @@ Changed
:class:`~omni.isaac.lab.sensors.Camera` did not clip them and had a different behavior for both types.
0.27.22 (2024-12-05)
0.27.23 (2024-12-05)
~~~~~~~~~~~~~~~~~~~~
Fixed
......@@ -279,7 +279,7 @@ Fixed
* Fixed the condition in ``isaaclab.sh`` that checks whether ``pre-commit`` is installed before attempting installation.
0.27.21 (2024-12-04)
0.27.22 (2024-12-04)
~~~~~~~~~~~~~~~~~~~~
Fixed
......@@ -288,7 +288,7 @@ Fixed
* Fixed the order of the incoming parameters in :class:`omni.isaac.lab.envs.DirectMARLEnv` to correctly use ``NoiseModel`` in marl-envs.
0.27.20 (2024-12-04)
0.27.21 (2024-12-04)
~~~~~~~~~~~~~~~~~~~~
Added
......@@ -303,7 +303,7 @@ Added
* Added ``replay_demos.py`` script to replay demos loaded from an HDF5 file.
0.27.19 (2024-12-02)
0.27.20 (2024-12-02)
~~~~~~~~~~~~~~~~~~~~
Changed
......@@ -312,6 +312,16 @@ Changed
* Changed :class:`omni.isaac.lab.envs.DirectMARLEnv` to inherit from ``Gymnasium.Env`` due to requirement from Gymnasium v1.0.0 requiring all environments to be a subclass of ``Gymnasium.Env`` when using the ``make`` interface.
0.27.19 (2024-12-02)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added ``omni.isaac.lab.utils.pretrained_checkpoints`` containing constants and utility functions used to manipulate
paths and load checkpoints from Nucleus.
0.27.18 (2024-11-28)
~~~~~~~~~~~~~~~~~~~~
......
......@@ -49,7 +49,8 @@ def check_file_path(path: str) -> Literal[0, 1, 2]:
"""
if os.path.isfile(path):
return 1
elif omni.client.stat(path)[0] == omni.client.Result.OK:
# we need to convert backslash to forward slash on Windows for omni.client API
elif omni.client.stat(path.replace(os.sep, "/"))[0] == omni.client.Result.OK:
return 2
else:
return 0
......@@ -91,12 +92,12 @@ def retrieve_file_path(path: str, download_dir: str | None = None, force_downloa
if not os.path.exists(download_dir):
os.makedirs(download_dir)
# download file in temp directory using os
file_name = os.path.basename(omni.client.break_url(path).path)
file_name = os.path.basename(omni.client.break_url(path.replace(os.sep, "/")).path)
target_path = os.path.join(download_dir, file_name)
# check if file already exists locally
if not os.path.isfile(target_path) or force_download:
# copy file to local machine
result = omni.client.copy(path, target_path)
result = omni.client.copy(path.replace(os.sep, "/"), target_path)
if result != omni.client.Result.OK and force_download:
raise RuntimeError(f"Unable to copy file: '{path}'. Is the Nucleus Server running?")
return os.path.abspath(target_path)
......@@ -122,7 +123,7 @@ def read_file(path: str) -> io.BytesIO:
with open(path, "rb") as f:
return io.BytesIO(f.read())
elif file_status == 2:
file_content = omni.client.read_file(path)[2]
file_content = omni.client.read_file(path.replace(os.sep, "/"))[2]
return io.BytesIO(memoryview(file_content).tobytes())
else:
raise FileNotFoundError(f"Unable to find the file: {path}")
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Sub-module for handling various pre-trained checkpoint tasks"""
import glob
import json
import os
import carb.settings
from omni.isaac.lab.utils.assets import ISAACLAB_NUCLEUS_DIR
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry # noqa: F401
from .assets import retrieve_file_path
PRETRAINED_CHECKPOINTS_ASSET_ROOT_DIR = carb.settings.get_settings().get(
"/persistent/isaaclab/asset_root/pretrained_checkpoints"
)
"""Path to the root directory on the Nucleus Server."""
WORKFLOWS = ["rl_games", "rsl_rl", "sb3", "skrl"]
"""The supported workflows for pre-trained checkpoints"""
WORKFLOW_TRAINER = {w: f"source/standalone/workflows/{w}/train.py" for w in WORKFLOWS}
"""A dict mapping workflow to their training program path"""
WORKFLOW_PLAYER = {w: f"source/standalone/workflows/{w}/play.py" for w in WORKFLOWS}
"""A dict mapping workflow to their play program path"""
PRETRAINED_CHECKPOINT_PATH = str(PRETRAINED_CHECKPOINTS_ASSET_ROOT_DIR) + "/Isaac/IsaacLab/PretrainedCheckpoints"
"""URL for where we store all the pre-trained checkpoints"""
"""The filename for checkpoints used by the different workflows"""
WORKFLOW_PRETRAINED_CHECKPOINT_FILENAMES = {
"rl_games": "checkpoint.pth",
"rsl_rl": "checkpoint.pt",
"sb3": "checkpoint.zip",
"skrl": "checkpoint.pt",
}
"""Maps workflow to the agent variable name that determines the logging directory logs/{workflow}/{variable}"""
WORKFLOW_EXPERIMENT_NAME_VARIABLE = {
"rl_games": "agent.params.config.name",
"rsl_rl": "agent.experiment_name",
"sb3": None,
"skrl": "agent.agent.experiment.directory",
}
def has_pretrained_checkpoints_asset_root_dir() -> bool:
"""Returns True if and only if /persistent/isaaclab/asset_root/pretrained_checkpoints exists"""
return PRETRAINED_CHECKPOINTS_ASSET_ROOT_DIR is not None
def get_log_root_path(workflow: str, task_name: str) -> str:
"""Returns the absolute path where the logs are written for a specific workflow and task_name"""
return os.path.abspath(os.path.join("logs", workflow, task_name))
def get_latest_job_run_path(workflow: str, task_name: str) -> str:
"""The local logs path of the most recent run of this workflow and task name"""
log_root_path = get_log_root_path(workflow, task_name)
return _get_latest_file_or_directory(log_root_path)
def get_pretrained_checkpoint_path(workflow: str, task_name: str) -> str:
"""The local logs path where we get the pre-trained checkpoints from"""
path = get_latest_job_run_path(workflow, task_name)
if not path:
return None
if workflow == "rl_games":
return os.path.join(path, "nn", f"{task_name}.pth")
elif workflow == "rsl_rl":
return _get_latest_file_or_directory(path, "*.pt")
elif workflow == "sb3":
return os.path.join(path, "model.zip")
elif workflow == "skrl":
return os.path.join(path, "checkpoints", "best_agent.pt")
else:
raise Exception(f"Unsupported workflow ({workflow})")
def get_pretrained_checkpoint_publish_path(workflow: str, task_name: str) -> str:
"""The path where pre-trained checkpoints are published to"""
return os.path.join(
PRETRAINED_CHECKPOINT_PATH, workflow, task_name, WORKFLOW_PRETRAINED_CHECKPOINT_FILENAMES[workflow]
)
def get_published_pretrained_checkpoint_path(workflow: str, task_name: str) -> str:
"""The path where pre-trained checkpoints are fetched from"""
return os.path.join(
ISAACLAB_NUCLEUS_DIR,
"PretrainedCheckpoints",
workflow,
task_name,
WORKFLOW_PRETRAINED_CHECKPOINT_FILENAMES[workflow],
)
def get_published_pretrained_checkpoint(workflow: str, task_name: str) -> str | None:
"""Gets the path for the pre-trained checkpoint.
If the checkpoint is not cached locally then the file is downloaded.
The cached path is then returned.
Args:
workflow: The workflow.
task_name: The task name.
Returns:
The path.
"""
ov_path = get_published_pretrained_checkpoint_path(workflow, task_name)
download_dir = os.path.join(".pretrained_checkpoints", workflow, task_name)
resume_path = os.path.join(download_dir, WORKFLOW_PRETRAINED_CHECKPOINT_FILENAMES[workflow])
if not os.path.exists(resume_path):
print(f"Fetching pre-trained checkpoint : {ov_path}")
try:
resume_path = retrieve_file_path(ov_path, download_dir)
except Exception:
print("A pre-trained checkpoint is currently unavailable for this task.")
return None
else:
print("Using pre-fetched pre-trained checkpoint")
return resume_path
def has_pretrained_checkpoint_job_run(workflow: str, task_name: str) -> bool:
"""Returns true if an experiment exists in the logs for the workflow and task"""
return os.path.exists(get_log_root_path(workflow, task_name))
def has_pretrained_checkpoint_job_finished(workflow: str, task_name: str) -> bool:
"""Returns true if an experiment has results which may or may not be final depending on workflow"""
local_path = get_pretrained_checkpoint_path(workflow, task_name)
return local_path is not None and os.path.exists(local_path)
def get_pretrained_checkpoint_review_path(workflow: str, task_name: str) -> str | None:
"""The path of the review JSON file for a workflow and task"""
run_path = get_latest_job_run_path(workflow, task_name)
if not run_path:
return None
return os.path.join(run_path, "pretrained_checkpoint_review.json")
def get_pretrained_checkpoint_review(workflow: str, task_name: str) -> dict | None:
"""Returns the review JSON file as a dict if it exists"""
review_path = get_pretrained_checkpoint_review_path(workflow, task_name)
if not review_path:
return None
if os.path.exists(review_path):
with open(review_path) as f:
return json.load(f)
return None
def _get_latest_file_or_directory(path: str, pattern: str = "*"):
"""Returns the path to the most recently modified file or directory at a path matching an optional pattern"""
g = glob.glob(f"{path}/{pattern}")
if len(g):
return max(g, key=os.path.getmtime)
return None
This diff is collapsed.
......@@ -21,6 +21,11 @@ parser.add_argument(
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument(
"--use_pretrained_checkpoint",
action="store_true",
help="Use the pre-trained checkpoint from Nucleus.",
)
parser.add_argument(
"--use_last_checkpoint",
action="store_true",
......@@ -53,6 +58,7 @@ from rl_games.torch_runner import Runner
from omni.isaac.lab.envs import DirectMARLEnv, multi_agent_to_single_agent
from omni.isaac.lab.utils.assets import retrieve_file_path
from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
......@@ -72,7 +78,12 @@ def main():
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}")
# find checkpoint
if args_cli.checkpoint is None:
if args_cli.use_pretrained_checkpoint:
resume_path = get_published_pretrained_checkpoint("rl_games", args_cli.task)
if not resume_path:
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
return
elif args_cli.checkpoint is None:
# specify directory for logging runs
run_dir = agent_cfg["params"]["config"].get("full_experiment_name", ".*")
# specify name of checkpoint
......
......@@ -23,6 +23,11 @@ parser.add_argument(
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument(
"--use_pretrained_checkpoint",
action="store_true",
help="Use the pre-trained checkpoint from Nucleus.",
)
# append RSL-RL cli arguments
cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args
......@@ -45,7 +50,9 @@ import torch
from rsl_rl.runners import OnPolicyRunner
from omni.isaac.lab.envs import DirectMARLEnv, multi_agent_to_single_agent
from omni.isaac.lab.utils.assets import retrieve_file_path
from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, parse_env_cfg
......@@ -69,7 +76,16 @@ def main():
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}")
if args_cli.use_pretrained_checkpoint:
resume_path = get_published_pretrained_checkpoint("rsl_rl", args_cli.task)
if not resume_path:
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
return
elif args_cli.checkpoint:
resume_path = retrieve_file_path(args_cli.checkpoint)
else:
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
log_dir = os.path.dirname(resume_path)
# create isaac environment
......
......@@ -21,6 +21,11 @@ parser.add_argument(
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument(
"--use_pretrained_checkpoint",
action="store_true",
help="Use the pre-trained checkpoint from Nucleus.",
)
parser.add_argument(
"--use_last_checkpoint",
action="store_true",
......@@ -50,6 +55,7 @@ from stable_baselines3.common.vec_env import VecNormalize
from omni.isaac.lab.envs import DirectMARLEnv, multi_agent_to_single_agent
from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
......@@ -67,8 +73,13 @@ def main():
# directory for logging into
log_root_path = os.path.join("logs", "sb3", args_cli.task)
log_root_path = os.path.abspath(log_root_path)
# check checkpoint is valid
if args_cli.checkpoint is None:
# checkpoint and log_dir stuff
if args_cli.use_pretrained_checkpoint:
checkpoint_path = get_published_pretrained_checkpoint("sb3", args_cli.task)
if not checkpoint_path:
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
return
elif args_cli.checkpoint is None:
if args_cli.use_last_checkpoint:
checkpoint = "model_.*.zip"
else:
......
......@@ -26,6 +26,11 @@ parser.add_argument(
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument(
"--use_pretrained_checkpoint",
action="store_true",
help="Use the pre-trained checkpoint from Nucleus.",
)
parser.add_argument(
"--ml_framework",
type=str,
......@@ -77,6 +82,7 @@ elif args_cli.ml_framework.startswith("jax"):
from omni.isaac.lab.envs import DirectMARLEnv, multi_agent_to_single_agent
from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
......@@ -106,7 +112,12 @@ def main():
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}")
# get checkpoint path
if args_cli.checkpoint:
if args_cli.use_pretrained_checkpoint:
resume_path = get_published_pretrained_checkpoint("skrl", args_cli.task)
if not resume_path:
print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.")
return
elif args_cli.checkpoint:
resume_path = os.path.abspath(args_cli.checkpoint)
else:
resume_path = get_checkpoint_path(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment