Unverified Commit ca4043cc authored by ooctipus's avatar ooctipus Committed by GitHub

Adds wandb native support in rl_games (#2650)

# Description

This PR creates support wandb logging in rl_games training. rl_games has
been supporting wandb logging, and the examples of how to configure it
can be seen from
[rl_games-wandb_support](https://github.com/Denys88/rl_games/blob/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/runner.py
) we could follow the same style and enable current rl_games pipeline to
use wandb as well.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent 7a489ad1
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
import argparse import argparse
import sys import sys
from distutils.util import strtobool
from isaaclab.app import AppLauncher from isaaclab.app import AppLauncher
...@@ -26,7 +27,17 @@ parser.add_argument( ...@@ -26,7 +27,17 @@ parser.add_argument(
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.") parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument("--sigma", type=str, default=None, help="The policy's initial standard deviation.") parser.add_argument("--sigma", type=str, default=None, help="The policy's initial standard deviation.")
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.") parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
parser.add_argument("--wandb-project-name", type=str, default=None, help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None, help="the entity (team) of wandb's project")
parser.add_argument("--wandb-name", type=str, default=None, help="the name of wandb's run")
parser.add_argument(
"--track",
type=lambda x: bool(strtobool(x)),
default=False,
nargs="?",
const=True,
help="if toggled, this experiment will be tracked with Weights and Biases",
)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
...@@ -109,7 +120,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -109,7 +120,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env_cfg.seed = agent_cfg["params"]["seed"] env_cfg.seed = agent_cfg["params"]["seed"]
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"]) config_name = agent_cfg["params"]["config"]["name"]
log_root_path = os.path.join("logs", "rl_games", config_name)
log_root_path = os.path.abspath(log_root_path) log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Logging experiment in directory: {log_root_path}") print(f"[INFO] Logging experiment in directory: {log_root_path}")
# specify directory for logging runs # specify directory for logging runs
...@@ -118,6 +130,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -118,6 +130,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# logging directory path: <train_dir>/<full_experiment_name> # logging directory path: <train_dir>/<full_experiment_name>
agent_cfg["params"]["config"]["train_dir"] = log_root_path agent_cfg["params"]["config"]["train_dir"] = log_root_path
agent_cfg["params"]["config"]["full_experiment_name"] = log_dir agent_cfg["params"]["config"]["full_experiment_name"] = log_dir
wandb_project = config_name if args_cli.wandb_project_name is None else args_cli.wandb_project_name
experiment_name = log_dir if args_cli.wandb_name is None else args_cli.wandb_name
# dump the configuration into log-directory # dump the configuration into log-directory
dump_yaml(os.path.join(log_root_path, log_dir, "params", "env.yaml"), env_cfg) dump_yaml(os.path.join(log_root_path, log_dir, "params", "env.yaml"), env_cfg)
...@@ -168,6 +182,23 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -168,6 +182,23 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# reset the agent and env # reset the agent and env
runner.reset() runner.reset()
# train the agent # train the agent
global_rank = int(os.getenv("RANK", "0"))
if args_cli.track and global_rank == 0:
if args_cli.wandb_entity is None:
raise ValueError("Weights and Biases entity must be specified for tracking.")
import wandb
wandb.init(
project=wandb_project,
entity=args_cli.wandb_entity,
name=experiment_name,
sync_tensorboard=True,
config=agent_cfg,
monitor_gym=True,
save_code=True,
)
if args_cli.checkpoint is not None: if args_cli.checkpoint is not None:
runner.run({"train": True, "play": False, "sigma": train_sigma, "checkpoint": resume_path}) runner.run({"train": True, "play": False, "sigma": train_sigma, "checkpoint": resume_path})
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment