Unverified Commit 6f8ec452 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Adds more details about state in InteractiveScene (#2119)

# Description

It was hard to understand what the scene's state means unless you check
the code. This MR adds more docstrings to make this simpler to
understand.

## Type of change

- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent f8cf6bad
...@@ -319,16 +319,23 @@ class ManagerBasedEnv: ...@@ -319,16 +319,23 @@ class ManagerBasedEnv:
env_ids: Sequence[int] | None, env_ids: Sequence[int] | None,
seed: int | None = None, seed: int | None = None,
is_relative: bool = False, is_relative: bool = False,
) -> None: ):
"""Resets specified environments to known states. """Resets specified environments to provided states.
Note that this is different from reset() function as it resets the environments to specific states This function resets the environments to the provided states. The state is a dictionary
containing the state of the scene entities. Please refer to :meth:`InteractiveScene.get_state`
for the format.
The function is different from the :meth:`reset` function as it resets the environments to specific states,
instead of using the randomization events for resetting the environments.
Args: Args:
state: The state to reset the specified environments to. state: The state to reset the specified environments to. Please refer to
:meth:`InteractiveScene.get_state` for the format.
env_ids: The environment ids to reset. Defaults to None, in which case all environments are reset. env_ids: The environment ids to reset. Defaults to None, in which case all environments are reset.
seed: The seed to use for randomization. Defaults to None, in which case the seed is not set. seed: The seed to use for randomization. Defaults to None, in which case the seed is not set.
is_relative: If set to True, the state is considered relative to the environment origins. Defaults to False. is_relative: If set to True, the state is considered relative to the environment origins.
Defaults to False.
""" """
# reset all envs in the scene if env_ids is None # reset all envs in the scene if env_ids is None
if env_ids is None: if env_ids is None:
......
...@@ -357,54 +357,12 @@ class InteractiveScene: ...@@ -357,54 +357,12 @@ class InteractiveScene:
@property @property
def state(self) -> dict[str, dict[str, dict[str, torch.Tensor]]]: def state(self) -> dict[str, dict[str, dict[str, torch.Tensor]]]:
"""Returns the state of the scene entities. """A dictionary of the state of the scene entities in the simulation world frame.
Returns: Please refer to :meth:`get_state` for the format.
A dictionary of the state of the scene entities.
""" """
return self.get_state(is_relative=False) return self.get_state(is_relative=False)
def get_state(self, is_relative: bool = False) -> dict[str, dict[str, dict[str, torch.Tensor]]]:
"""Returns the state of the scene entities.
Args:
is_relative: If set to True, the state is considered relative to the environment origins.
Returns:
A dictionary of the state of the scene entities.
"""
state = dict()
# articulations
state["articulation"] = dict()
for asset_name, articulation in self._articulations.items():
asset_state = dict()
asset_state["root_pose"] = articulation.data.root_state_w[:, :7].clone()
if is_relative:
asset_state["root_pose"][:, :3] -= self.env_origins
asset_state["root_velocity"] = articulation.data.root_vel_w.clone()
asset_state["joint_position"] = articulation.data.joint_pos.clone()
asset_state["joint_velocity"] = articulation.data.joint_vel.clone()
state["articulation"][asset_name] = asset_state
# deformable objects
state["deformable_object"] = dict()
for asset_name, deformable_object in self._deformable_objects.items():
asset_state = dict()
asset_state["nodal_position"] = deformable_object.data.nodal_pos_w.clone()
if is_relative:
asset_state["nodal_position"][:, :3] -= self.env_origins
asset_state["nodal_velocity"] = deformable_object.data.nodal_vel_w.clone()
state["deformable_object"][asset_name] = asset_state
# rigid objects
state["rigid_object"] = dict()
for asset_name, rigid_object in self._rigid_objects.items():
asset_state = dict()
asset_state["root_pose"] = rigid_object.data.root_state_w[:, :7].clone()
if is_relative:
asset_state["root_pose"][:, :3] -= self.env_origins
asset_state["root_velocity"] = rigid_object.data.root_vel_w.clone()
state["rigid_object"][asset_name] = asset_state
return state
""" """
Operations. Operations.
""" """
...@@ -429,20 +387,57 @@ class InteractiveScene: ...@@ -429,20 +387,57 @@ class InteractiveScene:
for sensor in self._sensors.values(): for sensor in self._sensors.values():
sensor.reset(env_ids) sensor.reset(env_ids)
def write_data_to_sim(self):
"""Writes the data of the scene entities to the simulation."""
# -- assets
for articulation in self._articulations.values():
articulation.write_data_to_sim()
for deformable_object in self._deformable_objects.values():
deformable_object.write_data_to_sim()
for rigid_object in self._rigid_objects.values():
rigid_object.write_data_to_sim()
for rigid_object_collection in self._rigid_object_collections.values():
rigid_object_collection.write_data_to_sim()
def update(self, dt: float) -> None:
"""Update the scene entities.
Args:
dt: The amount of time passed from last :meth:`update` call.
"""
# -- assets
for articulation in self._articulations.values():
articulation.update(dt)
for deformable_object in self._deformable_objects.values():
deformable_object.update(dt)
for rigid_object in self._rigid_objects.values():
rigid_object.update(dt)
for rigid_object_collection in self._rigid_object_collections.values():
rigid_object_collection.update(dt)
# -- sensors
for sensor in self._sensors.values():
sensor.update(dt, force_recompute=not self.cfg.lazy_sensor_update)
"""
Operations: Scene State.
"""
def reset_to( def reset_to(
self, self,
state: dict[str, dict[str, dict[str, torch.Tensor]]], state: dict[str, dict[str, dict[str, torch.Tensor]]],
env_ids: Sequence[int] | None = None, env_ids: Sequence[int] | None = None,
is_relative: bool = False, is_relative: bool = False,
): ):
"""Resets the scene entities to the given state. """Resets the entities in the scene to the provided state.
Args: Args:
state: The state to reset the scene entities to. state: The state to reset the scene entities to. Please refer to :meth:`get_state` for the format.
env_ids: The indices of the environments to reset. env_ids: The indices of the environments to reset. Defaults to None, in which case
Defaults to None (all instances). all environment instances are reset.
is_relative: If set to True, the state is considered relative to the environment origins. is_relative: If set to True, the state is considered relative to the environment origins.
Defaults to False.
""" """
# resolve env_ids
if env_ids is None: if env_ids is None:
env_ids = slice(None) env_ids = slice(None)
# articulations # articulations
...@@ -459,6 +454,8 @@ class InteractiveScene: ...@@ -459,6 +454,8 @@ class InteractiveScene:
joint_position = asset_state["joint_position"].clone() joint_position = asset_state["joint_position"].clone()
joint_velocity = asset_state["joint_velocity"].clone() joint_velocity = asset_state["joint_velocity"].clone()
articulation.write_joint_state_to_sim(joint_position, joint_velocity, env_ids=env_ids) articulation.write_joint_state_to_sim(joint_position, joint_velocity, env_ids=env_ids)
# FIXME: This is not generic as it assumes PD control over the joints.
# This assumption does not hold for effort controlled joints.
articulation.set_joint_position_target(joint_position, env_ids=env_ids) articulation.set_joint_position_target(joint_position, env_ids=env_ids)
articulation.set_joint_velocity_target(joint_velocity, env_ids=env_ids) articulation.set_joint_velocity_target(joint_velocity, env_ids=env_ids)
# deformable objects # deformable objects
...@@ -479,38 +476,93 @@ class InteractiveScene: ...@@ -479,38 +476,93 @@ class InteractiveScene:
root_velocity = asset_state["root_velocity"].clone() root_velocity = asset_state["root_velocity"].clone()
rigid_object.write_root_pose_to_sim(root_pose, env_ids=env_ids) rigid_object.write_root_pose_to_sim(root_pose, env_ids=env_ids)
rigid_object.write_root_velocity_to_sim(root_velocity, env_ids=env_ids) rigid_object.write_root_velocity_to_sim(root_velocity, env_ids=env_ids)
# write data to simulation to make sure initial state is set
# this propagates the joint targets to the simulation
self.write_data_to_sim() self.write_data_to_sim()
def write_data_to_sim(self): def get_state(self, is_relative: bool = False) -> dict[str, dict[str, dict[str, torch.Tensor]]]:
"""Writes the data of the scene entities to the simulation.""" """Returns the state of the scene entities.
# -- assets
for articulation in self._articulations.values():
articulation.write_data_to_sim()
for deformable_object in self._deformable_objects.values():
deformable_object.write_data_to_sim()
for rigid_object in self._rigid_objects.values():
rigid_object.write_data_to_sim()
for rigid_object_collection in self._rigid_object_collections.values():
rigid_object_collection.write_data_to_sim()
def update(self, dt: float) -> None: Based on the type of the entity, the state comprises of different components.
"""Update the scene entities.
* For an articulation, the state comprises of the root pose, root velocity, and joint position and velocity.
* For a deformable object, the state comprises of the nodal position and velocity.
* For a rigid object, the state comprises of the root pose and root velocity.
The returned state is a dictionary with the following format:
.. code-block:: python
{
"articulation": {
"entity_1_name": {
"root_pose": torch.Tensor,
"root_velocity": torch.Tensor,
"joint_position": torch.Tensor,
"joint_velocity": torch.Tensor,
},
"entity_2_name": {
"root_pose": torch.Tensor,
"root_velocity": torch.Tensor,
"joint_position": torch.Tensor,
"joint_velocity": torch.Tensor,
},
},
"deformable_object": {
"entity_3_name": {
"nodal_position": torch.Tensor,
"nodal_velocity": torch.Tensor,
}
},
"rigid_object": {
"entity_4_name": {
"root_pose": torch.Tensor,
"root_velocity": torch.Tensor,
}
},
}
where ``entity_N_name`` is the name of the entity registered in the scene.
Args: Args:
dt: The amount of time passed from last :meth:`update` call. is_relative: If set to True, the state is considered relative to the environment origins.
Defaults to False.
Returns:
A dictionary of the state of the scene entities.
""" """
# -- assets state = dict()
for articulation in self._articulations.values(): # articulations
articulation.update(dt) state["articulation"] = dict()
for deformable_object in self._deformable_objects.values(): for asset_name, articulation in self._articulations.items():
deformable_object.update(dt) asset_state = dict()
for rigid_object in self._rigid_objects.values(): asset_state["root_pose"] = articulation.data.root_state_w[:, :7].clone()
rigid_object.update(dt) if is_relative:
for rigid_object_collection in self._rigid_object_collections.values(): asset_state["root_pose"][:, :3] -= self.env_origins
rigid_object_collection.update(dt) asset_state["root_velocity"] = articulation.data.root_vel_w.clone()
# -- sensors asset_state["joint_position"] = articulation.data.joint_pos.clone()
for sensor in self._sensors.values(): asset_state["joint_velocity"] = articulation.data.joint_vel.clone()
sensor.update(dt, force_recompute=not self.cfg.lazy_sensor_update) state["articulation"][asset_name] = asset_state
# deformable objects
state["deformable_object"] = dict()
for asset_name, deformable_object in self._deformable_objects.items():
asset_state = dict()
asset_state["nodal_position"] = deformable_object.data.nodal_pos_w.clone()
if is_relative:
asset_state["nodal_position"][:, :3] -= self.env_origins
asset_state["nodal_velocity"] = deformable_object.data.nodal_vel_w.clone()
state["deformable_object"][asset_name] = asset_state
# rigid objects
state["rigid_object"] = dict()
for asset_name, rigid_object in self._rigid_objects.items():
asset_state = dict()
asset_state["root_pose"] = rigid_object.data.root_state_w[:, :7].clone()
if is_relative:
asset_state["root_pose"][:, :3] -= self.env_origins
asset_state["root_velocity"] = rigid_object.data.root_vel_w.clone()
state["rigid_object"][asset_name] = asset_state
return state
""" """
Operations: Iteration. Operations: Iteration.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment