omnigenome 0.3.0a0__py3-none-any.whl → 0.3.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic. Click here for more details.
- omnigenome/__init__.py +14 -37
- omnigenome/src/misc/utils.py +199 -139
- omnigenome/src/model/rna_design/model.py +139 -96
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/METADATA +3 -3
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/RECORD +9 -16
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/top_level.txt +0 -1
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/licenses/LICENSE +0 -0
tests/test_rna_functions.py
DELETED
|
@@ -1,255 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Test RNA-specific functionality based on examples.
|
|
3
|
-
"""
|
|
4
|
-
import pytest
|
|
5
|
-
import tempfile
|
|
6
|
-
import os
|
|
7
|
-
from unittest.mock import patch, MagicMock
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class TestRNAFunctions:
|
|
11
|
-
"""Test RNA functionality based on examples."""
|
|
12
|
-
|
|
13
|
-
def test_rna_sequence_validity_checker(self):
|
|
14
|
-
"""Test ss_validity_loss function from Secondary_Structure_Prediction.py."""
|
|
15
|
-
# Recreate the function from the example
|
|
16
|
-
def ss_validity_loss(rna_strct: str) -> float:
|
|
17
|
-
left = right = 0
|
|
18
|
-
dots = rna_strct.count('.')
|
|
19
|
-
for c in rna_strct:
|
|
20
|
-
if c == '(':
|
|
21
|
-
left += 1
|
|
22
|
-
elif c == ')':
|
|
23
|
-
if left:
|
|
24
|
-
left -= 1
|
|
25
|
-
else:
|
|
26
|
-
right += 1
|
|
27
|
-
elif c != '.':
|
|
28
|
-
raise ValueError(f"Invalid char {c}")
|
|
29
|
-
return (left + right) / (len(rna_strct) - dots + 1e-8)
|
|
30
|
-
|
|
31
|
-
# Test valid structures
|
|
32
|
-
assert ss_validity_loss("(())") == 0.0
|
|
33
|
-
assert ss_validity_loss("((..))") == 0.0
|
|
34
|
-
assert ss_validity_loss("....") == 0.0
|
|
35
|
-
|
|
36
|
-
# Test invalid structures
|
|
37
|
-
assert ss_validity_loss("(((") > 0.0 # Unmatched left
|
|
38
|
-
assert ss_validity_loss(")))") > 0.0 # Unmatched right
|
|
39
|
-
assert ss_validity_loss("())(") > 0.0 # Mixed unmatched
|
|
40
|
-
|
|
41
|
-
# Test error case
|
|
42
|
-
with pytest.raises(ValueError, match="Invalid char"):
|
|
43
|
-
ss_validity_loss("((X))")
|
|
44
|
-
|
|
45
|
-
def test_find_invalid_positions(self):
|
|
46
|
-
"""Test find_invalid_positions function from Secondary_Structure_Prediction.py."""
|
|
47
|
-
# Recreate the function from the example
|
|
48
|
-
def find_invalid_positions(struct: str) -> list:
|
|
49
|
-
stack, invalid = [], []
|
|
50
|
-
for i, c in enumerate(struct):
|
|
51
|
-
if c == '(':
|
|
52
|
-
stack.append(i)
|
|
53
|
-
elif c == ')':
|
|
54
|
-
if stack:
|
|
55
|
-
stack.pop()
|
|
56
|
-
else:
|
|
57
|
-
invalid.append(i)
|
|
58
|
-
invalid.extend(stack)
|
|
59
|
-
return invalid
|
|
60
|
-
|
|
61
|
-
# Test valid structures
|
|
62
|
-
assert find_invalid_positions("(())") == []
|
|
63
|
-
assert find_invalid_positions("((..))") == []
|
|
64
|
-
assert find_invalid_positions("....") == []
|
|
65
|
-
|
|
66
|
-
# Test invalid structures
|
|
67
|
-
assert find_invalid_positions("(((") == [0, 1, 2] # All unmatched left
|
|
68
|
-
assert find_invalid_positions(")))") == [0, 1, 2] # All unmatched right
|
|
69
|
-
assert find_invalid_positions("())(") == [2, 3] # One unmatched right, one left
|
|
70
|
-
|
|
71
|
-
def test_rna_structure_formats(self):
|
|
72
|
-
"""Test RNA structure format validation."""
|
|
73
|
-
valid_structures = [
|
|
74
|
-
"(())",
|
|
75
|
-
"((()))",
|
|
76
|
-
".((.))",
|
|
77
|
-
"....",
|
|
78
|
-
"",
|
|
79
|
-
"((..))",
|
|
80
|
-
]
|
|
81
|
-
|
|
82
|
-
invalid_structures = [
|
|
83
|
-
"((X))", # Invalid character
|
|
84
|
-
"(()", # Unmatched
|
|
85
|
-
"())", # Unmatched
|
|
86
|
-
")(", # Wrong order
|
|
87
|
-
]
|
|
88
|
-
|
|
89
|
-
def is_valid_structure_format(struct: str) -> bool:
|
|
90
|
-
"""Check if structure contains only valid characters."""
|
|
91
|
-
return all(c in "()." for c in struct)
|
|
92
|
-
|
|
93
|
-
for struct in valid_structures:
|
|
94
|
-
assert is_valid_structure_format(struct), f"Should be valid: {struct}"
|
|
95
|
-
|
|
96
|
-
for struct in invalid_structures:
|
|
97
|
-
if any(c not in "()." for c in struct):
|
|
98
|
-
assert not is_valid_structure_format(struct), f"Should be invalid: {struct}"
|
|
99
|
-
|
|
100
|
-
def test_sequence_replacement_patterns(self):
|
|
101
|
-
"""Test U/T replacement patterns from examples."""
|
|
102
|
-
# Pattern from web_rna_design.py
|
|
103
|
-
def rna_to_dna_pattern(sequence):
|
|
104
|
-
return sequence.replace("U", "T")
|
|
105
|
-
|
|
106
|
-
def dna_to_rna_pattern(sequence):
|
|
107
|
-
return sequence.replace("T", "U")
|
|
108
|
-
|
|
109
|
-
# Test RNA to DNA
|
|
110
|
-
assert rna_to_dna_pattern("AUCG") == "ATCG"
|
|
111
|
-
assert rna_to_dna_pattern("UUUU") == "TTTT"
|
|
112
|
-
assert rna_to_dna_pattern("ACGU") == "ACGT"
|
|
113
|
-
|
|
114
|
-
# Test DNA to RNA
|
|
115
|
-
assert dna_to_rna_pattern("ATCG") == "AUCG"
|
|
116
|
-
assert dna_to_rna_pattern("TTTT") == "UUUU"
|
|
117
|
-
assert dna_to_rna_pattern("ACGT") == "ACGU"
|
|
118
|
-
|
|
119
|
-
def test_random_base_generation_patterns(self):
|
|
120
|
-
"""Test random base generation patterns from RNA design examples."""
|
|
121
|
-
import random
|
|
122
|
-
|
|
123
|
-
def generate_random_rna_base():
|
|
124
|
-
"""Pattern from easy_rna_design_emoo.py."""
|
|
125
|
-
return random.choice(["A", "U", "G", "C"])
|
|
126
|
-
|
|
127
|
-
def generate_random_dna_base():
|
|
128
|
-
"""Pattern from easy_rna_design_emoo.py."""
|
|
129
|
-
return random.choice(["A", "T", "G", "C"])
|
|
130
|
-
|
|
131
|
-
# Test multiple generations to ensure valid bases
|
|
132
|
-
for _ in range(10):
|
|
133
|
-
rna_base = generate_random_rna_base()
|
|
134
|
-
assert rna_base in ["A", "U", "G", "C"]
|
|
135
|
-
|
|
136
|
-
dna_base = generate_random_dna_base()
|
|
137
|
-
assert dna_base in ["A", "T", "G", "C"]
|
|
138
|
-
|
|
139
|
-
def test_sequence_mutation_pattern(self):
|
|
140
|
-
"""Test sequence mutation pattern from mlm_mutate function."""
|
|
141
|
-
try:
|
|
142
|
-
import numpy as np
|
|
143
|
-
except ImportError:
|
|
144
|
-
pytest.skip("numpy not available")
|
|
145
|
-
|
|
146
|
-
def mutate_sequence_pattern(sequence, mutation_rate=0.1):
|
|
147
|
-
"""Simplified version of mutation pattern from examples."""
|
|
148
|
-
sequence_array = np.array(list(sequence), dtype=np.str_)
|
|
149
|
-
probability_matrix = np.full(sequence_array.shape, mutation_rate)
|
|
150
|
-
masked_indices = np.random.rand(*sequence_array.shape) < probability_matrix
|
|
151
|
-
sequence_array[masked_indices] = "$" # Mask token
|
|
152
|
-
return "".join(sequence_array.tolist())
|
|
153
|
-
|
|
154
|
-
# Test mutation with 0% rate
|
|
155
|
-
original = "AUCG"
|
|
156
|
-
mutated_zero = mutate_sequence_pattern(original, 0.0)
|
|
157
|
-
assert mutated_zero == original
|
|
158
|
-
|
|
159
|
-
# Test mutation with 100% rate
|
|
160
|
-
mutated_full = mutate_sequence_pattern(original, 1.0)
|
|
161
|
-
assert mutated_full == "$$$$"
|
|
162
|
-
|
|
163
|
-
# Test with moderate rate - should have some masks
|
|
164
|
-
np.random.seed(42) # For reproducible test
|
|
165
|
-
mutated_partial = mutate_sequence_pattern("AUCGAUCGAUCG", 0.5)
|
|
166
|
-
assert "$" in mutated_partial
|
|
167
|
-
|
|
168
|
-
@patch('tempfile.mkdtemp')
|
|
169
|
-
def test_temp_directory_pattern(self, mock_mkdtemp):
|
|
170
|
-
"""Test temp directory usage pattern from Secondary_Structure_Prediction.py."""
|
|
171
|
-
from pathlib import Path
|
|
172
|
-
|
|
173
|
-
mock_mkdtemp.return_value = "/tmp/test_dir"
|
|
174
|
-
|
|
175
|
-
# Pattern from Secondary_Structure_Prediction.py
|
|
176
|
-
TEMP_DIR = Path(tempfile.mkdtemp())
|
|
177
|
-
|
|
178
|
-
mock_mkdtemp.assert_called_once()
|
|
179
|
-
assert isinstance(TEMP_DIR, Path)
|
|
180
|
-
|
|
181
|
-
def test_rna_embedding_sequence_validation(self):
|
|
182
|
-
"""Test RNA sequence validation for embedding examples."""
|
|
183
|
-
# RNA sequences from RNA_Embedding_Tutorial.ipynb
|
|
184
|
-
rna_sequences = [
|
|
185
|
-
"AUGGCUACG",
|
|
186
|
-
"CGGAUACGGC",
|
|
187
|
-
"UGGCCAAGUC",
|
|
188
|
-
"AUGCUGCUAUGCUA"
|
|
189
|
-
]
|
|
190
|
-
|
|
191
|
-
def validate_rna_sequence(seq):
|
|
192
|
-
"""Validate RNA sequence format."""
|
|
193
|
-
return all(base in "AUCG" for base in seq) and len(seq) > 0
|
|
194
|
-
|
|
195
|
-
for seq in rna_sequences:
|
|
196
|
-
assert validate_rna_sequence(seq), f"Invalid RNA sequence: {seq}"
|
|
197
|
-
|
|
198
|
-
def test_structure_prediction_mock_pattern(self):
|
|
199
|
-
"""Test structure prediction pattern without ViennaRNA dependency."""
|
|
200
|
-
def mock_predict_structure_single(sequence):
|
|
201
|
-
"""Mock version of predict_structure_single from examples."""
|
|
202
|
-
# Return a mock structure and energy
|
|
203
|
-
return "." * len(sequence), -10.0
|
|
204
|
-
|
|
205
|
-
# Test the pattern
|
|
206
|
-
seq = "AUCG"
|
|
207
|
-
struct, energy = mock_predict_structure_single(seq)
|
|
208
|
-
|
|
209
|
-
assert len(struct) == len(seq)
|
|
210
|
-
assert isinstance(energy, float)
|
|
211
|
-
assert struct == "...."
|
|
212
|
-
|
|
213
|
-
def test_base64_encoding_pattern(self):
|
|
214
|
-
"""Test base64 encoding pattern from SVG generation."""
|
|
215
|
-
import base64
|
|
216
|
-
|
|
217
|
-
def create_mock_svg_datauri(content="test"):
|
|
218
|
-
"""Mock version of SVG data URI creation."""
|
|
219
|
-
svg_content = f'<svg>{content}</svg>'
|
|
220
|
-
b64 = base64.b64encode(svg_content.encode()).decode('utf-8')
|
|
221
|
-
return f"data:image/svg+xml;base64,{b64}"
|
|
222
|
-
|
|
223
|
-
uri = create_mock_svg_datauri("test")
|
|
224
|
-
assert uri.startswith("data:image/svg+xml;base64,")
|
|
225
|
-
|
|
226
|
-
# Decode and verify
|
|
227
|
-
_, b64_part = uri.split(",", 1)
|
|
228
|
-
decoded = base64.b64decode(b64_part).decode('utf-8')
|
|
229
|
-
assert decoded == "<svg>test</svg>"
|
|
230
|
-
|
|
231
|
-
def test_longest_bp_span_function(self):
|
|
232
|
-
"""Test longest_bp_span function from easy_rna_design_emoo.py."""
|
|
233
|
-
def longest_bp_span(structure):
|
|
234
|
-
"""Function from easy_rna_design_emoo.py."""
|
|
235
|
-
stack = []
|
|
236
|
-
max_span = 0
|
|
237
|
-
|
|
238
|
-
for i, char in enumerate(structure):
|
|
239
|
-
if char == '(':
|
|
240
|
-
stack.append(i)
|
|
241
|
-
elif char == ')':
|
|
242
|
-
if stack:
|
|
243
|
-
left_index = stack.pop()
|
|
244
|
-
current_span = i - left_index
|
|
245
|
-
max_span = max(max_span, current_span)
|
|
246
|
-
|
|
247
|
-
return max_span
|
|
248
|
-
|
|
249
|
-
# Test cases
|
|
250
|
-
assert longest_bp_span("(())") == 3 # Outer pair spans 3 positions
|
|
251
|
-
assert longest_bp_span("((()))") == 5 # Outer pair spans 5 positions
|
|
252
|
-
assert longest_bp_span("()()") == 1 # Each pair spans 1 position
|
|
253
|
-
assert longest_bp_span("....") == 0 # No pairs
|
|
254
|
-
assert longest_bp_span("") == 0 # Empty structure
|
|
255
|
-
assert longest_bp_span("((.))") == 4 # Outer pair spans 4 positions
|
tests/test_training_patterns.py
DELETED
|
@@ -1,302 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Test training patterns and configurations based on examples.
|
|
3
|
-
"""
|
|
4
|
-
import pytest
|
|
5
|
-
from unittest.mock import patch, MagicMock
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class TestTrainingPatterns:
|
|
9
|
-
"""Test training patterns from examples."""
|
|
10
|
-
|
|
11
|
-
def test_trainer_imports(self):
|
|
12
|
-
"""Test trainer imports as shown in quick_start.md."""
|
|
13
|
-
try:
|
|
14
|
-
from omnigenome import Trainer
|
|
15
|
-
assert True
|
|
16
|
-
except ImportError:
|
|
17
|
-
pytest.skip("omnigenome not available or missing dependencies")
|
|
18
|
-
|
|
19
|
-
def test_autobench_imports(self):
|
|
20
|
-
"""Test AutoBench imports from examples."""
|
|
21
|
-
try:
|
|
22
|
-
from omnigenome import AutoBench
|
|
23
|
-
assert True
|
|
24
|
-
except ImportError:
|
|
25
|
-
pytest.skip("omnigenome not available or missing dependencies")
|
|
26
|
-
|
|
27
|
-
def test_autocuda_import_pattern(self):
|
|
28
|
-
"""Test autocuda import pattern from examples."""
|
|
29
|
-
try:
|
|
30
|
-
import autocuda
|
|
31
|
-
# Pattern from examples
|
|
32
|
-
device = autocuda.auto_cuda()
|
|
33
|
-
# Just verify the function exists and returns something
|
|
34
|
-
assert device is not None
|
|
35
|
-
except ImportError:
|
|
36
|
-
# Skip if autocuda not available
|
|
37
|
-
pytest.skip("autocuda not available")
|
|
38
|
-
|
|
39
|
-
@patch('omnigenome.AutoBench')
|
|
40
|
-
def test_autobench_initialization_pattern(self, mock_autobench):
|
|
41
|
-
"""Test AutoBench initialization pattern from quick_start.md."""
|
|
42
|
-
mock_instance = MagicMock()
|
|
43
|
-
mock_autobench.return_value = mock_instance
|
|
44
|
-
|
|
45
|
-
from omnigenome import AutoBench
|
|
46
|
-
|
|
47
|
-
# Pattern from quick_start.md
|
|
48
|
-
auto_bench = AutoBench(
|
|
49
|
-
benchmark="RGB",
|
|
50
|
-
model_name_or_path="yangheng/OmniGenome-186M",
|
|
51
|
-
device="cuda:0",
|
|
52
|
-
overwrite=True
|
|
53
|
-
)
|
|
54
|
-
|
|
55
|
-
mock_autobench.assert_called_once_with(
|
|
56
|
-
benchmark="RGB",
|
|
57
|
-
model_name_or_path="yangheng/OmniGenome-186M",
|
|
58
|
-
device="cuda:0",
|
|
59
|
-
overwrite=True
|
|
60
|
-
)
|
|
61
|
-
|
|
62
|
-
def test_benchmark_names(self):
|
|
63
|
-
"""Test benchmark names from examples."""
|
|
64
|
-
# Benchmarks from quick_start.md
|
|
65
|
-
benchmarks = ["RGB", "GB", "PGB", "GUE", "BEACON"]
|
|
66
|
-
|
|
67
|
-
for benchmark in benchmarks:
|
|
68
|
-
assert isinstance(benchmark, str)
|
|
69
|
-
assert len(benchmark) > 0
|
|
70
|
-
assert benchmark.isupper()
|
|
71
|
-
|
|
72
|
-
def test_trainer_names(self):
|
|
73
|
-
"""Test trainer types from examples."""
|
|
74
|
-
# Trainers from autobench examples
|
|
75
|
-
trainers = ["accelerate", "huggingface"]
|
|
76
|
-
|
|
77
|
-
for trainer in trainers:
|
|
78
|
-
assert isinstance(trainer, str)
|
|
79
|
-
assert trainer in ["accelerate", "huggingface"]
|
|
80
|
-
|
|
81
|
-
@patch('omnigenome.Trainer')
|
|
82
|
-
def test_trainer_initialization_pattern(self, mock_trainer):
|
|
83
|
-
"""Test Trainer initialization pattern from quick_start.md."""
|
|
84
|
-
mock_trainer.return_value = MagicMock()
|
|
85
|
-
|
|
86
|
-
from omnigenome import Trainer
|
|
87
|
-
|
|
88
|
-
# Mock training arguments
|
|
89
|
-
mock_args = MagicMock()
|
|
90
|
-
mock_args.output_dir = "./results"
|
|
91
|
-
mock_args.num_train_epochs = 3
|
|
92
|
-
mock_args.per_device_train_batch_size = 8
|
|
93
|
-
mock_args.learning_rate = 2e-5
|
|
94
|
-
|
|
95
|
-
# Pattern from quick_start.md
|
|
96
|
-
trainer = Trainer(
|
|
97
|
-
model=MagicMock(),
|
|
98
|
-
train_dataset=MagicMock(),
|
|
99
|
-
eval_dataset=MagicMock(),
|
|
100
|
-
args=mock_args
|
|
101
|
-
)
|
|
102
|
-
|
|
103
|
-
mock_trainer.assert_called_once()
|
|
104
|
-
|
|
105
|
-
def test_training_arguments_pattern(self):
|
|
106
|
-
"""Test training arguments patterns from examples."""
|
|
107
|
-
# Common training parameters from examples
|
|
108
|
-
training_configs = {
|
|
109
|
-
"output_dir": "./results",
|
|
110
|
-
"num_train_epochs": 3,
|
|
111
|
-
"per_device_train_batch_size": 8,
|
|
112
|
-
"learning_rate": 2e-5,
|
|
113
|
-
"epochs": 3,
|
|
114
|
-
"batch_size": 8,
|
|
115
|
-
"seeds": [42, 43, 44]
|
|
116
|
-
}
|
|
117
|
-
|
|
118
|
-
# Verify types and ranges
|
|
119
|
-
assert isinstance(training_configs["output_dir"], str)
|
|
120
|
-
assert isinstance(training_configs["num_train_epochs"], int)
|
|
121
|
-
assert training_configs["num_train_epochs"] > 0
|
|
122
|
-
assert isinstance(training_configs["per_device_train_batch_size"], int)
|
|
123
|
-
assert training_configs["per_device_train_batch_size"] > 0
|
|
124
|
-
assert isinstance(training_configs["learning_rate"], float)
|
|
125
|
-
assert training_configs["learning_rate"] > 0
|
|
126
|
-
assert isinstance(training_configs["seeds"], list)
|
|
127
|
-
assert all(isinstance(seed, int) for seed in training_configs["seeds"])
|
|
128
|
-
|
|
129
|
-
def test_genetic_algorithm_parameters(self):
|
|
130
|
-
"""Test genetic algorithm parameters from RNA design examples."""
|
|
131
|
-
# Parameters from easy_rna_design_emoo.py
|
|
132
|
-
ga_params = {
|
|
133
|
-
"mutation_ratio": 0.1,
|
|
134
|
-
"num_population": 100,
|
|
135
|
-
"num_generation": 50,
|
|
136
|
-
"model": "anonymous8/OmniGenome-186M"
|
|
137
|
-
}
|
|
138
|
-
|
|
139
|
-
# Verify parameter types and ranges
|
|
140
|
-
assert isinstance(ga_params["mutation_ratio"], float)
|
|
141
|
-
assert 0.0 <= ga_params["mutation_ratio"] <= 1.0
|
|
142
|
-
assert isinstance(ga_params["num_population"], int)
|
|
143
|
-
assert ga_params["num_population"] > 0
|
|
144
|
-
assert isinstance(ga_params["num_generation"], int)
|
|
145
|
-
assert ga_params["num_generation"] > 0
|
|
146
|
-
assert isinstance(ga_params["model"], str)
|
|
147
|
-
|
|
148
|
-
def test_web_rna_design_parameters(self):
|
|
149
|
-
"""Test web RNA design parameters from web_rna_design.py."""
|
|
150
|
-
# Parameters from web_rna_design.py
|
|
151
|
-
web_params = {
|
|
152
|
-
"mutation_ratio": 0.5,
|
|
153
|
-
"num_population": 500,
|
|
154
|
-
"num_generation": 10,
|
|
155
|
-
"puzzle_id": 0
|
|
156
|
-
}
|
|
157
|
-
|
|
158
|
-
# Verify parameter types
|
|
159
|
-
assert isinstance(web_params["mutation_ratio"], float)
|
|
160
|
-
assert 0.0 <= web_params["mutation_ratio"] <= 1.0
|
|
161
|
-
assert isinstance(web_params["num_population"], int)
|
|
162
|
-
assert web_params["num_population"] > 0
|
|
163
|
-
assert isinstance(web_params["num_generation"], int)
|
|
164
|
-
assert web_params["num_generation"] > 0
|
|
165
|
-
assert isinstance(web_params["puzzle_id"], int)
|
|
166
|
-
assert web_params["puzzle_id"] >= 0
|
|
167
|
-
|
|
168
|
-
def test_model_optimization_patterns(self):
|
|
169
|
-
"""Test model optimization patterns from examples."""
|
|
170
|
-
# Patterns from examples for model optimization
|
|
171
|
-
optimization_configs = {
|
|
172
|
-
"torch_dtype": "float16",
|
|
173
|
-
"device_map": "auto",
|
|
174
|
-
"trust_remote_code": True,
|
|
175
|
-
"gradient_checkpointing": True,
|
|
176
|
-
"fp16": True
|
|
177
|
-
}
|
|
178
|
-
|
|
179
|
-
for key, value in optimization_configs.items():
|
|
180
|
-
assert isinstance(key, str)
|
|
181
|
-
# Value types vary, just ensure they exist
|
|
182
|
-
assert value is not None
|
|
183
|
-
|
|
184
|
-
@patch('torch.cuda.empty_cache')
|
|
185
|
-
def test_memory_management_pattern(self, mock_empty_cache):
|
|
186
|
-
"""Test memory management patterns from web_rna_design.py."""
|
|
187
|
-
try:
|
|
188
|
-
import torch
|
|
189
|
-
except ImportError:
|
|
190
|
-
pytest.skip("torch not available")
|
|
191
|
-
|
|
192
|
-
# Pattern from web_rna_design.py
|
|
193
|
-
def cleanup_model_pattern():
|
|
194
|
-
"""Memory cleanup pattern from examples."""
|
|
195
|
-
# del model, tokenizer # Would normally delete objects
|
|
196
|
-
torch.cuda.empty_cache()
|
|
197
|
-
|
|
198
|
-
cleanup_model_pattern()
|
|
199
|
-
mock_empty_cache.assert_called_once()
|
|
200
|
-
|
|
201
|
-
def test_random_seed_patterns(self):
|
|
202
|
-
"""Test random seed patterns from examples."""
|
|
203
|
-
import random
|
|
204
|
-
|
|
205
|
-
# Pattern from examples
|
|
206
|
-
def set_random_seed_pattern():
|
|
207
|
-
"""Random seed pattern from easy_rna_design_emoo.py."""
|
|
208
|
-
return random.randint(0, 99999999)
|
|
209
|
-
|
|
210
|
-
# Test seed generation
|
|
211
|
-
seed1 = set_random_seed_pattern()
|
|
212
|
-
seed2 = set_random_seed_pattern()
|
|
213
|
-
|
|
214
|
-
assert isinstance(seed1, int)
|
|
215
|
-
assert isinstance(seed2, int)
|
|
216
|
-
assert 0 <= seed1 <= 99999999
|
|
217
|
-
assert 0 <= seed2 <= 99999999
|
|
218
|
-
|
|
219
|
-
def test_evaluation_metrics_patterns(self):
|
|
220
|
-
"""Test evaluation metrics patterns from examples."""
|
|
221
|
-
# Common metrics mentioned in examples
|
|
222
|
-
metrics = [
|
|
223
|
-
"accuracy",
|
|
224
|
-
"f1_score",
|
|
225
|
-
"precision",
|
|
226
|
-
"recall",
|
|
227
|
-
"mse",
|
|
228
|
-
"mae",
|
|
229
|
-
"r2_score"
|
|
230
|
-
]
|
|
231
|
-
|
|
232
|
-
for metric in metrics:
|
|
233
|
-
assert isinstance(metric, str)
|
|
234
|
-
assert len(metric) > 0
|
|
235
|
-
|
|
236
|
-
def test_device_selection_patterns(self):
|
|
237
|
-
"""Test device selection patterns from examples."""
|
|
238
|
-
# Patterns from examples
|
|
239
|
-
device_patterns = [
|
|
240
|
-
"cuda:0",
|
|
241
|
-
"cuda",
|
|
242
|
-
"cpu",
|
|
243
|
-
"auto"
|
|
244
|
-
]
|
|
245
|
-
|
|
246
|
-
for device in device_patterns:
|
|
247
|
-
assert isinstance(device, str)
|
|
248
|
-
assert len(device) > 0
|
|
249
|
-
|
|
250
|
-
def test_batch_size_patterns(self):
|
|
251
|
-
"""Test batch size patterns from examples."""
|
|
252
|
-
# Common batch sizes from examples
|
|
253
|
-
batch_sizes = [4, 8, 16, 32, 64]
|
|
254
|
-
|
|
255
|
-
for batch_size in batch_sizes:
|
|
256
|
-
assert isinstance(batch_size, int)
|
|
257
|
-
assert batch_size > 0
|
|
258
|
-
assert batch_size <= 1024 # Reasonable upper limit
|
|
259
|
-
|
|
260
|
-
def test_learning_rate_patterns(self):
|
|
261
|
-
"""Test learning rate patterns from examples."""
|
|
262
|
-
# Common learning rates from examples
|
|
263
|
-
learning_rates = [1e-5, 2e-5, 5e-5, 1e-4, 2e-4]
|
|
264
|
-
|
|
265
|
-
for lr in learning_rates:
|
|
266
|
-
assert isinstance(lr, float)
|
|
267
|
-
assert lr > 0
|
|
268
|
-
assert lr < 1.0 # Should be small
|
|
269
|
-
|
|
270
|
-
def test_epoch_patterns(self):
|
|
271
|
-
"""Test epoch patterns from examples."""
|
|
272
|
-
# Common epoch counts from examples
|
|
273
|
-
epoch_counts = [1, 3, 5, 10, 20]
|
|
274
|
-
|
|
275
|
-
for epochs in epoch_counts:
|
|
276
|
-
assert isinstance(epochs, int)
|
|
277
|
-
assert epochs > 0
|
|
278
|
-
assert epochs <= 100 # Reasonable upper limit
|
|
279
|
-
|
|
280
|
-
def test_output_directory_patterns(self):
|
|
281
|
-
"""Test output directory patterns from examples."""
|
|
282
|
-
# Common output directory patterns
|
|
283
|
-
output_dirs = [
|
|
284
|
-
"./results",
|
|
285
|
-
"./output",
|
|
286
|
-
"./checkpoints",
|
|
287
|
-
"./models"
|
|
288
|
-
]
|
|
289
|
-
|
|
290
|
-
for output_dir in output_dirs:
|
|
291
|
-
assert isinstance(output_dir, str)
|
|
292
|
-
assert output_dir.startswith("./") or output_dir.startswith("/")
|
|
293
|
-
|
|
294
|
-
def test_model_saving_patterns(self):
|
|
295
|
-
"""Test model saving patterns from examples."""
|
|
296
|
-
# File extensions for saved models
|
|
297
|
-
model_extensions = [".pt", ".pth", ".bin", ".safetensors"]
|
|
298
|
-
|
|
299
|
-
for ext in model_extensions:
|
|
300
|
-
assert isinstance(ext, str)
|
|
301
|
-
assert ext.startswith(".")
|
|
302
|
-
assert len(ext) > 1
|
|
File without changes
|
|
File without changes
|
|
File without changes
|