Unverified Commit a958ac56 authored by shauryadNv's avatar shauryadNv Committed by GitHub

Updates cosmos test files to use pytest (#548)

# Description

Updated Mimic-Cosmos related tests to use pytest.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: 's avatarKelly Guo <kellyguo123@hotmail.com>
parent 24e83d72
...@@ -8,66 +8,65 @@ ...@@ -8,66 +8,65 @@
import json import json
import os import os
import tempfile import tempfile
import unittest
import pytest
from scripts.tools.cosmos.cosmos_prompt_gen import generate_prompt, main from scripts.tools.cosmos.cosmos_prompt_gen import generate_prompt, main
class TestCosmosPromptGen(unittest.TestCase): @pytest.fixture(scope="class")
def temp_templates_file():
"""Create temporary templates file."""
temp_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
# Create test templates
test_templates = {
"lighting": ["with bright lighting", "with dim lighting", "with natural lighting"],
"color": ["in warm colors", "in cool colors", "in vibrant colors"],
"style": ["in a realistic style", "in an artistic style", "in a minimalist style"],
"empty_section": [], # Test empty section
"invalid_section": "not a list", # Test invalid section
}
# Write templates to file
with open(temp_file.name, "w") as f:
json.dump(test_templates, f)
yield temp_file.name
# Cleanup
os.remove(temp_file.name)
@pytest.fixture
def temp_output_file():
"""Create temporary output file."""
temp_file = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
yield temp_file.name
# Cleanup
os.remove(temp_file.name)
class TestCosmosPromptGen:
"""Test cases for Cosmos prompt generation functionality.""" """Test cases for Cosmos prompt generation functionality."""
@classmethod def test_generate_prompt_valid_templates(self, temp_templates_file):
def setUpClass(cls):
"""Set up test fixtures that are shared across all test methods."""
# Create temporary templates file
cls.temp_templates_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
# Create test templates
test_templates = {
"lighting": ["with bright lighting", "with dim lighting", "with natural lighting"],
"color": ["in warm colors", "in cool colors", "in vibrant colors"],
"style": ["in a realistic style", "in an artistic style", "in a minimalist style"],
"empty_section": [], # Test empty section
"invalid_section": "not a list", # Test invalid section
}
# Write templates to file
with open(cls.temp_templates_file.name, "w") as f:
json.dump(test_templates, f)
def setUp(self):
"""Set up test fixtures that are created for each test method."""
self.temp_output_file = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
def tearDown(self):
"""Clean up test fixtures after each test method."""
# Remove the temporary output file
os.remove(self.temp_output_file.name)
@classmethod
def tearDownClass(cls):
"""Clean up test fixtures that are shared across all test methods."""
# Remove the temporary templates file
os.remove(cls.temp_templates_file.name)
def test_generate_prompt_valid_templates(self):
"""Test generating a prompt with valid templates.""" """Test generating a prompt with valid templates."""
prompt = generate_prompt(self.temp_templates_file.name) prompt = generate_prompt(temp_templates_file)
# Check that prompt is a string # Check that prompt is a string
self.assertIsInstance(prompt, str) assert isinstance(prompt, str)
# Check that prompt contains at least one word # Check that prompt contains at least one word
self.assertTrue(len(prompt.split()) > 0) assert len(prompt.split()) > 0
# Check that prompt contains valid sections # Check that prompt contains valid sections
valid_sections = ["lighting", "color", "style"] valid_sections = ["lighting", "color", "style"]
found_sections = [section for section in valid_sections if section in prompt.lower()] found_sections = [section for section in valid_sections if section in prompt.lower()]
self.assertTrue(len(found_sections) > 0) assert len(found_sections) > 0
def test_generate_prompt_invalid_file(self): def test_generate_prompt_invalid_file(self):
"""Test generating a prompt with invalid file path.""" """Test generating a prompt with invalid file path."""
with self.assertRaises(FileNotFoundError): with pytest.raises(FileNotFoundError):
generate_prompt("nonexistent_file.json") generate_prompt("nonexistent_file.json")
def test_generate_prompt_invalid_json(self): def test_generate_prompt_invalid_json(self):
...@@ -78,12 +77,12 @@ class TestCosmosPromptGen(unittest.TestCase): ...@@ -78,12 +77,12 @@ class TestCosmosPromptGen(unittest.TestCase):
temp_file.flush() temp_file.flush()
try: try:
with self.assertRaises(ValueError): with pytest.raises(ValueError):
generate_prompt(temp_file.name) generate_prompt(temp_file.name)
finally: finally:
os.remove(temp_file.name) os.remove(temp_file.name)
def test_main_function_single_prompt(self): def test_main_function_single_prompt(self, temp_templates_file, temp_output_file):
"""Test main function with single prompt generation.""" """Test main function with single prompt generation."""
# Mock command line arguments # Mock command line arguments
import sys import sys
...@@ -92,29 +91,29 @@ class TestCosmosPromptGen(unittest.TestCase): ...@@ -92,29 +91,29 @@ class TestCosmosPromptGen(unittest.TestCase):
sys.argv = [ sys.argv = [
"cosmos_prompt_gen.py", "cosmos_prompt_gen.py",
"--templates_path", "--templates_path",
self.temp_templates_file.name, temp_templates_file,
"--num_prompts", "--num_prompts",
"1", "1",
"--output_path", "--output_path",
self.temp_output_file.name, temp_output_file,
] ]
try: try:
main() main()
# Check if output file was created # Check if output file was created
self.assertTrue(os.path.exists(self.temp_output_file.name)) assert os.path.exists(temp_output_file)
# Check content of output file # Check content of output file
with open(self.temp_output_file.name) as f: with open(temp_output_file) as f:
content = f.read().strip() content = f.read().strip()
self.assertTrue(len(content) > 0) assert len(content) > 0
self.assertEqual(len(content.split("\n")), 1) assert len(content.split("\n")) == 1
finally: finally:
# Restore original argv # Restore original argv
sys.argv = original_argv sys.argv = original_argv
def test_main_function_multiple_prompts(self): def test_main_function_multiple_prompts(self, temp_templates_file, temp_output_file):
"""Test main function with multiple prompt generation.""" """Test main function with multiple prompt generation."""
# Mock command line arguments # Mock command line arguments
import sys import sys
...@@ -123,52 +122,48 @@ class TestCosmosPromptGen(unittest.TestCase): ...@@ -123,52 +122,48 @@ class TestCosmosPromptGen(unittest.TestCase):
sys.argv = [ sys.argv = [
"cosmos_prompt_gen.py", "cosmos_prompt_gen.py",
"--templates_path", "--templates_path",
self.temp_templates_file.name, temp_templates_file,
"--num_prompts", "--num_prompts",
"3", "3",
"--output_path", "--output_path",
self.temp_output_file.name, temp_output_file,
] ]
try: try:
main() main()
# Check if output file was created # Check if output file was created
self.assertTrue(os.path.exists(self.temp_output_file.name)) assert os.path.exists(temp_output_file)
# Check content of output file # Check content of output file
with open(self.temp_output_file.name) as f: with open(temp_output_file) as f:
content = f.read().strip() content = f.read().strip()
self.assertTrue(len(content) > 0) assert len(content) > 0
self.assertEqual(len(content.split("\n")), 3) assert len(content.split("\n")) == 3
# Check that each line is a valid prompt # Check that each line is a valid prompt
for line in content.split("\n"): for line in content.split("\n"):
self.assertTrue(len(line) > 0) assert len(line) > 0
finally: finally:
# Restore original argv # Restore original argv
sys.argv = original_argv sys.argv = original_argv
def test_main_function_default_output(self): def test_main_function_default_output(self, temp_templates_file):
"""Test main function with default output path.""" """Test main function with default output path."""
# Mock command line arguments # Mock command line arguments
import sys import sys
original_argv = sys.argv original_argv = sys.argv
sys.argv = ["cosmos_prompt_gen.py", "--templates_path", self.temp_templates_file.name, "--num_prompts", "1"] sys.argv = ["cosmos_prompt_gen.py", "--templates_path", temp_templates_file, "--num_prompts", "1"]
try: try:
main() main()
# Check if default output file was created # Check if default output file was created
self.assertTrue(os.path.exists("prompts.txt")) assert os.path.exists("prompts.txt")
# Clean up default output file # Clean up default output file
os.remove("prompts.txt") os.remove("prompts.txt")
finally: finally:
# Restore original argv # Restore original argv
sys.argv = original_argv sys.argv = original_argv
if __name__ == "__main__":
unittest.main()
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment