Unverified Commit a958ac56 authored by shauryadNv's avatar shauryadNv Committed by GitHub

Updates cosmos test files to use pytest (#548)

# Description

Updated Mimic-Cosmos related tests to use pytest.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
parent 24e83d72
...@@ -8,66 +8,65 @@ ...@@ -8,66 +8,65 @@
import json import json
import os import os
import tempfile import tempfile
import unittest
import pytest
from scripts.tools.cosmos.cosmos_prompt_gen import generate_prompt, main from scripts.tools.cosmos.cosmos_prompt_gen import generate_prompt, main
class TestCosmosPromptGen(unittest.TestCase): @pytest.fixture(scope="class")
def temp_templates_file():
"""Create temporary templates file."""
temp_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
# Create test templates
test_templates = {
"lighting": ["with bright lighting", "with dim lighting", "with natural lighting"],
"color": ["in warm colors", "in cool colors", "in vibrant colors"],
"style": ["in a realistic style", "in an artistic style", "in a minimalist style"],
"empty_section": [], # Test empty section
"invalid_section": "not a list", # Test invalid section
}
# Write templates to file
with open(temp_file.name, "w") as f:
json.dump(test_templates, f)
yield temp_file.name
# Cleanup
os.remove(temp_file.name)
@pytest.fixture
def temp_output_file():
"""Create temporary output file."""
temp_file = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
yield temp_file.name
# Cleanup
os.remove(temp_file.name)
class TestCosmosPromptGen:
"""Test cases for Cosmos prompt generation functionality.""" """Test cases for Cosmos prompt generation functionality."""
@classmethod def test_generate_prompt_valid_templates(self, temp_templates_file):
def setUpClass(cls):
"""Set up test fixtures that are shared across all test methods."""
# Create temporary templates file
cls.temp_templates_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
# Create test templates
test_templates = {
"lighting": ["with bright lighting", "with dim lighting", "with natural lighting"],
"color": ["in warm colors", "in cool colors", "in vibrant colors"],
"style": ["in a realistic style", "in an artistic style", "in a minimalist style"],
"empty_section": [], # Test empty section
"invalid_section": "not a list", # Test invalid section
}
# Write templates to file
with open(cls.temp_templates_file.name, "w") as f:
json.dump(test_templates, f)
def setUp(self):
"""Set up test fixtures that are created for each test method."""
self.temp_output_file = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
def tearDown(self):
"""Clean up test fixtures after each test method."""
# Remove the temporary output file
os.remove(self.temp_output_file.name)
@classmethod
def tearDownClass(cls):
"""Clean up test fixtures that are shared across all test methods."""
# Remove the temporary templates file
os.remove(cls.temp_templates_file.name)
def test_generate_prompt_valid_templates(self):
"""Test generating a prompt with valid templates.""" """Test generating a prompt with valid templates."""
prompt = generate_prompt(self.temp_templates_file.name) prompt = generate_prompt(temp_templates_file)
# Check that prompt is a string # Check that prompt is a string
self.assertIsInstance(prompt, str) assert isinstance(prompt, str)
# Check that prompt contains at least one word # Check that prompt contains at least one word
self.assertTrue(len(prompt.split()) > 0) assert len(prompt.split()) > 0
# Check that prompt contains valid sections # Check that prompt contains valid sections
valid_sections = ["lighting", "color", "style"] valid_sections = ["lighting", "color", "style"]
found_sections = [section for section in valid_sections if section in prompt.lower()] found_sections = [section for section in valid_sections if section in prompt.lower()]
self.assertTrue(len(found_sections) > 0) assert len(found_sections) > 0
def test_generate_prompt_invalid_file(self): def test_generate_prompt_invalid_file(self):
"""Test generating a prompt with invalid file path.""" """Test generating a prompt with invalid file path."""
with self.assertRaises(FileNotFoundError): with pytest.raises(FileNotFoundError):
generate_prompt("nonexistent_file.json") generate_prompt("nonexistent_file.json")
def test_generate_prompt_invalid_json(self): def test_generate_prompt_invalid_json(self):
...@@ -78,12 +77,12 @@ class TestCosmosPromptGen(unittest.TestCase): ...@@ -78,12 +77,12 @@ class TestCosmosPromptGen(unittest.TestCase):
temp_file.flush() temp_file.flush()
try: try:
with self.assertRaises(ValueError): with pytest.raises(ValueError):
generate_prompt(temp_file.name) generate_prompt(temp_file.name)
finally: finally:
os.remove(temp_file.name) os.remove(temp_file.name)
def test_main_function_single_prompt(self): def test_main_function_single_prompt(self, temp_templates_file, temp_output_file):
"""Test main function with single prompt generation.""" """Test main function with single prompt generation."""
# Mock command line arguments # Mock command line arguments
import sys import sys
...@@ -92,29 +91,29 @@ class TestCosmosPromptGen(unittest.TestCase): ...@@ -92,29 +91,29 @@ class TestCosmosPromptGen(unittest.TestCase):
sys.argv = [ sys.argv = [
"cosmos_prompt_gen.py", "cosmos_prompt_gen.py",
"--templates_path", "--templates_path",
self.temp_templates_file.name, temp_templates_file,
"--num_prompts", "--num_prompts",
"1", "1",
"--output_path", "--output_path",
self.temp_output_file.name, temp_output_file,
] ]
try: try:
main() main()
# Check if output file was created # Check if output file was created
self.assertTrue(os.path.exists(self.temp_output_file.name)) assert os.path.exists(temp_output_file)
# Check content of output file # Check content of output file
with open(self.temp_output_file.name) as f: with open(temp_output_file) as f:
content = f.read().strip() content = f.read().strip()
self.assertTrue(len(content) > 0) assert len(content) > 0
self.assertEqual(len(content.split("\n")), 1) assert len(content.split("\n")) == 1
finally: finally:
# Restore original argv # Restore original argv
sys.argv = original_argv sys.argv = original_argv
def test_main_function_multiple_prompts(self): def test_main_function_multiple_prompts(self, temp_templates_file, temp_output_file):
"""Test main function with multiple prompt generation.""" """Test main function with multiple prompt generation."""
# Mock command line arguments # Mock command line arguments
import sys import sys
...@@ -123,52 +122,48 @@ class TestCosmosPromptGen(unittest.TestCase): ...@@ -123,52 +122,48 @@ class TestCosmosPromptGen(unittest.TestCase):
sys.argv = [ sys.argv = [
"cosmos_prompt_gen.py", "cosmos_prompt_gen.py",
"--templates_path", "--templates_path",
self.temp_templates_file.name, temp_templates_file,
"--num_prompts", "--num_prompts",
"3", "3",
"--output_path", "--output_path",
self.temp_output_file.name, temp_output_file,
] ]
try: try:
main() main()
# Check if output file was created # Check if output file was created
self.assertTrue(os.path.exists(self.temp_output_file.name)) assert os.path.exists(temp_output_file)
# Check content of output file # Check content of output file
with open(self.temp_output_file.name) as f: with open(temp_output_file) as f:
content = f.read().strip() content = f.read().strip()
self.assertTrue(len(content) > 0) assert len(content) > 0
self.assertEqual(len(content.split("\n")), 3) assert len(content.split("\n")) == 3
# Check that each line is a valid prompt # Check that each line is a valid prompt
for line in content.split("\n"): for line in content.split("\n"):
self.assertTrue(len(line) > 0) assert len(line) > 0
finally: finally:
# Restore original argv # Restore original argv
sys.argv = original_argv sys.argv = original_argv
def test_main_function_default_output(self): def test_main_function_default_output(self, temp_templates_file):
"""Test main function with default output path.""" """Test main function with default output path."""
# Mock command line arguments # Mock command line arguments
import sys import sys
original_argv = sys.argv original_argv = sys.argv
sys.argv = ["cosmos_prompt_gen.py", "--templates_path", self.temp_templates_file.name, "--num_prompts", "1"] sys.argv = ["cosmos_prompt_gen.py", "--templates_path", temp_templates_file, "--num_prompts", "1"]
try: try:
main() main()
# Check if default output file was created # Check if default output file was created
self.assertTrue(os.path.exists("prompts.txt")) assert os.path.exists("prompts.txt")
# Clean up default output file # Clean up default output file
os.remove("prompts.txt") os.remove("prompts.txt")
finally: finally:
# Restore original argv # Restore original argv
sys.argv = original_argv sys.argv = original_argv
if __name__ == "__main__":
unittest.main()
...@@ -9,138 +9,128 @@ import h5py ...@@ -9,138 +9,128 @@ import h5py
import numpy as np import numpy as np
import os import os
import tempfile import tempfile
import unittest
import pytest
from scripts.tools.hdf5_to_mp4 import get_num_demos, main, write_demo_to_mp4 from scripts.tools.hdf5_to_mp4 import get_num_demos, main, write_demo_to_mp4
class TestHDF5ToMP4(unittest.TestCase): @pytest.fixture(scope="class")
def temp_hdf5_file():
"""Create temporary HDF5 file with test data."""
temp_file = tempfile.NamedTemporaryFile(suffix=".h5", delete=False)
with h5py.File(temp_file.name, "w") as h5f:
# Create test data structure
for demo_id in range(2): # Create 2 demos
demo_group = h5f.create_group(f"data/demo_{demo_id}/obs")
# Create RGB frames (2 frames per demo)
rgb_data = np.random.randint(0, 255, (2, 704, 1280, 3), dtype=np.uint8)
demo_group.create_dataset("table_cam", data=rgb_data)
# Create segmentation frames
seg_data = np.random.randint(0, 255, (2, 704, 1280, 4), dtype=np.uint8)
demo_group.create_dataset("table_cam_segmentation", data=seg_data)
# Create normal maps
normals_data = np.random.rand(2, 704, 1280, 3).astype(np.float32)
demo_group.create_dataset("table_cam_normals", data=normals_data)
# Create depth maps
depth_data = np.random.rand(2, 704, 1280, 1).astype(np.float32)
demo_group.create_dataset("table_cam_depth", data=depth_data)
yield temp_file.name
# Cleanup
os.remove(temp_file.name)
@pytest.fixture
def temp_output_dir():
"""Create temporary output directory."""
temp_dir = tempfile.mkdtemp()
yield temp_dir
# Cleanup
for file in os.listdir(temp_dir):
os.remove(os.path.join(temp_dir, file))
os.rmdir(temp_dir)
class TestHDF5ToMP4:
"""Test cases for HDF5 to MP4 conversion functionality.""" """Test cases for HDF5 to MP4 conversion functionality."""
@classmethod def test_get_num_demos(self, temp_hdf5_file):
def setUpClass(cls):
"""Set up test fixtures that are shared across all test methods."""
# Create temporary HDF5 file with test data
cls.temp_hdf5_file = tempfile.NamedTemporaryFile(suffix=".h5", delete=False)
with h5py.File(cls.temp_hdf5_file.name, "w") as h5f:
# Create test data structure
for demo_id in range(2): # Create 2 demos
demo_group = h5f.create_group(f"data/demo_{demo_id}/obs")
# Create RGB frames (2 frames per demo)
rgb_data = np.random.randint(0, 255, (2, 704, 1280, 3), dtype=np.uint8)
demo_group.create_dataset("table_cam", data=rgb_data)
# Create segmentation frames
seg_data = np.random.randint(0, 255, (2, 704, 1280, 4), dtype=np.uint8)
demo_group.create_dataset("table_cam_segmentation", data=seg_data)
# Create normal maps
normals_data = np.random.rand(2, 704, 1280, 3).astype(np.float32)
demo_group.create_dataset("table_cam_normals", data=normals_data)
# Create depth maps
depth_data = np.random.rand(2, 704, 1280, 1).astype(np.float32)
demo_group.create_dataset("table_cam_depth", data=depth_data)
def setUp(self):
"""Set up test fixtures that are created for each test method."""
self.temp_output_dir = tempfile.mkdtemp()
def tearDown(self):
"""Clean up test fixtures after each test method."""
# Remove all files in the output directory
for file in os.listdir(self.temp_output_dir):
os.remove(os.path.join(self.temp_output_dir, file))
# Remove the output directory
os.rmdir(self.temp_output_dir)
@classmethod
def tearDownClass(cls):
"""Clean up test fixtures that are shared across all test methods."""
# Remove the temporary HDF5 file
os.remove(cls.temp_hdf5_file.name)
def test_get_num_demos(self):
"""Test the get_num_demos function.""" """Test the get_num_demos function."""
num_demos = get_num_demos(self.temp_hdf5_file.name) num_demos = get_num_demos(temp_hdf5_file)
self.assertEqual(num_demos, 2) assert num_demos == 2
def test_write_demo_to_mp4_rgb(self): def test_write_demo_to_mp4_rgb(self, temp_hdf5_file, temp_output_dir):
"""Test writing RGB frames to MP4.""" """Test writing RGB frames to MP4."""
write_demo_to_mp4(self.temp_hdf5_file.name, 0, "data/demo_0/obs", "table_cam", self.temp_output_dir, 704, 1280) write_demo_to_mp4(temp_hdf5_file, 0, "data/demo_0/obs", "table_cam", temp_output_dir, 704, 1280)
output_file = os.path.join(self.temp_output_dir, "demo_0_table_cam.mp4") output_file = os.path.join(temp_output_dir, "demo_0_table_cam.mp4")
self.assertTrue(os.path.exists(output_file)) assert os.path.exists(output_file)
self.assertGreater(os.path.getsize(output_file), 0) assert os.path.getsize(output_file) > 0
def test_write_demo_to_mp4_segmentation(self): def test_write_demo_to_mp4_segmentation(self, temp_hdf5_file, temp_output_dir):
"""Test writing segmentation frames to MP4.""" """Test writing segmentation frames to MP4."""
write_demo_to_mp4( write_demo_to_mp4(temp_hdf5_file, 0, "data/demo_0/obs", "table_cam_segmentation", temp_output_dir, 704, 1280)
self.temp_hdf5_file.name, 0, "data/demo_0/obs", "table_cam_segmentation", self.temp_output_dir, 704, 1280
)
output_file = os.path.join(self.temp_output_dir, "demo_0_table_cam_segmentation.mp4") output_file = os.path.join(temp_output_dir, "demo_0_table_cam_segmentation.mp4")
self.assertTrue(os.path.exists(output_file)) assert os.path.exists(output_file)
self.assertGreater(os.path.getsize(output_file), 0) assert os.path.getsize(output_file) > 0
def test_write_demo_to_mp4_normals(self): def test_write_demo_to_mp4_normals(self, temp_hdf5_file, temp_output_dir):
"""Test writing normal maps to MP4.""" """Test writing normal maps to MP4."""
write_demo_to_mp4( write_demo_to_mp4(temp_hdf5_file, 0, "data/demo_0/obs", "table_cam_normals", temp_output_dir, 704, 1280)
self.temp_hdf5_file.name, 0, "data/demo_0/obs", "table_cam_normals", self.temp_output_dir, 704, 1280
)
output_file = os.path.join(self.temp_output_dir, "demo_0_table_cam_normals.mp4") output_file = os.path.join(temp_output_dir, "demo_0_table_cam_normals.mp4")
self.assertTrue(os.path.exists(output_file)) assert os.path.exists(output_file)
self.assertGreater(os.path.getsize(output_file), 0) assert os.path.getsize(output_file) > 0
def test_write_demo_to_mp4_shaded_segmentation(self): def test_write_demo_to_mp4_shaded_segmentation(self, temp_hdf5_file, temp_output_dir):
"""Test writing shaded_segmentation frames to MP4.""" """Test writing shaded_segmentation frames to MP4."""
write_demo_to_mp4( write_demo_to_mp4(
self.temp_hdf5_file.name, temp_hdf5_file,
0, 0,
"data/demo_0/obs", "data/demo_0/obs",
"table_cam_shaded_segmentation", "table_cam_shaded_segmentation",
self.temp_output_dir, temp_output_dir,
704, 704,
1280, 1280,
) )
output_file = os.path.join(self.temp_output_dir, "demo_0_table_cam_shaded_segmentation.mp4") output_file = os.path.join(temp_output_dir, "demo_0_table_cam_shaded_segmentation.mp4")
self.assertTrue(os.path.exists(output_file)) assert os.path.exists(output_file)
self.assertGreater(os.path.getsize(output_file), 0) assert os.path.getsize(output_file) > 0
def test_write_demo_to_mp4_depth(self): def test_write_demo_to_mp4_depth(self, temp_hdf5_file, temp_output_dir):
"""Test writing depth maps to MP4.""" """Test writing depth maps to MP4."""
write_demo_to_mp4( write_demo_to_mp4(temp_hdf5_file, 0, "data/demo_0/obs", "table_cam_depth", temp_output_dir, 704, 1280)
self.temp_hdf5_file.name, 0, "data/demo_0/obs", "table_cam_depth", self.temp_output_dir, 704, 1280
)
output_file = os.path.join(self.temp_output_dir, "demo_0_table_cam_depth.mp4") output_file = os.path.join(temp_output_dir, "demo_0_table_cam_depth.mp4")
self.assertTrue(os.path.exists(output_file)) assert os.path.exists(output_file)
self.assertGreater(os.path.getsize(output_file), 0) assert os.path.getsize(output_file) > 0
def test_write_demo_to_mp4_invalid_demo(self): def test_write_demo_to_mp4_invalid_demo(self, temp_hdf5_file, temp_output_dir):
"""Test writing with invalid demo ID.""" """Test writing with invalid demo ID."""
with self.assertRaises(KeyError): with pytest.raises(KeyError):
write_demo_to_mp4( write_demo_to_mp4(
self.temp_hdf5_file.name, temp_hdf5_file,
999, # Invalid demo ID 999, # Invalid demo ID
"data/demo_999/obs", "data/demo_999/obs",
"table_cam", "table_cam",
self.temp_output_dir, temp_output_dir,
704, 704,
1280, 1280,
) )
def test_write_demo_to_mp4_invalid_key(self): def test_write_demo_to_mp4_invalid_key(self, temp_hdf5_file, temp_output_dir):
"""Test writing with invalid input key.""" """Test writing with invalid input key."""
with self.assertRaises(KeyError): with pytest.raises(KeyError):
write_demo_to_mp4( write_demo_to_mp4(temp_hdf5_file, 0, "data/demo_0/obs", "invalid_key", temp_output_dir, 704, 1280)
self.temp_hdf5_file.name, 0, "data/demo_0/obs", "invalid_key", self.temp_output_dir, 704, 1280
)
def test_main_function(self): def test_main_function(self, temp_hdf5_file, temp_output_dir):
"""Test the main function.""" """Test the main function."""
# Mock command line arguments # Mock command line arguments
import sys import sys
...@@ -149,9 +139,9 @@ class TestHDF5ToMP4(unittest.TestCase): ...@@ -149,9 +139,9 @@ class TestHDF5ToMP4(unittest.TestCase):
sys.argv = [ sys.argv = [
"hdf5_to_mp4.py", "hdf5_to_mp4.py",
"--input_file", "--input_file",
self.temp_hdf5_file.name, temp_hdf5_file,
"--output_dir", "--output_dir",
self.temp_output_dir, temp_output_dir,
"--input_keys", "--input_keys",
"table_cam", "table_cam",
"table_cam_segmentation", "table_cam_segmentation",
...@@ -175,13 +165,9 @@ class TestHDF5ToMP4(unittest.TestCase): ...@@ -175,13 +165,9 @@ class TestHDF5ToMP4(unittest.TestCase):
] ]
for file in expected_files: for file in expected_files:
output_file = os.path.join(self.temp_output_dir, file) output_file = os.path.join(temp_output_dir, file)
self.assertTrue(os.path.exists(output_file)) assert os.path.exists(output_file)
self.assertGreater(os.path.getsize(output_file), 0) assert os.path.getsize(output_file) > 0
finally: finally:
# Restore original argv # Restore original argv
sys.argv = original_argv sys.argv = original_argv
if __name__ == "__main__":
unittest.main()
...@@ -9,130 +9,137 @@ import h5py ...@@ -9,130 +9,137 @@ import h5py
import numpy as np import numpy as np
import os import os
import tempfile import tempfile
import unittest
import cv2 import cv2
import pytest
from scripts.tools.mp4_to_hdf5 import get_frames_from_mp4, main, process_video_and_demo from scripts.tools.mp4_to_hdf5 import get_frames_from_mp4, main, process_video_and_demo
class TestMP4ToHDF5(unittest.TestCase): @pytest.fixture(scope="class")
def temp_hdf5_file():
"""Create temporary HDF5 file with test data."""
temp_file = tempfile.NamedTemporaryFile(suffix=".h5", delete=False)
with h5py.File(temp_file.name, "w") as h5f:
# Create test data structure for 2 demos
for demo_id in range(2):
demo_group = h5f.create_group(f"data/demo_{demo_id}")
obs_group = demo_group.create_group("obs")
# Create actions data
actions_data = np.random.rand(10, 7).astype(np.float32)
demo_group.create_dataset("actions", data=actions_data)
# Create robot state data
eef_pos_data = np.random.rand(10, 3).astype(np.float32)
eef_quat_data = np.random.rand(10, 4).astype(np.float32)
gripper_pos_data = np.random.rand(10, 1).astype(np.float32)
obs_group.create_dataset("eef_pos", data=eef_pos_data)
obs_group.create_dataset("eef_quat", data=eef_quat_data)
obs_group.create_dataset("gripper_pos", data=gripper_pos_data)
# Create camera data
table_cam_data = np.random.randint(0, 255, (10, 704, 1280, 3), dtype=np.uint8)
wrist_cam_data = np.random.randint(0, 255, (10, 704, 1280, 3), dtype=np.uint8)
obs_group.create_dataset("table_cam", data=table_cam_data)
obs_group.create_dataset("wrist_cam", data=wrist_cam_data)
# Set attributes
demo_group.attrs["num_samples"] = 10
yield temp_file.name
# Cleanup
os.remove(temp_file.name)
@pytest.fixture(scope="class")
def temp_videos_dir():
"""Create temporary MP4 files."""
temp_dir = tempfile.mkdtemp()
video_paths = []
for demo_id in range(2):
video_path = os.path.join(temp_dir, f"demo_{demo_id}_table_cam.mp4")
video_paths.append(video_path)
# Create a test video
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video = cv2.VideoWriter(video_path, fourcc, 30, (1280, 704))
# Write some random frames
for _ in range(10):
frame = np.random.randint(0, 255, (704, 1280, 3), dtype=np.uint8)
video.write(frame)
video.release()
yield temp_dir, video_paths
# Cleanup
for video_path in video_paths:
os.remove(video_path)
os.rmdir(temp_dir)
@pytest.fixture
def temp_output_file():
"""Create temporary output file."""
temp_file = tempfile.NamedTemporaryFile(suffix=".h5", delete=False)
yield temp_file.name
# Cleanup
os.remove(temp_file.name)
class TestMP4ToHDF5:
"""Test cases for MP4 to HDF5 conversion functionality.""" """Test cases for MP4 to HDF5 conversion functionality."""
@classmethod def test_get_frames_from_mp4(self, temp_videos_dir):
def setUpClass(cls):
"""Set up test fixtures that are shared across all test methods."""
# Create temporary HDF5 file with test data
cls.temp_hdf5_file = tempfile.NamedTemporaryFile(suffix=".h5", delete=False)
with h5py.File(cls.temp_hdf5_file.name, "w") as h5f:
# Create test data structure for 2 demos
for demo_id in range(2):
demo_group = h5f.create_group(f"data/demo_{demo_id}")
obs_group = demo_group.create_group("obs")
# Create actions data
actions_data = np.random.rand(10, 7).astype(np.float32)
demo_group.create_dataset("actions", data=actions_data)
# Create robot state data
eef_pos_data = np.random.rand(10, 3).astype(np.float32)
eef_quat_data = np.random.rand(10, 4).astype(np.float32)
gripper_pos_data = np.random.rand(10, 1).astype(np.float32)
obs_group.create_dataset("eef_pos", data=eef_pos_data)
obs_group.create_dataset("eef_quat", data=eef_quat_data)
obs_group.create_dataset("gripper_pos", data=gripper_pos_data)
# Create camera data
table_cam_data = np.random.randint(0, 255, (10, 704, 1280, 3), dtype=np.uint8)
wrist_cam_data = np.random.randint(0, 255, (10, 704, 1280, 3), dtype=np.uint8)
obs_group.create_dataset("table_cam", data=table_cam_data)
obs_group.create_dataset("wrist_cam", data=wrist_cam_data)
# Set attributes
demo_group.attrs["num_samples"] = 10
# Create temporary MP4 files
cls.temp_videos_dir = tempfile.mkdtemp()
cls.video_paths = []
for demo_id in range(2):
video_path = os.path.join(cls.temp_videos_dir, f"demo_{demo_id}_table_cam.mp4")
cls.video_paths.append(video_path)
# Create a test video
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video = cv2.VideoWriter(video_path, fourcc, 30, (1280, 704))
# Write some random frames
for _ in range(10):
frame = np.random.randint(0, 255, (704, 1280, 3), dtype=np.uint8)
video.write(frame)
video.release()
def setUp(self):
"""Set up test fixtures that are created for each test method."""
self.temp_output_file = tempfile.NamedTemporaryFile(suffix=".h5", delete=False)
def tearDown(self):
"""Clean up test fixtures after each test method."""
# Remove the temporary output file
os.remove(self.temp_output_file.name)
@classmethod
def tearDownClass(cls):
"""Clean up test fixtures that are shared across all test methods."""
# Remove the temporary HDF5 file
os.remove(cls.temp_hdf5_file.name)
# Remove temporary videos and directory
for video_path in cls.video_paths:
os.remove(video_path)
os.rmdir(cls.temp_videos_dir)
def test_get_frames_from_mp4(self):
"""Test extracting frames from MP4 video.""" """Test extracting frames from MP4 video."""
frames = get_frames_from_mp4(self.video_paths[0]) _, video_paths = temp_videos_dir
frames = get_frames_from_mp4(video_paths[0])
# Check frame properties # Check frame properties
self.assertEqual(frames.shape[0], 10) # Number of frames assert frames.shape[0] == 10 # Number of frames
self.assertEqual(frames.shape[1:], (704, 1280, 3)) # Frame dimensions assert frames.shape[1:] == (704, 1280, 3) # Frame dimensions
self.assertEqual(frames.dtype, np.uint8) # Data type assert frames.dtype == np.uint8 # Data type
def test_get_frames_from_mp4_resize(self): def test_get_frames_from_mp4_resize(self, temp_videos_dir):
"""Test extracting frames with resizing.""" """Test extracting frames with resizing."""
_, video_paths = temp_videos_dir
target_height, target_width = 352, 640 target_height, target_width = 352, 640
frames = get_frames_from_mp4(self.video_paths[0], target_height, target_width) frames = get_frames_from_mp4(video_paths[0], target_height, target_width)
# Check resized frame properties # Check resized frame properties
self.assertEqual(frames.shape[0], 10) # Number of frames assert frames.shape[0] == 10 # Number of frames
self.assertEqual(frames.shape[1:], (target_height, target_width, 3)) # Resized dimensions assert frames.shape[1:] == (target_height, target_width, 3) # Resized dimensions
self.assertEqual(frames.dtype, np.uint8) # Data type assert frames.dtype == np.uint8 # Data type
def test_process_video_and_demo(self): def test_process_video_and_demo(self, temp_hdf5_file, temp_videos_dir, temp_output_file):
"""Test processing a single video and creating a new demo.""" """Test processing a single video and creating a new demo."""
with h5py.File(self.temp_hdf5_file.name, "r") as f_in, h5py.File(self.temp_output_file.name, "w") as f_out: _, video_paths = temp_videos_dir
process_video_and_demo(f_in, f_out, self.video_paths[0], 0, 2) with h5py.File(temp_hdf5_file, "r") as f_in, h5py.File(temp_output_file, "w") as f_out:
process_video_and_demo(f_in, f_out, video_paths[0], 0, 2)
# Check if new demo was created with correct data # Check if new demo was created with correct data
self.assertIn("data/demo_2", f_out) assert "data/demo_2" in f_out
self.assertIn("data/demo_2/actions", f_out) assert "data/demo_2/actions" in f_out
self.assertIn("data/demo_2/obs/eef_pos", f_out) assert "data/demo_2/obs/eef_pos" in f_out
self.assertIn("data/demo_2/obs/eef_quat", f_out) assert "data/demo_2/obs/eef_quat" in f_out
self.assertIn("data/demo_2/obs/gripper_pos", f_out) assert "data/demo_2/obs/gripper_pos" in f_out
self.assertIn("data/demo_2/obs/table_cam", f_out) assert "data/demo_2/obs/table_cam" in f_out
self.assertIn("data/demo_2/obs/wrist_cam", f_out) assert "data/demo_2/obs/wrist_cam" in f_out
# Check data shapes # Check data shapes
self.assertEqual(f_out["data/demo_2/actions"].shape, (10, 7)) assert f_out["data/demo_2/actions"].shape == (10, 7)
self.assertEqual(f_out["data/demo_2/obs/eef_pos"].shape, (10, 3)) assert f_out["data/demo_2/obs/eef_pos"].shape == (10, 3)
self.assertEqual(f_out["data/demo_2/obs/eef_quat"].shape, (10, 4)) assert f_out["data/demo_2/obs/eef_quat"].shape == (10, 4)
self.assertEqual(f_out["data/demo_2/obs/gripper_pos"].shape, (10, 1)) assert f_out["data/demo_2/obs/gripper_pos"].shape == (10, 1)
self.assertEqual(f_out["data/demo_2/obs/table_cam"].shape, (10, 704, 1280, 3)) assert f_out["data/demo_2/obs/table_cam"].shape == (10, 704, 1280, 3)
self.assertEqual(f_out["data/demo_2/obs/wrist_cam"].shape, (10, 704, 1280, 3)) assert f_out["data/demo_2/obs/wrist_cam"].shape == (10, 704, 1280, 3)
# Check attributes # Check attributes
self.assertEqual(f_out["data/demo_2"].attrs["num_samples"], 10) assert f_out["data/demo_2"].attrs["num_samples"] == 10
def test_main_function(self): def test_main_function(self, temp_hdf5_file, temp_videos_dir, temp_output_file):
"""Test the main function.""" """Test the main function."""
# Mock command line arguments # Mock command line arguments
import sys import sys
...@@ -141,38 +148,34 @@ class TestMP4ToHDF5(unittest.TestCase): ...@@ -141,38 +148,34 @@ class TestMP4ToHDF5(unittest.TestCase):
sys.argv = [ sys.argv = [
"mp4_to_hdf5.py", "mp4_to_hdf5.py",
"--input_file", "--input_file",
self.temp_hdf5_file.name, temp_hdf5_file,
"--videos_dir", "--videos_dir",
self.temp_videos_dir, temp_videos_dir[0],
"--output_file", "--output_file",
self.temp_output_file.name, temp_output_file,
] ]
try: try:
main() main()
# Check if output file was created with correct data # Check if output file was created with correct data
with h5py.File(self.temp_output_file.name, "r") as f: with h5py.File(temp_output_file, "r") as f:
# Check if original demos were copied # Check if original demos were copied
self.assertIn("data/demo_0", f) assert "data/demo_0" in f
self.assertIn("data/demo_1", f) assert "data/demo_1" in f
# Check if new demos were created # Check if new demos were created
self.assertIn("data/demo_2", f) assert "data/demo_2" in f
self.assertIn("data/demo_3", f) assert "data/demo_3" in f
# Check data in new demos # Check data in new demos
for demo_id in [2, 3]: for demo_id in [2, 3]:
self.assertIn(f"data/demo_{demo_id}/actions", f) assert f"data/demo_{demo_id}/actions" in f
self.assertIn(f"data/demo_{demo_id}/obs/eef_pos", f) assert f"data/demo_{demo_id}/obs/eef_pos" in f
self.assertIn(f"data/demo_{demo_id}/obs/eef_quat", f) assert f"data/demo_{demo_id}/obs/eef_quat" in f
self.assertIn(f"data/demo_{demo_id}/obs/gripper_pos", f) assert f"data/demo_{demo_id}/obs/gripper_pos" in f
self.assertIn(f"data/demo_{demo_id}/obs/table_cam", f) assert f"data/demo_{demo_id}/obs/table_cam" in f
self.assertIn(f"data/demo_{demo_id}/obs/wrist_cam", f) assert f"data/demo_{demo_id}/obs/wrist_cam" in f
finally: finally:
# Restore original argv # Restore original argv
sys.argv = original_argv sys.argv = original_argv
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment