Commit f7e4183d authored by Mayank Mittal

Renames SkrlLogTrainer to SkrlSequentialLogTrainer for explicit visibility

parent e9862d42
......@@ -7,8 +7,8 @@ Changelog
Added
^^^^^
* Environment wrapper for the skrl RL library
* Training/evaluation configuration files for the skrl RL library
* Added environment wrapper and sequential trainer for the skrl RL library
* Added training/evaluation configuration files for the skrl RL library
0.1.2 (2023-01-19)
~~~~~~~~~~~~~~~~~~
......
......@@ -18,7 +18,7 @@ Or, equivalently, by directly calling the skrl library API as follows:
"""
import copy
import torch
from typing import List, Optional, Union
......@@ -28,10 +28,11 @@ import tqdm
from skrl.agents.torch import Agent
from skrl.envs.torch.wrappers import Wrapper, wrap_env
from skrl.trainers.torch import Trainer
from skrl.trainers.torch.sequential import SEQUENTIAL_TRAINER_DEFAULT_CONFIG
from omni.isaac.orbit_envs.isaac_env import IsaacEnv
__all__ = ["SkrlVecEnvWrapper"]
__all__ = ["SkrlVecEnvWrapper", "SkrlSequentialLogTrainer"]
"""
......@@ -57,25 +58,72 @@ def SkrlVecEnvWrapper(env: IsaacEnv):
return wrap_env(env, wrapper="isaac-orbit")
class SkrlLogTrainer(Trainer):
"""
Custom trainer for skrl.
"""
class SkrlSequentialLogTrainer(Trainer):
"""Sequential trainer with logging of episode information.
This trainer inherits from the :class:`skrl.trainers.base_class.Trainer` class. It is used to
train agents in a sequential manner (i.e., one after the other in each interaction with the
environment). It is most suitable for on-policy RL agents such as PPO, A2C, etc.
It modifies the :class:`skrl.trainers.torch.sequential.SequentialTrainer` class with the following
differences:
* It also logs episode information to the agent's logger.
* It does not close the environment at the end of the training.
Reference:
https://skrl.readthedocs.io/en/latest/modules/skrl.trainers.base_class.html
"""
def __init__(
self,
env: Wrapper,
agents: Union[Agent, List[Agent]],
agents_scope: Optional[List[int]] = None,
cfg: Optional[dict] = None,
) -> None:
"""Customized trainer for tracking episode information
Reference:
https://skrl.readthedocs.io/en/latest/modules/skrl.trainers.base_class.html
):
"""Initializes the trainer.
Args:
env (Wrapper): Environment to train on.
agents (Union[Agent, List[Agent]]): Agents to train.
agents_scope (Optional[List[int]], optional): Number of environments for each agent to
train on. Defaults to None.
cfg (Optional[dict], optional): Configuration dictionary. Defaults to None.
"""
default_cfg = {"timesteps": 1000, "disable_progressbar": False}
default_cfg.update(cfg if cfg is not None else {})
super().__init__(env=env, agents=agents, agents_scope=agents_scope, cfg=default_cfg)
# update the config
_cfg = copy.deepcopy(SEQUENTIAL_TRAINER_DEFAULT_CONFIG)
_cfg.update(cfg if cfg is not None else {})
# store agents scope
agents_scope = agents_scope if agents_scope is not None else []
# initialize the base class
super().__init__(env=env, agents=agents, agents_scope=agents_scope, cfg=_cfg)
# init agents
if self.num_agents > 1:
for agent in self.agents:
agent.init(trainer_cfg=self.cfg)
else:
self.agents.init(trainer_cfg=self.cfg)
def train(self):
"""Train the agent"""
"""Train the agents sequentially.
This method executes the training loop for the agents. It performs the following steps:
* Pre-interaction: Perform any pre-interaction operations.
* Compute actions: Compute the actions for the agents.
* Step the environments: Step the environments with the computed actions.
* Record the environments' transitions: Record the transitions from the environments.
* Log custom environment data: Log custom environment data.
* Post-interaction: Perform any post-interaction operations.
* Reset the environments: Reset the environments if they are terminated or truncated.
"""
# init agent
self.agents.init(trainer_cfg=self.cfg)
self.agents.set_running_mode("train")
......@@ -90,6 +138,7 @@ class SkrlLogTrainer(Trainer):
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
# note: here we do not call render scene since it is done in the env.step() method
# record the environments' transitions
with torch.no_grad():
self.agents.record_transition(
......@@ -107,8 +156,74 @@ class SkrlLogTrainer(Trainer):
if "episode" in infos:
for k, v in infos["episode"].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
self.agents.track_data(f"Info / {k}", v.item())
self.agents.track_data(f"Episode / {k}", v.item())
# post-interaction
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
# reset the environments
# note: here we do not call reset scene since it is done in the env.step() method
# update states
states.copy_(next_states)
def eval(self) -> None:
"""Evaluate the agents sequentially.
Every agent is first switched to the ``"eval"`` running mode. When there is exactly one agent,
the work is delegated to ``self.single_agent_eval()`` and this method returns. Otherwise the
environment is reset and the following steps are executed in a loop over
``range(self.initial_timestep, self.timesteps)``:
* Compute actions: Each agent acts on the slice ``states[scope[0]:scope[1]]`` given by its entry
in ``self.agents_scope``; the per-agent actions are stacked with ``torch.vstack``.
* Step the environments: Step the environments with the computed actions.
* Record the environments' transitions: Each agent records the slice of the transition that
falls inside its scope (``infos`` is passed unsliced).
* Log custom environment data: Every single-element tensor in ``infos["episode"]`` is tracked
by each agent under the key ``Episode / <name>``.
* Post-interaction: The ``post_interaction`` found via ``super(type(agent), agent)`` is called.
"""
# set running mode
if self.num_agents > 1:
for agent in self.agents:
agent.set_running_mode("eval")
else:
self.agents.set_running_mode("eval")
# single agent: hand off to the dedicated single-agent routine and stop here
if self.num_agents == 1:
self.single_agent_eval()
return
# reset env
states, infos = self.env.reset()
# evaluation loop (progress bar can be disabled through self.disable_progressbar)
for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar):
# compute actions: each agent only sees the states inside its own scope
with torch.no_grad():
actions = torch.vstack(
[
agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)[0]
for agent, scope in zip(self.agents, self.agents_scope)
]
)
# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
with torch.no_grad():
# record the transition and tracked data for every agent
for agent, scope in zip(self.agents, self.agents_scope):
# track data: pass each agent only the slice of the transition inside its scope
agent.record_transition(
states=states[scope[0] : scope[1]],
actions=actions[scope[0] : scope[1]],
rewards=rewards[scope[0] : scope[1]],
next_states=next_states[scope[0] : scope[1]],
terminated=terminated[scope[0] : scope[1]],
truncated=truncated[scope[0] : scope[1]],
infos=infos,
timestep=timestep,
timesteps=self.timesteps,
)
# log custom environment data (only single-element tensors are logged, as Python scalars)
if "episode" in infos:
for k, v in infos["episode"].items():
if isinstance(v, torch.Tensor) and v.numel() == 1:
agent.track_data(f"Episode / {k}", v.item())
# perform post-interaction
# NOTE(review): super(type(agent), agent) skips the agent class's own override and uses the
# next implementation in its MRO -- presumably to avoid a learning update during
# evaluation; confirm against the skrl Agent base class.
super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps)
# reset environments
# note: here we do not call reset scene since it is done in the env.step() method
# overwrite the states tensor in place with the next states
states.copy_(next_states)
......@@ -81,5 +81,7 @@ def convert_skrl_cfg(cfg):
elif key in ["rewards_shaper_scale"]:
d["rewards_shaper"] = reward_shaper_function(value)
# parse agent configuration and convert to classes
update_dict(cfg)
# return the updated configuration
return cfg
......@@ -114,18 +114,19 @@ def main():
else:
resume_path = get_checkpoint_path(log_root_path, os.path.join("*", "checkpoints"), None)
print(f"[INFO] Loading model checkpoint from: {resume_path}")
# initialize agent
agent.init()
agent.load(resume_path)
def get_actions(obs):
return agent.act(obs, timestep=0, timesteps=0)[0]
# set agent to evaluation mode
agent.set_running_mode("eval")
# reset environment
obs, _ = env.reset()
# simulate environment
while simulation_app.is_running():
# agent stepping
actions = get_actions(obs)
actions = agent.act(obs, timestep=0, timesteps=0)[0]
# env stepping
obs, _, _, _, _ = env.step(actions)
# check if simulator is stopped
......
......@@ -46,7 +46,7 @@ from omni.isaac.orbit.utils.io import dump_pickle, dump_yaml
import omni.isaac.contrib_envs # noqa: F401
import omni.isaac.orbit_envs # noqa: F401
from omni.isaac.orbit_envs.utils import parse_env_cfg
from omni.isaac.orbit_envs.utils.wrappers.skrl import SkrlLogTrainer, SkrlVecEnvWrapper
from omni.isaac.orbit_envs.utils.wrappers.skrl import SkrlSequentialLogTrainer, SkrlVecEnvWrapper
from config import convert_skrl_cfg, parse_skrl_cfg
......@@ -146,7 +146,7 @@ def main():
# configure and instantiate a custom RL trainer for logging episode events
# https://skrl.readthedocs.io/en/latest/modules/skrl.trainers.base_class.html
trainer_cfg = experiment_cfg["trainer"]
trainer = SkrlLogTrainer(cfg=trainer_cfg, env=env, agents=agent)
trainer = SkrlSequentialLogTrainer(cfg=trainer_cfg, env=env, agents=agent)
# train the agent
trainer.train()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment