Unverified Commit 5cb47282 authored by Özhan Özen's avatar Özhan Özen Committed by GitHub

Adds early stopping support for Ray integration (#3276)

# Description

This PR introduces support for early stopping in Ray integration through
the `Stopper` class. It enables trials to end sooner when they are
unlikely to yield useful results, reducing wasted compute time and
speeding up experimentation.

Previously, when running hyperparameter tuning with Ray integration, all
trials would continue until the training configuration’s maximum
iterations were reached, even if a trial was clearly underperforming.
This wasn’t always efficient, since poor-performing trials could often
be identified early on. With this PR, an optional early stopping
mechanism is introduced, allowing Ray to terminate unpromising trials
sooner and improve the overall efficiency of hyperparameter tuning.

The PR also includes a `CartpoleEarlyStopper` example in
`vision_cartpole_cfg.py`. This serves as a reference implementation that
halts a trial if the `out_of_bounds` metric doesn’t reduce after a set
number of iterations. It’s meant as a usage example: users are
encouraged to create their own custom stoppers tailored to their
specific use cases.

Fixes #3270.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: 's avatargarylvov <67614381+garylvov@users.noreply.github.com>
Co-authored-by: 's avatargarylvov <gary.lvov@gmail.com>
Co-authored-by: 's avatarsbtc-sipbb <sbtc@sipbb.ch>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent c42fc738
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
import pathlib import pathlib
import sys import sys
from typing import Any
# Allow for import of items from the ray workflow. # Allow for import of items from the ray workflow.
CUR_DIR = pathlib.Path(__file__).parent CUR_DIR = pathlib.Path(__file__).parent
...@@ -12,6 +13,7 @@ sys.path.extend([str(UTIL_DIR), str(CUR_DIR)]) ...@@ -12,6 +13,7 @@ sys.path.extend([str(UTIL_DIR), str(CUR_DIR)])
import util import util
import vision_cfg import vision_cfg
from ray import tune from ray import tune
from ray.tune.stopper import Stopper
class CartpoleRGBNoTuneJobCfg(vision_cfg.CameraJobCfg): class CartpoleRGBNoTuneJobCfg(vision_cfg.CameraJobCfg):
...@@ -47,3 +49,21 @@ class CartpoleTheiaJobCfg(vision_cfg.TheiaCameraJob): ...@@ -47,3 +49,21 @@ class CartpoleTheiaJobCfg(vision_cfg.TheiaCameraJob):
cfg = util.populate_isaac_ray_cfg_args(cfg) cfg = util.populate_isaac_ray_cfg_args(cfg)
cfg["runner_args"]["--task"] = tune.choice(["Isaac-Cartpole-RGB-TheiaTiny-v0"]) cfg["runner_args"]["--task"] = tune.choice(["Isaac-Cartpole-RGB-TheiaTiny-v0"])
super().__init__(cfg) super().__init__(cfg)
class CartpoleEarlyStopper(Stopper):
def __init__(self):
self._bad_trials = set()
def __call__(self, trial_id: str, result: dict[str, Any]) -> bool:
iter = result.get("training_iteration", 0)
out_of_bounds = result.get("Episode/Episode_Termination/cart_out_of_bounds")
# Mark the trial for stopping if conditions are met
if 20 <= iter and out_of_bounds is not None and out_of_bounds > 0.85:
self._bad_trials.add(trial_id)
return trial_id in self._bad_trials
def stop_all(self) -> bool:
return False # only stop individual trials
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import argparse import argparse
import importlib.util import importlib.util
import os import os
import random
import subprocess import subprocess
import sys import sys
from time import sleep, time from time import sleep, time
...@@ -12,8 +13,10 @@ from time import sleep, time ...@@ -12,8 +13,10 @@ from time import sleep, time
import ray import ray
import util import util
from ray import air, tune from ray import air, tune
from ray.tune import Callback
from ray.tune.search.optuna import OptunaSearch from ray.tune.search.optuna import OptunaSearch
from ray.tune.search.repeater import Repeater from ray.tune.search.repeater import Repeater
from ray.tune.stopper import CombinedStopper
""" """
This script breaks down an aggregate tuning job, as defined by a hyperparameter sweep configuration, This script breaks down an aggregate tuning job, as defined by a hyperparameter sweep configuration,
...@@ -60,7 +63,7 @@ WORKFLOW = "scripts/reinforcement_learning/rl_games/train.py" ...@@ -60,7 +63,7 @@ WORKFLOW = "scripts/reinforcement_learning/rl_games/train.py"
NUM_WORKERS_PER_NODE = 1 # needed for local parallelism NUM_WORKERS_PER_NODE = 1 # needed for local parallelism
PROCESS_RESPONSE_TIMEOUT = 200.0 # seconds to wait before killing the process when it stops responding PROCESS_RESPONSE_TIMEOUT = 200.0 # seconds to wait before killing the process when it stops responding
MAX_LINES_TO_SEARCH_EXPERIMENT_LOGS = 1000 # maximum number of lines to read from the training process logs MAX_LINES_TO_SEARCH_EXPERIMENT_LOGS = 1000 # maximum number of lines to read from the training process logs
MAX_LOG_EXTRACTION_ERRORS = 2 # maximum allowed LogExtractionErrors before we abort the whole training MAX_LOG_EXTRACTION_ERRORS = 10 # maximum allowed LogExtractionErrors before we abort the whole training
class IsaacLabTuneTrainable(tune.Trainable): class IsaacLabTuneTrainable(tune.Trainable):
...@@ -203,13 +206,38 @@ class LogExtractionErrorStopper(tune.Stopper): ...@@ -203,13 +206,38 @@ class LogExtractionErrorStopper(tune.Stopper):
return False return False
def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None: class ProcessCleanupCallback(Callback):
"""Callback to clean up processes when trials are stopped."""
def on_trial_error(self, iteration, trials, trial, error, **info):
"""Called when a trial encounters an error."""
self._cleanup_trial(trial)
def on_trial_complete(self, iteration, trials, trial, **info):
"""Called when a trial completes."""
self._cleanup_trial(trial)
def _cleanup_trial(self, trial):
"""Clean up processes for a trial using SIGKILL."""
try:
subprocess.run(["pkill", "-9", "-f", f"rid {trial.config['runner_args']['-rid']}"], check=False)
sleep(5)
except Exception as e:
print(f"[ERROR]: Failed to cleanup trial {trial.trial_id}: {e}")
def invoke_tuning_run(
cfg: dict,
args: argparse.Namespace,
stopper: tune.Stopper | None = None,
) -> None:
"""Invoke an Isaac-Ray tuning run. """Invoke an Isaac-Ray tuning run.
Log either to a local directory or to MLFlow. Log either to a local directory or to MLFlow.
Args: Args:
cfg: Configuration dictionary extracted from job setup cfg: Configuration dictionary extracted from job setup
args: Command-line arguments related to tuning. args: Command-line arguments related to tuning.
stopper: Custom stopper, optional.
""" """
# Allow for early exit # Allow for early exit
os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1" os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1"
...@@ -237,16 +265,23 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None: ...@@ -237,16 +265,23 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
) )
repeat_search = Repeater(searcher, repeat=args.repeat_run_count) repeat_search = Repeater(searcher, repeat=args.repeat_run_count)
# Configure the stoppers
stoppers: CombinedStopper = CombinedStopper(*[
LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS),
*([stopper] if stopper is not None else []),
])
if args.run_mode == "local": # Standard config, to file if args.run_mode == "local": # Standard config, to file
run_config = air.RunConfig( run_config = air.RunConfig(
storage_path="/tmp/ray", storage_path="/tmp/ray",
name=f"IsaacRay-{args.cfg_class}-tune", name=f"IsaacRay-{args.cfg_class}-tune",
callbacks=[ProcessCleanupCallback()],
verbose=1, verbose=1,
checkpoint_config=air.CheckpointConfig( checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=0, # Disable periodic checkpointing checkpoint_frequency=0, # Disable periodic checkpointing
checkpoint_at_end=False, # Disable final checkpoint checkpoint_at_end=False, # Disable final checkpoint
), ),
stop=LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS), stop=stoppers,
) )
elif args.run_mode == "remote": # MLFlow, to MLFlow server elif args.run_mode == "remote": # MLFlow, to MLFlow server
...@@ -260,13 +295,14 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None: ...@@ -260,13 +295,14 @@ def invoke_tuning_run(cfg: dict, args: argparse.Namespace) -> None:
run_config = ray.train.RunConfig( run_config = ray.train.RunConfig(
name="mlflow", name="mlflow",
storage_path="/tmp/ray", storage_path="/tmp/ray",
callbacks=[mlflow_callback], callbacks=[ProcessCleanupCallback(), mlflow_callback],
checkpoint_config=ray.train.CheckpointConfig(checkpoint_frequency=0, checkpoint_at_end=False), checkpoint_config=ray.train.CheckpointConfig(checkpoint_frequency=0, checkpoint_at_end=False),
stop=LogExtractionErrorStopper(max_errors=MAX_LOG_EXTRACTION_ERRORS), stop=stoppers,
) )
else: else:
raise ValueError("Unrecognized run mode.") raise ValueError("Unrecognized run mode.")
# RID isn't optimized as it is sampled from, but useful for cleanup later
cfg["runner_args"]["-rid"] = tune.sample_from(lambda _: str(random.randint(int(1e9), int(1e10) - 1)))
# Configure the tuning job # Configure the tuning job
tuner = tune.Tuner( tuner = tune.Tuner(
IsaacLabTuneTrainable, IsaacLabTuneTrainable,
...@@ -399,6 +435,12 @@ if __name__ == "__main__": ...@@ -399,6 +435,12 @@ if __name__ == "__main__":
default=MAX_LOG_EXTRACTION_ERRORS, default=MAX_LOG_EXTRACTION_ERRORS,
help="Max number number of LogExtractionError failures before we abort the whole tuning run.", help="Max number number of LogExtractionError failures before we abort the whole tuning run.",
) )
parser.add_argument(
"--stopper",
type=str,
default=None,
help="A stop criteria in the cfg_file, must be a tune.Stopper instance.",
)
args = parser.parse_args() args = parser.parse_args()
PROCESS_RESPONSE_TIMEOUT = args.process_response_timeout PROCESS_RESPONSE_TIMEOUT = args.process_response_timeout
...@@ -457,7 +499,16 @@ if __name__ == "__main__": ...@@ -457,7 +499,16 @@ if __name__ == "__main__":
print(f"[INFO]: Successfully instantiated class '{class_name}' from {file_path}") print(f"[INFO]: Successfully instantiated class '{class_name}' from {file_path}")
cfg = instance.cfg cfg = instance.cfg
print(f"[INFO]: Grabbed the following hyperparameter sweep config: \n {cfg}") print(f"[INFO]: Grabbed the following hyperparameter sweep config: \n {cfg}")
invoke_tuning_run(cfg, args) # Load optional stopper config
stopper = None
if args.stopper and hasattr(module, args.stopper):
stopper = getattr(module, args.stopper)
if isinstance(stopper, type) and issubclass(stopper, tune.Stopper):
stopper = stopper()
else:
raise TypeError(f"[ERROR]: Unsupported stop criteria type: {type(stopper)}")
print(f"[INFO]: Loaded custom stop criteria from '{args.stopper}'")
invoke_tuning_run(cfg, args, stopper=stopper)
else: else:
raise AttributeError(f"[ERROR]:Class '{class_name}' not found in {file_path}") raise AttributeError(f"[ERROR]:Class '{class_name}' not found in {file_path}")
...@@ -71,7 +71,7 @@ def get_invocation_command_from_cfg( ...@@ -71,7 +71,7 @@ def get_invocation_command_from_cfg(
if not is_hydra: if not is_hydra:
if key.endswith("_singleton"): if key.endswith("_singleton"):
target_list.append(value) target_list.append(value)
elif key.startswith("--"): elif key.startswith("--") or key.startswith("-"):
target_list.append(f"{key} {value}") # Space instead of = for runner args target_list.append(f"{key} {value}") # Space instead of = for runner args
else: else:
target_list.append(f"{value}") target_list.append(f"{value}")
......
...@@ -42,6 +42,9 @@ parser.add_argument( ...@@ -42,6 +42,9 @@ parser.add_argument(
help="if toggled, this experiment will be tracked with Weights and Biases", help="if toggled, this experiment will be tracked with Weights and Biases",
) )
parser.add_argument("--export_io_descriptors", action="store_true", default=False, help="Export IO descriptors.") parser.add_argument("--export_io_descriptors", action="store_true", default=False, help="Export IO descriptors.")
parser.add_argument(
"--ray-proc-id", "-rid", type=int, default=None, help="Automatically configured by Ray integration, otherwise None."
)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
......
...@@ -31,6 +31,9 @@ parser.add_argument( ...@@ -31,6 +31,9 @@ parser.add_argument(
"--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes." "--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes."
) )
parser.add_argument("--export_io_descriptors", action="store_true", default=False, help="Export IO descriptors.") parser.add_argument("--export_io_descriptors", action="store_true", default=False, help="Export IO descriptors.")
parser.add_argument(
"--ray-proc-id", "-rid", type=int, default=None, help="Automatically configured by Ray integration, otherwise None."
)
# append RSL-RL cli arguments # append RSL-RL cli arguments
cli_args.add_rsl_rl_args(parser) cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args # append AppLauncher cli args
......
...@@ -37,6 +37,9 @@ parser.add_argument( ...@@ -37,6 +37,9 @@ parser.add_argument(
default=False, default=False,
help="Use a slower SB3 wrapper but keep all the extra training info.", help="Use a slower SB3 wrapper but keep all the extra training info.",
) )
parser.add_argument(
"--ray-proc-id", "-rid", type=int, default=None, help="Automatically configured by Ray integration, otherwise None."
)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
......
...@@ -54,7 +54,9 @@ parser.add_argument( ...@@ -54,7 +54,9 @@ parser.add_argument(
choices=["AMP", "PPO", "IPPO", "MAPPO"], choices=["AMP", "PPO", "IPPO", "MAPPO"],
help="The RL algorithm used for training the skrl agent.", help="The RL algorithm used for training the skrl agent.",
) )
parser.add_argument(
"--ray-proc-id", "-rid", type=int, default=None, help="Automatically configured by Ray integration, otherwise None."
)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment