Unverified Commit 3d5ea25d authored by yijieg's avatar yijieg Committed by GitHub

Removes wandb logging in AutoMate env (#2912)

# Description

wandb logging function is provided in rl_games script. So we remove the
wandb logging in task level. Also, we edit the task registration style
to help startup perf.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- This change requires a documentation update

## Checklist

- [ x ] I have run the [`pre-commit` checks](https://pre-commit.com/)
with `./isaaclab.sh --format`
- [ x ] I have made corresponding changes to the documentation
- [ x ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ x ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
parent 13965cc3
...@@ -225,7 +225,7 @@ We provide environments for both disassembly and assembly. ...@@ -225,7 +225,7 @@ We provide environments for both disassembly and assembly.
* |disassembly-link|: The plug starts inserted in the socket. A low-level controller lifts the plug out and moves it to a random position. This process is purely scripted and does not involve any learned policy. Therefore, it does not require policy training or evaluation. The resulting trajectories serve as demonstrations for the reverse process, i.e., learning to assemble. To run disassembly for a specific task: ``python source/isaaclab_tasks/isaaclab_tasks/direct/automate/run_disassembly_w_id.py --assembly_id=ASSEMBLY_ID --disassembly_dir=DISASSEMBLY_DIR``. All generated trajectories are saved to a local directory ``DISASSEMBLY_DIR``. * |disassembly-link|: The plug starts inserted in the socket. A low-level controller lifts the plug out and moves it to a random position. This process is purely scripted and does not involve any learned policy. Therefore, it does not require policy training or evaluation. The resulting trajectories serve as demonstrations for the reverse process, i.e., learning to assemble. To run disassembly for a specific task: ``python source/isaaclab_tasks/isaaclab_tasks/direct/automate/run_disassembly_w_id.py --assembly_id=ASSEMBLY_ID --disassembly_dir=DISASSEMBLY_DIR``. All generated trajectories are saved to a local directory ``DISASSEMBLY_DIR``.
* |assembly-link|: The goal is to insert the plug into the socket. You can use this environment to train a policy via reinforcement learning or evaluate a pre-trained checkpoint. * |assembly-link|: The goal is to insert the plug into the socket. You can use this environment to train a policy via reinforcement learning or evaluate a pre-trained checkpoint.
* To train an assembly policy, we run the command ``python source/isaaclab_tasks/isaaclab_tasks/direct/automate/run_w_id.py --assembly_id=ASSEMBLY_ID --train``. We can customize the training process using the optional flags: ``--headless`` to run without opening the GUI windows, ``--max_iterations=MAX_ITERATIONS`` to set the number of training iterations, ``--num_envs=NUM_ENVS`` to set the number of parallel environments during training, ``--seed=SEED`` to assign the random seed, ``--wandb`` to enable logging to WandB (requires a WandB account). The policy checkpoints will be saved automatically during training in the directory ``logs/rl_games/Assembly/test``. * To train an assembly policy, we run the command ``python source/isaaclab_tasks/isaaclab_tasks/direct/automate/run_w_id.py --assembly_id=ASSEMBLY_ID --train``. We can customize the training process using the optional flags: ``--headless`` to run without opening the GUI windows, ``--max_iterations=MAX_ITERATIONS`` to set the number of training iterations, ``--num_envs=NUM_ENVS`` to set the number of parallel environments during training, ``--seed=SEED`` to assign the random seed. The policy checkpoints will be saved automatically during training in the directory ``logs/rl_games/Assembly/test``.
* To evaluate an assembly policy, we run the command ``python source/isaaclab_tasks/isaaclab_tasks/direct/automate/run_w_id.py --assembly_id=ASSEMBLY_ID --checkpoint=CHECKPOINT --log_eval``. The evaluation results are stored in ``evaluation_{ASSEMBLY_ID}.h5``. * To evaluate an assembly policy, we run the command ``python source/isaaclab_tasks/isaaclab_tasks/direct/automate/run_w_id.py --assembly_id=ASSEMBLY_ID --checkpoint=CHECKPOINT --log_eval``. The evaluation results are stored in ``evaluation_{ASSEMBLY_ID}.h5``.
.. table:: .. table::
......
...@@ -6,8 +6,6 @@ ...@@ -6,8 +6,6 @@
import gymnasium as gym import gymnasium as gym
from . import agents from . import agents
from .assembly_env import AssemblyEnv, AssemblyEnvCfg
from .disassembly_env import DisassemblyEnv, DisassemblyEnvCfg
## ##
# Register Gym environments. # Register Gym environments.
...@@ -15,10 +13,10 @@ from .disassembly_env import DisassemblyEnv, DisassemblyEnvCfg ...@@ -15,10 +13,10 @@ from .disassembly_env import DisassemblyEnv, DisassemblyEnvCfg
gym.register( gym.register(
id="Isaac-AutoMate-Assembly-Direct-v0", id="Isaac-AutoMate-Assembly-Direct-v0",
entry_point="isaaclab_tasks.direct.automate:AssemblyEnv", entry_point=f"{__name__}.assembly_env:AssemblyEnv",
disable_env_checker=True, disable_env_checker=True,
kwargs={ kwargs={
"env_cfg_entry_point": AssemblyEnvCfg, "env_cfg_entry_point": f"{__name__}.assembly_env:AssemblyEnvCfg",
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
}, },
) )
...@@ -26,10 +24,10 @@ gym.register( ...@@ -26,10 +24,10 @@ gym.register(
gym.register( gym.register(
id="Isaac-AutoMate-Disassembly-Direct-v0", id="Isaac-AutoMate-Disassembly-Direct-v0",
entry_point="isaaclab_tasks.direct.automate:DisassemblyEnv", entry_point=f"{__name__}.disassembly_env:DisassemblyEnv",
disable_env_checker=True, disable_env_checker=True,
kwargs={ kwargs={
"env_cfg_entry_point": DisassemblyEnvCfg, "env_cfg_entry_point": f"{__name__}.disassembly_env:DisassemblyEnvCfg",
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
}, },
) )
...@@ -7,11 +7,9 @@ import json ...@@ -7,11 +7,9 @@ import json
import numpy as np import numpy as np
import os import os
import torch import torch
from datetime import datetime
import carb import carb
import isaacsim.core.utils.torch as torch_utils import isaacsim.core.utils.torch as torch_utils
import wandb
import warp as wp import warp as wp
import isaaclab.sim as sim_utils import isaaclab.sim as sim_utils
...@@ -71,9 +69,6 @@ class AssemblyEnv(DirectRLEnv): ...@@ -71,9 +69,6 @@ class AssemblyEnv(DirectRLEnv):
if self.cfg_task.sample_from != "rand": if self.cfg_task.sample_from != "rand":
self._init_eval_loading() self._init_eval_loading()
if self.cfg_task.wandb:
wandb.init(project="automate", name=self.cfg_task.assembly_id + "_" + datetime.now().strftime("%m/%d/%Y"))
def _init_eval_loading(self): def _init_eval_loading(self):
eval_held_asset_pose, eval_fixed_asset_pose, eval_success = automate_log.load_log_from_hdf5( eval_held_asset_pose, eval_fixed_asset_pose, eval_success = automate_log.load_log_from_hdf5(
self.cfg_task.eval_filename self.cfg_task.eval_filename
...@@ -554,9 +549,6 @@ class AssemblyEnv(DirectRLEnv): ...@@ -554,9 +549,6 @@ class AssemblyEnv(DirectRLEnv):
rew_buf = self._update_rew_buf(curr_successes) rew_buf = self._update_rew_buf(curr_successes)
self.ep_succeeded = torch.logical_or(self.ep_succeeded, curr_successes) self.ep_succeeded = torch.logical_or(self.ep_succeeded, curr_successes)
if self.cfg_task.wandb:
wandb.log(self.extras)
# Only log episode success rates at the end of an episode. # Only log episode success rates at the end of an episode.
if torch.any(self.reset_buf): if torch.any(self.reset_buf):
self.extras["successes"] = torch.count_nonzero(self.ep_succeeded) / self.num_envs self.extras["successes"] = torch.count_nonzero(self.ep_succeeded) / self.num_envs
...@@ -579,12 +571,6 @@ class AssemblyEnv(DirectRLEnv): ...@@ -579,12 +571,6 @@ class AssemblyEnv(DirectRLEnv):
) )
self.extras["curr_max_disp"] = self.curr_max_disp self.extras["curr_max_disp"] = self.curr_max_disp
if self.cfg_task.wandb:
wandb.log({
"success": torch.mean(self.ep_succeeded.float()),
"reward": torch.mean(rew_buf),
"sbc_rwd_scale": sbc_rwd_scale,
})
if self.cfg_task.if_logging_eval: if self.cfg_task.if_logging_eval:
self.success_log = torch.cat([self.success_log, self.ep_succeeded.reshape((self.num_envs, 1))], dim=0) self.success_log = torch.cat([self.success_log, self.ep_succeeded.reshape((self.num_envs, 1))], dim=0)
......
...@@ -138,7 +138,6 @@ class AssemblyTask: ...@@ -138,7 +138,6 @@ class AssemblyTask:
if_logging_eval: bool = False if_logging_eval: bool = False
num_eval_trials: int = 100 num_eval_trials: int = 100
eval_filename: str = "evaluation_00015.h5" eval_filename: str = "evaluation_00015.h5"
wandb: bool = False
# Fine-tuning # Fine-tuning
sample_from: str = "rand" # gp, gmm, idv, rand sample_from: str = "rand" # gp, gmm, idv, rand
......
...@@ -9,7 +9,7 @@ import subprocess ...@@ -9,7 +9,7 @@ import subprocess
import sys import sys
def update_task_param(task_cfg, assembly_id, if_sbc, if_log_eval, if_wandb): def update_task_param(task_cfg, assembly_id, if_sbc, if_log_eval):
# Read the file lines. # Read the file lines.
with open(task_cfg) as f: with open(task_cfg) as f:
lines = f.readlines() lines = f.readlines()
...@@ -21,7 +21,6 @@ def update_task_param(task_cfg, assembly_id, if_sbc, if_log_eval, if_wandb): ...@@ -21,7 +21,6 @@ def update_task_param(task_cfg, assembly_id, if_sbc, if_log_eval, if_wandb):
if_sbc_pattern = re.compile(r"^(.*if_sbc\s*:\s*bool\s*=\s*).*$") if_sbc_pattern = re.compile(r"^(.*if_sbc\s*:\s*bool\s*=\s*).*$")
if_log_eval_pattern = re.compile(r"^(.*if_logging_eval\s*:\s*bool\s*=\s*).*$") if_log_eval_pattern = re.compile(r"^(.*if_logging_eval\s*:\s*bool\s*=\s*).*$")
eval_file_pattern = re.compile(r"^(.*eval_filename\s*:\s*str\s*=\s*).*$") eval_file_pattern = re.compile(r"^(.*eval_filename\s*:\s*str\s*=\s*).*$")
if_wandb_pattern = re.compile(r"^(.*wandb\s*:\s*bool\s*=\s*).*$")
for line in lines: for line in lines:
if "assembly_id =" in line: if "assembly_id =" in line:
...@@ -32,8 +31,6 @@ def update_task_param(task_cfg, assembly_id, if_sbc, if_log_eval, if_wandb): ...@@ -32,8 +31,6 @@ def update_task_param(task_cfg, assembly_id, if_sbc, if_log_eval, if_wandb):
line = if_log_eval_pattern.sub(rf"\1{str(if_log_eval)}", line) line = if_log_eval_pattern.sub(rf"\1{str(if_log_eval)}", line)
elif "eval_filename: str = " in line: elif "eval_filename: str = " in line:
line = eval_file_pattern.sub(r"\1'{}'".format(f"evaluation_{assembly_id}.h5"), line) line = eval_file_pattern.sub(r"\1'{}'".format(f"evaluation_{assembly_id}.h5"), line)
elif "wandb: bool =" in line:
line = if_wandb_pattern.sub(rf"\1{str(if_wandb)}", line)
updated_lines.append(line) updated_lines.append(line)
...@@ -51,7 +48,6 @@ def main(): ...@@ -51,7 +48,6 @@ def main():
default="source/isaaclab_tasks/isaaclab_tasks/direct/automate/assembly_tasks_cfg.py", default="source/isaaclab_tasks/isaaclab_tasks/direct/automate/assembly_tasks_cfg.py",
) )
parser.add_argument("--assembly_id", type=str, help="New assembly ID to set.") parser.add_argument("--assembly_id", type=str, help="New assembly ID to set.")
parser.add_argument("--wandb", action="store_true", help="Use wandb to record learning curves")
parser.add_argument("--checkpoint", type=str, help="Checkpoint path.") parser.add_argument("--checkpoint", type=str, help="Checkpoint path.")
parser.add_argument("--num_envs", type=int, default=128, help="Number of parallel environment.") parser.add_argument("--num_envs", type=int, default=128, help="Number of parallel environment.")
parser.add_argument("--seed", type=int, default=-1, help="Random seed.") parser.add_argument("--seed", type=int, default=-1, help="Random seed.")
...@@ -61,7 +57,7 @@ def main(): ...@@ -61,7 +57,7 @@ def main():
parser.add_argument("--max_iterations", type=int, default=1500, help="Number of iteration for policy learning.") parser.add_argument("--max_iterations", type=int, default=1500, help="Number of iteration for policy learning.")
args = parser.parse_args() args = parser.parse_args()
update_task_param(args.cfg_path, args.assembly_id, args.train, args.log_eval, args.wandb) update_task_param(args.cfg_path, args.assembly_id, args.train, args.log_eval)
bash_command = None bash_command = None
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment