Commit 4049aa8d (unverified), authored by Özhan Özen, committed by GitHub

Adds support for custom `ProgressReporter` to Ray integration (#3269)

# Description

This PR adds support for providing a custom `ProgressReporter` when
running hyperparameter tuning with the Ray integration.

Without this PR, the Ray integration defaults to the standard
`CLIReporter`, which often displays metrics that aren’t particularly
relevant, or reports them at a frequency other than the desired one.
Similar to how we allow users to specify a `--cfg_class` (e.g.,
`CartpoleTheiaJobCfg`), this PR lets them optionally provide a custom
`ProgressReporter` class. If none is provided, the integration falls
back to the default reporter.

Moreover, I have added an example inside `vision_cartpole_cfg.py` (i.e.,
`CustomCartpoleProgressReporter`).

One point to highlight is that the new "[context-aware progress
reporting](https://github.com/ray-project/ray/issues/36949)" conflicts
with a custom `ProgressReporter`, so if a custom `ProgressReporter` is
provided, the PR disables the context-aware progress reporting (by
setting the `RAY_AIR_NEW_OUTPUT` environment variable to `"0"`).

Fixes #3268.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: garylvov <67614381+garylvov@users.noreply.github.com>
parent 1de99635
......@@ -65,7 +65,7 @@ The three following files contain the core functionality of the Ray integration.
.. literalinclude:: ../../../scripts/reinforcement_learning/ray/tuner.py
:language: python
:emphasize-lines: 18-54
:emphasize-lines: 18-59
.. dropdown:: scripts/reinforcement_learning/ray/task_runner.py
:icon: code
......
......@@ -13,6 +13,7 @@ sys.path.extend([str(UTIL_DIR), str(CUR_DIR)])
import util
import vision_cfg
from ray import tune
from ray.tune.progress_reporter import CLIReporter
from ray.tune.stopper import Stopper
......@@ -51,6 +52,21 @@ class CartpoleTheiaJobCfg(vision_cfg.TheiaCameraJob):
super().__init__(cfg)
class CustomCartpoleProgressReporter(CLIReporter):
    """Progress reporter that shows a curated set of Cartpole tuning metrics.

    Configures the parent ``CLIReporter`` with a fixed mapping from reported
    metric keys to short column labels, ``max_report_frequency=5`` and
    ``sort_by_metric=True``. The class takes no constructor arguments, so it
    can be instantiated as ``CustomCartpoleProgressReporter()``.
    """

    def __init__(self):
        # Reported metric key -> column label passed to the parent reporter.
        column_labels = {
            "training_iteration": "iter",
            "time_total_s": "total time (s)",
            "Episode/Episode_Reward/alive": "alive",
            "Episode/Episode_Reward/cart_vel": "cart velocity",
            "rewards/time": "rewards/time",
        }
        super().__init__(
            metric_columns=column_labels,
            max_report_frequency=5,
            sort_by_metric=True,
        )
class CartpoleEarlyStopper(Stopper):
def __init__(self):
self._bad_trials = set()
......
......@@ -14,6 +14,7 @@ import ray
import util
from ray import air, tune
from ray.tune import Callback
from ray.tune.progress_reporter import ProgressReporter
from ray.tune.search.optuna import OptunaSearch
from ray.tune.search.repeater import Repeater
from ray.tune.stopper import CombinedStopper
......@@ -48,6 +49,11 @@ Usage:
./isaaclab.sh -p scripts/reinforcement_learning/ray/tuner.py --run_mode local \
--cfg_file scripts/reinforcement_learning/ray/hyperparameter_tuning/vision_cartpole_cfg.py \
--cfg_class CartpoleTheiaJobCfg
# Local with a custom progress reporter
./isaaclab.sh -p scripts/reinforcement_learning/ray/tuner.py \
--cfg_file scripts/reinforcement_learning/ray/hyperparameter_tuning/vision_cartpole_cfg.py \
--cfg_class CartpoleTheiaJobCfg \
--progress_reporter CustomCartpoleProgressReporter
# Remote (run grok cluster or create config file mentioned in :file:`submit_job.py`)
./isaaclab.sh -p scripts/reinforcement_learning/ray/submit_job.py \
--aggregate_jobs tuner.py \
......@@ -229,6 +235,7 @@ class ProcessCleanupCallback(Callback):
def invoke_tuning_run(
cfg: dict,
args: argparse.Namespace,
progress_reporter: ProgressReporter | None = None,
stopper: tune.Stopper | None = None,
) -> None:
"""Invoke an Isaac-Ray tuning run.
......@@ -237,6 +244,7 @@ def invoke_tuning_run(
Args:
cfg: Configuration dictionary extracted from job setup
args: Command-line arguments related to tuning.
progress_reporter: Custom progress reporter. Defaults to CLIReporter or JupyterNotebookReporter if not provided.
stopper: Custom stopper, optional.
"""
# Allow for early exit
......@@ -271,6 +279,17 @@ def invoke_tuning_run(
*([stopper] if stopper is not None else []),
])
if progress_reporter is not None:
os.environ["RAY_AIR_NEW_OUTPUT"] = "0"
if (
getattr(progress_reporter, "_metric", None) is not None
or getattr(progress_reporter, "_mode", None) is not None
):
raise ValueError(
"Do not set <metric> or <mode> directly in the custom progress reporter class, "
"provide them as arguments to tuner.py instead."
)
if args.run_mode == "local": # Standard config, to file
run_config = air.RunConfig(
storage_path="/tmp/ray",
......@@ -282,6 +301,7 @@ def invoke_tuning_run(
checkpoint_at_end=False, # Disable final checkpoint
),
stop=stoppers,
progress_reporter=progress_reporter,
)
elif args.run_mode == "remote": # MLFlow, to MLFlow server
......@@ -298,6 +318,7 @@ def invoke_tuning_run(
callbacks=[ProcessCleanupCallback(), mlflow_callback],
checkpoint_config=ray.train.CheckpointConfig(checkpoint_frequency=0, checkpoint_at_end=False),
stop=stoppers,
progress_reporter=progress_reporter,
)
else:
raise ValueError("Unrecognized run mode.")
......@@ -435,6 +456,16 @@ if __name__ == "__main__":
default=MAX_LOG_EXTRACTION_ERRORS,
help="Max number number of LogExtractionError failures before we abort the whole tuning run.",
)
parser.add_argument(
"--progress_reporter",
type=str,
default=None,
help=(
"Optional: name of a custom reporter class defined in the cfg_file. "
"Must subclass ray.tune.ProgressReporter "
"(e.g., CustomCartpoleProgressReporter)."
),
)
parser.add_argument(
"--stopper",
type=str,
......@@ -508,7 +539,16 @@ if __name__ == "__main__":
else:
raise TypeError(f"[ERROR]: Unsupported stop criteria type: {type(stopper)}")
print(f"[INFO]: Loaded custom stop criteria from '{args.stopper}'")
invoke_tuning_run(cfg, args, stopper=stopper)
# Load optional progress reporter config
progress_reporter = None
if args.progress_reporter and hasattr(module, args.progress_reporter):
progress_reporter = getattr(module, args.progress_reporter)
if isinstance(progress_reporter, type) and issubclass(progress_reporter, tune.ProgressReporter):
progress_reporter = progress_reporter()
else:
raise TypeError(f"[ERROR]: {args.progress_reporter} is not a valid ProgressReporter.")
print(f"[INFO]: Loaded custom progress reporter from '{args.progress_reporter}'")
invoke_tuning_run(cfg, args, progress_reporter=progress_reporter, stopper=stopper)
else:
raise AttributeError(f"[ERROR]:Class '{class_name}' not found in {file_path}")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment