Unverified Commit 90dda53f authored by ooctipus's avatar ooctipus Committed by GitHub

Enhances Pbt usage experience through small improvements (#3449)

# Description

This PR is added with feedback from PBT user, and made below improvments

1. added resume logic to allow wandb to continue on the same run_id
2. corrected broadcasting order in distributed setup
3. made score query general by using dotted keys to access dictionary of
arbitrary depth

Fixes # (issue)

- Bug fix (non-breaking change which fixes an issue)


## Screenshots

Please attach before and after screenshots of the change if applicable.

<!--
Example:

| Before | After |
| ------ | ----- |
| _gif/png before_ | _gif/png after_ |

To upload images to a PR -- simply drag and drop an image while in edit
mode and it should upload the image directly. You can then paste that
source into the above before/after sections.
-->

## Checklist

- [x] I have read and understood the [contribution
guidelines](https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html)
- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent 187f9a58
...@@ -49,7 +49,7 @@ Example Config ...@@ -49,7 +49,7 @@ Example Config
num_policies: 8 num_policies: 8
directory: . directory: .
workspace: "pbt_workspace" workspace: "pbt_workspace"
objective: Curriculum/difficulty_level objective: episode.Curriculum/difficulty_level
interval_steps: 50000000 interval_steps: 50000000
threshold_std: 0.1 threshold_std: 0.1
threshold_abs: 0.025 threshold_abs: 0.025
...@@ -66,9 +66,9 @@ Example Config ...@@ -66,9 +66,9 @@ Example Config
agent.params.config.tau: "mutate_discount" agent.params.config.tau: "mutate_discount"
``objective: Curriculum/difficulty_level`` uses ``infos["episode"]["Curriculum/difficulty_level"]`` as the scalar to ``objective: episode.Curriculum/difficulty_level`` is the dotted expression that uses
**rank policies** (higher is better). With ``num_policies: 8``, launch eight processes sharing the same ``workspace`` ``infos["episode"]["Curriculum/difficulty_level"]`` as the scalar to **rank policies** (higher is better).
and unique ``policy_idx`` (0-7). With ``num_policies: 8``, launch eight processes sharing the same ``workspace`` and unique ``policy_idx`` (0-7).
Launching PBT Launching PBT
......
...@@ -226,6 +226,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -226,6 +226,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
monitor_gym=True, monitor_gym=True,
save_code=True, save_code=True,
) )
if not wandb.run.resumed:
wandb.config.update({"env_cfg": env_cfg.to_dict()}) wandb.config.update({"env_cfg": env_cfg.to_dict()})
wandb.config.update({"agent_cfg": agent_cfg}) wandb.config.update({"agent_cfg": agent_cfg})
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.4.0" version = "0.4.1"
# Description # Description
title = "Isaac Lab RL" title = "Isaac Lab RL"
......
Changelog Changelog
--------- ---------
0.4.1 (2025-09-09)
~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Made PBT a bit nicer by
* 1. added resume logic to allow wandb to continue on the same run_id
* 2. corrected broadcasting order in distributed setup
* 3. made score query general by using dotted keys to access dictionary of arbitrary depth
0.4.0 (2025-09-09) 0.4.0 (2025-09-09)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -68,9 +68,12 @@ class PbtAlgoObserver(AlgoObserver): ...@@ -68,9 +68,12 @@ class PbtAlgoObserver(AlgoObserver):
"""Extract the scalar objective from environment infos and store in `self.score`. """Extract the scalar objective from environment infos and store in `self.score`.
Notes: Notes:
Expects the objective to be at `infos["episode"][self.cfg.objective]`. Expects the objective to be at `infos[self.cfg.objective]` where self.cfg.objective is dotted address.
""" """
self.score = infos["episode"][self.cfg.objective] score = infos
for part in self.cfg.objective.split("."):
score = score[part]
self.score = score
def after_steps(self): def after_steps(self):
"""Main PBT tick executed every train step. """Main PBT tick executed every train step.
...@@ -84,6 +87,9 @@ class PbtAlgoObserver(AlgoObserver): ...@@ -84,6 +87,9 @@ class PbtAlgoObserver(AlgoObserver):
whitelisted params, set `restart_flag`, broadcast (if distributed), whitelisted params, set `restart_flag`, broadcast (if distributed),
and print a mutation diff table. and print a mutation diff table.
""" """
if self.distributed_args.distributed:
dist.broadcast(self.restart_flag, src=0)
if self.distributed_args.rank != 0: if self.distributed_args.rank != 0:
if self.restart_flag.cpu().item() == 1: if self.restart_flag.cpu().item() == 1:
os._exit(0) os._exit(0)
...@@ -154,9 +160,6 @@ class PbtAlgoObserver(AlgoObserver): ...@@ -154,9 +160,6 @@ class PbtAlgoObserver(AlgoObserver):
self.new_params = mutate(cur_params, self.cfg.mutation, self.cfg.mutation_rate, self.cfg.change_range) self.new_params = mutate(cur_params, self.cfg.mutation, self.cfg.mutation_rate, self.cfg.change_range)
self.restart_from_checkpoint = os.path.abspath(ckpts[replacement_policy_candidate]["checkpoint"]) self.restart_from_checkpoint = os.path.abspath(ckpts[replacement_policy_candidate]["checkpoint"])
self.restart_flag[0] = 1 self.restart_flag[0] = 1
if self.distributed_args.distributed:
dist.broadcast(self.restart_flag, src=0)
self.printer.print_mutation_diff(cur_params, self.new_params) self.printer.print_mutation_diff(cur_params, self.new_params)
def _restart_with_new_params(self, new_params, restart_from_checkpoint): def _restart_with_new_params(self, new_params, restart_from_checkpoint):
...@@ -191,6 +194,11 @@ class PbtAlgoObserver(AlgoObserver): ...@@ -191,6 +194,11 @@ class PbtAlgoObserver(AlgoObserver):
if self.wandb_args.enabled: if self.wandb_args.enabled:
import wandb import wandb
# note setdefault will only affect child process, that mean don't have to worry it env variable
# propagate beyond restarted child process
os.environ.setdefault("WANDB_RUN_ID", wandb.run.id) # continue with the same run id
os.environ.setdefault("WANDB_RESUME", "allow") # allow wandb to resume
os.environ.setdefault("WANDB_INIT_TIMEOUT", "300") # give wandb init more time to be fault tolerant
wandb.run.finish() wandb.run.finish()
# Get the directory of the current file # Get the directory of the current file
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment