Unverified Commit 1103a0f3 authored by yijieg's avatar yijieg Committed by GitHub

Fixes cuda version as float for AutoMate to correctly convert patch versions (#3795)

# Description

To convert cuda version from a string to a float, I update the function
to handle cases with multiple points, e.g. string '12.8.9' will be
converted to float 12.89. Before, float('12.8.9') will return None for
failure conversion.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have read and understood the [contribution
guidelines](https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html)
- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent b70bd42a
...@@ -61,7 +61,7 @@ class AssemblyEnv(DirectRLEnv): ...@@ -61,7 +61,7 @@ class AssemblyEnv(DirectRLEnv):
# Create criterion for dynamic time warping (later used for imitation reward) # Create criterion for dynamic time warping (later used for imitation reward)
cuda_version = automate_algo.get_cuda_version() cuda_version = automate_algo.get_cuda_version()
if (cuda_version is not None) and (cuda_version < 13.0): if (cuda_version is not None) and (cuda_version < (13, 0, 0)):
self.soft_dtw_criterion = SoftDTW(use_cuda=True, device=self.device, gamma=self.cfg_task.soft_dtw_gamma) self.soft_dtw_criterion = SoftDTW(use_cuda=True, device=self.device, gamma=self.cfg_task.soft_dtw_gamma)
else: else:
self.soft_dtw_criterion = SoftDTW(use_cuda=False, device=self.device, gamma=self.cfg_task.soft_dtw_gamma) self.soft_dtw_criterion = SoftDTW(use_cuda=False, device=self.device, gamma=self.cfg_task.soft_dtw_gamma)
......
...@@ -25,6 +25,28 @@ Util Functions ...@@ -25,6 +25,28 @@ Util Functions
""" """
def parse_cuda_version(version_string):
"""
Parse CUDA version string into comparable tuple of (major, minor, patch).
Args:
version_string: Version string like "12.8.9" or "11.2"
Returns:
Tuple of (major, minor, patch) as integers, where patch defaults to 0 iff
not present.
Example:
"12.8.9" -> (12, 8, 9)
"11.2" -> (11, 2, 0)
"""
parts = version_string.split(".")
major = int(parts[0])
minor = int(parts[1]) if len(parts) > 1 else 0
patch = int(parts[2]) if len(parts) > 2 else 0
return (major, minor, patch)
def get_cuda_version(): def get_cuda_version():
try: try:
# Execute nvcc --version command # Execute nvcc --version command
...@@ -34,7 +56,7 @@ def get_cuda_version(): ...@@ -34,7 +56,7 @@ def get_cuda_version():
# Use regex to find the CUDA version (e.g., V11.2.67) # Use regex to find the CUDA version (e.g., V11.2.67)
match = re.search(r"V(\d+\.\d+(\.\d+)?)", output) match = re.search(r"V(\d+\.\d+(\.\d+)?)", output)
if match: if match:
return float(match.group(1)) return parse_cuda_version(match.group(1))
else: else:
print("CUDA version not found in output.") print("CUDA version not found in output.")
return None return None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment