Unverified commit 2dc138a5, authored by Nikita Rudin and committed by GitHub

Adds method to get a specific term from the action manager (#427)

# Description

Added an option to get a specific term by name from the action manager.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 61c4fd9d
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.12.0" version = "0.12.1"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.12.1 (2024-03-09)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added an option to the last action observation term to get the action of a specific term by name from the action manager.
If the name is None, the behavior remains the same as before (the entire action is returned).
0.12.0 (2024-03-08) 0.12.0 (2024-03-08)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -101,6 +101,7 @@ Sensors. ...@@ -101,6 +101,7 @@ Sensors.
def height_scan(env: BaseEnv, sensor_cfg: SceneEntityCfg, offset: float = 0.5) -> torch.Tensor: def height_scan(env: BaseEnv, sensor_cfg: SceneEntityCfg, offset: float = 0.5) -> torch.Tensor:
"""Height scan from the given sensor w.r.t. the sensor's frame. """Height scan from the given sensor w.r.t. the sensor's frame.
The provided offset (Defaults to 0.5) is subtracted from the returned values. The provided offset (Defaults to 0.5) is subtracted from the returned values.
""" """
# extract the used quantities (to enable type-hinting) # extract the used quantities (to enable type-hinting)
...@@ -126,9 +127,16 @@ Actions. ...@@ -126,9 +127,16 @@ Actions.
""" """
def last_action(env: BaseEnv) -> torch.Tensor: def last_action(env: BaseEnv, action_name: str | None = None) -> torch.Tensor:
"""The last input action to the environment.""" """The last input action to the environment.
return env.action_manager.action
The name of the action term for which the action is required. If None, the
entire action tensor is returned.
"""
if action_name is None:
return env.action_manager.action
else:
return env.action_manager.get_term(action_name).raw_actions
""" """
......
...@@ -134,7 +134,7 @@ class ActionManager(ManagerBase): ...@@ -134,7 +134,7 @@ class ActionManager(ManagerBase):
table.align["Name"] = "l" table.align["Name"] = "l"
table.align["Dimension"] = "r" table.align["Dimension"] = "r"
# add info on each term # add info on each term
for index, (name, term) in enumerate(zip(self._term_names, self._terms)): for index, (name, term) in enumerate(self._terms.items()):
table.add_row([index, name, term.action_dim]) table.add_row([index, name, term.action_dim])
# convert table to string # convert table to string
msg += table.get_string() msg += table.get_string()
...@@ -159,7 +159,7 @@ class ActionManager(ManagerBase): ...@@ -159,7 +159,7 @@ class ActionManager(ManagerBase):
@property @property
def action_term_dim(self) -> list[int]: def action_term_dim(self) -> list[int]:
"""Shape of each action term.""" """Shape of each action term."""
return [term.action_dim for term in self._terms] return [term.action_dim for term in self._terms.values()]
@property @property
def action(self) -> torch.Tensor: def action(self) -> torch.Tensor:
...@@ -192,7 +192,7 @@ class ActionManager(ManagerBase): ...@@ -192,7 +192,7 @@ class ActionManager(ManagerBase):
self._prev_action[env_ids] = 0.0 self._prev_action[env_ids] = 0.0
self._action[env_ids] = 0.0 self._action[env_ids] = 0.0
# reset all action terms # reset all action terms
for term in self._terms: for term in self._terms.values():
term.reset(env_ids=env_ids) term.reset(env_ids=env_ids)
# nothing to log here # nothing to log here
return {} return {}
...@@ -215,7 +215,7 @@ class ActionManager(ManagerBase): ...@@ -215,7 +215,7 @@ class ActionManager(ManagerBase):
# split the actions and apply to each tensor # split the actions and apply to each tensor
idx = 0 idx = 0
for term in self._terms: for term in self._terms.values():
term_actions = action[:, idx : idx + term.action_dim] term_actions = action[:, idx : idx + term.action_dim]
term.process_actions(term_actions) term.process_actions(term_actions)
idx += term.action_dim idx += term.action_dim
...@@ -226,9 +226,20 @@ class ActionManager(ManagerBase): ...@@ -226,9 +226,20 @@ class ActionManager(ManagerBase):
Note: Note:
This should be called at every simulation step. This should be called at every simulation step.
""" """
for term in self._terms: for term in self._terms.values():
term.apply_actions() term.apply_actions()
def get_term(self, name: str) -> ActionTerm:
    """Returns the action term with the specified name.

    Args:
        name: The name of the action term.

    Returns:
        The action term with the specified name.

    Raises:
        KeyError: If no action term with the specified name exists.
    """
    try:
        return self._terms[name]
    except KeyError:
        # re-raise with the list of valid names so a typo in the name is easy to spot
        raise KeyError(
            f"Action term '{name}' not found. Available terms: {list(self._terms)}."
        ) from None
""" """
Helper functions. Helper functions.
""" """
...@@ -237,7 +248,7 @@ class ActionManager(ManagerBase): ...@@ -237,7 +248,7 @@ class ActionManager(ManagerBase):
"""Prepares a list of action terms.""" """Prepares a list of action terms."""
# parse action terms from the config # parse action terms from the config
self._term_names: list[str] = list() self._term_names: list[str] = list()
self._terms: list[ActionTerm] = list() self._terms: dict[str, ActionTerm] = dict()
# check if config is dict already # check if config is dict already
if isinstance(self.cfg, dict): if isinstance(self.cfg, dict):
...@@ -261,4 +272,4 @@ class ActionManager(ManagerBase): ...@@ -261,4 +272,4 @@ class ActionManager(ManagerBase):
raise TypeError(f"Returned object for the term '{term_name}' is not of type ActionType.") raise TypeError(f"Returned object for the term '{term_name}' is not of type ActionType.")
# add term name and parameters # add term name and parameters
self._term_names.append(term_name) self._term_names.append(term_name)
self._terms.append(term) self._terms[term_name] = term
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment