Commit 553f1f5d authored by Michael Noseworthy, committed by Kelly Guo

Adds FORGE tasks for contact-rich manipulation with force sensing to IsaacLab (#2968)

This MR adds new tasks which extend the `Factory` tasks to include:
1. Force sensing: Add observations for force experienced by the
end-effector.
2. Excessive force penalty: Add an option to penalize the agent for
excessive contact forces.
3. Dynamics randomization: Randomize controller gains, asset properties
(friction, mass), and dead-zone.
4. Success prediction: Add an extra action that predicts task success.

The new tasks are `Isaac-Forge-PegInsert-Direct-v0`,
`Isaac-Forge-GearMesh-Direct-v0`, and `Isaac-Forge-NutThread-Direct-v0`.
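
The dead-zone in item 3 is applied to the task-space wrench inside `compute_dof_torque`: components whose magnitude is below the threshold are zeroed, and the rest are shrunk toward zero by the threshold. A minimal pure-Python sketch of that rule, assuming per-component thresholds (the name `apply_dead_zone` and the sample values are illustrative and not part of this MR, which uses `torch.where` on batched tensors):

```python
def apply_dead_zone(wrench, thresholds):
    """Zero components inside the dead zone; shrink the others toward zero."""
    shifted = []
    for w, t in zip(wrench, thresholds):
        if abs(w) < t:
            # Commanded force/torque is too small to be realized reliably.
            shifted.append(0.0)
        else:
            sign = 1.0 if w > 0 else -1.0
            shifted.append(sign * (abs(w) - t))
    return shifted


# Illustrative 6-D wrench (fx, fy, fz, tx, ty, tz) with a uniform threshold.
wrench = [0.3, -0.3, 2.0, -2.0, 0.0, 0.5]
thresholds = [0.5] * 6
shifted = apply_dead_zone(wrench, thresholds)
```

Because the threshold is subtracted rather than only used as a cutoff, the mapping stays continuous at the dead-zone boundary, so small changes in the commanded wrench never cause a jump in the applied wrench.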

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: Kelly Guo <kellyguo123@hotmail.com>
Co-authored-by: Octi Zhang <zhengyuz@nvidia.com>
Co-authored-by: Kelly Guo <kellyguo123@hotmail.com>
parent 1c780a02
...@@ -360,5 +360,10 @@ ...@@ -360,5 +360,10 @@
"package": "referencing", "package": "referencing",
"license": "UNKNOWN", "license": "UNKNOWN",
"comment": "MIT" "comment": "MIT"
},
{
"package": "regex",
"license": "UNKNOWN",
"comment": "Apache 2.0"
} }
] ]
...@@ -245,6 +245,44 @@ We provide environments for both disassembly and assembly. ...@@ -245,6 +245,44 @@ We provide environments for both disassembly and assembly.
.. |assembly-link| replace:: `Isaac-AutoMate-Assembly-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/isaaclab_tasks/isaaclab_tasks/direct/automate/assembly_env_cfg.py>`__ .. |assembly-link| replace:: `Isaac-AutoMate-Assembly-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/isaaclab_tasks/isaaclab_tasks/direct/automate/assembly_env_cfg.py>`__
.. |disassembly-link| replace:: `Isaac-AutoMate-Disassembly-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/isaaclab_tasks/isaaclab_tasks/direct/automate/disassembly_env_cfg.py>`__ .. |disassembly-link| replace:: `Isaac-AutoMate-Disassembly-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/isaaclab_tasks/isaaclab_tasks/direct/automate/disassembly_env_cfg.py>`__
FORGE
~~~~~~~~
FORGE environments extend Factory environments with:
* Force sensing: Add observations for force experienced by the end-effector.
* Excessive force penalty: Add an option to penalize the agent for excessive contact forces.
* Dynamics randomization: Randomize controller gains, asset properties (friction, mass), and dead-zone.
* Success prediction: Add an extra action that predicts task success.
These tasks share the same task configurations and control options. You can switch between them by specifying the task name.
* |forge-peg-link|: Peg insertion with the Franka arm
* |forge-gear-link|: Gear meshing with the Franka arm
* |forge-nut-link|: Nut-Bolt fastening with the Franka arm
.. table::
:widths: 33 37 30
+--------------------+-------------------------+-----------------------------------------------------------------------------+
| World | Environment ID | Description |
+====================+=========================+=============================================================================+
| |forge-peg| | |forge-peg-link| | Insert peg into the socket with the Franka robot |
+--------------------+-------------------------+-----------------------------------------------------------------------------+
| |forge-gear| | |forge-gear-link| | Insert and mesh gear into the base with other gears, using the Franka robot |
+--------------------+-------------------------+-----------------------------------------------------------------------------+
| |forge-nut| | |forge-nut-link| | Thread the nut onto the first 2 threads of the bolt, using the Franka robot |
+--------------------+-------------------------+-----------------------------------------------------------------------------+
.. |forge-peg| image:: ../_static/tasks/factory/peg_insert.jpg
.. |forge-gear| image:: ../_static/tasks/factory/gear_mesh.jpg
.. |forge-nut| image:: ../_static/tasks/factory/nut_thread.jpg
.. |forge-peg-link| replace:: `Isaac-Forge-PegInsert-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/isaaclab_tasks/isaaclab_tasks/direct/forge/forge_env_cfg.py>`__
.. |forge-gear-link| replace:: `Isaac-Forge-GearMesh-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/isaaclab_tasks/isaaclab_tasks/direct/forge/forge_env_cfg.py>`__
.. |forge-nut-link| replace:: `Isaac-Forge-NutThread-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/isaaclab_tasks/isaaclab_tasks/direct/forge/forge_env_cfg.py>`__
Locomotion Locomotion
~~~~~~~~~~ ~~~~~~~~~~
...@@ -743,6 +781,18 @@ inferencing, including reading from an already trained checkpoint and disabling ...@@ -743,6 +781,18 @@ inferencing, including reading from an already trained checkpoint and disabling
- -
- Direct - Direct
- -
* - Isaac-Forge-GearMesh-Direct-v0
-
- Direct
- **rl_games** (PPO)
* - Isaac-Forge-NutThread-Direct-v0
-
- Direct
- **rl_games** (PPO)
* - Isaac-Forge-PegInsert-Direct-v0
-
- Direct
- **rl_games** (PPO)
* - Isaac-Franka-Cabinet-Direct-v0 * - Isaac-Franka-Cabinet-Direct-v0
- -
- Direct - Direct
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.10.43" version = "0.10.44"
# Description # Description
title = "Isaac Lab Environments" title = "Isaac Lab Environments"
......
Changelog Changelog
--------- ---------
0.10.44 (2025-07-16)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added ``Isaac-Forge-PegInsert-Direct-v0``, ``Isaac-Forge-GearMesh-Direct-v0``,
and ``Isaac-Forge-NutThread-Direct-v0`` environments as direct RL envs. These
environments extend ``Isaac-Factory-*-v0`` with force sensing, an excessive force
penalty, dynamics randomization, and success prediction.
0.10.43 (2025-07-24) 0.10.43 (2025-07-24)
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
......
...@@ -31,6 +31,7 @@ def compute_dof_torque( ...@@ -31,6 +31,7 @@ def compute_dof_torque(
task_prop_gains, task_prop_gains,
task_deriv_gains, task_deriv_gains,
device, device,
dead_zone_thresholds=None,
): ):
"""Compute Franka DOF torque to move fingertips towards target pose.""" """Compute Franka DOF torque to move fingertips towards target pose."""
# References: # References:
...@@ -61,6 +62,15 @@ def compute_dof_torque( ...@@ -61,6 +62,15 @@ def compute_dof_torque(
) )
task_wrench += task_wrench_motion task_wrench += task_wrench_motion
# Offset task_wrench motion by random amount to simulate unreliability at low forces.
# Check if absolute value is less than specified amount. If so, 0 out, otherwise, subtract.
if dead_zone_thresholds is not None:
task_wrench = torch.where(
task_wrench.abs() < dead_zone_thresholds,
torch.zeros_like(task_wrench),
task_wrench.sign() * (task_wrench.abs() - dead_zone_thresholds),
)
# Set tau = J^T * tau, i.e., map tau into joint space as desired # Set tau = J^T * tau, i.e., map tau into joint space as desired
jacobian_T = torch.transpose(jacobian, dim0=1, dim1=2) jacobian_T = torch.transpose(jacobian, dim0=1, dim1=2)
dof_torque[:, 0:7] = (jacobian_T @ task_wrench.unsqueeze(-1)).squeeze(-1) dof_torque[:, 0:7] = (jacobian_T @ task_wrench.unsqueeze(-1)).squeeze(-1)
...@@ -135,7 +145,7 @@ def get_pose_error( ...@@ -135,7 +145,7 @@ def get_pose_error(
return pos_error, axis_angle_error return pos_error, axis_angle_error
def _get_delta_dof_pos(delta_pose, ik_method, jacobian, device): def get_delta_dof_pos(delta_pose, ik_method, jacobian, device):
"""Get delta Franka DOF position from delta pose using specified IK method.""" """Get delta Franka DOF position from delta pose using specified IK method."""
# References: # References:
# 1) https://www.cs.cmu.edu/~15464-s13/lectures/lecture6/iksurvey.pdf # 1) https://www.cs.cmu.edu/~15464-s13/lectures/lecture6/iksurvey.pdf
......
...@@ -16,7 +16,7 @@ from isaaclab.sim.spawners.from_files import GroundPlaneCfg, spawn_ground_plane ...@@ -16,7 +16,7 @@ from isaaclab.sim.spawners.from_files import GroundPlaneCfg, spawn_ground_plane
from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR
from isaaclab.utils.math import axis_angle_from_quat from isaaclab.utils.math import axis_angle_from_quat
from . import factory_control as fc from . import factory_control, factory_utils
from .factory_env_cfg import OBS_DIM_CFG, STATE_DIM_CFG, FactoryEnvCfg from .factory_env_cfg import OBS_DIM_CFG, STATE_DIM_CFG, FactoryEnvCfg
...@@ -33,18 +33,9 @@ class FactoryEnv(DirectRLEnv): ...@@ -33,18 +33,9 @@ class FactoryEnv(DirectRLEnv):
super().__init__(cfg, render_mode, **kwargs) super().__init__(cfg, render_mode, **kwargs)
self._set_body_inertias() factory_utils.set_body_inertias(self._robot, self.scene.num_envs)
self._init_tensors() self._init_tensors()
self._set_default_dynamics_parameters() self._set_default_dynamics_parameters()
self._compute_intermediate_values(dt=self.physics_dt)
def _set_body_inertias(self):
"""Note: this is to account for the asset_options.armature parameter in IGE."""
inertias = self._robot.root_physx_view.get_inertias()
offset = torch.zeros_like(inertias)
offset[:, :, [0, 4, 8]] += 0.01
new_inertias = inertias + offset
self._robot.root_physx_view.set_inertias(new_inertias, torch.arange(self.num_envs))
def _set_default_dynamics_parameters(self): def _set_default_dynamics_parameters(self):
"""Set parameters defining dynamic interactions.""" """Set parameters defining dynamic interactions."""
...@@ -60,55 +51,21 @@ class FactoryEnv(DirectRLEnv): ...@@ -60,55 +51,21 @@ class FactoryEnv(DirectRLEnv):
) )
# Set masses and frictions. # Set masses and frictions.
self._set_friction(self._held_asset, self.cfg_task.held_asset_cfg.friction) factory_utils.set_friction(self._held_asset, self.cfg_task.held_asset_cfg.friction, self.scene.num_envs)
self._set_friction(self._fixed_asset, self.cfg_task.fixed_asset_cfg.friction) factory_utils.set_friction(self._fixed_asset, self.cfg_task.fixed_asset_cfg.friction, self.scene.num_envs)
self._set_friction(self._robot, self.cfg_task.robot_cfg.friction) factory_utils.set_friction(self._robot, self.cfg_task.robot_cfg.friction, self.scene.num_envs)
def _set_friction(self, asset, value):
"""Update material properties for a given asset."""
materials = asset.root_physx_view.get_material_properties()
materials[..., 0] = value # Static friction.
materials[..., 1] = value # Dynamic friction.
env_ids = torch.arange(self.scene.num_envs, device="cpu")
asset.root_physx_view.set_material_properties(materials, env_ids)
def _init_tensors(self): def _init_tensors(self):
"""Initialize tensors once.""" """Initialize tensors once."""
self.identity_quat = (
torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).unsqueeze(0).repeat(self.num_envs, 1)
)
# Control targets. # Control targets.
self.ctrl_target_joint_pos = torch.zeros((self.num_envs, self._robot.num_joints), device=self.device) self.ctrl_target_joint_pos = torch.zeros((self.num_envs, self._robot.num_joints), device=self.device)
self.ctrl_target_fingertip_midpoint_pos = torch.zeros((self.num_envs, 3), device=self.device) self.ema_factor = self.cfg.ctrl.ema_factor
self.ctrl_target_fingertip_midpoint_quat = torch.zeros((self.num_envs, 4), device=self.device) self.dead_zone_thresholds = None
# Fixed asset. # Fixed asset.
self.fixed_pos_action_frame = torch.zeros((self.num_envs, 3), device=self.device)
self.fixed_pos_obs_frame = torch.zeros((self.num_envs, 3), device=self.device) self.fixed_pos_obs_frame = torch.zeros((self.num_envs, 3), device=self.device)
self.init_fixed_pos_obs_noise = torch.zeros((self.num_envs, 3), device=self.device) self.init_fixed_pos_obs_noise = torch.zeros((self.num_envs, 3), device=self.device)
# Held asset
held_base_x_offset = 0.0
if self.cfg_task.name == "peg_insert":
held_base_z_offset = 0.0
elif self.cfg_task.name == "gear_mesh":
gear_base_offset = self._get_target_gear_base_offset()
held_base_x_offset = gear_base_offset[0]
held_base_z_offset = gear_base_offset[2]
elif self.cfg_task.name == "nut_thread":
held_base_z_offset = self.cfg_task.fixed_asset_cfg.base_height
else:
raise NotImplementedError("Task not implemented")
self.held_base_pos_local = torch.tensor([0.0, 0.0, 0.0], device=self.device).repeat((self.num_envs, 1))
self.held_base_pos_local[:, 0] = held_base_x_offset
self.held_base_pos_local[:, 2] = held_base_z_offset
self.held_base_quat_local = self.identity_quat.clone().detach()
self.held_base_pos = torch.zeros_like(self.held_base_pos_local)
self.held_base_quat = self.identity_quat.clone().detach()
# Computer body indices. # Computer body indices.
self.left_finger_body_idx = self._robot.body_names.index("panda_leftfinger") self.left_finger_body_idx = self._robot.body_names.index("panda_leftfinger")
self.right_finger_body_idx = self._robot.body_names.index("panda_rightfinger") self.right_finger_body_idx = self._robot.body_names.index("panda_rightfinger")
...@@ -117,44 +74,14 @@ class FactoryEnv(DirectRLEnv): ...@@ -117,44 +74,14 @@ class FactoryEnv(DirectRLEnv):
# Tensors for finite-differencing. # Tensors for finite-differencing.
self.last_update_timestamp = 0.0 # Note: This is for finite differencing body velocities. self.last_update_timestamp = 0.0 # Note: This is for finite differencing body velocities.
self.prev_fingertip_pos = torch.zeros((self.num_envs, 3), device=self.device) self.prev_fingertip_pos = torch.zeros((self.num_envs, 3), device=self.device)
self.prev_fingertip_quat = self.identity_quat.clone() self.prev_fingertip_quat = (
torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).unsqueeze(0).repeat(self.num_envs, 1)
)
self.prev_joint_pos = torch.zeros((self.num_envs, 7), device=self.device) self.prev_joint_pos = torch.zeros((self.num_envs, 7), device=self.device)
# Keypoint tensors.
self.target_held_base_pos = torch.zeros((self.num_envs, 3), device=self.device)
self.target_held_base_quat = self.identity_quat.clone().detach()
offsets = self._get_keypoint_offsets(self.cfg_task.num_keypoints)
self.keypoint_offsets = offsets * self.cfg_task.keypoint_scale
self.keypoints_held = torch.zeros((self.num_envs, self.cfg_task.num_keypoints, 3), device=self.device)
self.keypoints_fixed = torch.zeros_like(self.keypoints_held, device=self.device)
# Used to compute target poses.
self.fixed_success_pos_local = torch.zeros((self.num_envs, 3), device=self.device)
if self.cfg_task.name == "peg_insert":
self.fixed_success_pos_local[:, 2] = 0.0
elif self.cfg_task.name == "gear_mesh":
gear_base_offset = self._get_target_gear_base_offset()
self.fixed_success_pos_local[:, 0] = gear_base_offset[0]
self.fixed_success_pos_local[:, 2] = gear_base_offset[2]
elif self.cfg_task.name == "nut_thread":
head_height = self.cfg_task.fixed_asset_cfg.base_height
shank_length = self.cfg_task.fixed_asset_cfg.height
thread_pitch = self.cfg_task.fixed_asset_cfg.thread_pitch
self.fixed_success_pos_local[:, 2] = head_height + shank_length - thread_pitch * 1.5
else:
raise NotImplementedError("Task not implemented")
self.ep_succeeded = torch.zeros((self.num_envs,), dtype=torch.long, device=self.device) self.ep_succeeded = torch.zeros((self.num_envs,), dtype=torch.long, device=self.device)
self.ep_success_times = torch.zeros((self.num_envs,), dtype=torch.long, device=self.device) self.ep_success_times = torch.zeros((self.num_envs,), dtype=torch.long, device=self.device)
def _get_keypoint_offsets(self, num_keypoints):
"""Get uniformly-spaced keypoints along a line of unit length, centered at 0."""
keypoint_offsets = torch.zeros((num_keypoints, 3), device=self.device)
keypoint_offsets[:, -1] = torch.linspace(0.0, 1.0, num_keypoints, device=self.device) - 0.5
return keypoint_offsets
def _setup_scene(self): def _setup_scene(self):
"""Initialize simulation scene.""" """Initialize simulation scene."""
spawn_ground_plane(prim_path="/World/ground", cfg=GroundPlaneCfg(), translation=(0.0, 0.0, -1.05)) spawn_ground_plane(prim_path="/World/ground", cfg=GroundPlaneCfg(), translation=(0.0, 0.0, -1.05))
...@@ -228,31 +155,10 @@ class FactoryEnv(DirectRLEnv): ...@@ -228,31 +155,10 @@ class FactoryEnv(DirectRLEnv):
self.joint_vel_fd = joint_diff / dt self.joint_vel_fd = joint_diff / dt
self.prev_joint_pos = self.joint_pos[:, 0:7].clone() self.prev_joint_pos = self.joint_pos[:, 0:7].clone()
# Keypoint tensors.
self.held_base_quat[:], self.held_base_pos[:] = torch_utils.tf_combine(
self.held_quat, self.held_pos, self.held_base_quat_local, self.held_base_pos_local
)
self.target_held_base_quat[:], self.target_held_base_pos[:] = torch_utils.tf_combine(
self.fixed_quat, self.fixed_pos, self.identity_quat, self.fixed_success_pos_local
)
# Compute pos of keypoints on held asset, and fixed asset in world frame
for idx, keypoint_offset in enumerate(self.keypoint_offsets):
self.keypoints_held[:, idx] = torch_utils.tf_combine(
self.held_base_quat, self.held_base_pos, self.identity_quat, keypoint_offset.repeat(self.num_envs, 1)
)[1]
self.keypoints_fixed[:, idx] = torch_utils.tf_combine(
self.target_held_base_quat,
self.target_held_base_pos,
self.identity_quat,
keypoint_offset.repeat(self.num_envs, 1),
)[1]
self.keypoint_dist = torch.norm(self.keypoints_held - self.keypoints_fixed, p=2, dim=-1).mean(-1)
self.last_update_timestamp = self._robot._data._sim_timestamp self.last_update_timestamp = self._robot._data._sim_timestamp
def _get_observations(self): def _get_factory_obs_state_dict(self):
"""Get actor/critic inputs using asymmetric critic.""" """Populate dictionaries for the policy and critic."""
noisy_fixed_pos = self.fixed_pos_obs_frame + self.init_fixed_pos_obs_noise noisy_fixed_pos = self.fixed_pos_obs_frame + self.init_fixed_pos_obs_noise
prev_actions = self.actions.clone() prev_actions = self.actions.clone()
...@@ -283,15 +189,20 @@ class FactoryEnv(DirectRLEnv): ...@@ -283,15 +189,20 @@ class FactoryEnv(DirectRLEnv):
"rot_threshold": self.rot_threshold, "rot_threshold": self.rot_threshold,
"prev_actions": prev_actions, "prev_actions": prev_actions,
} }
obs_tensors = [obs_dict[obs_name] for obs_name in self.cfg.obs_order + ["prev_actions"]] return obs_dict, state_dict
obs_tensors = torch.cat(obs_tensors, dim=-1)
state_tensors = [state_dict[state_name] for state_name in self.cfg.state_order + ["prev_actions"]] def _get_observations(self):
state_tensors = torch.cat(state_tensors, dim=-1) """Get actor/critic inputs using asymmetric critic."""
obs_dict, state_dict = self._get_factory_obs_state_dict()
obs_tensors = factory_utils.collapse_obs_dict(obs_dict, self.cfg.obs_order + ["prev_actions"])
state_tensors = factory_utils.collapse_obs_dict(state_dict, self.cfg.state_order + ["prev_actions"])
return {"policy": obs_tensors, "critic": state_tensors} return {"policy": obs_tensors, "critic": state_tensors}
def _reset_buffers(self, env_ids): def _reset_buffers(self, env_ids):
"""Reset buffers.""" """Reset buffers."""
self.ep_succeeded[env_ids] = 0 self.ep_succeeded[env_ids] = 0
self.ep_success_times[env_ids] = 0
def _pre_physics_step(self, action): def _pre_physics_step(self, action):
"""Apply policy actions with smoothing.""" """Apply policy actions with smoothing."""
...@@ -299,18 +210,15 @@ class FactoryEnv(DirectRLEnv): ...@@ -299,18 +210,15 @@ class FactoryEnv(DirectRLEnv):
if len(env_ids) > 0: if len(env_ids) > 0:
self._reset_buffers(env_ids) self._reset_buffers(env_ids)
self.actions = ( self.actions = self.ema_factor * action.clone().to(self.device) + (1 - self.ema_factor) * self.actions
self.cfg.ctrl.ema_factor * action.clone().to(self.device) + (1 - self.cfg.ctrl.ema_factor) * self.actions
)
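
The action update above is an exponential moving average (EMA) of the policy output. A small sketch of the same update on plain Python lists, assuming a constant target action (the helper name `ema_update` and the numbers are illustrative only):

```python
def ema_update(prev_actions, new_actions, ema_factor):
    """Blend the new action with the previous smoothed action."""
    return [
        ema_factor * new + (1.0 - ema_factor) * prev
        for prev, new in zip(prev_actions, new_actions)
    ]


# Start from zero and repeatedly command the same action.
actions = [0.0, 0.0]
for _ in range(3):
    actions = ema_update(actions, [1.0, -1.0], ema_factor=0.5)
```

With `ema_factor=0.5`, the smoothed action closes half of the remaining gap to the commanded action on every step; an `ema_factor` of 1.0 disables smoothing.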
def close_gripper_in_place(self): def close_gripper_in_place(self):
"""Keep gripper in current position as gripper closes.""" """Keep gripper in current position as gripper closes."""
actions = torch.zeros((self.num_envs, 6), device=self.device) actions = torch.zeros((self.num_envs, 6), device=self.device)
ctrl_target_gripper_dof_pos = 0.0
# Interpret actions as target pos displacements and set pos target # Interpret actions as target pos displacements and set pos target
pos_actions = actions[:, 0:3] * self.pos_threshold pos_actions = actions[:, 0:3] * self.pos_threshold
self.ctrl_target_fingertip_midpoint_pos = self.fingertip_midpoint_pos + pos_actions ctrl_target_fingertip_midpoint_pos = self.fingertip_midpoint_pos + pos_actions
# Interpret actions as target rot (axis-angle) displacements # Interpret actions as target rot (axis-angle) displacements
rot_actions = actions[:, 3:6] rot_actions = actions[:, 3:6]
...@@ -326,25 +234,24 @@ class FactoryEnv(DirectRLEnv): ...@@ -326,25 +234,24 @@ class FactoryEnv(DirectRLEnv):
rot_actions_quat, rot_actions_quat,
torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).repeat(self.num_envs, 1), torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).repeat(self.num_envs, 1),
) )
self.ctrl_target_fingertip_midpoint_quat = torch_utils.quat_mul(rot_actions_quat, self.fingertip_midpoint_quat) ctrl_target_fingertip_midpoint_quat = torch_utils.quat_mul(rot_actions_quat, self.fingertip_midpoint_quat)
target_euler_xyz = torch.stack(torch_utils.get_euler_xyz(self.ctrl_target_fingertip_midpoint_quat), dim=1) target_euler_xyz = torch.stack(torch_utils.get_euler_xyz(ctrl_target_fingertip_midpoint_quat), dim=1)
target_euler_xyz[:, 0] = 3.14159 target_euler_xyz[:, 0] = 3.14159
target_euler_xyz[:, 1] = 0.0 target_euler_xyz[:, 1] = 0.0
self.ctrl_target_fingertip_midpoint_quat = torch_utils.quat_from_euler_xyz( ctrl_target_fingertip_midpoint_quat = torch_utils.quat_from_euler_xyz(
roll=target_euler_xyz[:, 0], pitch=target_euler_xyz[:, 1], yaw=target_euler_xyz[:, 2] roll=target_euler_xyz[:, 0], pitch=target_euler_xyz[:, 1], yaw=target_euler_xyz[:, 2]
) )
self.ctrl_target_gripper_dof_pos = ctrl_target_gripper_dof_pos self.generate_ctrl_signals(
self.generate_ctrl_signals() ctrl_target_fingertip_midpoint_pos=ctrl_target_fingertip_midpoint_pos,
ctrl_target_fingertip_midpoint_quat=ctrl_target_fingertip_midpoint_quat,
ctrl_target_gripper_dof_pos=0.0,
)
def _apply_action(self): def _apply_action(self):
"""Apply actions for policy as delta targets from current position.""" """Apply actions for policy as delta targets from current position."""
# Get current yaw for success checking.
_, _, curr_yaw = torch_utils.get_euler_xyz(self.fingertip_midpoint_quat)
self.curr_yaw = torch.where(curr_yaw > np.deg2rad(235), curr_yaw - 2 * np.pi, curr_yaw)
# Note: We use finite-differenced velocities for control and observations. # Note: We use finite-differenced velocities for control and observations.
# Check if we need to re-compute velocities within the decimation loop. # Check if we need to re-compute velocities within the decimation loop.
if self.last_update_timestamp < self._robot._data._sim_timestamp: if self.last_update_timestamp < self._robot._data._sim_timestamp:
...@@ -359,13 +266,14 @@ class FactoryEnv(DirectRLEnv): ...@@ -359,13 +266,14 @@ class FactoryEnv(DirectRLEnv):
rot_actions[:, 2] = -(rot_actions[:, 2] + 1.0) * 0.5 # [-1, 0] rot_actions[:, 2] = -(rot_actions[:, 2] + 1.0) * 0.5 # [-1, 0]
rot_actions = rot_actions * self.rot_threshold rot_actions = rot_actions * self.rot_threshold
self.ctrl_target_fingertip_midpoint_pos = self.fingertip_midpoint_pos + pos_actions ctrl_target_fingertip_midpoint_pos = self.fingertip_midpoint_pos + pos_actions
# To speed up learning, never allow the policy to move more than 5cm away from the base. # To speed up learning, never allow the policy to move more than 5cm away from the base.
delta_pos = self.ctrl_target_fingertip_midpoint_pos - self.fixed_pos_action_frame fixed_pos_action_frame = self.fixed_pos_obs_frame + self.init_fixed_pos_obs_noise
delta_pos = ctrl_target_fingertip_midpoint_pos - fixed_pos_action_frame
pos_error_clipped = torch.clip( pos_error_clipped = torch.clip(
delta_pos, -self.cfg.ctrl.pos_action_bounds[0], self.cfg.ctrl.pos_action_bounds[1] delta_pos, -self.cfg.ctrl.pos_action_bounds[0], self.cfg.ctrl.pos_action_bounds[1]
) )
self.ctrl_target_fingertip_midpoint_pos = self.fixed_pos_action_frame + pos_error_clipped ctrl_target_fingertip_midpoint_pos = fixed_pos_action_frame + pos_error_clipped
# Convert to quat and set rot target # Convert to quat and set rot target
angle = torch.norm(rot_actions, p=2, dim=-1) angle = torch.norm(rot_actions, p=2, dim=-1)
...@@ -377,53 +285,57 @@ class FactoryEnv(DirectRLEnv): ...@@ -377,53 +285,57 @@ class FactoryEnv(DirectRLEnv):
rot_actions_quat, rot_actions_quat,
torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).repeat(self.num_envs, 1), torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).repeat(self.num_envs, 1),
) )
self.ctrl_target_fingertip_midpoint_quat = torch_utils.quat_mul(rot_actions_quat, self.fingertip_midpoint_quat) ctrl_target_fingertip_midpoint_quat = torch_utils.quat_mul(rot_actions_quat, self.fingertip_midpoint_quat)
target_euler_xyz = torch.stack(torch_utils.get_euler_xyz(self.ctrl_target_fingertip_midpoint_quat), dim=1) target_euler_xyz = torch.stack(torch_utils.get_euler_xyz(ctrl_target_fingertip_midpoint_quat), dim=1)
target_euler_xyz[:, 0] = 3.14159 # Restrict actions to be upright. target_euler_xyz[:, 0] = 3.14159 # Restrict actions to be upright.
target_euler_xyz[:, 1] = 0.0 target_euler_xyz[:, 1] = 0.0
self.ctrl_target_fingertip_midpoint_quat = torch_utils.quat_from_euler_xyz( ctrl_target_fingertip_midpoint_quat = torch_utils.quat_from_euler_xyz(
roll=target_euler_xyz[:, 0], pitch=target_euler_xyz[:, 1], yaw=target_euler_xyz[:, 2] roll=target_euler_xyz[:, 0], pitch=target_euler_xyz[:, 1], yaw=target_euler_xyz[:, 2]
) )
self.ctrl_target_gripper_dof_pos = 0.0 self.generate_ctrl_signals(
self.generate_ctrl_signals() ctrl_target_fingertip_midpoint_pos=ctrl_target_fingertip_midpoint_pos,
ctrl_target_fingertip_midpoint_quat=ctrl_target_fingertip_midpoint_quat,
def _set_gains(self, prop_gains, rot_deriv_scale=1.0): ctrl_target_gripper_dof_pos=0.0,
"""Set robot gains using critical damping.""" )
self.task_prop_gains = prop_gains
self.task_deriv_gains = 2 * torch.sqrt(prop_gains)
self.task_deriv_gains[:, 3:6] /= rot_deriv_scale
def generate_ctrl_signals(self): def generate_ctrl_signals(
self, ctrl_target_fingertip_midpoint_pos, ctrl_target_fingertip_midpoint_quat, ctrl_target_gripper_dof_pos
):
"""Get Jacobian. Set Franka DOF position targets (fingers) or DOF torques (arm).""" """Get Jacobian. Set Franka DOF position targets (fingers) or DOF torques (arm)."""
self.joint_torque, self.applied_wrench = fc.compute_dof_torque( self.joint_torque, self.applied_wrench = factory_control.compute_dof_torque(
cfg=self.cfg, cfg=self.cfg,
dof_pos=self.joint_pos, dof_pos=self.joint_pos,
dof_vel=self.joint_vel, # _fd, dof_vel=self.joint_vel,
fingertip_midpoint_pos=self.fingertip_midpoint_pos, fingertip_midpoint_pos=self.fingertip_midpoint_pos,
fingertip_midpoint_quat=self.fingertip_midpoint_quat, fingertip_midpoint_quat=self.fingertip_midpoint_quat,
fingertip_midpoint_linvel=self.ee_linvel_fd, fingertip_midpoint_linvel=self.fingertip_midpoint_linvel,
fingertip_midpoint_angvel=self.ee_angvel_fd, fingertip_midpoint_angvel=self.fingertip_midpoint_angvel,
jacobian=self.fingertip_midpoint_jacobian, jacobian=self.fingertip_midpoint_jacobian,
arm_mass_matrix=self.arm_mass_matrix, arm_mass_matrix=self.arm_mass_matrix,
ctrl_target_fingertip_midpoint_pos=self.ctrl_target_fingertip_midpoint_pos, ctrl_target_fingertip_midpoint_pos=ctrl_target_fingertip_midpoint_pos,
ctrl_target_fingertip_midpoint_quat=self.ctrl_target_fingertip_midpoint_quat, ctrl_target_fingertip_midpoint_quat=ctrl_target_fingertip_midpoint_quat,
task_prop_gains=self.task_prop_gains, task_prop_gains=self.task_prop_gains,
task_deriv_gains=self.task_deriv_gains, task_deriv_gains=self.task_deriv_gains,
device=self.device, device=self.device,
dead_zone_thresholds=self.dead_zone_thresholds,
) )
# set target for gripper joints to use physx's PD controller # set target for gripper joints to use physx's PD controller
self.ctrl_target_joint_pos[:, 7:9] = self.ctrl_target_gripper_dof_pos self.ctrl_target_joint_pos[:, 7:9] = ctrl_target_gripper_dof_pos
self.joint_torque[:, 7:9] = 0.0 self.joint_torque[:, 7:9] = 0.0
self._robot.set_joint_position_target(self.ctrl_target_joint_pos) self._robot.set_joint_position_target(self.ctrl_target_joint_pos)
self._robot.set_joint_effort_target(self.joint_torque) self._robot.set_joint_effort_target(self.joint_torque)
def _get_dones(self): def _get_dones(self):
"""Update intermediate values used for rewards and observations.""" """Check which environments are terminated.
For Factory reset logic, it is important that all environments
stay in sync (i.e., _get_dones should return all true or all false).
"""
self._compute_intermediate_values(dt=self.physics_dt) self._compute_intermediate_values(dt=self.physics_dt)
time_out = self.episode_length_buf >= self.max_episode_length - 1 time_out = self.episode_length_buf >= self.max_episode_length - 1
return time_out, time_out return time_out, time_out
...@@ -432,8 +344,20 @@ class FactoryEnv(DirectRLEnv): ...@@ -432,8 +344,20 @@ class FactoryEnv(DirectRLEnv):
"""Get success mask at current timestep.""" """Get success mask at current timestep."""
curr_successes = torch.zeros((self.num_envs,), dtype=torch.bool, device=self.device) curr_successes = torch.zeros((self.num_envs,), dtype=torch.bool, device=self.device)
xy_dist = torch.linalg.vector_norm(self.target_held_base_pos[:, 0:2] - self.held_base_pos[:, 0:2], dim=1) held_base_pos, held_base_quat = factory_utils.get_held_base_pose(
z_disp = self.held_base_pos[:, 2] - self.target_held_base_pos[:, 2] self.held_pos, self.held_quat, self.cfg_task.name, self.cfg_task.fixed_asset_cfg, self.num_envs, self.device
)
target_held_base_pos, target_held_base_quat = factory_utils.get_target_held_base_pose(
self.fixed_pos,
self.fixed_quat,
self.cfg_task.name,
self.cfg_task.fixed_asset_cfg,
self.num_envs,
self.device,
)
xy_dist = torch.linalg.vector_norm(target_held_base_pos[:, 0:2] - held_base_pos[:, 0:2], dim=1)
z_disp = held_base_pos[:, 2] - target_held_base_pos[:, 2]
is_centered = torch.where(xy_dist < 0.0025, torch.ones_like(curr_successes), torch.zeros_like(curr_successes)) is_centered = torch.where(xy_dist < 0.0025, torch.ones_like(curr_successes), torch.zeros_like(curr_successes))
# Height threshold to target # Height threshold to target
@@ -450,21 +374,15 @@ class FactoryEnv(DirectRLEnv):
         curr_successes = torch.logical_and(is_centered, is_close_or_below)
         if check_rot:
-            is_rotated = self.curr_yaw < self.cfg_task.ee_success_yaw
+            _, _, curr_yaw = torch_utils.get_euler_xyz(self.fingertip_midpoint_quat)
+            curr_yaw = factory_utils.wrap_yaw(curr_yaw)
+            is_rotated = curr_yaw < self.cfg_task.ee_success_yaw
             curr_successes = torch.logical_and(curr_successes, is_rotated)
         return curr_successes
-    def _get_rewards(self):
-        """Update rewards and compute success statistics."""
-        # Get successful and failed envs at current timestep
-        check_rot = self.cfg_task.name == "nut_thread"
-        curr_successes = self._get_curr_successes(
-            success_threshold=self.cfg_task.success_threshold, check_rot=check_rot
-        )
-        rew_buf = self._update_rew_buf(curr_successes)
+    def _log_factory_metrics(self, rew_dict, curr_successes):
+        """Keep track of episode statistics and log rewards."""
         # Only log episode success rates at the end of an episode.
         if torch.any(self.reset_buf):
             self.extras["successes"] = torch.count_nonzero(curr_successes) / self.num_envs
@@ -481,53 +399,94 @@ class FactoryEnv(DirectRLEnv):
             success_times = self.ep_success_times[nonzero_success_ids].sum() / len(nonzero_success_ids)
             self.extras["success_times"] = success_times
+        for rew_name, rew in rew_dict.items():
+            self.extras[f"logs_rew_{rew_name}"] = rew.mean()
+    def _get_rewards(self):
+        """Update rewards and compute success statistics."""
+        # Get successful and failed envs at current timestep
+        check_rot = self.cfg_task.name == "nut_thread"
+        curr_successes = self._get_curr_successes(
+            success_threshold=self.cfg_task.success_threshold, check_rot=check_rot
+        )
+        rew_dict, rew_scales = self._get_factory_rew_dict(curr_successes)
+        rew_buf = torch.zeros_like(rew_dict["kp_coarse"])
+        for rew_name, rew in rew_dict.items():
+            rew_buf += rew_dict[rew_name] * rew_scales[rew_name]
         self.prev_actions = self.actions.clone()
+        self._log_factory_metrics(rew_dict, curr_successes)
         return rew_buf
-    def _update_rew_buf(self, curr_successes):
-        """Compute reward at current timestep."""
-        rew_dict = {}
-        # Keypoint rewards.
-        def squashing_fn(x, a, b):
-            return 1 / (torch.exp(a * x) + b + torch.exp(-a * x))
+    def _get_factory_rew_dict(self, curr_successes):
+        """Compute reward terms at current timestep."""
+        rew_dict, rew_scales = {}, {}
+        # Compute pos of keypoints on held asset, and fixed asset in world frame
+        held_base_pos, held_base_quat = factory_utils.get_held_base_pose(
+            self.held_pos, self.held_quat, self.cfg_task.name, self.cfg_task.fixed_asset_cfg, self.num_envs, self.device
+        )
+        target_held_base_pos, target_held_base_quat = factory_utils.get_target_held_base_pose(
+            self.fixed_pos,
+            self.fixed_quat,
+            self.cfg_task.name,
+            self.cfg_task.fixed_asset_cfg,
+            self.num_envs,
+            self.device,
+        )
+        keypoints_held = torch.zeros((self.num_envs, self.cfg_task.num_keypoints, 3), device=self.device)
+        keypoints_fixed = torch.zeros((self.num_envs, self.cfg_task.num_keypoints, 3), device=self.device)
+        offsets = factory_utils.get_keypoint_offsets(self.cfg_task.num_keypoints, self.device)
+        keypoint_offsets = offsets * self.cfg_task.keypoint_scale
+        for idx, keypoint_offset in enumerate(keypoint_offsets):
+            keypoints_held[:, idx] = torch_utils.tf_combine(
+                held_base_quat,
+                held_base_pos,
+                torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).unsqueeze(0).repeat(self.num_envs, 1),
+                keypoint_offset.repeat(self.num_envs, 1),
+            )[1]
+            keypoints_fixed[:, idx] = torch_utils.tf_combine(
+                target_held_base_quat,
+                target_held_base_pos,
+                torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).unsqueeze(0).repeat(self.num_envs, 1),
+                keypoint_offset.repeat(self.num_envs, 1),
+            )[1]
+        keypoint_dist = torch.norm(keypoints_held - keypoints_fixed, p=2, dim=-1).mean(-1)
         a0, b0 = self.cfg_task.keypoint_coef_baseline
-        rew_dict["kp_baseline"] = squashing_fn(self.keypoint_dist, a0, b0)
-        # a1, b1 = 25, 2
         a1, b1 = self.cfg_task.keypoint_coef_coarse
-        rew_dict["kp_coarse"] = squashing_fn(self.keypoint_dist, a1, b1)
         a2, b2 = self.cfg_task.keypoint_coef_fine
-        # a2, b2 = 300, 0
-        rew_dict["kp_fine"] = squashing_fn(self.keypoint_dist, a2, b2)
         # Action penalties.
-        rew_dict["action_penalty"] = torch.norm(self.actions, p=2)
-        rew_dict["action_grad_penalty"] = torch.norm(self.actions - self.prev_actions, p=2, dim=-1)
-        rew_dict["curr_engaged"] = (
-            self._get_curr_successes(success_threshold=self.cfg_task.engage_threshold, check_rot=False).clone().float()
-        )
-        rew_dict["curr_successes"] = curr_successes.clone().float()
-        rew_buf = (
-            rew_dict["kp_coarse"]
-            + rew_dict["kp_baseline"]
-            + rew_dict["kp_fine"]
-            - rew_dict["action_penalty"] * self.cfg_task.action_penalty_scale
-            - rew_dict["action_grad_penalty"] * self.cfg_task.action_grad_penalty_scale
-            + rew_dict["curr_engaged"]
-            + rew_dict["curr_successes"]
-        )
-        for rew_name, rew in rew_dict.items():
-            self.extras[f"logs_rew_{rew_name}"] = rew.mean()
-        return rew_buf
+        action_penalty_ee = torch.norm(self.actions, p=2)
+        action_grad_penalty = torch.norm(self.actions - self.prev_actions, p=2, dim=-1)
+        curr_engaged = self._get_curr_successes(success_threshold=self.cfg_task.engage_threshold, check_rot=False)
+        rew_dict = {
+            "kp_baseline": factory_utils.squashing_fn(keypoint_dist, a0, b0),
+            "kp_coarse": factory_utils.squashing_fn(keypoint_dist, a1, b1),
+            "kp_fine": factory_utils.squashing_fn(keypoint_dist, a2, b2),
+            "action_penalty_ee": action_penalty_ee,
+            "action_grad_penalty": action_grad_penalty,
+            "curr_engaged": curr_engaged.float(),
+            "curr_success": curr_successes.float(),
+        }
+        rew_scales = {
+            "kp_baseline": 1.0,
+            "kp_coarse": 1.0,
+            "kp_fine": 1.0,
+            "action_penalty_ee": -self.cfg_task.action_penalty_ee_scale,
+            "action_grad_penalty": -self.cfg_task.action_grad_penalty_scale,
+            "curr_engaged": 1.0,
+            "curr_success": 1.0,
+        }
+        return rew_dict, rew_scales
     def _reset_idx(self, env_ids):
-        """
-        We assume all envs will always be reset at the same time.
-        """
+        """We assume all envs will always be reset at the same time."""
         super()._reset_idx(env_ids)
         self._set_assets_to_default_pose(env_ids)
@@ -536,19 +495,6 @@ class FactoryEnv(DirectRLEnv):
         self.randomize_initial_state(env_ids)
-    def _get_target_gear_base_offset(self):
-        """Get offset of target gear from the gear base asset."""
-        target_gear = self.cfg_task.target_gear
-        if target_gear == "gear_large":
-            gear_base_offset = self.cfg_task.fixed_asset_cfg.large_gear_base_offset
-        elif target_gear == "gear_medium":
-            gear_base_offset = self.cfg_task.fixed_asset_cfg.medium_gear_base_offset
-        elif target_gear == "gear_small":
-            gear_base_offset = self.cfg_task.fixed_asset_cfg.small_gear_base_offset
-        else:
-            raise ValueError(f"{target_gear} not valid in this context!")
-        return gear_base_offset
     def _set_assets_to_default_pose(self, env_ids):
         """Move assets to default pose before randomization."""
         held_state = self._held_asset.data.default_root_state.clone()[env_ids]
@@ -565,16 +511,18 @@ class FactoryEnv(DirectRLEnv):
         self._fixed_asset.write_root_velocity_to_sim(fixed_state[:, 7:], env_ids=env_ids)
         self._fixed_asset.reset()
-    def set_pos_inverse_kinematics(self, env_ids):
+    def set_pos_inverse_kinematics(
+        self, ctrl_target_fingertip_midpoint_pos, ctrl_target_fingertip_midpoint_quat, env_ids
+    ):
         """Set robot joint position using DLS IK."""
         ik_time = 0.0
         while ik_time < 0.25:
             # Compute error to target.
-            pos_error, axis_angle_error = fc.get_pose_error(
+            pos_error, axis_angle_error = factory_control.get_pose_error(
                 fingertip_midpoint_pos=self.fingertip_midpoint_pos[env_ids],
                 fingertip_midpoint_quat=self.fingertip_midpoint_quat[env_ids],
-                ctrl_target_fingertip_midpoint_pos=self.ctrl_target_fingertip_midpoint_pos[env_ids],
-                ctrl_target_fingertip_midpoint_quat=self.ctrl_target_fingertip_midpoint_quat[env_ids],
+                ctrl_target_fingertip_midpoint_pos=ctrl_target_fingertip_midpoint_pos[env_ids],
+                ctrl_target_fingertip_midpoint_quat=ctrl_target_fingertip_midpoint_quat[env_ids],
                 jacobian_type="geometric",
                 rot_error_type="axis_angle",
             )
@@ -582,7 +530,7 @@ class FactoryEnv(DirectRLEnv):
             delta_hand_pose = torch.cat((pos_error, axis_angle_error), dim=-1)
             # Solve DLS problem.
-            delta_dof_pos = fc._get_delta_dof_pos(
+            delta_dof_pos = factory_control.get_delta_dof_pos(
                 delta_pose=delta_hand_pose,
                 ik_method="dls",
                 jacobian=self.fingertip_midpoint_jacobian[env_ids],
@@ -605,21 +553,25 @@ class FactoryEnv(DirectRLEnv):
     def get_handheld_asset_relative_pose(self):
         """Get default relative pose between help asset and fingertip."""
         if self.cfg_task.name == "peg_insert":
-            held_asset_relative_pos = torch.zeros_like(self.held_base_pos_local)
+            held_asset_relative_pos = torch.zeros((self.num_envs, 3), device=self.device)
             held_asset_relative_pos[:, 2] = self.cfg_task.held_asset_cfg.height
             held_asset_relative_pos[:, 2] -= self.cfg_task.robot_cfg.franka_fingerpad_length
         elif self.cfg_task.name == "gear_mesh":
-            held_asset_relative_pos = torch.zeros_like(self.held_base_pos_local)
-            gear_base_offset = self._get_target_gear_base_offset()
+            held_asset_relative_pos = torch.zeros((self.num_envs, 3), device=self.device)
+            gear_base_offset = self.cfg_task.fixed_asset_cfg.medium_gear_base_offset
             held_asset_relative_pos[:, 0] += gear_base_offset[0]
             held_asset_relative_pos[:, 2] += gear_base_offset[2]
             held_asset_relative_pos[:, 2] += self.cfg_task.held_asset_cfg.height / 2.0 * 1.1
         elif self.cfg_task.name == "nut_thread":
-            held_asset_relative_pos = self.held_base_pos_local
+            held_asset_relative_pos = factory_utils.get_held_base_pos_local(
+                self.cfg_task.name, self.cfg_task.fixed_asset_cfg, self.num_envs, self.device
+            )
         else:
             raise NotImplementedError("Task not implemented")
-        held_asset_relative_quat = self.identity_quat
+        held_asset_relative_quat = (
+            torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).unsqueeze(0).repeat(self.num_envs, 1)
+        )
         if self.cfg_task.name == "nut_thread":
             # Rotate along z-axis of frame for default position.
             initial_rot_deg = self.cfg_task.held_asset_rot_init
@@ -649,7 +601,11 @@ class FactoryEnv(DirectRLEnv):
         self.step_sim_no_action()
     def step_sim_no_action(self):
-        """Step the simulation without an action. Used for resets."""
+        """Step the simulation without an action. Used for resets only.
+
+        This method should only be called during resets when all environments
+        reset at the same time.
+        """
         self.scene.write_data_to_sim()
         self.sim.step(render=False)
         self.scene.update(dt=self.physics_dt)
@@ -698,14 +654,17 @@ class FactoryEnv(DirectRLEnv):
         # Compute the frame on the bolt that would be used as observation: fixed_pos_obs_frame
         # For example, the tip of the bolt can be used as the observation frame
-        fixed_tip_pos_local = torch.zeros_like(self.fixed_pos)
+        fixed_tip_pos_local = torch.zeros((self.num_envs, 3), device=self.device)
         fixed_tip_pos_local[:, 2] += self.cfg_task.fixed_asset_cfg.height
         fixed_tip_pos_local[:, 2] += self.cfg_task.fixed_asset_cfg.base_height
         if self.cfg_task.name == "gear_mesh":
-            fixed_tip_pos_local[:, 0] = self._get_target_gear_base_offset()[0]
+            fixed_tip_pos_local[:, 0] = self.cfg_task.fixed_asset_cfg.medium_gear_base_offset[0]
         _, fixed_tip_pos = torch_utils.tf_combine(
-            self.fixed_quat, self.fixed_pos, self.identity_quat, fixed_tip_pos_local
+            self.fixed_quat,
+            self.fixed_pos,
+            torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).unsqueeze(0).repeat(self.num_envs, 1),
+            fixed_tip_pos_local,
         )
         self.fixed_pos_obs_frame[:] = fixed_tip_pos
@@ -715,7 +674,6 @@ class FactoryEnv(DirectRLEnv):
         ik_attempt = 0
         hand_down_quat = torch.zeros((self.num_envs, 4), dtype=torch.float32, device=self.device)
-        self.hand_down_euler = torch.zeros((self.num_envs, 3), dtype=torch.float32, device=self.device)
         while True:
             n_bad = bad_envs.shape[0]
@@ -738,16 +696,16 @@ class FactoryEnv(DirectRLEnv):
             hand_init_orn_rand = torch.tensor(self.cfg_task.hand_init_orn_noise, device=self.device)
             above_fixed_orn_noise = above_fixed_orn_noise @ torch.diag(hand_init_orn_rand)
             hand_down_euler += above_fixed_orn_noise
-            self.hand_down_euler[bad_envs, ...] = hand_down_euler
             hand_down_quat[bad_envs, :] = torch_utils.quat_from_euler_xyz(
                 roll=hand_down_euler[:, 0], pitch=hand_down_euler[:, 1], yaw=hand_down_euler[:, 2]
             )
             # (c) iterative IK Method
-            self.ctrl_target_fingertip_midpoint_pos[bad_envs, ...] = above_fixed_pos[bad_envs, ...]
-            self.ctrl_target_fingertip_midpoint_quat[bad_envs, ...] = hand_down_quat[bad_envs, :]
-            pos_error, aa_error = self.set_pos_inverse_kinematics(env_ids=bad_envs)
+            pos_error, aa_error = self.set_pos_inverse_kinematics(
+                ctrl_target_fingertip_midpoint_pos=above_fixed_pos,
+                ctrl_target_fingertip_midpoint_quat=hand_down_quat,
+                env_ids=bad_envs,
+            )
             pos_error = torch.linalg.norm(pos_error, dim=1) > 1e-3
             angle_error = torch.norm(aa_error, dim=1) > 1e-3
             any_error = torch.logical_or(pos_error, angle_error)
@@ -788,7 +746,7 @@ class FactoryEnv(DirectRLEnv):
             q1=self.fingertip_midpoint_quat,
             t1=self.fingertip_midpoint_pos,
             q2=flip_z_quat,
-            t2=torch.zeros_like(self.fingertip_midpoint_pos),
+            t2=torch.zeros((self.num_envs, 3), device=self.device),
         )
         # get default gripper in asset transform
@@ -803,17 +761,17 @@ class FactoryEnv(DirectRLEnv):
         # Add asset in hand randomization
         rand_sample = torch.rand((self.num_envs, 3), dtype=torch.float32, device=self.device)
-        self.held_asset_pos_noise = 2 * (rand_sample - 0.5)  # [-1, 1]
+        held_asset_pos_noise = 2 * (rand_sample - 0.5)  # [-1, 1]
         if self.cfg_task.name == "gear_mesh":
-            self.held_asset_pos_noise[:, 2] = -rand_sample[:, 2]  # [-1, 0]
+            held_asset_pos_noise[:, 2] = -rand_sample[:, 2]  # [-1, 0]
-        held_asset_pos_noise = torch.tensor(self.cfg_task.held_asset_pos_noise, device=self.device)
-        self.held_asset_pos_noise = self.held_asset_pos_noise @ torch.diag(held_asset_pos_noise)
+        held_asset_pos_noise_level = torch.tensor(self.cfg_task.held_asset_pos_noise, device=self.device)
+        held_asset_pos_noise = held_asset_pos_noise @ torch.diag(held_asset_pos_noise_level)
         translated_held_asset_quat, translated_held_asset_pos = torch_utils.tf_combine(
             q1=translated_held_asset_quat,
             t1=translated_held_asset_pos,
-            q2=self.identity_quat,
-            t2=self.held_asset_pos_noise,
+            q2=torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).unsqueeze(0).repeat(self.num_envs, 1),
+            t2=held_asset_pos_noise,
         )
         held_state = self._held_asset.data.default_root_state.clone()
@@ -829,15 +787,16 @@ class FactoryEnv(DirectRLEnv):
         reset_task_prop_gains = torch.tensor(self.cfg.ctrl.reset_task_prop_gains, device=self.device).repeat(
             (self.num_envs, 1)
         )
-        reset_rot_deriv_scale = self.cfg.ctrl.reset_rot_deriv_scale
-        self._set_gains(reset_task_prop_gains, reset_rot_deriv_scale)
+        self.task_prop_gains = reset_task_prop_gains
+        self.task_deriv_gains = factory_utils.get_deriv_gains(
+            reset_task_prop_gains, self.cfg.ctrl.reset_rot_deriv_scale
+        )
         self.step_sim_no_action()
         grasp_time = 0.0
         while grasp_time < 0.25:
             self.ctrl_target_joint_pos[env_ids, 7:] = 0.0  # Close gripper.
-            self.ctrl_target_gripper_dof_pos = 0.0
             self.close_gripper_in_place()
             self.step_sim_no_action()
             grasp_time += self.sim.get_physics_dt()
@@ -849,38 +808,13 @@ class FactoryEnv(DirectRLEnv):
         # Set initial actions to involve no-movement. Needed for EMA/correct penalties.
         self.actions = torch.zeros_like(self.actions)
         self.prev_actions = torch.zeros_like(self.actions)
-        # Back out what actions should be for initial state.
-        # Relative position to bolt tip.
-        self.fixed_pos_action_frame[:] = self.fixed_pos_obs_frame + self.init_fixed_pos_obs_noise
-        pos_actions = self.fingertip_midpoint_pos - self.fixed_pos_action_frame
-        pos_action_bounds = torch.tensor(self.cfg.ctrl.pos_action_bounds, device=self.device)
-        pos_actions = pos_actions @ torch.diag(1.0 / pos_action_bounds)
-        self.actions[:, 0:3] = self.prev_actions[:, 0:3] = pos_actions
-        # Relative yaw to bolt.
-        unrot_180_euler = torch.tensor([-np.pi, 0.0, 0.0], device=self.device).repeat(self.num_envs, 1)
-        unrot_quat = torch_utils.quat_from_euler_xyz(
-            roll=unrot_180_euler[:, 0], pitch=unrot_180_euler[:, 1], yaw=unrot_180_euler[:, 2]
-        )
-        fingertip_quat_rel_bolt = torch_utils.quat_mul(unrot_quat, self.fingertip_midpoint_quat)
-        fingertip_yaw_bolt = torch_utils.get_euler_xyz(fingertip_quat_rel_bolt)[-1]
-        fingertip_yaw_bolt = torch.where(
-            fingertip_yaw_bolt > torch.pi / 2, fingertip_yaw_bolt - 2 * torch.pi, fingertip_yaw_bolt
-        )
-        fingertip_yaw_bolt = torch.where(
-            fingertip_yaw_bolt < -torch.pi, fingertip_yaw_bolt + 2 * torch.pi, fingertip_yaw_bolt
-        )
-        yaw_action = (fingertip_yaw_bolt + np.deg2rad(180.0)) / np.deg2rad(270.0) * 2.0 - 1.0
-        self.actions[:, 5] = self.prev_actions[:, 5] = yaw_action
         # Zero initial velocity.
         self.ee_angvel_fd[:, :] = 0.0
         self.ee_linvel_fd[:, :] = 0.0
         # Set initial gains for the episode.
-        self._set_gains(self.default_gains)
+        self.task_prop_gains = self.default_gains
+        self.task_deriv_gains = factory_utils.get_deriv_gains(self.default_gains)
         physics_sim_view.set_gravity(carb.Float3(*self.cfg.sim.gravity))
@@ -164,6 +164,7 @@ class FactoryEnvCfg(DirectRLEnvCfg):
                 friction=0.0,
                 armature=0.0,
                 effort_limit_sim=87,
+                velocity_limit_sim=124.6,
             ),
             "panda_arm2": ImplicitActuatorCfg(
                 joint_names_expr=["panda_joint[5-7]"],
@@ -172,10 +173,12 @@ class FactoryEnvCfg(DirectRLEnvCfg):
                 friction=0.0,
                 armature=0.0,
                 effort_limit_sim=12,
+                velocity_limit_sim=149.5,
             ),
             "panda_hand": ImplicitActuatorCfg(
                 joint_names_expr=["panda_finger_joint[1-2]"],
                 effort_limit_sim=40.0,
+                velocity_limit_sim=0.04,
                 stiffness=7500.0,
                 damping=173.0,
                 friction=0.1,
@@ -67,7 +67,7 @@ class FactoryTask:
     # Reward
     ee_success_yaw: float = 0.0  # nut_thread task only.
-    action_penalty_scale: float = 0.0
+    action_penalty_ee_scale: float = 0.0
     action_grad_penalty_scale: float = 0.0
     # Reward function details can be found in Appendix B of https://arxiv.org/pdf/2408.04587.
     # Multi-scale keypoints are used to capture different phases of the task.
@@ -206,7 +206,6 @@ class GearMesh(FactoryTask):
     name = "gear_mesh"
     fixed_asset_cfg = GearBase()
     held_asset_cfg = MediumGear()
-    target_gear = "gear_medium"
     duration_s = 20.0
     small_gear_usd = f"{ASSET_DIR}/factory_gear_small.usd"
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import torch

import isaacsim.core.utils.torch as torch_utils


def get_keypoint_offsets(num_keypoints, device):
    """Get uniformly-spaced keypoints along a line of unit length, centered at 0."""
    keypoint_offsets = torch.zeros((num_keypoints, 3), device=device)
    keypoint_offsets[:, -1] = torch.linspace(0.0, 1.0, num_keypoints, device=device) - 0.5
    return keypoint_offsets
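
To make the spacing concrete, the following is a pure-Python sketch of the z offsets that `get_keypoint_offsets` places in the last column, written without the torch dependency. The helper name `keypoint_z_offsets` and the scale value `0.15` are illustrative assumptions and are not part of this change.

```python
def keypoint_z_offsets(num_keypoints):
    """Evenly spaced z offsets on a unit-length line centered at 0.

    Hypothetical pure-Python stand-in for
    torch.linspace(0.0, 1.0, num_keypoints) - 0.5.
    """
    if num_keypoints == 1:
        # torch.linspace with a single step returns only the start value (0.0).
        return [-0.5]
    step = 1.0 / (num_keypoints - 1)
    return [i * step - 0.5 for i in range(num_keypoints)]


offsets = keypoint_z_offsets(4)  # endpoints sit at -0.5 and +0.5
# The environment multiplies the offsets by cfg_task.keypoint_scale;
# 0.15 is a made-up value used only for illustration.
scaled = [z * 0.15 for z in offsets]
```

With more than one keypoint the offsets are symmetric about zero, so scaling them stretches the keypoint line without shifting its center.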


def get_deriv_gains(prop_gains, rot_deriv_scale=1.0):
    """Set robot gains using critical damping."""
    deriv_gains = 2 * torch.sqrt(prop_gains)
    deriv_gains[:, 3:6] /= rot_deriv_scale
    return deriv_gains
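
The rule `kd = 2 * sqrt(kp)` is the critical-damping condition for a unit-mass second-order system. A minimal single-row sketch, assuming a 6-element gain vector (three translational gains followed by three rotational gains) and made-up gain values, could look like this; the real function operates on a `(num_envs, 6)` tensor.

```python
import math


def deriv_gains(prop_gains, rot_deriv_scale=1.0):
    """Critically damped derivative gains for one row of proportional gains.

    Pure-Python illustration only; mirrors kd = 2 * sqrt(kp) and then
    divides the rotational entries (indices 3..5) by rot_deriv_scale.
    """
    gains = [2.0 * math.sqrt(kp) for kp in prop_gains]
    gains[3:6] = [kd / rot_deriv_scale for kd in gains[3:6]]
    return gains


# Hypothetical gains: 100 for position axes, 30 for rotation axes.
kd = deriv_gains([100.0, 100.0, 100.0, 30.0, 30.0, 30.0], rot_deriv_scale=10.0)
```

A `rot_deriv_scale` above 1 reduces rotational damping relative to the critically damped value, which is how the reset controller in this diff passes `reset_rot_deriv_scale`.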


def wrap_yaw(angle):
    """Ensure yaw stays within range."""
    return torch.where(angle > np.deg2rad(235), angle - 2 * np.pi, angle)
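
As a scalar illustration of the wrap, the sketch below applies the same rule to single angles. It assumes the incoming yaw lies in `[0, 2π)`; under that assumption the result lies in roughly `(-125°, 235°]`. The function name `wrap_yaw_scalar` is hypothetical.

```python
import math


def wrap_yaw_scalar(angle):
    """Shift yaw values above 235 degrees down by a full turn (radians in, radians out)."""
    if angle > math.radians(235):
        return angle - 2.0 * math.pi
    return angle


wrapped_high = wrap_yaw_scalar(math.radians(240))  # becomes -120 degrees
wrapped_low = wrap_yaw_scalar(math.radians(200))   # unchanged
```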


def set_friction(asset, value, num_envs):
    """Update material properties for a given asset."""
    materials = asset.root_physx_view.get_material_properties()
    materials[..., 0] = value  # Static friction.
    materials[..., 1] = value  # Dynamic friction.
    env_ids = torch.arange(num_envs, device="cpu")
    asset.root_physx_view.set_material_properties(materials, env_ids)


def set_body_inertias(robot, num_envs):
    """Note: this is to account for the asset_options.armature parameter in IGE."""
    inertias = robot.root_physx_view.get_inertias()
    offset = torch.zeros_like(inertias)
    offset[:, :, [0, 4, 8]] += 0.01
    new_inertias = inertias + offset
    robot.root_physx_view.set_inertias(new_inertias, torch.arange(num_envs))


def get_held_base_pos_local(task_name, fixed_asset_cfg, num_envs, device):
    """Get transform between asset default frame and geometric base frame."""
    held_base_x_offset = 0.0
    if task_name == "peg_insert":
        held_base_z_offset = 0.0
    elif task_name == "gear_mesh":
        gear_base_offset = fixed_asset_cfg.medium_gear_base_offset
        held_base_x_offset = gear_base_offset[0]
        held_base_z_offset = gear_base_offset[2]
    elif task_name == "nut_thread":
        held_base_z_offset = fixed_asset_cfg.base_height
    else:
        raise NotImplementedError("Task not implemented")

    held_base_pos_local = torch.tensor([0.0, 0.0, 0.0], device=device).repeat((num_envs, 1))
    held_base_pos_local[:, 0] = held_base_x_offset
    held_base_pos_local[:, 2] = held_base_z_offset

    return held_base_pos_local


def get_held_base_pose(held_pos, held_quat, task_name, fixed_asset_cfg, num_envs, device):
    """Get current poses for keypoint and success computation."""
    held_base_pos_local = get_held_base_pos_local(task_name, fixed_asset_cfg, num_envs, device)
    held_base_quat_local = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device).unsqueeze(0).repeat(num_envs, 1)

    held_base_quat, held_base_pos = torch_utils.tf_combine(
        held_quat, held_pos, held_base_quat_local, held_base_pos_local
    )
    return held_base_pos, held_base_quat


def get_target_held_base_pose(fixed_pos, fixed_quat, task_name, fixed_asset_cfg, num_envs, device):
    """Get target poses for keypoint and success computation."""
    fixed_success_pos_local = torch.zeros((num_envs, 3), device=device)
    if task_name == "peg_insert":
        fixed_success_pos_local[:, 2] = 0.0
    elif task_name == "gear_mesh":
        gear_base_offset = fixed_asset_cfg.medium_gear_base_offset
        fixed_success_pos_local[:, 0] = gear_base_offset[0]
        fixed_success_pos_local[:, 2] = gear_base_offset[2]
    elif task_name == "nut_thread":
        head_height = fixed_asset_cfg.base_height
        shank_length = fixed_asset_cfg.height
        thread_pitch = fixed_asset_cfg.thread_pitch
        fixed_success_pos_local[:, 2] = head_height + shank_length - thread_pitch * 1.5
    else:
        raise NotImplementedError("Task not implemented")
    fixed_success_quat_local = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device).unsqueeze(0).repeat(num_envs, 1)

    target_held_base_quat, target_held_base_pos = torch_utils.tf_combine(
        fixed_quat, fixed_pos, fixed_success_quat_local, fixed_success_pos_local
    )
    return target_held_base_pos, target_held_base_quat


def squashing_fn(x, a, b):
    """Compute bounded reward function."""
    return 1 / (torch.exp(a * x) + b + torch.exp(-a * x))
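
To see why this reward is bounded, note that at `x = 0` both exponentials equal 1, so the value is `1 / (2 + b)`, and it decays toward 0 as the keypoint distance grows. The scalar sketch below checks this with the coefficient pairs `(25, 2)` and `(300, 0)` that appear in the removed comments of the old `_update_rew_buf`; whether the task configs still use those values is an assumption.

```python
import math


def squash(x, a, b):
    """Scalar version of the bounded reward: 1 / (e^{a x} + b + e^{-a x})."""
    return 1.0 / (math.exp(a * x) + b + math.exp(-a * x))


peak_coarse = squash(0.0, 25, 2)  # 1 / (2 + 2)
peak_fine = squash(0.0, 300, 0)   # 1 / (2 + 0)
# A larger `a` makes the reward fall off faster with distance (here 1 cm).
coarse_at_1cm = squash(0.01, 25, 2)
fine_at_1cm = squash(0.01, 300, 0)
```

The fine term therefore only contributes noticeably when the held asset is very close to its target, while the coarse term still provides a gradient farther away.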


def collapse_obs_dict(obs_dict, obs_order):
    """Stack observations in given order."""
    obs_tensors = [obs_dict[obs_name] for obs_name in obs_order]
    obs_tensors = torch.cat(obs_tensors, dim=-1)
    return obs_tensors
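
The point of passing `obs_order` explicitly is that the concatenated layout depends on that list and not on the dictionary's insertion order. A list-based sketch of the same idea for a single environment follows; the observation names and values are hypothetical.

```python
def collapse_obs(obs_dict, obs_order):
    """Concatenate per-key feature lists in the order given by obs_order."""
    flat = []
    for name in obs_order:
        flat.extend(obs_dict[name])
    return flat


# Hypothetical observations for one environment.
obs = {
    "ee_linvel": [0.1, 0.2, 0.3],
    "fingertip_pos": [1.0, 2.0, 3.0],
}
flat_obs = collapse_obs(obs, ["fingertip_pos", "ee_linvel"])
```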
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import gymnasium as gym

from . import agents
from .forge_env import ForgeEnv
from .forge_env_cfg import ForgeTaskGearMeshCfg, ForgeTaskNutThreadCfg, ForgeTaskPegInsertCfg

##
# Register Gym environments.
##

gym.register(
    id="Isaac-Forge-PegInsert-Direct-v0",
    entry_point="isaaclab_tasks.direct.forge:ForgeEnv",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": ForgeTaskPegInsertCfg,
        "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
    },
)

gym.register(
    id="Isaac-Forge-GearMesh-Direct-v0",
    entry_point="isaaclab_tasks.direct.forge:ForgeEnv",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": ForgeTaskGearMeshCfg,
        "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
    },
)

gym.register(
    id="Isaac-Forge-NutThread-Direct-v0",
    entry_point="isaaclab_tasks.direct.forge:ForgeEnv",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": ForgeTaskNutThreadCfg,
        "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg_nut_thread.yaml",
    },
)
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

params:
  seed: 0
  algo:
    name: a2c_continuous

  env:
    clip_actions: 1.0

  model:
    name: continuous_a2c_logstd

  network:
    name: actor_critic
    separate: False

    space:
      continuous:
        mu_activation: None
        sigma_activation: None
        mu_init:
          name: default
        sigma_init:
          name: const_initializer
          val: 0
        fixed_sigma: False
    mlp:
      units: [512, 128, 64]
      activation: elu
      d2rl: False
      initializer:
        name: default
      regularizer:
        name: None

    rnn:
      name: lstm
      units: 1024
      layers: 2
      before_mlp: True
      concat_input: True
      layer_norm: True

  load_checkpoint: False
  load_path: ""

  config:
    name: Forge
    device: cuda:0
    full_experiment_name: test
    env_name: rlgpu
    multi_gpu: False
    ppo: True
    mixed_precision: True
    normalize_input: True
    normalize_value: True
    value_bootstrap: True
    num_actors: 128
    reward_shaper:
      scale_value: 1.0
    normalize_advantage: True
    gamma: 0.995
    tau: 0.95
    learning_rate: 1.0e-4
    lr_schedule: adaptive
    schedule_type: standard
    kl_threshold: 0.008
    score_to_win: 20000
    max_epochs: 200
    save_best_after: 10
    save_frequency: 100
    print_stats: True
    grad_norm: 1.0
    entropy_coef: 0.0
    truncate_grads: True
    e_clip: 0.2
    horizon_length: 128
    minibatch_size: 512 # batch size = num_envs * horizon_length; minibatch_size = batch_size / num_minibatches
    mini_epochs: 4
    critic_coef: 2
    clip_value: True
    seq_length: 128
    bounds_loss_coef: 0.0001

    central_value_config:
      minibatch_size: 512
      mini_epochs: 4
      learning_rate: 1e-4
      lr_schedule: adaptive
      kl_threshold: 0.008
      clip_value: True
      normalize_input: True
      truncate_grads: True

      network:
        name: actor_critic
        central_value: True

        mlp:
          units: [512, 128, 64]
          activation: elu
          d2rl: False
          initializer:
            name: default
          regularizer:
            name: None

        rnn:
          name: lstm
          units: 1024
          layers: 2
          before_mlp: True
          concat_input: True
          layer_norm: True

    player:
      deterministic: False
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
params:
  seed: 0
  algo:
    name: a2c_continuous

  env:
    clip_actions: 1.0

  model:
    name: continuous_a2c_logstd

  network:
    name: actor_critic
    separate: False

    space:
      continuous:
        mu_activation: None
        sigma_activation: None
        mu_init:
          name: default
        sigma_init:
          name: const_initializer
          val: 0
        fixed_sigma: False
    mlp:
      units: [512, 128, 64]
      activation: elu
      d2rl: False
      initializer:
        name: default
      regularizer:
        name: None

    rnn:
      name: lstm
      units: 1024
      layers: 2
      before_mlp: True
      concat_input: True
      layer_norm: True

  load_checkpoint: False
  load_path: ""

  config:
    name: Forge
    device: cuda:0
    full_experiment_name: test
    env_name: rlgpu
    multi_gpu: False
    ppo: True
    mixed_precision: True
    normalize_input: True
    normalize_value: True
    value_bootstrap: True
    num_actors: 128
    reward_shaper:
      scale_value: 1.0
    normalize_advantage: True
    gamma: 0.995
    tau: 0.95
    learning_rate: 1.0e-4
    lr_schedule: adaptive
    schedule_type: standard
    kl_threshold: 0.008
    score_to_win: 20000
    max_epochs: 200
    save_best_after: 10
    save_frequency: 100
    print_stats: True
    grad_norm: 1.0
    entropy_coef: 0.0
    truncate_grads: True
    e_clip: 0.2
    horizon_length: 256
    minibatch_size: 512  # batch size = num_envs * horizon_length; minibatch_size = batch_size / num_minibatches
    mini_epochs: 4
    critic_coef: 2
    clip_value: True
    seq_length: 128
    bounds_loss_coef: 0.0001

    central_value_config:
      minibatch_size: 512
      mini_epochs: 4
      learning_rate: 1e-4
      lr_schedule: adaptive
      kl_threshold: 0.008
      clip_value: True
      normalize_input: True
      truncate_grads: True

      network:
        name: actor_critic
        central_value: True

        mlp:
          units: [512, 128, 64]
          activation: elu
          d2rl: False
          initializer:
            name: default
          regularizer:
            name: None

        rnn:
          name: lstm
          units: 1024
          layers: 2
          before_mlp: True
          concat_input: True
          layer_norm: True

    player:
      deterministic: False
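The comment on `minibatch_size` relates it to the rollout batch size. As a rough sanity check, the sketch below plugs in the values from this config. It assumes the number of environments equals `num_actors: 128`; the training launcher may override that with the actual environment count, in which case the numbers change accordingly.

```python
# Values copied from the config above. num_envs is an assumption
# (taken from num_actors); the launcher may override it at runtime.
num_envs = 128
horizon_length = 256
minibatch_size = 512

# batch size = num_envs * horizon_length
batch_size = num_envs * horizon_length

# minibatch_size = batch_size / num_minibatches, so the batch should
# split evenly into minibatches.
assert batch_size % minibatch_size == 0, "minibatch_size should divide the batch size"
num_minibatches = batch_size // minibatch_size

print(f"batch_size={batch_size}, num_minibatches={num_minibatches}")
```

If the environment count is changed, `minibatch_size` should be re-checked so that the division stays exact.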
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
import torch

import isaacsim.core.utils.torch as torch_utils

from isaaclab.utils.math import axis_angle_from_quat

from isaaclab_tasks.direct.factory import factory_utils
from isaaclab_tasks.direct.factory.factory_env import FactoryEnv

from . import forge_utils
from .forge_env_cfg import ForgeEnvCfg


class ForgeEnv(FactoryEnv):
    cfg: ForgeEnvCfg

    def __init__(self, cfg: ForgeEnvCfg, render_mode: str | None = None, **kwargs):
        """Initialize additional randomization and logging tensors."""
        super().__init__(cfg, render_mode, **kwargs)

        # Success prediction.
        self.success_pred_scale = 0.0
        self.first_pred_success_tx = {}
        for thresh in [0.5, 0.6, 0.7, 0.8, 0.9]:
            self.first_pred_success_tx[thresh] = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)

        # Flip quaternions.
        self.flip_quats = torch.ones((self.num_envs,), dtype=torch.float32, device=self.device)

        # Force sensor information.
        self.force_sensor_body_idx = self._robot.body_names.index("force_sensor")
        self.force_sensor_smooth = torch.zeros((self.num_envs, 6), device=self.device)
        self.force_sensor_world_smooth = torch.zeros((self.num_envs, 6), device=self.device)

        # Set nominal dynamics parameters for randomization.
        self.default_gains = torch.tensor(self.cfg.ctrl.default_task_prop_gains, device=self.device).repeat(
            (self.num_envs, 1)
        )
        self.default_pos_threshold = torch.tensor(self.cfg.ctrl.pos_action_threshold, device=self.device).repeat(
            (self.num_envs, 1)
        )
        self.default_rot_threshold = torch.tensor(self.cfg.ctrl.rot_action_threshold, device=self.device).repeat(
            (self.num_envs, 1)
        )
        self.default_dead_zone = torch.tensor(self.cfg.ctrl.default_dead_zone, device=self.device).repeat(
            (self.num_envs, 1)
        )

        self.pos_threshold = self.default_pos_threshold.clone()
        self.rot_threshold = self.default_rot_threshold.clone()

    def _compute_intermediate_values(self, dt):
        """Add noise to observations for force sensing."""
        super()._compute_intermediate_values(dt)

        # Add noise to fingertip position.
        pos_noise_level, rot_noise_level_deg = self.cfg.obs_rand.fingertip_pos, self.cfg.obs_rand.fingertip_rot_deg
        fingertip_pos_noise = torch.randn((self.num_envs, 3), dtype=torch.float32, device=self.device)
        fingertip_pos_noise = fingertip_pos_noise @ torch.diag(
            torch.tensor([pos_noise_level, pos_noise_level, pos_noise_level], dtype=torch.float32, device=self.device)
        )
        self.noisy_fingertip_pos = self.fingertip_midpoint_pos + fingertip_pos_noise

        rot_noise_axis = torch.randn((self.num_envs, 3), dtype=torch.float32, device=self.device)
        rot_noise_axis /= torch.linalg.norm(rot_noise_axis, dim=1, keepdim=True)
        rot_noise_angle = torch.randn((self.num_envs,), dtype=torch.float32, device=self.device) * np.deg2rad(
            rot_noise_level_deg
        )
        self.noisy_fingertip_quat = torch_utils.quat_mul(
            self.fingertip_midpoint_quat, torch_utils.quat_from_angle_axis(rot_noise_angle, rot_noise_axis)
        )
        self.noisy_fingertip_quat[:, [0, 3]] = 0.0
        self.noisy_fingertip_quat = self.noisy_fingertip_quat * self.flip_quats.unsqueeze(-1)

        # Repeat finite differencing with noisy fingertip positions.
        self.ee_linvel_fd = (self.noisy_fingertip_pos - self.prev_fingertip_pos) / dt
        self.prev_fingertip_pos = self.noisy_fingertip_pos.clone()

        # Add state differences if velocity isn't being added.
        rot_diff_quat = torch_utils.quat_mul(
            self.noisy_fingertip_quat, torch_utils.quat_conjugate(self.prev_fingertip_quat)
        )
        rot_diff_quat *= torch.sign(rot_diff_quat[:, 0]).unsqueeze(-1)
        rot_diff_aa = axis_angle_from_quat(rot_diff_quat)
        self.ee_angvel_fd = rot_diff_aa / dt
        self.ee_angvel_fd[:, 0:2] = 0.0
        self.prev_fingertip_quat = self.noisy_fingertip_quat.clone()

        # Update and smooth force values.
        self.force_sensor_world = self._robot.root_physx_view.get_link_incoming_joint_force()[
            :, self.force_sensor_body_idx
        ]

        alpha = self.cfg.ft_smoothing_factor
        self.force_sensor_world_smooth = alpha * self.force_sensor_world + (1 - alpha) * self.force_sensor_world_smooth

        self.force_sensor_smooth = torch.zeros_like(self.force_sensor_world)
        identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device).unsqueeze(0).repeat(self.num_envs, 1)
        self.force_sensor_smooth[:, :3], self.force_sensor_smooth[:, 3:6] = forge_utils.change_FT_frame(
            self.force_sensor_world_smooth[:, 0:3],
            self.force_sensor_world_smooth[:, 3:6],
            (identity_quat, torch.zeros((self.num_envs, 3), device=self.device)),
            (identity_quat, self.fixed_pos_obs_frame + self.init_fixed_pos_obs_noise),
        )

        # Compute noisy force values.
        force_noise = torch.randn((self.num_envs, 3), dtype=torch.float32, device=self.device)
        force_noise *= self.cfg.obs_rand.ft_force
        self.noisy_force = self.force_sensor_smooth[:, 0:3] + force_noise

    def _get_observations(self):
        """Add additional FORGE observations."""
        obs_dict, state_dict = self._get_factory_obs_state_dict()

        noisy_fixed_pos = self.fixed_pos_obs_frame + self.init_fixed_pos_obs_noise
        prev_actions = self.actions.clone()
        prev_actions[:, 3:5] = 0.0

        obs_dict.update({
            "fingertip_pos": self.noisy_fingertip_pos,
            "fingertip_pos_rel_fixed": self.noisy_fingertip_pos - noisy_fixed_pos,
            "fingertip_quat": self.noisy_fingertip_quat,
            "force_threshold": self.contact_penalty_thresholds[:, None],
            "ft_force": self.noisy_force,
            "prev_actions": prev_actions,
        })

        state_dict.update({
            "ema_factor": self.ema_factor,
            "ft_force": self.force_sensor_smooth[:, 0:3],
            "force_threshold": self.contact_penalty_thresholds[:, None],
            "prev_actions": prev_actions,
        })

        obs_tensors = factory_utils.collapse_obs_dict(obs_dict, self.cfg.obs_order + ["prev_actions"])
        state_tensors = factory_utils.collapse_obs_dict(state_dict, self.cfg.state_order + ["prev_actions"])
        return {"policy": obs_tensors, "critic": state_tensors}

    def _apply_action(self):
        """FORGE actions are defined as targets relative to the fixed asset."""
        if self.last_update_timestamp < self._robot._data._sim_timestamp:
            self._compute_intermediate_values(dt=self.physics_dt)

        # Step (0): Scale actions to allowed range.
        pos_actions = self.actions[:, 0:3]
        pos_actions = pos_actions @ torch.diag(torch.tensor(self.cfg.ctrl.pos_action_bounds, device=self.device))

        rot_actions = self.actions[:, 3:6]
        rot_actions = rot_actions @ torch.diag(torch.tensor(self.cfg.ctrl.rot_action_bounds, device=self.device))

        # Step (1): Compute desired pose targets in EE frame.
        # (1.a) Position. Action frame is assumed to be the top of the bolt (noisy estimate).
        fixed_pos_action_frame = self.fixed_pos_obs_frame + self.init_fixed_pos_obs_noise
        ctrl_target_fingertip_preclipped_pos = fixed_pos_action_frame + pos_actions
        # (1.b) Enforce rotation action constraints.
        rot_actions[:, 0:2] = 0.0

        # Assumes joint limit is in (+x, -y)-quadrant of world frame.
        rot_actions[:, 2] = np.deg2rad(-180.0) + np.deg2rad(270.0) * (rot_actions[:, 2] + 1.0) / 2.0  # Joint limit.
        # (1.c) Get desired orientation target.
        bolt_frame_quat = torch_utils.quat_from_euler_xyz(
            roll=rot_actions[:, 0], pitch=rot_actions[:, 1], yaw=rot_actions[:, 2]
        )

        rot_180_euler = torch.tensor([np.pi, 0.0, 0.0], device=self.device).repeat(self.num_envs, 1)
        quat_bolt_to_ee = torch_utils.quat_from_euler_xyz(
            roll=rot_180_euler[:, 0], pitch=rot_180_euler[:, 1], yaw=rot_180_euler[:, 2]
        )

        ctrl_target_fingertip_preclipped_quat = torch_utils.quat_mul(quat_bolt_to_ee, bolt_frame_quat)

        # Step (2): Clip targets if they are too far from current EE pose.
        # (2.a): Clip position targets.
        self.delta_pos = ctrl_target_fingertip_preclipped_pos - self.fingertip_midpoint_pos  # Used for action_penalty.
        pos_error_clipped = torch.clip(self.delta_pos, -self.pos_threshold, self.pos_threshold)
        ctrl_target_fingertip_midpoint_pos = self.fingertip_midpoint_pos + pos_error_clipped

        # (2.b) Clip orientation targets. Use Euler angles. We assume we are near upright, so
        # clipping yaw will effectively cause slow motions. When we clip, we also need to make
        # sure we avoid the joint limit.

        # (2.b.i) Get current and desired Euler angles.
        curr_roll, curr_pitch, curr_yaw = torch_utils.get_euler_xyz(self.fingertip_midpoint_quat)
        desired_roll, desired_pitch, desired_yaw = torch_utils.get_euler_xyz(ctrl_target_fingertip_preclipped_quat)
        desired_xyz = torch.stack([desired_roll, desired_pitch, desired_yaw], dim=1)

        # (2.b.ii) Correct the direction of motion to avoid joint limit.
        # Map yaws between [-125, 235] degrees (so that angles appear on a continuous span uninterrupted by the joint limit).
        curr_yaw = factory_utils.wrap_yaw(curr_yaw)
        desired_yaw = factory_utils.wrap_yaw(desired_yaw)

        # (2.b.iii) Clip motion in the correct direction.
        self.delta_yaw = desired_yaw - curr_yaw  # Used later for action_penalty.
        clipped_yaw = torch.clip(self.delta_yaw, -self.rot_threshold[:, 2], self.rot_threshold[:, 2])
        desired_xyz[:, 2] = curr_yaw + clipped_yaw

        # (2.b.iv) Clip roll and pitch.
        desired_roll = torch.where(desired_roll < 0.0, desired_roll + 2 * torch.pi, desired_roll)
        desired_pitch = torch.where(desired_pitch < 0.0, desired_pitch + 2 * torch.pi, desired_pitch)

        delta_roll = desired_roll - curr_roll
        clipped_roll = torch.clip(delta_roll, -self.rot_threshold[:, 0], self.rot_threshold[:, 0])
        desired_xyz[:, 0] = curr_roll + clipped_roll

        curr_pitch = torch.where(curr_pitch > torch.pi, curr_pitch - 2 * torch.pi, curr_pitch)
        desired_pitch = torch.where(desired_pitch > torch.pi, desired_pitch - 2 * torch.pi, desired_pitch)

        delta_pitch = desired_pitch - curr_pitch
        clipped_pitch = torch.clip(delta_pitch, -self.rot_threshold[:, 1], self.rot_threshold[:, 1])
        desired_xyz[:, 1] = curr_pitch + clipped_pitch

        ctrl_target_fingertip_midpoint_quat = torch_utils.quat_from_euler_xyz(
            roll=desired_xyz[:, 0], pitch=desired_xyz[:, 1], yaw=desired_xyz[:, 2]
        )

        self.generate_ctrl_signals(
            ctrl_target_fingertip_midpoint_pos=ctrl_target_fingertip_midpoint_pos,
            ctrl_target_fingertip_midpoint_quat=ctrl_target_fingertip_midpoint_quat,
            ctrl_target_gripper_dof_pos=0.0,
        )

    def _get_rewards(self):
        """FORGE reward includes a contact penalty and success prediction error."""
        # Use same base rewards as Factory.
        rew_buf = super()._get_rewards()

        rew_dict, rew_scales = {}, {}
        # Calculate action penalty for the asset-relative action space.
        pos_error = torch.norm(self.delta_pos, p=2, dim=-1) / self.cfg.ctrl.pos_action_threshold[0]
        rot_error = torch.abs(self.delta_yaw) / self.cfg.ctrl.rot_action_threshold[0]
        # Contact penalty.
        contact_force = torch.norm(self.force_sensor_smooth[:, 0:3], p=2, dim=-1, keepdim=False)
        contact_penalty = torch.nn.functional.relu(contact_force - self.contact_penalty_thresholds)
        # Add success prediction rewards.
        check_rot = self.cfg_task.name == "nut_thread"
        true_successes = self._get_curr_successes(
            success_threshold=self.cfg_task.success_threshold, check_rot=check_rot
        )
        policy_success_pred = (self.actions[:, 6] + 1) / 2  # rescale from [-1, 1] to [0, 1]
        success_pred_error = (true_successes.float() - policy_success_pred).abs()
        # Delay success prediction penalty until some successes have occurred.
        if true_successes.float().mean() >= self.cfg_task.delay_until_ratio:
            self.success_pred_scale = 1.0

        # Add new FORGE reward terms.
        rew_dict = {
            "action_penalty_asset": pos_error + rot_error,
            "contact_penalty": contact_penalty,
            "success_pred_error": success_pred_error,
        }
        rew_scales = {
            "action_penalty_asset": -self.cfg_task.action_penalty_asset_scale,
            "contact_penalty": -self.cfg_task.contact_penalty_scale,
            "success_pred_error": -self.success_pred_scale,
        }
        for rew_name, rew in rew_dict.items():
            rew_buf += rew_dict[rew_name] * rew_scales[rew_name]

        self._log_forge_metrics(rew_dict, policy_success_pred)
        return rew_buf

    def _reset_idx(self, env_ids):
        """Perform additional randomizations."""
        super()._reset_idx(env_ids)

        # Compute initial action for correct EMA computation.
        fixed_pos_action_frame = self.fixed_pos_obs_frame + self.init_fixed_pos_obs_noise
        pos_actions = self.fingertip_midpoint_pos - fixed_pos_action_frame
        pos_action_bounds = torch.tensor(self.cfg.ctrl.pos_action_bounds, device=self.device)
        pos_actions = pos_actions @ torch.diag(1.0 / pos_action_bounds)
        self.actions[:, 0:3] = self.prev_actions[:, 0:3] = pos_actions

        # Relative yaw to bolt.
        unrot_180_euler = torch.tensor([-np.pi, 0.0, 0.0], device=self.device).repeat(self.num_envs, 1)
        unrot_quat = torch_utils.quat_from_euler_xyz(
            roll=unrot_180_euler[:, 0], pitch=unrot_180_euler[:, 1], yaw=unrot_180_euler[:, 2]
        )

        fingertip_quat_rel_bolt = torch_utils.quat_mul(unrot_quat, self.fingertip_midpoint_quat)
        fingertip_yaw_bolt = torch_utils.get_euler_xyz(fingertip_quat_rel_bolt)[-1]
        fingertip_yaw_bolt = torch.where(
            fingertip_yaw_bolt > torch.pi / 2, fingertip_yaw_bolt - 2 * torch.pi, fingertip_yaw_bolt
        )
        fingertip_yaw_bolt = torch.where(
            fingertip_yaw_bolt < -torch.pi, fingertip_yaw_bolt + 2 * torch.pi, fingertip_yaw_bolt
        )

        yaw_action = (fingertip_yaw_bolt + np.deg2rad(180.0)) / np.deg2rad(270.0) * 2.0 - 1.0
        self.actions[:, 5] = self.prev_actions[:, 5] = yaw_action
        self.actions[:, 6] = self.prev_actions[:, 6] = -1.0

        # EMA randomization.
        ema_rand = torch.rand((self.num_envs, 1), dtype=torch.float32, device=self.device)
        ema_lower, ema_upper = self.cfg.ctrl.ema_factor_range
        self.ema_factor = ema_lower + ema_rand * (ema_upper - ema_lower)

        # Set initial gains for the episode.
        prop_gains = self.default_gains.clone()
        self.pos_threshold = self.default_pos_threshold.clone()
        self.rot_threshold = self.default_rot_threshold.clone()
        prop_gains = forge_utils.get_random_prop_gains(
            prop_gains, self.cfg.ctrl.task_prop_gains_noise_level, self.num_envs, self.device
        )
        self.pos_threshold = forge_utils.get_random_prop_gains(
            self.pos_threshold, self.cfg.ctrl.pos_threshold_noise_level, self.num_envs, self.device
        )
        self.rot_threshold = forge_utils.get_random_prop_gains(
            self.rot_threshold, self.cfg.ctrl.rot_threshold_noise_level, self.num_envs, self.device
        )
        self.task_prop_gains = prop_gains
        self.task_deriv_gains = factory_utils.get_deriv_gains(prop_gains)

        contact_rand = torch.rand((self.num_envs,), dtype=torch.float32, device=self.device)
        contact_lower, contact_upper = self.cfg.task.contact_penalty_threshold_range
        self.contact_penalty_thresholds = contact_lower + contact_rand * (contact_upper - contact_lower)

        self.dead_zone_thresholds = (
            torch.rand((self.num_envs, 6), dtype=torch.float32, device=self.device) * self.default_dead_zone
        )

        self.force_sensor_world_smooth[:, :] = 0.0

        self.flip_quats = torch.ones((self.num_envs,), dtype=torch.float32, device=self.device)
        rand_flips = torch.rand(self.num_envs) > 0.5
        self.flip_quats[rand_flips] = -1.0

    def _reset_buffers(self, env_ids):
        """Reset additional logging metrics."""
        super()._reset_buffers(env_ids)
        # Reset success pred metrics.
        for thresh in [0.5, 0.6, 0.7, 0.8, 0.9]:
            self.first_pred_success_tx[thresh][env_ids] = 0

    def _log_forge_metrics(self, rew_dict, policy_success_pred):
        """Log metrics to evaluate success prediction performance."""
        for rew_name, rew in rew_dict.items():
            self.extras[f"logs_rew_{rew_name}"] = rew.mean()

        for thresh, first_success_tx in self.first_pred_success_tx.items():
            curr_predicted_success = policy_success_pred > thresh
            first_success_idxs = torch.logical_and(curr_predicted_success, first_success_tx == 0)

            first_success_tx[:] = torch.where(first_success_idxs, self.episode_length_buf, first_success_tx)

            # Only log at the end.
            if torch.any(self.reset_buf):
                # Log prediction delay.
                delay_ids = torch.logical_and(self.ep_success_times != 0, first_success_tx != 0)
                delay_times = (first_success_tx[delay_ids] - self.ep_success_times[delay_ids]).sum() / delay_ids.sum()
                if delay_ids.sum().item() > 0:
                    self.extras[f"early_term_delay_all/{thresh}"] = delay_times

                correct_delay_ids = torch.logical_and(delay_ids, first_success_tx > self.ep_success_times)
                correct_delay_times = (
                    first_success_tx[correct_delay_ids] - self.ep_success_times[correct_delay_ids]
                ).sum() / correct_delay_ids.sum()
                if correct_delay_ids.sum().item() > 0:
                    self.extras[f"early_term_delay_correct/{thresh}"] = correct_delay_times.item()

                # Log early-term success rate (for all episodes we have "stopped", did we succeed?).
                pred_success_idxs = first_success_tx != 0  # Episodes which we have predicted success.

                true_success_preds = torch.logical_and(
                    self.ep_success_times[pred_success_idxs] > 0,  # Success has actually occurred.
                    self.ep_success_times[pred_success_idxs]
                    < first_success_tx[pred_success_idxs],  # Success occurred before we predicted it.
                )

                num_pred_success = pred_success_idxs.sum().item()
                et_prec = true_success_preds.sum() / num_pred_success
                if num_pred_success > 0:
                    self.extras[f"early_term_precision/{thresh}"] = et_prec

                true_success_idxs = self.ep_success_times > 0
                num_true_success = true_success_idxs.sum().item()
                et_recall = true_success_preds.sum() / num_true_success
                if num_true_success > 0:
                    self.extras[f"early_term_recall/{thresh}"] = et_recall
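`_apply_action` maps the normalized yaw action in [-1, 1] onto a 270-degree span starting at -180 degrees, and `_reset_idx` inverts that mapping to seed the initial action from the current fingertip yaw. The sketch below restates the two formulas with scalar math so the round trip can be checked in isolation. The helper names `action_to_yaw` and `yaw_to_action` are hypothetical and do not exist in the environment, which applies the same arithmetic to batched tensors.

```python
import math

# Constants taken from the formulas in _apply_action / _reset_idx.
YAW_MIN = math.radians(-180.0)  # yaw reached at action = -1
YAW_SPAN = math.radians(270.0)  # total yaw range covered by the action


def action_to_yaw(action: float) -> float:
    """Forward mapping used when applying an action (hypothetical helper)."""
    return YAW_MIN + YAW_SPAN * (action + 1.0) / 2.0


def yaw_to_action(yaw: float) -> float:
    """Inverse mapping used when resetting (hypothetical helper)."""
    return (yaw - YAW_MIN) / YAW_SPAN * 2.0 - 1.0


for a in (-1.0, 0.0, 1.0):
    yaw = action_to_yaw(a)
    print(f"action={a:+.1f} -> yaw={math.degrees(yaw):+.1f} deg -> action={yaw_to_action(yaw):+.3f}")
```

Under these formulas the action endpoints correspond to yaws of -180 and +90 degrees, with the midpoint at -45 degrees.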
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import isaaclab.envs.mdp as mdp
from isaaclab.managers import EventTermCfg as EventTerm
from isaaclab.managers import SceneEntityCfg
from isaaclab.utils import configclass

from isaaclab_tasks.direct.factory.factory_env_cfg import OBS_DIM_CFG, STATE_DIM_CFG, CtrlCfg, FactoryEnvCfg, ObsRandCfg

from .forge_events import randomize_dead_zone
from .forge_tasks_cfg import ForgeGearMesh, ForgeNutThread, ForgePegInsert, ForgeTask

OBS_DIM_CFG.update({"force_threshold": 1, "ft_force": 3})

STATE_DIM_CFG.update({"force_threshold": 1, "ft_force": 3})


@configclass
class ForgeCtrlCfg(CtrlCfg):
    ema_factor_range = [0.025, 0.1]
    default_task_prop_gains = [565.0, 565.0, 565.0, 28.0, 28.0, 28.0]
    task_prop_gains_noise_level = [0.41, 0.41, 0.41, 0.41, 0.41, 0.41]
    pos_threshold_noise_level = [0.25, 0.25, 0.25]
    rot_threshold_noise_level = [0.29, 0.29, 0.29]
    default_dead_zone = [5.0, 5.0, 5.0, 1.0, 1.0, 1.0]


@configclass
class ForgeObsRandCfg(ObsRandCfg):
    fingertip_pos = 0.00025
    fingertip_rot_deg = 0.1
    ft_force = 1.0


@configclass
class EventCfg:
    object_scale_mass = EventTerm(
        func=mdp.randomize_rigid_body_mass,
        mode="reset",
        params={
            "asset_cfg": SceneEntityCfg("held_asset"),
            "mass_distribution_params": (-0.005, 0.005),
            "operation": "add",
            "distribution": "uniform",
        },
    )

    held_physics_material = EventTerm(
        func=mdp.randomize_rigid_body_material,
        mode="startup",
        params={
            "asset_cfg": SceneEntityCfg("held_asset"),
            "static_friction_range": (0.75, 0.75),
            "dynamic_friction_range": (0.75, 0.75),
            "restitution_range": (0.0, 0.0),
            "num_buckets": 1,
        },
    )

    fixed_physics_material = EventTerm(
        func=mdp.randomize_rigid_body_material,
        mode="startup",
        params={
            "asset_cfg": SceneEntityCfg("fixed_asset"),
            "static_friction_range": (0.25, 1.25),  # TODO: Set these values based on asset type.
            "dynamic_friction_range": (0.25, 0.25),
            "restitution_range": (0.0, 0.0),
            "num_buckets": 128,
        },
    )

    robot_physics_material = EventTerm(
        func=mdp.randomize_rigid_body_material,
        mode="startup",
        params={
            "asset_cfg": SceneEntityCfg("robot", body_names=".*"),
            "static_friction_range": (0.75, 0.75),
            "dynamic_friction_range": (0.75, 0.75),
            "restitution_range": (0.0, 0.0),
            "num_buckets": 1,
        },
    )

    dead_zone_thresholds = EventTerm(
        func=randomize_dead_zone, mode="interval", interval_range_s=(2.0, 2.0)  # (0.25, 0.25)
    )


@configclass
class ForgeEnvCfg(FactoryEnvCfg):
    action_space: int = 7
    obs_rand: ForgeObsRandCfg = ForgeObsRandCfg()
    ctrl: ForgeCtrlCfg = ForgeCtrlCfg()
    task: ForgeTask = ForgeTask()
    events: EventCfg = EventCfg()

    ft_smoothing_factor: float = 0.25

    obs_order: list = [
        "fingertip_pos_rel_fixed",
        "fingertip_quat",
        "ee_linvel",
        "ee_angvel",
        "ft_force",
        "force_threshold",
    ]
    state_order: list = [
        "fingertip_pos",
        "fingertip_quat",
        "ee_linvel",
        "ee_angvel",
        "joint_pos",
        "held_pos",
        "held_pos_rel_fixed",
        "held_quat",
        "fixed_pos",
        "fixed_quat",
        "task_prop_gains",
        "ema_factor",
        "ft_force",
        "pos_threshold",
        "rot_threshold",
        "force_threshold",
    ]


@configclass
class ForgeTaskPegInsertCfg(ForgeEnvCfg):
    task_name = "peg_insert"
    task = ForgePegInsert()
    episode_length_s = 10.0


@configclass
class ForgeTaskGearMeshCfg(ForgeEnvCfg):
    task_name = "gear_mesh"
    task = ForgeGearMesh()
    episode_length_s = 20.0


@configclass
class ForgeTaskNutThreadCfg(ForgeEnvCfg):
    task_name = "nut_thread"
    task = ForgeNutThread()
    episode_length_s = 30.0
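`ft_smoothing_factor` is the `alpha` in the exponential moving average that `ForgeEnv._compute_intermediate_values` applies to the raw force reading: `alpha * new + (1 - alpha) * previous`. The scalar sketch below gives a feel for how quickly the default of 0.25 tracks a step change. The helper `ema_update` and the constant input of 10.0 are illustrative assumptions; the environment applies the same update to a batched six-dimensional tensor.

```python
def ema_update(previous: float, new: float, alpha: float = 0.25) -> float:
    """One exponential-moving-average step (hypothetical scalar helper)."""
    return alpha * new + (1.0 - alpha) * previous


# Start from zero, as after a reset, and feed a constant reading.
smoothed = 0.0
history = []
for _ in range(10):
    smoothed = ema_update(smoothed, 10.0)
    history.append(smoothed)

print([round(v, 3) for v in history])
```

With a constant input `x` and a zero initial value, the smoothed value after `n` steps is `x * (1 - (1 - alpha) ** n)`, so a larger `alpha` responds faster but filters less noise.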
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import torch

from isaaclab.envs import DirectRLEnv


def randomize_dead_zone(env: DirectRLEnv, env_ids: torch.Tensor | None):
    env.dead_zone_thresholds = (
        torch.rand((env.num_envs, 6), dtype=torch.float32, device=env.device) * env.default_dead_zone
    )
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from isaaclab.utils import configclass

from isaaclab_tasks.direct.factory.factory_tasks_cfg import FactoryTask, GearMesh, NutThread, PegInsert


@configclass
class ForgeTask(FactoryTask):
    action_penalty_ee_scale: float = 0.0
    action_penalty_asset_scale: float = 0.001
    action_grad_penalty_scale: float = 0.1
    contact_penalty_scale: float = 0.05
    delay_until_ratio: float = 0.25
    contact_penalty_threshold_range = [5.0, 10.0]


@configclass
class ForgePegInsert(PegInsert, ForgeTask):
    contact_penalty_scale: float = 0.2


@configclass
class ForgeGearMesh(GearMesh, ForgeTask):
    contact_penalty_scale: float = 0.05


@configclass
class ForgeNutThread(NutThread, ForgeTask):
    contact_penalty_scale: float = 0.05
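To make the role of `contact_penalty_scale` and `contact_penalty_threshold_range` concrete, the sketch below reproduces the contact term from `ForgeEnv._get_rewards` for a single scalar force magnitude: the penalty is zero below the per-episode threshold and grows linearly above it. The function name `contact_penalty_reward` and the sample force values are illustrative assumptions, not part of the task code.

```python
def contact_penalty_reward(force_norm: float, threshold: float, scale: float) -> float:
    """Reward contribution of the contact penalty (hypothetical scalar helper).

    Mirrors relu(force_norm - threshold) multiplied by -scale.
    """
    excess = max(force_norm - threshold, 0.0)
    return -scale * excess


# Threshold sampled somewhere in contact_penalty_threshold_range = [5.0, 10.0];
# 7.5 is an arbitrary pick for illustration.
threshold = 7.5
for force in (3.0, 7.5, 12.0):
    peg = contact_penalty_reward(force, threshold, scale=0.2)    # ForgePegInsert
    gear = contact_penalty_reward(force, threshold, scale=0.05)  # ForgeGearMesh
    print(f"force={force:5.1f}  peg_insert={peg:+.3f}  gear_mesh={gear:+.3f}")
```

For the same excess force, the peg-insert task is penalized four times as strongly as the gear-mesh and nut-thread tasks.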
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
import torch

import isaacsim.core.utils.torch as torch_utils


def get_random_prop_gains(default_values, noise_levels, num_envs, device):
    """Helper function to randomize controller gains."""
    c_param_noise = torch.rand((num_envs, default_values.shape[1]), dtype=torch.float32, device=device)
    c_param_noise = c_param_noise @ torch.diag(torch.tensor(noise_levels, dtype=torch.float32, device=device))
    c_param_multiplier = 1.0 + c_param_noise
    decrease_param_flag = torch.rand((num_envs, default_values.shape[1]), dtype=torch.float32, device=device) > 0.5
    c_param_multiplier = torch.where(decrease_param_flag, 1.0 / c_param_multiplier, c_param_multiplier)

    prop_gains = default_values * c_param_multiplier

    return prop_gains


def change_FT_frame(source_F, source_T, source_frame, target_frame):
    """Convert force/torque reading from source to target frame."""
    # Modern Robotics eq. 3.95
    source_frame_inv = torch_utils.tf_inverse(source_frame[0], source_frame[1])
    target_T_source_quat, target_T_source_pos = torch_utils.tf_combine(
        source_frame_inv[0], source_frame_inv[1], target_frame[0], target_frame[1]
    )
    target_F = torch_utils.quat_apply(target_T_source_quat, source_F)
    target_T = torch_utils.quat_apply(
        target_T_source_quat, (source_T + torch.cross(target_T_source_pos, source_F, dim=-1))
    )
    return target_F, target_T
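`get_random_prop_gains` scales each default value by a multiplier of `1 + u * noise_level` (with `u` uniform in [0, 1)) and, with probability one half, by the reciprocal of that multiplier instead. The result therefore stays within `[default / (1 + noise_level), default * (1 + noise_level)]`. The scalar sketch below checks that bound using only the standard library. The helper `random_gain`, the seed, and the sample count are illustrative assumptions; the real function operates on batched tensors.

```python
import random


def random_gain(default: float, noise_level: float, rng: random.Random) -> float:
    """Scalar restatement of the gain randomization (hypothetical helper)."""
    multiplier = 1.0 + rng.random() * noise_level
    # Half of the time the gain is decreased by using the reciprocal.
    if rng.random() > 0.5:
        multiplier = 1.0 / multiplier
    return default * multiplier


rng = random.Random(0)
default_gain, noise = 565.0, 0.41  # values from ForgeCtrlCfg
samples = [random_gain(default_gain, noise, rng) for _ in range(1000)]

lower, upper = default_gain / (1.0 + noise), default_gain * (1.0 + noise)
print(f"bounds: [{lower:.1f}, {upper:.1f}]")
print(f"observed: [{min(samples):.1f}, {max(samples):.1f}]")
```

Using the reciprocal for decreases makes the range symmetric in a multiplicative sense: a gain is as likely to be divided by a given factor as multiplied by it.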