omnigenome 0.3.0a0__py3-none-any.whl → 0.3.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic. Click here for more details.
- omnigenome/__init__.py +14 -37
- omnigenome/src/misc/utils.py +199 -139
- omnigenome/src/model/rna_design/model.py +139 -96
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/METADATA +3 -3
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/RECORD +9 -16
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/top_level.txt +0 -1
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/licenses/LICENSE +0 -0
|
@@ -18,9 +18,11 @@ import numpy as np
|
|
|
18
18
|
import torch
|
|
19
19
|
import autocuda
|
|
20
20
|
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
|
21
|
-
from concurrent.futures import ProcessPoolExecutor
|
|
21
|
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
22
22
|
import ViennaRNA
|
|
23
23
|
from scipy.spatial.distance import hamming
|
|
24
|
+
import warnings
|
|
25
|
+
import os
|
|
24
26
|
|
|
25
27
|
from omnigenome.src.misc.utils import fprint
|
|
26
28
|
|
|
@@ -72,162 +74,207 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
72
74
|
Generate a random base pair span.
|
|
73
75
|
|
|
74
76
|
Args:
|
|
75
|
-
bp_span (int, optional):
|
|
77
|
+
bp_span (int, optional): Fixed base pair span. If None, generates random.
|
|
76
78
|
|
|
77
79
|
Returns:
|
|
78
|
-
int:
|
|
80
|
+
int: Base pair span value
|
|
79
81
|
"""
|
|
80
|
-
|
|
82
|
+
if bp_span is None:
|
|
83
|
+
return random.randint(1, 10)
|
|
84
|
+
return bp_span
|
|
81
85
|
|
|
82
86
|
@staticmethod
|
|
83
87
|
def _longest_bp_span(structure):
|
|
84
88
|
"""
|
|
85
|
-
|
|
89
|
+
Find the longest base pair span in the structure.
|
|
86
90
|
|
|
87
91
|
Args:
|
|
88
92
|
structure (str): RNA structure in dot-bracket notation
|
|
89
93
|
|
|
90
94
|
Returns:
|
|
91
|
-
int: Length of the longest base
|
|
95
|
+
int: Length of the longest base pair span
|
|
92
96
|
"""
|
|
93
|
-
stack = []
|
|
94
97
|
max_span = 0
|
|
95
|
-
|
|
98
|
+
current_span = 0
|
|
99
|
+
|
|
100
|
+
for char in structure:
|
|
96
101
|
if char == "(":
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
102
|
+
current_span += 1
|
|
103
|
+
max_span = max(max_span, current_span)
|
|
104
|
+
elif char == ")":
|
|
105
|
+
current_span = max(0, current_span - 1)
|
|
106
|
+
else:
|
|
107
|
+
current_span = 0
|
|
108
|
+
|
|
101
109
|
return max_span
|
|
102
110
|
|
|
103
111
|
@staticmethod
|
|
104
112
|
def _predict_structure_single(sequence, bp_span=-1):
|
|
105
113
|
"""
|
|
106
|
-
Predict
|
|
114
|
+
Predict structure for a single sequence (worker function for multiprocessing).
|
|
107
115
|
|
|
108
116
|
Args:
|
|
109
|
-
sequence (str): RNA sequence
|
|
110
|
-
bp_span (int):
|
|
117
|
+
sequence (str): RNA sequence to fold
|
|
118
|
+
bp_span (int): Base pair span parameter
|
|
111
119
|
|
|
112
120
|
Returns:
|
|
113
|
-
tuple: (structure, mfe)
|
|
121
|
+
tuple: (structure, mfe) tuple
|
|
114
122
|
"""
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
123
|
+
try:
|
|
124
|
+
return ViennaRNA.fold(sequence)
|
|
125
|
+
except Exception as e:
|
|
126
|
+
warnings.warn(f"Failed to fold sequence {sequence}: {e}")
|
|
127
|
+
return ("." * len(sequence), 0.0)
|
|
119
128
|
|
|
120
129
|
def _predict_structure(self, sequences, bp_span=-1):
|
|
121
130
|
"""
|
|
122
|
-
Predict
|
|
131
|
+
Predict structures for multiple sequences.
|
|
123
132
|
|
|
124
133
|
Args:
|
|
125
134
|
sequences (list): List of RNA sequences
|
|
126
|
-
bp_span (int):
|
|
135
|
+
bp_span (int): Base pair span parameter
|
|
127
136
|
|
|
128
137
|
Returns:
|
|
129
138
|
list: List of (structure, mfe) tuples
|
|
130
139
|
"""
|
|
131
|
-
|
|
140
|
+
if not self.parallel or len(sequences) <= 1:
|
|
141
|
+
# Sequential processing
|
|
142
|
+
return [self._predict_structure_single(seq, bp_span) for seq in sequences]
|
|
143
|
+
|
|
144
|
+
# Parallel processing with improved error handling
|
|
145
|
+
try:
|
|
146
|
+
# Determine number of workers
|
|
147
|
+
max_workers = min(os.cpu_count(), len(sequences), 8) # Limit to 8 workers
|
|
148
|
+
|
|
149
|
+
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
|
150
|
+
# Submit all tasks
|
|
151
|
+
future_to_seq = {
|
|
152
|
+
executor.submit(self._predict_structure_single, seq, bp_span): seq
|
|
153
|
+
for seq in sequences
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
# Collect results
|
|
157
|
+
results = []
|
|
158
|
+
for future in as_completed(future_to_seq):
|
|
159
|
+
try:
|
|
160
|
+
result = future.result()
|
|
161
|
+
results.append(result)
|
|
162
|
+
except Exception as e:
|
|
163
|
+
seq = future_to_seq[future]
|
|
164
|
+
warnings.warn(f"Failed to process sequence {seq}: {e}")
|
|
165
|
+
# Fallback to dot structure
|
|
166
|
+
results.append(("." * len(seq), 0.0))
|
|
167
|
+
|
|
168
|
+
return results
|
|
169
|
+
|
|
170
|
+
except Exception as e:
|
|
171
|
+
warnings.warn(f"Parallel processing failed, falling back to sequential: {e}")
|
|
172
|
+
# Fallback to sequential processing
|
|
173
|
+
return [self._predict_structure_single(seq, bp_span) for seq in sequences]
|
|
132
174
|
|
|
133
175
|
def _init_population(self, structure, num_population):
|
|
134
176
|
"""
|
|
135
|
-
Initialize the population with
|
|
177
|
+
Initialize the population with random sequences.
|
|
136
178
|
|
|
137
179
|
Args:
|
|
138
|
-
structure (str): Target RNA structure
|
|
139
|
-
num_population (int):
|
|
180
|
+
structure (str): Target RNA structure
|
|
181
|
+
num_population (int): Population size
|
|
140
182
|
|
|
141
183
|
Returns:
|
|
142
|
-
list: List of (sequence, bp_span) tuples
|
|
184
|
+
list: List of (sequence, bp_span) tuples
|
|
143
185
|
"""
|
|
144
186
|
population = []
|
|
145
|
-
|
|
187
|
+
bp_span = self._longest_bp_span(structure)
|
|
188
|
+
|
|
146
189
|
for _ in range(num_population):
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
outputs = self._mlm_predict(mlm_inputs, structure)
|
|
153
|
-
|
|
154
|
-
for i, output in enumerate(outputs):
|
|
155
|
-
sequence = self.tokenizer.convert_ids_to_tokens(output.tolist())
|
|
156
|
-
fixed_sequence = [
|
|
157
|
-
x if x in "AGCT" else random.choice(["A", "T", "G", "C"])
|
|
158
|
-
for x in sequence
|
|
159
|
-
]
|
|
160
|
-
bp_span = self._random_bp_span(len(structure))
|
|
161
|
-
population.append(("".join(fixed_sequence), bp_span))
|
|
162
|
-
|
|
190
|
+
# Generate random sequence
|
|
191
|
+
sequence = "".join(random.choice("ACGU") for _ in range(len(structure)))
|
|
192
|
+
population.append((sequence, bp_span))
|
|
193
|
+
|
|
163
194
|
return population
|
|
164
195
|
|
|
165
196
|
def _mlm_mutate(self, population, structure, mutation_ratio):
|
|
166
197
|
"""
|
|
167
|
-
|
|
198
|
+
Mutate population using masked language modeling.
|
|
168
199
|
|
|
169
200
|
Args:
|
|
170
|
-
population (list): Current population
|
|
201
|
+
population (list): Current population
|
|
171
202
|
structure (str): Target RNA structure
|
|
172
203
|
mutation_ratio (float): Ratio of tokens to mutate
|
|
173
204
|
|
|
174
205
|
Returns:
|
|
175
|
-
list: Mutated population
|
|
206
|
+
list: Mutated population
|
|
176
207
|
"""
|
|
177
|
-
|
|
178
208
|
def mutate(sequence, mutation_rate):
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
209
|
+
# Create masked sequence
|
|
210
|
+
masked_sequence = list(sequence)
|
|
211
|
+
num_mutations = int(len(sequence) * mutation_rate)
|
|
212
|
+
mutation_positions = random.sample(range(len(sequence)), num_mutations)
|
|
213
|
+
|
|
214
|
+
for pos in mutation_positions:
|
|
215
|
+
masked_sequence[pos] = self.tokenizer.mask_token
|
|
216
|
+
|
|
217
|
+
return "".join(masked_sequence)
|
|
218
|
+
|
|
219
|
+
# Prepare inputs for MLM
|
|
184
220
|
mlm_inputs = []
|
|
185
221
|
for sequence, bp_span in population:
|
|
186
|
-
|
|
187
|
-
mlm_inputs.append(
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
222
|
+
masked_seq = mutate(sequence, mutation_ratio)
|
|
223
|
+
mlm_inputs.append(masked_seq)
|
|
224
|
+
|
|
225
|
+
# Get predictions from MLM
|
|
226
|
+
predicted_tokens = self._mlm_predict(mlm_inputs, structure)
|
|
227
|
+
|
|
228
|
+
# Convert predictions back to sequences
|
|
229
|
+
mutated_population = []
|
|
230
|
+
for i, (sequence, bp_span) in enumerate(population):
|
|
231
|
+
# Convert token IDs back to nucleotides
|
|
232
|
+
new_sequence = self.tokenizer.decode(predicted_tokens[i], skip_special_tokens=True)
|
|
233
|
+
# Ensure the sequence has the correct length
|
|
234
|
+
if len(new_sequence) != len(structure):
|
|
235
|
+
new_sequence = new_sequence[:len(structure)].ljust(len(structure), "A")
|
|
236
|
+
mutated_population.append((new_sequence, bp_span))
|
|
237
|
+
|
|
238
|
+
return mutated_population
|
|
202
239
|
|
|
203
240
|
def _crossover(self, population, num_points=3):
|
|
204
241
|
"""
|
|
205
|
-
Perform crossover operation
|
|
242
|
+
Perform crossover operation on the population.
|
|
206
243
|
|
|
207
244
|
Args:
|
|
208
|
-
population (list): Current population
|
|
209
|
-
num_points (int): Number of crossover points
|
|
245
|
+
population (list): Current population
|
|
246
|
+
num_points (int): Number of crossover points
|
|
210
247
|
|
|
211
248
|
Returns:
|
|
212
|
-
list:
|
|
249
|
+
list: Population after crossover
|
|
213
250
|
"""
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
251
|
+
if len(population) < 2:
|
|
252
|
+
return population
|
|
253
|
+
|
|
254
|
+
# Create crossover masks
|
|
255
|
+
num_sequences = len(population)
|
|
256
|
+
masks = np.zeros((num_sequences, len(population[0][0])), dtype=bool)
|
|
257
|
+
|
|
258
|
+
# Generate random crossover points
|
|
259
|
+
crossover_points = np.random.randint(0, len(population[0][0]), (num_sequences, num_points))
|
|
260
|
+
|
|
261
|
+
# Create parent indices
|
|
262
|
+
parent_indices = np.random.randint(0, num_sequences, (num_sequences, 2))
|
|
263
|
+
|
|
264
|
+
# Generate crossover masks
|
|
265
|
+
for i in range(num_sequences):
|
|
226
266
|
for j in range(num_points):
|
|
227
|
-
|
|
228
|
-
|
|
267
|
+
if j == 0:
|
|
268
|
+
masks[i, :crossover_points[i, j]] = True
|
|
269
|
+
else:
|
|
270
|
+
last_point = crossover_points[i, j-1]
|
|
271
|
+
masks[i, last_point:crossover_points[i, j]] = j % 2 == 0
|
|
272
|
+
|
|
273
|
+
# Handle the last segment
|
|
274
|
+
last_point = crossover_points[i, -1]
|
|
229
275
|
masks[i, last_point:] = num_points % 2 == 0
|
|
230
276
|
|
|
277
|
+
# Perform crossover
|
|
231
278
|
population_array = np.array([list(seq[0]) for seq in population])
|
|
232
279
|
child1_array = np.where(
|
|
233
280
|
masks,
|
|
@@ -259,15 +306,11 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
259
306
|
Returns:
|
|
260
307
|
list: Sorted population with fitness scores and MFE values
|
|
261
308
|
"""
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
)
|
|
268
|
-
)
|
|
269
|
-
else:
|
|
270
|
-
structures_mfe = self._predict_structure([seq for seq, _ in sequences])
|
|
309
|
+
# Get sequences for structure prediction
|
|
310
|
+
seq_list = [seq for seq, _ in sequences]
|
|
311
|
+
|
|
312
|
+
# Predict structures (with improved multiprocessing)
|
|
313
|
+
structures_mfe = self._predict_structure(seq_list)
|
|
271
314
|
|
|
272
315
|
sorted_population = []
|
|
273
316
|
for (seq, bp_span), (ss, mfe) in zip(sequences, structures_mfe):
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: omnigenome
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.0a1
|
|
4
4
|
Summary: OmniGenome: A comprehensive toolkit for genome analysis.
|
|
5
|
-
Home-page: https://github.com/yangheng95/
|
|
5
|
+
Home-page: https://github.com/yangheng95/OmniGenomeBench
|
|
6
6
|
Author: Yang, Heng
|
|
7
7
|
Author-email: hy345@exeter.ac.uk
|
|
8
|
-
License:
|
|
8
|
+
License: Apache-2.0
|
|
9
9
|
Platform: Windows
|
|
10
10
|
Platform: Linux
|
|
11
11
|
Platform: Mac OS-X
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
omnigenome/__init__.py,sha256=
|
|
1
|
+
omnigenome/__init__.py,sha256=ueMMkmyP6EjSvPUwNGLupoWT0W673sRbMXULhjbPjnU,9863
|
|
2
2
|
omnigenome/auto/__init__.py,sha256=UhcuYy43WsR7IowjajlcGwNVFFFDaufl8KqtNDmVqz0,97
|
|
3
3
|
omnigenome/auto/auto_bench/__init__.py,sha256=o0sPxaZM_KP5lRgidFUySr12OWguqB6PlL9ZhvWV1DM,411
|
|
4
4
|
omnigenome/auto/auto_bench/auto_bench.py,sha256=nprUgDGLLh4OIG9Qys6Aing1j8n_aw3ndSmx4PzAYN4,20781
|
|
@@ -34,7 +34,7 @@ omnigenome/src/metric/metric.py,sha256=mDd-8huMv9PiyWSaVWiIqNIaXQC5yI-zc_5WOTXWA
|
|
|
34
34
|
omnigenome/src/metric/ranking_metric.py,sha256=DTyNyhleDPDPEyg5HlDjlUpLS5uYne17SdDUejpXmCs,5826
|
|
35
35
|
omnigenome/src/metric/regression_metric.py,sha256=J_XOZ1jXSdqzkOgw4adHA-YLA4A_QcGlW8g0lgIm9xs,7753
|
|
36
36
|
omnigenome/src/misc/__init__.py,sha256=Dpa-uCQdwKVKkprqy26Np71mRobcWglCjgtITjU6yw0,63
|
|
37
|
-
omnigenome/src/misc/utils.py,sha256=
|
|
37
|
+
omnigenome/src/misc/utils.py,sha256=U8wk7-F2YhODKfSWhzkP8aJuoWIm49H5pAt3jHoJmVE,17241
|
|
38
38
|
omnigenome/src/model/__init__.py,sha256=vu1vJVYp8FR9BgF7X2msKkwMfa6jbzsfAsUHduTB21w,621
|
|
39
39
|
omnigenome/src/model/module_utils.py,sha256=rPJJfAcA4C8KumxSBJRCrCRxUSrwiRvLdbilIYIPS5U,9286
|
|
40
40
|
omnigenome/src/model/augmentation/__init__.py,sha256=JEZ1rszRUq7NBzwyu02eyNb_TTph2K3lXnXOCbHTtJc,396
|
|
@@ -49,7 +49,7 @@ omnigenome/src/model/regression/__init__.py,sha256=Qdd4ctbc6jqTJDxHLe5MzSA3eDvW4
|
|
|
49
49
|
omnigenome/src/model/regression/model.py,sha256=sgFqZ00J_gmeP9eRt1JYlbNN_KZhWLP1m4bEKKzV1Z8,28177
|
|
50
50
|
omnigenome/src/model/regression/resnet.py,sha256=YgzUAhGdXG_pAmvjQOpEjjzwxtm7sOb-a4et0CPJ09Y,17093
|
|
51
51
|
omnigenome/src/model/rna_design/__init__.py,sha256=jHAhyxuJScz1h1HY1UfZ3_fSVmwJOwsSACQkTItAl38,396
|
|
52
|
-
omnigenome/src/model/rna_design/model.py,sha256=
|
|
52
|
+
omnigenome/src/model/rna_design/model.py,sha256=HW5KcJiN-SWCvLalYS3w5ZprDK3GXR1sGr_15OybRlM,17343
|
|
53
53
|
omnigenome/src/model/seq2seq/__init__.py,sha256=OAi4RVSwCbFOIvEwQZCDTImBOFrLkHs1JXwipL_4fqs,406
|
|
54
54
|
omnigenome/src/model/seq2seq/model.py,sha256=-dGUjg7uRmnbR4rPH_lF8SgpR-U5lCoVJm4oNqzCOGg,1715
|
|
55
55
|
omnigenome/src/tokenizer/__init__.py,sha256=zYUgX-FJ-fw0GNJuuW8ovo9kflDmGDd8Z0F3AMDFXF4,556
|
|
@@ -70,16 +70,9 @@ omnigenome/utility/model_hub/model_hub.py,sha256=kgyjrU9qUb_pflIKqOQOUrk3zlF5pM8
|
|
|
70
70
|
omnigenome/utility/pipeline_hub/__init__.py,sha256=rm7k6GDXyrYGQyLO3ZFpYLnjAYf6s8xmJuOPypDNQ-g,395
|
|
71
71
|
omnigenome/utility/pipeline_hub/pipeline.py,sha256=F_pDC_JKJF3b8OZtqzKzl99Q1FLMRQdBaGURi8CjZzg,20121
|
|
72
72
|
omnigenome/utility/pipeline_hub/pipeline_hub.py,sha256=9HB5xZTr8HZtsuC6MrWWNbR4cg_5BW0CVXKQk2AwcWA,5384
|
|
73
|
-
omnigenome-0.3.
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
tests/test_rna_functions.py,sha256=f5RsT0n1dWv8YCuHkAaXzUjrn3nLqNoe3CIyGfMDYNY,10066
|
|
80
|
-
tests/test_training_patterns.py,sha256=ouAP-tDlAbUR2EmHjqDcsMnfOyp3Y4s7rfftzxZPF0I,10979
|
|
81
|
-
omnigenome-0.3.0a0.dist-info/METADATA,sha256=gQmzq0zgIiL7Lbl8qvMqraVDPqRu74C_WTDF9LODX0M,10306
|
|
82
|
-
omnigenome-0.3.0a0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
|
|
83
|
-
omnigenome-0.3.0a0.dist-info/entry_points.txt,sha256=uu40UgMPxY65ASdRbrhkwH94r7CIYgyG_iDBmqFQbD8,84
|
|
84
|
-
omnigenome-0.3.0a0.dist-info/top_level.txt,sha256=m8gQveMmM9nKDt36SOZTsagU7jEtZq7seCOwmDws-Lw,17
|
|
85
|
-
omnigenome-0.3.0a0.dist-info/RECORD,,
|
|
73
|
+
omnigenome-0.3.0a1.dist-info/licenses/LICENSE,sha256=oQoefBV6siHctF0ET-OO3EaSZgtqGtf-wdIAmokS8iY,11560
|
|
74
|
+
omnigenome-0.3.0a1.dist-info/METADATA,sha256=yT37KTD8T7iMB8nrqAasko3IxhpVR5L3QIkRdT6Qf3o,10318
|
|
75
|
+
omnigenome-0.3.0a1.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
|
|
76
|
+
omnigenome-0.3.0a1.dist-info/entry_points.txt,sha256=uu40UgMPxY65ASdRbrhkwH94r7CIYgyG_iDBmqFQbD8,84
|
|
77
|
+
omnigenome-0.3.0a1.dist-info/top_level.txt,sha256=LVFxm_WPaxjj9KnAqdW94W4D4lbOk30gdsaKlJiSzTo,11
|
|
78
|
+
omnigenome-0.3.0a1.dist-info/RECORD,,
|
tests/__init__.py
DELETED
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
OmniGenBench test suite.
|
|
3
|
-
|
|
4
|
-
This test suite validates functionality based on examples in the examples/ directory.
|
|
5
|
-
Tests are designed to be fast and avoid heavy dependencies while ensuring
|
|
6
|
-
code patterns and interfaces work correctly.
|
|
7
|
-
"""
|
|
8
|
-
|
|
9
|
-
__version__ = "0.1.0"
|
tests/conftest.py
DELETED
|
@@ -1,160 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Pytest configuration and shared fixtures for OmniGenBench tests.
|
|
3
|
-
"""
|
|
4
|
-
import pytest
|
|
5
|
-
import sys
|
|
6
|
-
import os
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
|
|
9
|
-
# Add the project root to Python path
|
|
10
|
-
ROOT_DIR = Path(__file__).parent.parent
|
|
11
|
-
sys.path.insert(0, str(ROOT_DIR))
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def pytest_configure(config):
|
|
15
|
-
"""Configure pytest with custom markers."""
|
|
16
|
-
config.addinivalue_line(
|
|
17
|
-
"markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')"
|
|
18
|
-
)
|
|
19
|
-
config.addinivalue_line(
|
|
20
|
-
"markers", "gpu: marks tests that require GPU (deselect with '-m \"not gpu\"')"
|
|
21
|
-
)
|
|
22
|
-
config.addinivalue_line(
|
|
23
|
-
"markers", "integration: marks tests as integration tests"
|
|
24
|
-
)
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def pytest_collection_modifyitems(config, items):
|
|
28
|
-
"""Auto-mark slow tests and skip GPU tests if CUDA not available."""
|
|
29
|
-
try:
|
|
30
|
-
import torch
|
|
31
|
-
cuda_available = torch.cuda.is_available()
|
|
32
|
-
except ImportError:
|
|
33
|
-
cuda_available = False
|
|
34
|
-
|
|
35
|
-
for item in items:
|
|
36
|
-
# Auto-mark slow tests
|
|
37
|
-
if "slow" in item.nodeid or "model_loading" in item.nodeid:
|
|
38
|
-
item.add_marker(pytest.mark.slow)
|
|
39
|
-
|
|
40
|
-
# Skip GPU tests if CUDA not available
|
|
41
|
-
if item.get_closest_marker("gpu") and not cuda_available:
|
|
42
|
-
item.add_marker(pytest.mark.skip(reason="CUDA not available"))
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
@pytest.fixture
|
|
46
|
-
def sample_rna_sequences():
|
|
47
|
-
"""Sample RNA sequences for testing."""
|
|
48
|
-
return [
|
|
49
|
-
"AUGGCUACG",
|
|
50
|
-
"CGGAUACGGC",
|
|
51
|
-
"UGGCCAAGUC",
|
|
52
|
-
"AUGCUGCUAUGCUA"
|
|
53
|
-
]
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
@pytest.fixture
|
|
57
|
-
def sample_rna_structures():
|
|
58
|
-
"""Sample RNA secondary structures for testing."""
|
|
59
|
-
return [
|
|
60
|
-
"(((())))",
|
|
61
|
-
"(((...)))",
|
|
62
|
-
"........",
|
|
63
|
-
"((..))"
|
|
64
|
-
]
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
@pytest.fixture
|
|
68
|
-
def sample_dataset_entries():
|
|
69
|
-
"""Sample dataset entries in the format used by examples."""
|
|
70
|
-
return [
|
|
71
|
-
{"seq": "AUCG", "label": "(..)"},
|
|
72
|
-
{"seq": "AUGC", "label": "().."},
|
|
73
|
-
{"seq": "CGAU", "label": "(())"},
|
|
74
|
-
{"seq": "GAUC", "label": "...."}
|
|
75
|
-
]
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
@pytest.fixture
|
|
79
|
-
def mock_model_config():
|
|
80
|
-
"""Mock model configuration for testing."""
|
|
81
|
-
from unittest.mock import MagicMock
|
|
82
|
-
config = MagicMock()
|
|
83
|
-
config.hidden_size = 768
|
|
84
|
-
config.num_labels = 2
|
|
85
|
-
config.vocab_size = 32
|
|
86
|
-
config.max_position_embeddings = 512
|
|
87
|
-
return config
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
@pytest.fixture
|
|
91
|
-
def mock_tokenizer():
|
|
92
|
-
"""Mock tokenizer for testing."""
|
|
93
|
-
from unittest.mock import MagicMock
|
|
94
|
-
tokenizer = MagicMock()
|
|
95
|
-
tokenizer.encode.return_value = [1, 2, 3, 4, 5]
|
|
96
|
-
tokenizer.decode.return_value = "AUGC"
|
|
97
|
-
tokenizer.convert_ids_to_tokens.return_value = ["A", "U", "G", "C"]
|
|
98
|
-
tokenizer.vocab_size = 32
|
|
99
|
-
tokenizer.pad_token_id = 0
|
|
100
|
-
tokenizer.eos_token_id = 2
|
|
101
|
-
return tokenizer
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
@pytest.fixture
|
|
105
|
-
def temp_data_dir(tmp_path):
|
|
106
|
-
"""Create temporary directory with sample data files."""
|
|
107
|
-
data_dir = tmp_path / "data"
|
|
108
|
-
data_dir.mkdir()
|
|
109
|
-
|
|
110
|
-
# Create sample train.json
|
|
111
|
-
train_file = data_dir / "train.json"
|
|
112
|
-
train_data = [
|
|
113
|
-
'{"seq": "AUCG", "label": "(..)"}',
|
|
114
|
-
'{"seq": "AUGC", "label": "().."}',
|
|
115
|
-
'{"seq": "CGAU", "label": "(())"}'
|
|
116
|
-
]
|
|
117
|
-
train_file.write_text("\n".join(train_data))
|
|
118
|
-
|
|
119
|
-
# Create sample test.json
|
|
120
|
-
test_file = data_dir / "test.json"
|
|
121
|
-
test_data = [
|
|
122
|
-
'{"seq": "GAUC", "label": "...."}',
|
|
123
|
-
'{"seq": "UCGA", "label": "(.)"}'
|
|
124
|
-
]
|
|
125
|
-
test_file.write_text("\n".join(test_data))
|
|
126
|
-
|
|
127
|
-
# Create sample config.py
|
|
128
|
-
config_file = data_dir / "config.py"
|
|
129
|
-
config_content = '''
|
|
130
|
-
# Dataset configuration
|
|
131
|
-
max_length = 512
|
|
132
|
-
num_labels = 4
|
|
133
|
-
task_type = "classification"
|
|
134
|
-
'''
|
|
135
|
-
config_file.write_text(config_content)
|
|
136
|
-
|
|
137
|
-
return data_dir
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
@pytest.fixture(scope="session")
|
|
141
|
-
def examples_dir():
|
|
142
|
-
"""Path to examples directory."""
|
|
143
|
-
return ROOT_DIR / "examples"
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
@pytest.fixture
|
|
147
|
-
def skip_if_no_omnigenome():
|
|
148
|
-
"""Skip test if omnigenome package is not available."""
|
|
149
|
-
try:
|
|
150
|
-
import omnigenome
|
|
151
|
-
return False
|
|
152
|
-
except ImportError:
|
|
153
|
-
pytest.skip("omnigenome package not available")
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
# Custom pytest markers
|
|
157
|
-
pytestmark = [
|
|
158
|
-
pytest.mark.filterwarnings("ignore:.*:DeprecationWarning"),
|
|
159
|
-
pytest.mark.filterwarnings("ignore:.*:UserWarning"),
|
|
160
|
-
]
|