omnigenome 0.3.0a0__py3-none-any.whl → 0.3.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of omnigenome might be problematic. Click here for more details.

@@ -1,255 +0,0 @@
1
- """
2
- Test RNA-specific functionality based on examples.
3
- """
4
- import pytest
5
- import tempfile
6
- import os
7
- from unittest.mock import patch, MagicMock
8
-
9
-
10
class TestRNAFunctions:
    """Test RNA functionality based on examples.

    Each test defines a local copy of a helper taken from an example script
    (named in the test's docstring). It checks that copy against hand-computed
    expectations, so the tests do not need the example scripts to be
    importable.
    """

    def test_rna_sequence_validity_checker(self):
        """Test ss_validity_loss function from Secondary_Structure_Prediction.py."""

        # Recreate the function from the example
        def ss_validity_loss(rna_strct: str) -> float:
            """Return the fraction of bracket characters that are unmatched.

            Raises:
                ValueError: on any character other than '(', ')' or '.'.
            """
            left = right = 0
            dots = rna_strct.count(".")
            for c in rna_strct:
                if c == "(":
                    left += 1
                elif c == ")":
                    if left:
                        left -= 1
                    else:
                        right += 1
                elif c != ".":
                    raise ValueError(f"Invalid char {c}")
            # The 1e-8 term keeps the division defined for dot-only/empty input.
            return (left + right) / (len(rna_strct) - dots + 1e-8)

        # Test valid structures
        assert ss_validity_loss("(())") == 0.0
        assert ss_validity_loss("((..))") == 0.0
        assert ss_validity_loss("....") == 0.0

        # Test invalid structures
        assert ss_validity_loss("(((") > 0.0  # Unmatched left
        assert ss_validity_loss(")))") > 0.0  # Unmatched right
        assert ss_validity_loss("())(") > 0.0  # Mixed unmatched

        # Test error case
        with pytest.raises(ValueError, match="Invalid char"):
            ss_validity_loss("((X))")

    def test_find_invalid_positions(self):
        """Test find_invalid_positions function from Secondary_Structure_Prediction.py."""

        # Recreate the function from the example
        def find_invalid_positions(struct: str) -> list:
            """Return indices of unmatched ')' (in scan order) then unmatched '('."""
            stack, invalid = [], []
            for i, c in enumerate(struct):
                if c == "(":
                    stack.append(i)
                elif c == ")":
                    if stack:
                        stack.pop()
                    else:
                        invalid.append(i)
            # Whatever is left on the stack is an opening bracket never closed.
            invalid.extend(stack)
            return invalid

        # Test valid structures
        assert find_invalid_positions("(())") == []
        assert find_invalid_positions("((..))") == []
        assert find_invalid_positions("....") == []

        # Test invalid structures
        assert find_invalid_positions("(((") == [0, 1, 2]  # All unmatched left
        assert find_invalid_positions(")))") == [0, 1, 2]  # All unmatched right
        assert find_invalid_positions("())(") == [2, 3]  # One unmatched right, one left

    def test_rna_structure_formats(self):
        """Test RNA structure format validation (characters and bracket balance)."""
        valid_structures = [
            "(())",
            "((()))",
            ".((.))",
            "....",
            "",
            "((..))",
        ]

        invalid_structures = [
            "((X))",  # Invalid character
            "(()",  # Unmatched
            "())",  # Unmatched
            ")(",  # Wrong order
        ]

        def is_valid_structure_format(struct: str) -> bool:
            """Check if structure contains only valid characters."""
            return all(c in "()." for c in struct)

        def is_balanced(struct: str) -> bool:
            """Check every ')' closes an earlier '(' and none are left open."""
            depth = 0
            for c in struct:
                if c == "(":
                    depth += 1
                elif c == ")":
                    depth -= 1
                    if depth < 0:  # closing bracket with nothing to close
                        return False
            return depth == 0

        def is_valid_structure(struct: str) -> bool:
            """A structure is valid when its characters and brackets are both valid."""
            return is_valid_structure_format(struct) and is_balanced(struct)

        for struct in valid_structures:
            assert is_valid_structure_format(struct), f"Should be valid: {struct}"
            assert is_valid_structure(struct), f"Should be valid: {struct}"

        for struct in invalid_structures:
            if any(c not in "()." for c in struct):
                assert not is_valid_structure_format(struct), f"Should be invalid: {struct}"
            # Every listed structure, including the bracket-only ones, must be
            # rejected; previously those entries were never checked.
            assert not is_valid_structure(struct), f"Should be invalid: {struct}"

    def test_sequence_replacement_patterns(self):
        """Test U/T replacement patterns from examples."""

        # Pattern from web_rna_design.py
        def rna_to_dna_pattern(sequence):
            return sequence.replace("U", "T")

        def dna_to_rna_pattern(sequence):
            return sequence.replace("T", "U")

        # Test RNA to DNA
        assert rna_to_dna_pattern("AUCG") == "ATCG"
        assert rna_to_dna_pattern("UUUU") == "TTTT"
        assert rna_to_dna_pattern("ACGU") == "ACGT"

        # Test DNA to RNA
        assert dna_to_rna_pattern("ATCG") == "AUCG"
        assert dna_to_rna_pattern("TTTT") == "UUUU"
        assert dna_to_rna_pattern("ACGT") == "ACGU"

    def test_random_base_generation_patterns(self):
        """Test random base generation patterns from RNA design examples."""
        import random

        def generate_random_rna_base():
            """Pattern from easy_rna_design_emoo.py."""
            return random.choice(["A", "U", "G", "C"])

        def generate_random_dna_base():
            """Pattern from easy_rna_design_emoo.py."""
            return random.choice(["A", "T", "G", "C"])

        # Test multiple generations to ensure valid bases
        for _ in range(10):
            rna_base = generate_random_rna_base()
            assert rna_base in ["A", "U", "G", "C"]

            dna_base = generate_random_dna_base()
            assert dna_base in ["A", "T", "G", "C"]

    def test_sequence_mutation_pattern(self):
        """Test sequence mutation pattern from mlm_mutate function."""
        try:
            import numpy as np
        except ImportError:
            pytest.skip("numpy not available")

        def mutate_sequence_pattern(sequence, mutation_rate=0.1, rng=None):
            """Replace each base with the mask token '$' with probability mutation_rate.

            A local Generator is used so the test never reseeds NumPy's
            global random state (which would leak into other tests).
            """
            if rng is None:
                rng = np.random.default_rng()
            sequence_array = np.array(list(sequence), dtype=np.str_)
            # random() draws from [0, 1): rate 0.0 masks nothing, 1.0 masks all.
            masked_indices = rng.random(sequence_array.shape) < mutation_rate
            sequence_array[masked_indices] = "$"  # Mask token
            return "".join(sequence_array.tolist())

        # Test mutation with 0% rate
        original = "AUCG"
        mutated_zero = mutate_sequence_pattern(original, 0.0)
        assert mutated_zero == original

        # Test mutation with 100% rate
        mutated_full = mutate_sequence_pattern(original, 1.0)
        assert mutated_full == "$$$$"

        # Test with moderate rate - should have some masks (seeded for reproducibility)
        rng = np.random.default_rng(42)
        mutated_partial = mutate_sequence_pattern("AUCGAUCGAUCG", 0.5, rng=rng)
        assert "$" in mutated_partial

    @patch("tempfile.mkdtemp")
    def test_temp_directory_pattern(self, mock_mkdtemp):
        """Test temp directory usage pattern from Secondary_Structure_Prediction.py."""
        from pathlib import Path

        mock_mkdtemp.return_value = "/tmp/test_dir"

        # Pattern from Secondary_Structure_Prediction.py
        TEMP_DIR = Path(tempfile.mkdtemp())

        mock_mkdtemp.assert_called_once()
        assert isinstance(TEMP_DIR, Path)

    def test_rna_embedding_sequence_validation(self):
        """Test RNA sequence validation for embedding examples."""
        # RNA sequences from RNA_Embedding_Tutorial.ipynb
        rna_sequences = [
            "AUGGCUACG",
            "CGGAUACGGC",
            "UGGCCAAGUC",
            "AUGCUGCUAUGCUA",
        ]

        def validate_rna_sequence(seq):
            """Validate RNA sequence format."""
            return all(base in "AUCG" for base in seq) and len(seq) > 0

        for seq in rna_sequences:
            assert validate_rna_sequence(seq), f"Invalid RNA sequence: {seq}"

    def test_structure_prediction_mock_pattern(self):
        """Test structure prediction pattern without ViennaRNA dependency."""

        def mock_predict_structure_single(sequence):
            """Mock version of predict_structure_single from examples."""
            # Return a mock structure and energy
            return "." * len(sequence), -10.0

        # Test the pattern
        seq = "AUCG"
        struct, energy = mock_predict_structure_single(seq)

        assert len(struct) == len(seq)
        assert isinstance(energy, float)
        assert struct == "...."

    def test_base64_encoding_pattern(self):
        """Test base64 encoding pattern from SVG generation."""
        import base64

        def create_mock_svg_datauri(content="test"):
            """Mock version of SVG data URI creation."""
            svg_content = f"<svg>{content}</svg>"
            b64 = base64.b64encode(svg_content.encode()).decode("utf-8")
            return f"data:image/svg+xml;base64,{b64}"

        uri = create_mock_svg_datauri("test")
        assert uri.startswith("data:image/svg+xml;base64,")

        # Decode and verify
        _, b64_part = uri.split(",", 1)
        decoded = base64.b64decode(b64_part).decode("utf-8")
        assert decoded == "<svg>test</svg>"

    def test_longest_bp_span_function(self):
        """Test longest_bp_span function from easy_rna_design_emoo.py."""

        def longest_bp_span(structure):
            """Return the largest index distance between a matched '(' and ')'."""
            stack = []
            max_span = 0

            for i, char in enumerate(structure):
                if char == "(":
                    stack.append(i)
                elif char == ")":
                    # Unmatched ')' are ignored rather than treated as errors.
                    if stack:
                        left_index = stack.pop()
                        current_span = i - left_index
                        max_span = max(max_span, current_span)

            return max_span

        # Test cases
        assert longest_bp_span("(())") == 3  # Outer pair spans 3 positions
        assert longest_bp_span("((()))") == 5  # Outer pair spans 5 positions
        assert longest_bp_span("()()") == 1  # Each pair spans 1 position
        assert longest_bp_span("....") == 0  # No pairs
        assert longest_bp_span("") == 0  # Empty structure
        assert longest_bp_span("((.))") == 4  # Outer pair spans 4 positions
@@ -1,302 +0,0 @@
1
- """
2
- Test training patterns and configurations based on examples.
3
- """
4
- import pytest
5
- from unittest.mock import patch, MagicMock
6
-
7
-
8
class TestTrainingPatterns:
    """Test training patterns from examples.

    Tests that need optional packages (omnigenome, autocuda, torch) skip
    when the package is missing. They check availability *before* patching,
    because ``unittest.mock.patch`` imports its target when it starts.
    Patching first would raise ModuleNotFoundError instead of skipping.
    """

    def test_trainer_imports(self):
        """Test trainer imports as shown in quick_start.md."""
        try:
            from omnigenome import Trainer
        except ImportError:
            pytest.skip("omnigenome not available or missing dependencies")
        assert Trainer is not None

    def test_autobench_imports(self):
        """Test AutoBench imports from examples."""
        try:
            from omnigenome import AutoBench
        except ImportError:
            pytest.skip("omnigenome not available or missing dependencies")
        assert AutoBench is not None

    def test_autocuda_import_pattern(self):
        """Test autocuda import pattern from examples."""
        # Guard only the import; an ImportError raised from inside auto_cuda()
        # should fail the test rather than be silently reported as a skip.
        autocuda = pytest.importorskip("autocuda")
        # Pattern from examples
        device = autocuda.auto_cuda()
        # Just verify the function exists and returns something
        assert device is not None

    def test_autobench_initialization_pattern(self):
        """Test AutoBench initialization pattern from quick_start.md."""
        # patch() imports "omnigenome" when it starts; skip first so a missing
        # package yields a skip instead of a ModuleNotFoundError.
        pytest.importorskip("omnigenome")
        with patch("omnigenome.AutoBench") as mock_autobench:
            mock_instance = MagicMock()
            mock_autobench.return_value = mock_instance

            from omnigenome import AutoBench

            # Pattern from quick_start.md
            auto_bench = AutoBench(
                benchmark="RGB",
                model_name_or_path="yangheng/OmniGenome-186M",
                device="cuda:0",
                overwrite=True,
            )

        mock_autobench.assert_called_once_with(
            benchmark="RGB",
            model_name_or_path="yangheng/OmniGenome-186M",
            device="cuda:0",
            overwrite=True,
        )
        assert auto_bench is mock_instance

    def test_benchmark_names(self):
        """Test benchmark names from examples."""
        # Benchmarks from quick_start.md
        benchmarks = ["RGB", "GB", "PGB", "GUE", "BEACON"]

        for benchmark in benchmarks:
            assert isinstance(benchmark, str)
            assert len(benchmark) > 0
            assert benchmark.isupper()

    def test_trainer_names(self):
        """Test trainer types from examples."""
        # Trainers from autobench examples
        trainers = ["accelerate", "huggingface"]

        for trainer in trainers:
            assert isinstance(trainer, str)
            assert trainer in ["accelerate", "huggingface"]

    def test_trainer_initialization_pattern(self):
        """Test Trainer initialization pattern from quick_start.md."""
        # Skip before patching: patch() would otherwise import the missing
        # package and raise ModuleNotFoundError.
        pytest.importorskip("omnigenome")
        with patch("omnigenome.Trainer") as mock_trainer:
            mock_trainer.return_value = MagicMock()

            from omnigenome import Trainer

            # Mock training arguments
            mock_args = MagicMock()
            mock_args.output_dir = "./results"
            mock_args.num_train_epochs = 3
            mock_args.per_device_train_batch_size = 8
            mock_args.learning_rate = 2e-5

            # Pattern from quick_start.md
            trainer = Trainer(
                model=MagicMock(),
                train_dataset=MagicMock(),
                eval_dataset=MagicMock(),
                args=mock_args,
            )

        mock_trainer.assert_called_once()
        assert mock_trainer.call_args.kwargs["args"] is mock_args
        assert trainer is mock_trainer.return_value

    def test_training_arguments_pattern(self):
        """Test training arguments patterns from examples."""
        # Common training parameters from examples
        training_configs = {
            "output_dir": "./results",
            "num_train_epochs": 3,
            "per_device_train_batch_size": 8,
            "learning_rate": 2e-5,
            "epochs": 3,
            "batch_size": 8,
            "seeds": [42, 43, 44],
        }

        # Verify types and ranges
        assert isinstance(training_configs["output_dir"], str)
        assert isinstance(training_configs["num_train_epochs"], int)
        assert training_configs["num_train_epochs"] > 0
        assert isinstance(training_configs["per_device_train_batch_size"], int)
        assert training_configs["per_device_train_batch_size"] > 0
        assert isinstance(training_configs["learning_rate"], float)
        assert training_configs["learning_rate"] > 0
        assert isinstance(training_configs["seeds"], list)
        assert all(isinstance(seed, int) for seed in training_configs["seeds"])

    def test_genetic_algorithm_parameters(self):
        """Test genetic algorithm parameters from RNA design examples."""
        # Parameters from easy_rna_design_emoo.py
        ga_params = {
            "mutation_ratio": 0.1,
            "num_population": 100,
            "num_generation": 50,
            "model": "anonymous8/OmniGenome-186M",
        }

        # Verify parameter types and ranges
        assert isinstance(ga_params["mutation_ratio"], float)
        assert 0.0 <= ga_params["mutation_ratio"] <= 1.0
        assert isinstance(ga_params["num_population"], int)
        assert ga_params["num_population"] > 0
        assert isinstance(ga_params["num_generation"], int)
        assert ga_params["num_generation"] > 0
        assert isinstance(ga_params["model"], str)

    def test_web_rna_design_parameters(self):
        """Test web RNA design parameters from web_rna_design.py."""
        # Parameters from web_rna_design.py
        web_params = {
            "mutation_ratio": 0.5,
            "num_population": 500,
            "num_generation": 10,
            "puzzle_id": 0,
        }

        # Verify parameter types
        assert isinstance(web_params["mutation_ratio"], float)
        assert 0.0 <= web_params["mutation_ratio"] <= 1.0
        assert isinstance(web_params["num_population"], int)
        assert web_params["num_population"] > 0
        assert isinstance(web_params["num_generation"], int)
        assert web_params["num_generation"] > 0
        assert isinstance(web_params["puzzle_id"], int)
        assert web_params["puzzle_id"] >= 0

    def test_model_optimization_patterns(self):
        """Test model optimization patterns from examples."""
        # Patterns from examples for model optimization
        optimization_configs = {
            "torch_dtype": "float16",
            "device_map": "auto",
            "trust_remote_code": True,
            "gradient_checkpointing": True,
            "fp16": True,
        }

        for key, value in optimization_configs.items():
            assert isinstance(key, str)
            # Value types vary, just ensure they exist
            assert value is not None

    def test_memory_management_pattern(self):
        """Test memory management patterns from web_rna_design.py."""
        # Previously a @patch("torch.cuda.empty_cache") decorator imported
        # torch before this body ran, so the skip below was unreachable and a
        # missing torch raised ModuleNotFoundError instead of skipping.
        torch = pytest.importorskip("torch")

        # Pattern from web_rna_design.py
        def cleanup_model_pattern():
            """Memory cleanup pattern from examples."""
            # del model, tokenizer  # Would normally delete objects
            torch.cuda.empty_cache()

        with patch.object(torch.cuda, "empty_cache") as mock_empty_cache:
            cleanup_model_pattern()
        mock_empty_cache.assert_called_once()

    def test_random_seed_patterns(self):
        """Test random seed patterns from examples."""
        import random

        # Pattern from examples
        def set_random_seed_pattern():
            """Random seed pattern from easy_rna_design_emoo.py."""
            return random.randint(0, 99999999)

        # Test seed generation
        seed1 = set_random_seed_pattern()
        seed2 = set_random_seed_pattern()

        assert isinstance(seed1, int)
        assert isinstance(seed2, int)
        assert 0 <= seed1 <= 99999999
        assert 0 <= seed2 <= 99999999

    def test_evaluation_metrics_patterns(self):
        """Test evaluation metrics patterns from examples."""
        # Common metrics mentioned in examples
        metrics = [
            "accuracy",
            "f1_score",
            "precision",
            "recall",
            "mse",
            "mae",
            "r2_score",
        ]

        for metric in metrics:
            assert isinstance(metric, str)
            assert len(metric) > 0

    def test_device_selection_patterns(self):
        """Test device selection patterns from examples."""
        # Patterns from examples
        device_patterns = [
            "cuda:0",
            "cuda",
            "cpu",
            "auto",
        ]

        for device in device_patterns:
            assert isinstance(device, str)
            assert len(device) > 0

    def test_batch_size_patterns(self):
        """Test batch size patterns from examples."""
        # Common batch sizes from examples
        batch_sizes = [4, 8, 16, 32, 64]

        for batch_size in batch_sizes:
            assert isinstance(batch_size, int)
            assert batch_size > 0
            assert batch_size <= 1024  # Reasonable upper limit

    def test_learning_rate_patterns(self):
        """Test learning rate patterns from examples."""
        # Common learning rates from examples
        learning_rates = [1e-5, 2e-5, 5e-5, 1e-4, 2e-4]

        for lr in learning_rates:
            assert isinstance(lr, float)
            assert lr > 0
            assert lr < 1.0  # Should be small

    def test_epoch_patterns(self):
        """Test epoch patterns from examples."""
        # Common epoch counts from examples
        epoch_counts = [1, 3, 5, 10, 20]

        for epochs in epoch_counts:
            assert isinstance(epochs, int)
            assert epochs > 0
            assert epochs <= 100  # Reasonable upper limit

    def test_output_directory_patterns(self):
        """Test output directory patterns from examples."""
        # Common output directory patterns
        output_dirs = [
            "./results",
            "./output",
            "./checkpoints",
            "./models",
        ]

        for output_dir in output_dirs:
            assert isinstance(output_dir, str)
            assert output_dir.startswith(("./", "/"))

    def test_model_saving_patterns(self):
        """Test model saving patterns from examples."""
        # File extensions for saved models
        model_extensions = [".pt", ".pth", ".bin", ".safetensors"]

        for ext in model_extensions:
            assert isinstance(ext, str)
            assert ext.startswith(".")
            assert len(ext) > 1