Unverified Commit 8ddc4830 authored by Kelly Guo's avatar Kelly Guo Committed by GitHub

Adds position threshold check for state transitions (#1544)

# Description

Adds a position threshold check, resolving 3 TODO error comments, to
ensure the robot's end effector is within a specified distance from the
target position before transitioning between states in the pick and lift
state machine. Improves the precision of state transitions and helps
prevent premature actions during object manipulation. I.e, the threshold
ensures the robot is "close enough" to the target position before
proceeding, reducing the likelihood of failed grasps or incorrect
movements.

PR adapted from https://github.com/isaac-sim/IsaacLab/pull/1273/ by
@DorsaRoh.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarDorsaRoh <dorsa.rohani@gmail.com>
parent f01c6f9e
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.27.27" version = "0.27.28"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.27.28 (2024-12-14)
~~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Added check for error below threshold in state machines to ensure the state has been reached.
0.27.27 (2024-12-13) 0.27.27 (2024-12-13)
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
......
...@@ -81,6 +81,11 @@ class PickSmWaitTime: ...@@ -81,6 +81,11 @@ class PickSmWaitTime:
LIFT_OBJECT = wp.constant(1.0) LIFT_OBJECT = wp.constant(1.0)
@wp.func
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
return wp.length(current_pos - desired_pos) < threshold
@wp.kernel @wp.kernel
def infer_state_machine( def infer_state_machine(
dt: wp.array(dtype=float), dt: wp.array(dtype=float),
...@@ -92,6 +97,7 @@ def infer_state_machine( ...@@ -92,6 +97,7 @@ def infer_state_machine(
des_ee_pose: wp.array(dtype=wp.transform), des_ee_pose: wp.array(dtype=wp.transform),
gripper_state: wp.array(dtype=float), gripper_state: wp.array(dtype=float),
offset: wp.array(dtype=wp.transform), offset: wp.array(dtype=wp.transform),
position_threshold: float,
): ):
# retrieve thread id # retrieve thread id
tid = wp.tid() tid = wp.tid()
...@@ -109,21 +115,28 @@ def infer_state_machine( ...@@ -109,21 +115,28 @@ def infer_state_machine(
elif state == PickSmState.APPROACH_ABOVE_OBJECT: elif state == PickSmState.APPROACH_ABOVE_OBJECT:
des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid]) des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
gripper_state[tid] = GripperState.OPEN gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold if distance_below_threshold(
# wait for a while wp.transform_get_translation(ee_pose[tid]),
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT: wp.transform_get_translation(des_ee_pose[tid]),
# move to next state and reset wait time position_threshold,
sm_state[tid] = PickSmState.APPROACH_OBJECT ):
sm_wait_time[tid] = 0.0 # wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.APPROACH_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.APPROACH_OBJECT: elif state == PickSmState.APPROACH_OBJECT:
des_ee_pose[tid] = object_pose[tid] des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.OPEN gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold if distance_below_threshold(
# wait for a while wp.transform_get_translation(ee_pose[tid]),
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT: wp.transform_get_translation(des_ee_pose[tid]),
# move to next state and reset wait time position_threshold,
sm_state[tid] = PickSmState.GRASP_OBJECT ):
sm_wait_time[tid] = 0.0 if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.GRASP_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.GRASP_OBJECT: elif state == PickSmState.GRASP_OBJECT:
des_ee_pose[tid] = object_pose[tid] des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.CLOSE gripper_state[tid] = GripperState.CLOSE
...@@ -135,12 +148,16 @@ def infer_state_machine( ...@@ -135,12 +148,16 @@ def infer_state_machine(
elif state == PickSmState.LIFT_OBJECT: elif state == PickSmState.LIFT_OBJECT:
des_ee_pose[tid] = des_object_pose[tid] des_ee_pose[tid] = des_object_pose[tid]
gripper_state[tid] = GripperState.CLOSE gripper_state[tid] = GripperState.CLOSE
# TODO: error between current and desired ee pose below threshold if distance_below_threshold(
# wait for a while wp.transform_get_translation(ee_pose[tid]),
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT: wp.transform_get_translation(des_ee_pose[tid]),
# move to next state and reset wait time position_threshold,
sm_state[tid] = PickSmState.LIFT_OBJECT ):
sm_wait_time[tid] = 0.0 # wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.LIFT_OBJECT
sm_wait_time[tid] = 0.0
# increment wait time # increment wait time
sm_wait_time[tid] = sm_wait_time[tid] + dt[tid] sm_wait_time[tid] = sm_wait_time[tid] + dt[tid]
...@@ -160,7 +177,7 @@ class PickAndLiftSm: ...@@ -160,7 +177,7 @@ class PickAndLiftSm:
5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state. 5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state.
""" """
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"): def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
"""Initialize the state machine. """Initialize the state machine.
Args: Args:
...@@ -172,6 +189,7 @@ class PickAndLiftSm: ...@@ -172,6 +189,7 @@ class PickAndLiftSm:
self.dt = float(dt) self.dt = float(dt)
self.num_envs = num_envs self.num_envs = num_envs
self.device = device self.device = device
self.position_threshold = position_threshold
# initialize state machine # initialize state machine
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device) self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device) self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
...@@ -201,7 +219,7 @@ class PickAndLiftSm: ...@@ -201,7 +219,7 @@ class PickAndLiftSm:
self.sm_state[env_ids] = 0 self.sm_state[env_ids] = 0
self.sm_wait_time[env_ids] = 0.0 self.sm_wait_time[env_ids] = 0.0
def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_pose: torch.Tensor): def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_pose: torch.Tensor) -> torch.Tensor:
"""Compute the desired state of the robot's end-effector and the gripper.""" """Compute the desired state of the robot's end-effector and the gripper."""
# convert all transformations from (w, x, y, z) to (x, y, z, w) # convert all transformations from (w, x, y, z) to (x, y, z, w)
ee_pose = ee_pose[:, [0, 1, 2, 4, 5, 6, 3]] ee_pose = ee_pose[:, [0, 1, 2, 4, 5, 6, 3]]
...@@ -227,6 +245,7 @@ class PickAndLiftSm: ...@@ -227,6 +245,7 @@ class PickAndLiftSm:
self.des_ee_pose_wp, self.des_ee_pose_wp,
self.des_gripper_state_wp, self.des_gripper_state_wp,
self.offset_wp, self.offset_wp,
self.position_threshold,
], ],
device=self.device, device=self.device,
) )
...@@ -257,7 +276,9 @@ def main(): ...@@ -257,7 +276,9 @@ def main():
desired_orientation = torch.zeros((env.unwrapped.num_envs, 4), device=env.unwrapped.device) desired_orientation = torch.zeros((env.unwrapped.num_envs, 4), device=env.unwrapped.device)
desired_orientation[:, 1] = 1.0 desired_orientation[:, 1] = 1.0
# create state machine # create state machine
pick_sm = PickAndLiftSm(env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device) pick_sm = PickAndLiftSm(
env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device, position_threshold=0.01
)
while simulation_app.is_running(): while simulation_app.is_running():
# run everything in inference mode # run everything in inference mode
......
...@@ -80,6 +80,11 @@ class PickSmWaitTime: ...@@ -80,6 +80,11 @@ class PickSmWaitTime:
OPEN_GRIPPER = wp.constant(0.0) OPEN_GRIPPER = wp.constant(0.0)
@wp.func
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
return wp.length(current_pos - desired_pos) < threshold
@wp.kernel @wp.kernel
def infer_state_machine( def infer_state_machine(
dt: wp.array(dtype=float), dt: wp.array(dtype=float),
...@@ -91,6 +96,7 @@ def infer_state_machine( ...@@ -91,6 +96,7 @@ def infer_state_machine(
des_ee_pose: wp.array(dtype=wp.transform), des_ee_pose: wp.array(dtype=wp.transform),
gripper_state: wp.array(dtype=float), gripper_state: wp.array(dtype=float),
offset: wp.array(dtype=wp.transform), offset: wp.array(dtype=wp.transform),
position_threshold: float,
): ):
# retrieve thread id # retrieve thread id
tid = wp.tid() tid = wp.tid()
...@@ -108,21 +114,29 @@ def infer_state_machine( ...@@ -108,21 +114,29 @@ def infer_state_machine(
elif state == PickSmState.APPROACH_ABOVE_OBJECT: elif state == PickSmState.APPROACH_ABOVE_OBJECT:
des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid]) des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
gripper_state[tid] = GripperState.OPEN gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold if distance_below_threshold(
# wait for a while wp.transform_get_translation(ee_pose[tid]),
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT: wp.transform_get_translation(des_ee_pose[tid]),
# move to next state and reset wait time position_threshold,
sm_state[tid] = PickSmState.APPROACH_OBJECT ):
sm_wait_time[tid] = 0.0 # wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.APPROACH_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.APPROACH_OBJECT: elif state == PickSmState.APPROACH_OBJECT:
des_ee_pose[tid] = object_pose[tid] des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.OPEN gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold if distance_below_threshold(
# wait for a while wp.transform_get_translation(ee_pose[tid]),
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT: wp.transform_get_translation(des_ee_pose[tid]),
# move to next state and reset wait time position_threshold,
sm_state[tid] = PickSmState.GRASP_OBJECT ):
sm_wait_time[tid] = 0.0 # wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.GRASP_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.GRASP_OBJECT: elif state == PickSmState.GRASP_OBJECT:
des_ee_pose[tid] = object_pose[tid] des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.CLOSE gripper_state[tid] = GripperState.CLOSE
...@@ -134,12 +148,16 @@ def infer_state_machine( ...@@ -134,12 +148,16 @@ def infer_state_machine(
elif state == PickSmState.LIFT_OBJECT: elif state == PickSmState.LIFT_OBJECT:
des_ee_pose[tid] = des_object_pose[tid] des_ee_pose[tid] = des_object_pose[tid]
gripper_state[tid] = GripperState.CLOSE gripper_state[tid] = GripperState.CLOSE
# TODO: error between current and desired ee pose below threshold if distance_below_threshold(
# wait for a while wp.transform_get_translation(ee_pose[tid]),
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT: wp.transform_get_translation(des_ee_pose[tid]),
# move to next state and reset wait time position_threshold,
sm_state[tid] = PickSmState.OPEN_GRIPPER ):
sm_wait_time[tid] = 0.0 # wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.OPEN_GRIPPER
sm_wait_time[tid] = 0.0
elif state == PickSmState.OPEN_GRIPPER: elif state == PickSmState.OPEN_GRIPPER:
# des_ee_pose[tid] = object_pose[tid] # des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.OPEN gripper_state[tid] = GripperState.OPEN
...@@ -167,7 +185,7 @@ class PickAndLiftSm: ...@@ -167,7 +185,7 @@ class PickAndLiftSm:
5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state. 5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state.
""" """
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"): def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
"""Initialize the state machine. """Initialize the state machine.
Args: Args:
...@@ -179,6 +197,7 @@ class PickAndLiftSm: ...@@ -179,6 +197,7 @@ class PickAndLiftSm:
self.dt = float(dt) self.dt = float(dt)
self.num_envs = num_envs self.num_envs = num_envs
self.device = device self.device = device
self.position_threshold = position_threshold
# initialize state machine # initialize state machine
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device) self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device) self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
...@@ -234,6 +253,7 @@ class PickAndLiftSm: ...@@ -234,6 +253,7 @@ class PickAndLiftSm:
self.des_ee_pose_wp, self.des_ee_pose_wp,
self.des_gripper_state_wp, self.des_gripper_state_wp,
self.offset_wp, self.offset_wp,
self.position_threshold,
], ],
device=self.device, device=self.device,
) )
......
...@@ -83,6 +83,11 @@ class OpenDrawerSmWaitTime: ...@@ -83,6 +83,11 @@ class OpenDrawerSmWaitTime:
RELEASE_HANDLE = wp.constant(0.2) RELEASE_HANDLE = wp.constant(0.2)
@wp.func
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
return wp.length(current_pos - desired_pos) < threshold
@wp.kernel @wp.kernel
def infer_state_machine( def infer_state_machine(
dt: wp.array(dtype=float), dt: wp.array(dtype=float),
...@@ -95,6 +100,7 @@ def infer_state_machine( ...@@ -95,6 +100,7 @@ def infer_state_machine(
handle_approach_offset: wp.array(dtype=wp.transform), handle_approach_offset: wp.array(dtype=wp.transform),
handle_grasp_offset: wp.array(dtype=wp.transform), handle_grasp_offset: wp.array(dtype=wp.transform),
drawer_opening_rate: wp.array(dtype=wp.transform), drawer_opening_rate: wp.array(dtype=wp.transform),
position_threshold: float,
): ):
# retrieve thread id # retrieve thread id
tid = wp.tid() tid = wp.tid()
...@@ -112,21 +118,29 @@ def infer_state_machine( ...@@ -112,21 +118,29 @@ def infer_state_machine(
elif state == OpenDrawerSmState.APPROACH_INFRONT_HANDLE: elif state == OpenDrawerSmState.APPROACH_INFRONT_HANDLE:
des_ee_pose[tid] = wp.transform_multiply(handle_approach_offset[tid], handle_pose[tid]) des_ee_pose[tid] = wp.transform_multiply(handle_approach_offset[tid], handle_pose[tid])
gripper_state[tid] = GripperState.OPEN gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold if distance_below_threshold(
# wait for a while wp.transform_get_translation(ee_pose[tid]),
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_INFRONT_HANDLE: wp.transform_get_translation(des_ee_pose[tid]),
# move to next state and reset wait time position_threshold,
sm_state[tid] = OpenDrawerSmState.APPROACH_HANDLE ):
sm_wait_time[tid] = 0.0 # wait for a while
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_INFRONT_HANDLE:
# move to next state and reset wait time
sm_state[tid] = OpenDrawerSmState.APPROACH_HANDLE
sm_wait_time[tid] = 0.0
elif state == OpenDrawerSmState.APPROACH_HANDLE: elif state == OpenDrawerSmState.APPROACH_HANDLE:
des_ee_pose[tid] = handle_pose[tid] des_ee_pose[tid] = handle_pose[tid]
gripper_state[tid] = GripperState.OPEN gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold if distance_below_threshold(
# wait for a while wp.transform_get_translation(ee_pose[tid]),
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_HANDLE: wp.transform_get_translation(des_ee_pose[tid]),
# move to next state and reset wait time position_threshold,
sm_state[tid] = OpenDrawerSmState.GRASP_HANDLE ):
sm_wait_time[tid] = 0.0 # wait for a while
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_HANDLE:
# move to next state and reset wait time
sm_state[tid] = OpenDrawerSmState.GRASP_HANDLE
sm_wait_time[tid] = 0.0
elif state == OpenDrawerSmState.GRASP_HANDLE: elif state == OpenDrawerSmState.GRASP_HANDLE:
des_ee_pose[tid] = wp.transform_multiply(handle_grasp_offset[tid], handle_pose[tid]) des_ee_pose[tid] = wp.transform_multiply(handle_grasp_offset[tid], handle_pose[tid])
gripper_state[tid] = GripperState.CLOSE gripper_state[tid] = GripperState.CLOSE
...@@ -170,7 +184,7 @@ class OpenDrawerSm: ...@@ -170,7 +184,7 @@ class OpenDrawerSm:
5. RELEASE_HANDLE: The robot releases the handle of the drawer. This is the final state. 5. RELEASE_HANDLE: The robot releases the handle of the drawer. This is the final state.
""" """
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"): def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
"""Initialize the state machine. """Initialize the state machine.
Args: Args:
...@@ -182,6 +196,7 @@ class OpenDrawerSm: ...@@ -182,6 +196,7 @@ class OpenDrawerSm:
self.dt = float(dt) self.dt = float(dt)
self.num_envs = num_envs self.num_envs = num_envs
self.device = device self.device = device
self.position_threshold = position_threshold
# initialize state machine # initialize state machine
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device) self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device) self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
...@@ -248,6 +263,7 @@ class OpenDrawerSm: ...@@ -248,6 +263,7 @@ class OpenDrawerSm:
self.handle_approach_offset_wp, self.handle_approach_offset_wp,
self.handle_grasp_offset_wp, self.handle_grasp_offset_wp,
self.drawer_opening_rate_wp, self.drawer_opening_rate_wp,
self.position_threshold,
], ],
device=self.device, device=self.device,
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment