Unverified Commit 2dc138a5 authored by Nikita Rudin's avatar Nikita Rudin Committed by GitHub

Adds method to get a specific term from the action manager (#427)

# Description

Added option to get specific term by name from the action manager.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 61c4fd9d
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.12.0"
version = "0.12.1"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.12.1 (2024-03-09)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added an option to the last actions observation term to get a specific term by name from the action manager.
If None, the behavior remains the same as before (the entire action is returned).
0.12.0 (2024-03-08)
~~~~~~~~~~~~~~~~~~~
......
......@@ -101,6 +101,7 @@ Sensors.
def height_scan(env: BaseEnv, sensor_cfg: SceneEntityCfg, offset: float = 0.5) -> torch.Tensor:
"""Height scan from the given sensor w.r.t. the sensor's frame.
The provided offset (Defaults to 0.5) is subtracted from the returned values.
"""
# extract the used quantities (to enable type-hinting)
......@@ -126,9 +127,16 @@ Actions.
"""
def last_action(env: BaseEnv) -> torch.Tensor:
"""The last input action to the environment."""
def last_action(env: BaseEnv, action_name: str | None = None) -> torch.Tensor:
"""The last input action to the environment.
The name of the action term for which the action is required. If None, the
entire action tensor is returned.
"""
if action_name is None:
return env.action_manager.action
else:
return env.action_manager.get_term(action_name).raw_actions
"""
......
......@@ -134,7 +134,7 @@ class ActionManager(ManagerBase):
table.align["Name"] = "l"
table.align["Dimension"] = "r"
# add info on each term
for index, (name, term) in enumerate(zip(self._term_names, self._terms)):
for index, (name, term) in enumerate(self._terms.items()):
table.add_row([index, name, term.action_dim])
# convert table to string
msg += table.get_string()
......@@ -159,7 +159,7 @@ class ActionManager(ManagerBase):
@property
def action_term_dim(self) -> list[int]:
"""Shape of each action term."""
return [term.action_dim for term in self._terms]
return [term.action_dim for term in self._terms.values()]
@property
def action(self) -> torch.Tensor:
......@@ -192,7 +192,7 @@ class ActionManager(ManagerBase):
self._prev_action[env_ids] = 0.0
self._action[env_ids] = 0.0
# reset all action terms
for term in self._terms:
for term in self._terms.values():
term.reset(env_ids=env_ids)
# nothing to log here
return {}
......@@ -215,7 +215,7 @@ class ActionManager(ManagerBase):
# split the actions and apply to each tensor
idx = 0
for term in self._terms:
for term in self._terms.values():
term_actions = action[:, idx : idx + term.action_dim]
term.process_actions(term_actions)
idx += term.action_dim
......@@ -226,9 +226,20 @@ class ActionManager(ManagerBase):
Note:
This should be called at every simulation step.
"""
for term in self._terms:
for term in self._terms.values():
term.apply_actions()
def get_term(self, name: str) -> ActionTerm:
"""Returns the action term with the specified name.
Args:
name: The name of the action term.
Returns:
The action term with the specified name.
"""
return self._terms[name]
"""
Helper functions.
"""
......@@ -237,7 +248,7 @@ class ActionManager(ManagerBase):
"""Prepares a list of action terms."""
# parse action terms from the config
self._term_names: list[str] = list()
self._terms: list[ActionTerm] = list()
self._terms: dict[str, ActionTerm] = dict()
# check if config is dict already
if isinstance(self.cfg, dict):
......@@ -261,4 +272,4 @@ class ActionManager(ManagerBase):
raise TypeError(f"Returned object for the term '{term_name}' is not of type ActionType.")
# add term name and parameters
self._term_names.append(term_name)
self._terms.append(term)
self._terms[term_name] = term
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment