Commit 40c8d16d, authored by ooctipus, committed by GitHub

Adds PBT algorithm to rl games (#3399)

# Description

This PR introduces the Population Based Training algorithm originally
implemented in

Petrenko, Aleksei, et al. "Dexpbt: Scaling up dexterous manipulation for
hand-arm systems with population based training." arXiv preprint
arXiv:2305.12127 (2023).

PBT offers an alternative way to scale when increasing the number of
environments yields only marginal gains.
It borrows the idea of natural selection and exploits the stochasticity of
RL training: the top-performing agents are always kept, while weak agents
are replaced with top performers, which helps overcome catastrophic
failures and improves exploration.

Training view: underperformers are rescued by the best performers, later
surpass them, and become the best performers themselves.
<img width="1078" height="509" alt="Screenshot from 2025-09-09 00-55-11"
src="https://github.com/user-attachments/assets/34434bf1-5cb6-4956-a344-49c9969d4861"
/>


Note:
PBT is still in beta and has the following limitations:

1. In theory it can work with any RL algorithm, but the current
implementation only works with rl-games.
2. The API could be further simplified so that `num_policies` and
`policy_idx` do not need to be passed explicitly, which would allow a
dynamic max_population; this is left for future work.


## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

parent c7dde1b7
......@@ -116,6 +116,7 @@ Table of Contents
source/features/hydra
source/features/multi_gpu
source/features/population_based_training
Tiled Rendering</source/overview/core-concepts/sensors/camera>
source/features/ray
source/features/reproducibility
......
Population Based Training
=========================
What PBT Does
-------------
* Trains *N* policies in parallel (a "population") on the **same task**.
* Every ``interval_steps``:

  #. Save each policy's checkpoint and objective.
  #. Score the population and identify **leaders** and **underperformers**.
  #. For underperformers, load the weights of a random leader and **mutate** selected hyperparameters.
  #. Restart that process with the new weights/params automatically.
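The cadence check behind "every ``interval_steps``" can be sketched as follows. This is a minimal illustration, assuming a frame counter that only increases; the function name ``pbt_ticks`` is hypothetical and not part of the implementation. The first call only records the current interval, so a freshly restarted process does not trigger a PBT iteration immediately.

```python
def pbt_ticks(frames, interval_steps):
    """Yield (frame, iteration) each time the frame counter enters a new interval."""
    last_iteration = None
    for frame in frames:
        iteration = frame // interval_steps
        if last_iteration is None:
            # First observation: remember the starting interval, do nothing else.
            last_iteration = iteration
            continue
        if iteration > last_iteration:
            # A boundary was crossed: this is where save/score/replace would run.
            last_iteration = iteration
            yield frame, iteration


# Frames advance in steps of 30, which does not divide the interval of 100.
ticks = list(pbt_ticks(range(0, 250, 30), interval_steps=100))
print(ticks)
```

Because the frame counter usually advances by a whole batch at a time, a PBT iteration fires on the first step at or past each boundary rather than exactly on it.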
Leader / Underperformer Selection
---------------------------------
Let ``o_i`` be each initialized policy's objective, with mean ``μ`` and std ``σ``.
Upper and lower performance cuts are::

    upper_cut = max(μ + threshold_std * σ, μ + threshold_abs)
    lower_cut = min(μ - threshold_std * σ, μ - threshold_abs)

* **Leaders**: ``o_i > upper_cut``
* **Underperformers**: ``o_i < lower_cut``
The "Natural-Selection" rules:
1. Only underperformers are acted on (mutated or replaced).
2. If leaders exist, replace an underperformer with a random leader; otherwise, self-mutate.
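As a worked example of the cuts above, the following sketch classifies a small population. The helper name ``classify_population`` and the sample scores are hypothetical; the population standard deviation is used on the assumption that it matches the ``σ`` in the formula.

```python
import statistics


def classify_population(objectives, threshold_std=0.1, threshold_abs=0.025):
    """Split policies into leaders and underperformers.

    `objectives` maps policy index -> objective, for initialized policies only.
    """
    values = list(objectives.values())
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)  # population standard deviation
    upper_cut = max(mean + threshold_std * std, mean + threshold_abs)
    lower_cut = min(mean - threshold_std * std, mean - threshold_abs)
    leaders = [p for p, o in objectives.items() if o > upper_cut]
    underperformers = [p for p, o in objectives.items() if o < lower_cut]
    return leaders, underperformers, upper_cut, lower_cut


scores = {0: 0.10, 1: 0.50, 2: 0.55, 3: 0.90}  # hypothetical objectives
leaders, underperformers, upper_cut, lower_cut = classify_population(scores)
print(f"upper={upper_cut:.4f} lower={lower_cut:.4f}")
print("leaders:", leaders, "underperformers:", underperformers)
```

With these scores the mean is 0.5125, so policies 2 and 3 are leaders, policy 0 is an underperformer, and policy 1 sits inside the band and is left untouched.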
Mutation (Hyperparameters)
--------------------------
* Each param has a mutation function (e.g., ``mutate_float``, ``mutate_discount``, etc.).
* A param is mutated with probability ``mutation_rate``.
* When mutated, its value is multiplied or divided by a random factor drawn from ``change_range = (min, max)``.
* Only whitelisted keys (from the PBT config) are considered.
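To give a feel for the size of a mutation, here is a small sketch of the multiplicative perturbation. The names ``perturb`` and ``perturb_discount`` are hypothetical stand-ins for ``mutate_float`` and ``mutate_discount``, and the fixed ``[1.1, 1.2]`` range for discount-like values is an assumption taken from the accompanying source.

```python
import random


def perturb(value, change_range, rng=random):
    """Multiply or divide `value` by a factor drawn uniformly from `change_range`."""
    low, high = change_range
    factor = rng.uniform(low, high)
    return value / factor if rng.random() < 0.5 else value * factor


def perturb_discount(value, rng=random):
    """Perturb the distance to 1.0 instead of the value, so the result stays close to 1.0."""
    return 1.0 - perturb(1.0 - value, (1.1, 1.2), rng)


rng = random.Random(0)  # seeded only to make the sketch repeatable
new_lr = perturb(3e-4, (1.1, 2.0), rng)
new_gamma = perturb_discount(0.99, rng)
print(new_lr, new_gamma)
```

A learning rate of ``3e-4`` therefore lands somewhere between ``1.5e-4`` and ``6e-4`` but never stays unchanged, while a discount of ``0.99`` moves only to roughly ``0.988``-``0.992``.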
Example Config
--------------
.. code-block:: yaml

    pbt:
      enabled: True
      policy_idx: 0
      num_policies: 8
      directory: .
      workspace: "pbt_workspace"
      objective: Curriculum/difficulty_level
      interval_steps: 50000000
      threshold_std: 0.1
      threshold_abs: 0.025
      mutation_rate: 0.25
      change_range: [1.1, 2.0]
      mutation:
        agent.params.config.learning_rate: "mutate_float"
        agent.params.config.grad_norm: "mutate_float"
        agent.params.config.entropy_coef: "mutate_float"
        agent.params.config.critic_coef: "mutate_float"
        agent.params.config.bounds_loss_coef: "mutate_float"
        agent.params.config.kl_threshold: "mutate_float"
        agent.params.config.gamma: "mutate_discount"
        agent.params.config.tau: "mutate_discount"

``objective: Curriculum/difficulty_level`` uses ``infos["episode"]["Curriculum/difficulty_level"]`` as the scalar to
**rank policies** (higher is better). With ``num_policies: 8``, launch eight processes sharing the same ``workspace``
and unique ``policy_idx`` (0-7).
Launching PBT
-------------
You must start **one process per policy** and point them to the **same workspace**. Set a unique
``policy_idx`` for each process and the common ``num_policies``.
Minimal flags you need:
* ``agent.pbt.enabled=True``
* ``agent.pbt.workspace=<path/to/shared_folder>``
* ``agent.pbt.policy_idx=<0..num_policies-1>``
* ``agent.pbt.num_policies=<N>``
.. note::
   All processes must use the same ``agent.pbt.workspace`` so they can see each other's checkpoints.

.. caution::
   PBT is currently supported **only** with the **rl_games** library. Other RL libraries are not supported yet.
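One way to assemble the per-policy command lines is sketched below. The script path and task name are placeholders, and the use of ``sys.executable`` is an assumption (the restart logic in this change goes through the Isaac Sim ``python.sh`` launcher instead); only the ``agent.pbt.*`` overrides come from the list above.

```python
import sys


def build_pbt_commands(train_script, task, workspace, num_policies):
    """Return one argv list per policy; every process shares the same workspace."""
    commands = []
    for policy_idx in range(num_policies):
        commands.append([
            sys.executable,  # assumption: replace with your launcher if needed
            train_script,
            f"--task={task}",
            "--headless",
            "agent.pbt.enabled=True",
            f"agent.pbt.workspace={workspace}",
            f"agent.pbt.policy_idx={policy_idx}",
            f"agent.pbt.num_policies={num_policies}",
        ])
    return commands


commands = build_pbt_commands(
    train_script="path/to/rl_games/train.py",  # hypothetical path
    task="My-Task-v0",  # hypothetical task name
    workspace="pbt_workspace",
    num_policies=8,
)
for argv in commands:
    print(" ".join(argv))
```

Each printed line can then be started as its own process, for example one per GPU; how the processes are scheduled is left to the user.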
Tips
----
* Keep checkpoints fast: reduce ``interval_steps`` only if you really need tighter PBT cadence.
* Running 6+ workers is recommended to see the benefit of PBT.
References
----------
This PBT implementation is a reimplementation inspired by *Dexpbt: Scaling up dexterous manipulation for hand-arm systems with population based training* (Petrenko et al., 2023).
.. code-block:: bibtex

    @article{petrenko2023dexpbt,
      title={Dexpbt: Scaling up dexterous manipulation for hand-arm systems with population based training},
      author={Petrenko, Aleksei and Allshire, Arthur and State, Gavriel and Handa, Ankur and Makoviychuk, Viktor},
      journal={arXiv preprint arXiv:2305.12127},
      year={2023}
    }
......@@ -81,7 +81,7 @@ from isaaclab.utils.assets import retrieve_file_path
from isaaclab.utils.dict import print_dict
from isaaclab.utils.io import dump_pickle, dump_yaml
from isaaclab_rl.rl_games import RlGamesGpuEnv, RlGamesVecEnvWrapper
from isaaclab_rl.rl_games import MultiObserver, PbtAlgoObserver, RlGamesGpuEnv, RlGamesVecEnvWrapper
import isaaclab_tasks # noqa: F401
from isaaclab_tasks.utils.hydra import hydra_task_config
......@@ -127,7 +127,12 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# specify directory for logging experiments
config_name = agent_cfg["params"]["config"]["name"]
log_root_path = os.path.join("logs", "rl_games", config_name)
if "pbt" in agent_cfg:
if agent_cfg["pbt"]["directory"] == ".":
log_root_path = os.path.abspath(log_root_path)
else:
log_root_path = os.path.join(agent_cfg["pbt"]["directory"], log_root_path)
print(f"[INFO] Logging experiment in directory: {log_root_path}")
# specify directory for logging runs
log_dir = agent_cfg["params"]["config"].get("full_experiment_name", datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
......@@ -192,7 +197,13 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# set number of actors into agent config
agent_cfg["params"]["config"]["num_actors"] = env.unwrapped.num_envs
# create runner from rl-games
if "pbt" in agent_cfg and agent_cfg["pbt"]["enabled"]:
observers = MultiObserver([IsaacAlgoObserver(), PbtAlgoObserver(agent_cfg, args_cli)])
runner = Runner(observers)
else:
runner = Runner(IsaacAlgoObserver())
runner.load(agent_cfg)
# reset the agent and env
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.3.0"
version = "0.4.0"
# Description
title = "Isaac Lab RL"
......
Changelog
---------
0.4.0 (2025-09-09)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Introduced PBT to rl-games.
0.3.0 (2025-09-03)
~~~~~~~~~~~~~~~~~~
......
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Wrappers and utilities to configure an environment for rl-games library."""
from .pbt import *
from .rl_games import *
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from .pbt import MultiObserver, PbtAlgoObserver
from .pbt_cfg import PbtCfg
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import random
from collections.abc import Callable
from typing import Any


def mutate_float(x: float, change_min: float = 1.1, change_max: float = 1.5) -> float:
    """Multiply or divide by a random factor in [change_min, change_max]."""
    k = random.uniform(change_min, change_max)
    return x / k if random.random() < 0.5 else x * k


def mutate_discount(x: float, **kwargs) -> float:
    """Conservatively change a value near 1.0 by scaling (1 - x) by a random factor in [1.1, 1.2].

    Any ``change_min``/``change_max`` passed in ``kwargs`` is ignored.
    """
    inv = 1.0 - x
    new_inv = mutate_float(inv, change_min=1.1, change_max=1.2)
    return 1.0 - new_inv


MUTATION_FUNCS: dict[str, Callable[..., Any]] = {
    "mutate_float": mutate_float,
    "mutate_discount": mutate_discount,
}


def mutate(
    params: dict[str, Any],
    mutations: dict[str, str],
    mutation_rate: float,
    change_range: tuple[float, float],
) -> dict[str, Any]:
    """Return a copy of ``params`` where each whitelisted entry is mutated with probability ``mutation_rate``."""
    cmin, cmax = change_range
    out: dict[str, Any] = {}
    for name, val in params.items():
        fn_name = mutations.get(name)
        # skip if no rule or coin flip says "no"
        if fn_name is None or random.random() > mutation_rate:
            out[name] = val
            continue
        fn = MUTATION_FUNCS.get(fn_name)
        if fn is None:
            raise KeyError(f"Unknown mutation function: {fn_name!r}")
        out[name] = fn(val, change_min=cmin, change_max=cmax)
    return out
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
import os
import random
import sys
import torch
import torch.distributed as dist
from rl_games.common.algo_observer import AlgoObserver
from . import pbt_utils
from .mutation import mutate
from .pbt_cfg import PbtCfg
# i.e. value for target objective when it is not known
_UNINITIALIZED_VALUE = float(-1e9)
class PbtAlgoObserver(AlgoObserver):
"""rl_games observer that implements Population-Based Training for a single policy process."""
def __init__(self, params, args_cli):
"""Initialize observer, print the mutation table, and allocate the restart flag.
Args:
params (dict): Full agent/task params (Hydra style).
args_cli: Parsed CLI args used to reconstruct a restart command.
"""
super().__init__()
self.printer = pbt_utils.PbtTablePrinter()
self.dir = params["pbt"]["directory"]
self.rendering_args = pbt_utils.RenderingArgs(args_cli)
self.wandb_args = pbt_utils.WandbArgs(args_cli)
self.env_args = pbt_utils.EnvArgs(args_cli)
self.distributed_args = pbt_utils.DistributedArgs(args_cli)
self.cfg = PbtCfg(**params["pbt"])
self.pbt_it = -1 # dummy value, stands for "not initialized"
self.score = _UNINITIALIZED_VALUE
self.pbt_params = pbt_utils.filter_params(pbt_utils.flatten_dict({"agent": params}), self.cfg.mutation)
assert len(self.pbt_params) > 0, "[DANGER]: Dictionary that contains params to mutate is empty"
self.printer.print_params_table(self.pbt_params, header="List of params to mutate")
self.device = params["params"]["config"]["device"]
self.restart_flag = torch.tensor([0], device=self.device)
def after_init(self, algo):
"""Capture training directories on rank 0 and create this policy's workspace folder.
Args:
algo: rl_games algorithm object (provides writer, train_dir, frame counter, etc.).
"""
if self.distributed_args.rank != 0:
return
self.algo = algo
self.root_dir = algo.train_dir
self.ws_dir = os.path.join(self.root_dir, self.cfg.workspace)
self.curr_policy_dir = os.path.join(self.ws_dir, f"{self.cfg.policy_idx:03d}")
os.makedirs(self.curr_policy_dir, exist_ok=True)
def process_infos(self, infos, done_indices):
"""Extract the scalar objective from environment infos and store in `self.score`.
Notes:
Expects the objective to be at `infos["episode"][self.cfg.objective]`.
"""
self.score = infos["episode"][self.cfg.objective]
def after_steps(self):
"""Main PBT tick executed every train step.
Flow:
1) Non-zero ranks: exit immediately if `restart_flag == 1`, else return.
2) Rank 0: if `restart_flag == 1`, restart this process with new params.
3) Rank 0: on PBT cadence boundary (`interval_steps`), save checkpoint,
load population checkpoints, compute bands, and if this policy is an
underperformer, select a replacement (random leader or self), mutate
whitelisted params, set `restart_flag`, broadcast (if distributed),
and print a mutation diff table.
"""
if self.distributed_args.rank != 0:
if self.restart_flag.cpu().item() == 1:
os._exit(0)
return
elif self.restart_flag.cpu().item() == 1:
self._restart_with_new_params(self.new_params, self.restart_from_checkpoint)
return
# Non-zero can continue
if self.distributed_args.rank != 0:
return
if self.pbt_it == -1:
self.pbt_it = self.algo.frame // self.cfg.interval_steps
return
if self.algo.frame // self.cfg.interval_steps <= self.pbt_it:
return
self.pbt_it = self.algo.frame // self.cfg.interval_steps
frame_left = (self.pbt_it + 1) * self.cfg.interval_steps - self.algo.frame
print(f"Policy {self.cfg.policy_idx}, frames_left {frame_left}, PBT it {self.pbt_it}")
try:
pbt_utils.save_pbt_checkpoint(self.curr_policy_dir, self.score, self.pbt_it, self.algo, self.pbt_params)
ckpts = pbt_utils.load_pbt_ckpts(self.ws_dir, self.cfg.policy_idx, self.cfg.num_policies, self.pbt_it)
pbt_utils.cleanup(ckpts, self.curr_policy_dir)
except Exception as exc:
print(f"Policy {self.cfg.policy_idx}: Exception {exc} during sanity log!")
return
sumry = {i: None if c is None else {k: v for k, v in c.items() if k != "params"} for i, c in ckpts.items()}
self.printer.print_ckpt_summary(sumry)
policies = list(range(self.cfg.num_policies))
target_objectives = [ckpts[p]["true_objective"] if ckpts[p] else _UNINITIALIZED_VALUE for p in policies]
initialized = [(obj, p) for obj, p in zip(target_objectives, policies) if obj > _UNINITIALIZED_VALUE]
if not initialized:
print("No policies initialized; skipping PBT iteration.")
return
initialized_objectives, initialized_policies = zip(*initialized)
# 1) Stats
mean_obj = float(np.mean(initialized_objectives))
std_obj = float(np.std(initialized_objectives))
upper_cut = max(mean_obj + self.cfg.threshold_std * std_obj, mean_obj + self.cfg.threshold_abs)
lower_cut = min(mean_obj - self.cfg.threshold_std * std_obj, mean_obj - self.cfg.threshold_abs)
leaders = [p for obj, p in zip(initialized_objectives, initialized_policies) if obj > upper_cut]
underperformers = [p for obj, p in zip(initialized_objectives, initialized_policies) if obj < lower_cut]
print(f"mean={mean_obj:.4f}, std={std_obj:.4f}, upper={upper_cut:.4f}, lower={lower_cut:.4f}")
print(f"Leaders: {leaders} Underperformers: {underperformers}")
# 3) Only replace if *this* policy is an underperformer
if self.cfg.policy_idx in underperformers:
# 4) If there are any leaders, pick one at random; else simply mutate with no replacement
replacement_policy_candidate = random.choice(leaders) if leaders else self.cfg.policy_idx
print(f"Replacing policy {self.cfg.policy_idx} with {replacement_policy_candidate}.")
if self.distributed_args.rank == 0:
for param, value in self.pbt_params.items():
self.algo.writer.add_scalar(f"pbt/{param}", value, self.algo.frame)
self.algo.writer.add_scalar("pbt/00_best_objective", max(initialized_objectives), self.algo.frame)
self.algo.writer.flush()
# Decided to replace the policy weights!
cur_params = ckpts[replacement_policy_candidate]["params"]
self.new_params = mutate(cur_params, self.cfg.mutation, self.cfg.mutation_rate, self.cfg.change_range)
self.restart_from_checkpoint = os.path.abspath(ckpts[replacement_policy_candidate]["checkpoint"])
self.restart_flag[0] = 1
if self.distributed_args.distributed:
dist.broadcast(self.restart_flag, src=0)
self.printer.print_mutation_diff(cur_params, self.new_params)
def _restart_with_new_params(self, new_params, restart_from_checkpoint):
"""Re-exec the current process with a filtered/augmented CLI to apply new params.
Notes:
- Filters out existing Hydra-style overrides that will be replaced,
and appends `--checkpoint=<path>` and new param overrides.
- On distributed runs, assigns a fresh master port and forwards
distributed args to the python.sh launcher.
"""
cli_args = sys.argv
print(f"previous command line args: {cli_args}")
SKIP = ["checkpoint"]
is_hydra = lambda arg: ( # noqa: E731
(name := arg.split("=", 1)[0]) not in new_params and not any(k in name for k in SKIP)
)
modified_args = [cli_args[0]] + [arg for arg in cli_args[1:] if "=" not in arg or is_hydra(arg)]
modified_args.append(f"--checkpoint={restart_from_checkpoint}")
modified_args.extend(self.wandb_args.get_args_list())
modified_args.extend(self.rendering_args.get_args_list())
# add all of the new (possibly mutated) parameters
for param, value in new_params.items():
modified_args.append(f"{param}={value}")
self.algo.writer.flush()
self.algo.writer.close()
if self.wandb_args.enabled:
import wandb
wandb.run.finish()
# Get the directory of the current file
thisfile_dir = os.path.dirname(os.path.abspath(__file__))
isaac_sim_path = os.path.abspath(os.path.join(thisfile_dir, "../../../../../_isaac_sim"))
command = [f"{isaac_sim_path}/python.sh"]
if self.distributed_args.distributed:
self.distributed_args.master_port = str(pbt_utils.find_free_port())
command.extend(self.distributed_args.get_args_list())
command += [modified_args[0]]
command.extend(self.env_args.get_args_list())
command += modified_args[1:]
if self.distributed_args.distributed:
command += ["--distributed"]
print("Running command:", command, flush=True)
print("sys.executable = ", sys.executable)
print(f"Policy {self.cfg.policy_idx}: Restarting self with args {modified_args}", flush=True)
if self.distributed_args.rank == 0:
pbt_utils.dump_env_sizes()
# De-duplicate path-like variables before exec'ing python.sh so the environment does not keep growing on every restart:
for var in ("PATH", "PYTHONPATH", "LD_LIBRARY_PATH", "OMNI_USD_RESOLVER_MDL_BUILTIN_PATHS"):
val = os.environ.get(var)
if not val or os.pathsep not in val:
continue
seen = set()
new_parts = []
for p in val.split(os.pathsep):
if p and p not in seen:
seen.add(p)
new_parts.append(p)
os.environ[var] = os.pathsep.join(new_parts)
os.execv(f"{isaac_sim_path}/python.sh", command)
class MultiObserver(AlgoObserver):
"""Meta-observer that allows the user to add several observers."""
def __init__(self, observers_):
super().__init__()
self.observers = observers_
def _call_multi(self, method, *args_, **kwargs_):
for o in self.observers:
getattr(o, method)(*args_, **kwargs_)
def before_init(self, base_name, config, experiment_name):
self._call_multi("before_init", base_name, config, experiment_name)
def after_init(self, algo):
self._call_multi("after_init", algo)
def process_infos(self, infos, done_indices):
self._call_multi("process_infos", infos, done_indices)
def after_steps(self):
self._call_multi("after_steps")
def after_clear_stats(self):
self._call_multi("after_clear_stats")
def after_print_stats(self, frame, epoch_num, total_time):
self._call_multi("after_print_stats", frame, epoch_num, total_time)
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from isaaclab.utils import configclass
@configclass
class PbtCfg:
"""
Population-Based Training (PBT) configuration.
leaders are policies with score > max(mean + threshold_std*std, mean + threshold_abs).
underperformers are policies with score < min(mean - threshold_std*std, mean - threshold_abs).
On replacement, selected hyperparameters are mutated multiplicatively in [change_min, change_max].
"""
enabled: bool = False
"""Enable/disable PBT logic."""
policy_idx: int = 0
"""Index of this learner in the population (unique in [0, num_policies-1])."""
num_policies: int = 8
"""Total number of learners participating in PBT."""
directory: str = ""
"""Root directory for PBT artifacts (checkpoints, metadata)."""
workspace: str = "pbt_workspace"
"""Subfolder under the training dir to isolate this PBT run."""
objective: str = "Episode_Reward/success"
"""The key in the info dict returned by env.step that PBT uses to determine leaders and underperformers.
If the reward is stationary, the term that corresponds to task success is usually enough; when rewards
are non-stationary, consider using a better objective.
"""
interval_steps: int = 100_000
"""Environment steps between PBT iterations (save, compare, replace/mutate)."""
threshold_std: float = 0.10
"""Std-based margin k: leaders score above mean + max(k·std, threshold_abs); underperformers score below mean - max(k·std, threshold_abs)."""
threshold_abs: float = 0.05
"""Absolute margin A: the minimum distance from the mean used in the cuts mean ± max(threshold_std·std, A)."""
mutation_rate: float = 0.25
"""Per-parameter probability of mutation when a policy is replaced."""
change_range: tuple[float, float] = (1.1, 2.0)
"""Lower and upper bound of multiplicative change factor (sampled in [change_min, change_max])."""
mutation: dict[str, str] = {}
"""Mapping from parameter name to the mutation function applied when PBT restarts a policy.
example:
{
"agent.params.config.learning_rate": "mutate_float",
"agent.params.config.grad_norm": "mutate_float",
"agent.params.config.entropy_coef": "mutate_float",
}
"""
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import datetime
import os
import random
import socket
import yaml
from collections import OrderedDict
from pathlib import Path
from prettytable import PrettyTable
from rl_games.algos_torch.torch_ext import safe_filesystem_op, safe_save
class DistributedArgs:
def __init__(self, args_cli):
self.distributed = args_cli.distributed
self.nproc_per_node = int(os.environ.get("WORLD_SIZE", 1))
self.rank = int(os.environ.get("RANK", 0))
self.nnodes = 1
self.master_port = getattr(args_cli, "master_port", None)
def get_args_list(self) -> list[str]:
args = ["-m", "torch.distributed.run", f"--nnodes={self.nnodes}", f"--nproc_per_node={self.nproc_per_node}"]
if self.master_port:
args.append(f"--master_port={self.master_port}")
return args
class EnvArgs:
def __init__(self, args_cli):
self.task = args_cli.task
self.seed = args_cli.seed if args_cli.seed is not None else -1
self.headless = args_cli.headless
self.num_envs = args_cli.num_envs
def get_args_list(self) -> list[str]:
args = []
args.append(f"--task={self.task}")
args.append(f"--seed={self.seed}")
args.append(f"--num_envs={self.num_envs}")
if self.headless:
args.append("--headless")
return args
class RenderingArgs:
def __init__(self, args_cli):
self.camera_enabled = args_cli.enable_cameras
self.video = args_cli.video
self.video_length = args_cli.video_length
self.video_interval = args_cli.video_interval
def get_args_list(self) -> list[str]:
args = []
if self.camera_enabled:
args.append("--enable_cameras")
if self.video:
args.extend(["--video", f"--video_length={self.video_length}", f"--video_interval={self.video_interval}"])
return args
class WandbArgs:
def __init__(self, args_cli):
self.enabled = args_cli.track
self.project_name = args_cli.wandb_project_name
self.name = args_cli.wandb_name
self.entity = args_cli.wandb_entity
def get_args_list(self) -> list[str]:
args = []
if self.enabled:
args.append("--track")
if self.entity:
args.append(f"--wandb-entity={self.entity}")
else:
raise ValueError("entity must be specified if wandb is enabled")
if self.project_name:
args.append(f"--wandb-project-name={self.project_name}")
if self.name:
args.append(f"--wandb-name={self.name}")
return args
def dump_env_sizes():
"""Print summary of environment variable usage (count, bytes, top-5 largest, SC_ARG_MAX)."""
n = len(os.environ)
# total bytes in "KEY=VAL\0" for all envp entries
total = sum(len(k) + 1 + len(v) + 1 for k, v in os.environ.items())
# find the 5 largest values
biggest = sorted(os.environ.items(), key=lambda kv: len(kv[1]), reverse=True)[:5]
print(f"[ENV MONITOR] vars={n}, total_bytes={total}")
for k, v in biggest:
print(f" {k!r} length={len(v)} → {v[:60]}{'…' if len(v) > 60 else ''}")
try:
argmax = os.sysconf("SC_ARG_MAX")
print(f"[ENV MONITOR] SC_ARG_MAX = {argmax}")
except (ValueError, AttributeError):
pass
def flatten_dict(d, prefix="", separator="."):
"""Flatten nested dictionaries into a flat dict with keys joined by `separator`."""
res = dict()
for key, value in d.items():
if isinstance(value, (dict, OrderedDict)):
res.update(flatten_dict(value, prefix + key + separator, separator))
else:
res[prefix + key] = value
return res
def find_free_port(max_tries: int = 20) -> int:
"""Return an OS-assigned free TCP port, with a few retries; fall back to a random high port."""
for _ in range(max_tries):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
continue
return random.randint(20000, 65000)
def filter_params(params, params_to_mutate):
"""Filter `params` to only those in `params_to_mutate`, converting str floats (e.g. '1e-4') to float."""
def try_float(v):
if isinstance(v, str):
try:
return float(v)
except ValueError:
return v
return v
return {k: try_float(v) for k, v in params.items() if k in params_to_mutate}
def save_pbt_checkpoint(workspace_dir, curr_policy_score, curr_iter, algo, params):
"""Save a PBT checkpoint (.pth and .yaml) with policy state, score, and metadata (rank 0 only)."""
if int(os.environ.get("RANK", "0")) == 0:
checkpoint_file = os.path.join(workspace_dir, f"{curr_iter:06d}.pth")
safe_save(algo.get_full_state_weights(), checkpoint_file)
pbt_checkpoint_file = os.path.join(workspace_dir, f"{curr_iter:06d}.yaml")
pbt_checkpoint = {
"iteration": curr_iter,
"true_objective": curr_policy_score,
"frame": algo.frame,
"params": params,
"checkpoint": os.path.abspath(checkpoint_file),
"pbt_checkpoint": os.path.abspath(pbt_checkpoint_file),
"experiment_name": algo.experiment_name,
}
with open(pbt_checkpoint_file, "w") as fobj:
yaml.dump(pbt_checkpoint, fobj)
def load_pbt_ckpts(workspace_dir, cur_policy_id, num_policies, pbt_iteration) -> dict | None:
"""
Load the latest available PBT checkpoint for each policy (≤ current iteration).
Returns a dict mapping policy_idx → checkpoint dict or None. (rank 0 only)
"""
if int(os.environ.get("RANK", "0")) != 0:
return None
checkpoints = dict()
for policy_idx in range(num_policies):
checkpoints[policy_idx] = None
policy_dir = os.path.join(workspace_dir, f"{policy_idx:03d}")
if not os.path.isdir(policy_dir):
continue
pbt_checkpoint_files = sorted([f for f in os.listdir(policy_dir) if f.endswith(".yaml")], reverse=True)
for pbt_checkpoint_file in pbt_checkpoint_files:
iteration = int(pbt_checkpoint_file.split(".")[0])
# current local time
now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
ctime_ts = os.path.getctime(os.path.join(policy_dir, pbt_checkpoint_file))
created_str = datetime.datetime.fromtimestamp(ctime_ts).strftime("%Y-%m-%d %H:%M:%S")
if iteration <= pbt_iteration:
with open(os.path.join(policy_dir, pbt_checkpoint_file)) as fobj:
now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(
f"Policy {cur_policy_id} [{now_str}]: Loading"
f" policy-{policy_idx} {pbt_checkpoint_file} (created at {created_str})"
)
checkpoints[policy_idx] = safe_filesystem_op(yaml.load, fobj, Loader=yaml.FullLoader)
break
return checkpoints
def cleanup(checkpoints: dict[int, dict], policy_dir, keep_back: int = 20, max_yaml: int = 50) -> None:
"""
Cleanup old checkpoints for the current policy directory (rank 0 only).
- Delete files older than (oldest iteration - keep_back).
- Keep at most `max_yaml` latest YAML iterations.
"""
if int(os.environ.get("RANK", "0")) == 0:
oldest = min((ckpt["iteration"] if ckpt else 0) for ckpt in checkpoints.values())
threshold = max(0, oldest - keep_back)
root = Path(policy_dir)
# group files by numeric iteration (only *.yaml / *.pth)
groups: dict[int, list[Path]] = {}
for p in root.iterdir():
if p.suffix in (".yaml", ".pth") and p.stem.isdigit():
groups.setdefault(int(p.stem), []).append(p)
# 1) drop anything older than threshold
for it in [i for i in groups if i <= threshold]:
for p in groups[it]:
p.unlink(missing_ok=True)
groups.pop(it, None)
# 2) cap total YAML checkpoints: keep newest `max_yaml` iters
yaml_iters = sorted((i for i, ps in groups.items() if any(p.suffix == ".yaml" for p in ps)), reverse=True)
for it in yaml_iters[max_yaml:]:
for p in groups.get(it, []):
p.unlink(missing_ok=True)
groups.pop(it, None)
class PbtTablePrinter:
"""All PrettyTable-related rendering lives here."""
def __init__(self, *, float_digits: int = 6, path_maxlen: int = 52):
self.float_digits = float_digits
self.path_maxlen = path_maxlen
# format helpers
def fmt(self, v):
return f"{v:.{self.float_digits}g}" if isinstance(v, float) else v
def short(self, s: str) -> str:
s = str(s)
L = self.path_maxlen
return s if len(s) <= L else s[: L // 2 - 1] + "…" + s[-L // 2 :]
# tables
def print_params_table(self, params: dict, header: str = "Parameters"):
table = PrettyTable(field_names=["Parameter", "Value"])
table.align["Parameter"] = "l"
table.align["Value"] = "r"
for k in sorted(params):
table.add_row([k, self.fmt(params[k])])
print(header + ":")
print(table.get_string())
def print_ckpt_summary(self, sumry: dict[int, dict | None]):
t = PrettyTable(["Policy", "Status", "Objective", "Iter", "Frame", "Experiment", "Checkpoint", "YAML"])
t.align["Policy"] = "r"
t.align["Status"] = "l"
t.align["Objective"] = "r"
t.align["Iter"] = "r"
t.align["Frame"] = "r"
t.align["Experiment"] = "l"
t.align["Checkpoint"] = "l"
t.align["YAML"] = "l"
for p in sorted(sumry.keys()):
c = sumry[p]
if c is None:
t.add_row([p, "—", "", "", "", "", "", ""])
else:
t.add_row([
p,
"OK",
self.fmt(c.get("true_objective", "")),
c.get("iteration", ""),
c.get("frame", ""),
c.get("experiment_name", ""),
self.short(c.get("checkpoint", "")),
self.short(c.get("pbt_checkpoint", "")),
])
print(t)
def print_mutation_diff(self, before: dict, after: dict, *, header: str = "Mutated params (changed only)"):
t = PrettyTable(["Parameter", "Old", "New"])
for k in sorted(before):
if before[k] != after[k]:
t.add_row([k, self.fmt(before[k]), self.fmt(after[k])])
print(header + ":")
print(t if t._rows else "(no changes)")