omnigenome 0.3.0a0__py3-none-any.whl → 0.3.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of omnigenome might be problematic. Click here for more details.

@@ -18,9 +18,11 @@ import numpy as np
18
18
  import torch
19
19
  import autocuda
20
20
  from transformers import AutoModelForMaskedLM, AutoTokenizer
21
- from concurrent.futures import ProcessPoolExecutor
21
+ from concurrent.futures import ProcessPoolExecutor, as_completed
22
22
  import ViennaRNA
23
23
  from scipy.spatial.distance import hamming
24
+ import warnings
25
+ import os
24
26
 
25
27
  from omnigenome.src.misc.utils import fprint
26
28
 
@@ -72,162 +74,207 @@ class OmniModelForRNADesign(torch.nn.Module):
72
74
  Generate a random base pair span.
73
75
 
74
76
  Args:
75
- bp_span (int, optional): Base pair span to center around (default: None)
77
+ bp_span (int, optional): Fixed base pair span. If None, generates random.
76
78
 
77
79
  Returns:
78
- int: Random base pair span within ±50 of the input span
80
+ int: Base pair span value
79
81
  """
80
- return random.choice(range(max(0, bp_span - 50), min(bp_span + 50, 400)))
82
+ if bp_span is None:
83
+ return random.randint(1, 10)
84
+ return bp_span
81
85
 
82
86
  @staticmethod
83
87
  def _longest_bp_span(structure):
84
88
  """
85
- Compute the longest base-pair span from RNA structure.
89
+ Find the longest base pair span in the structure.
86
90
 
87
91
  Args:
88
92
  structure (str): RNA structure in dot-bracket notation
89
93
 
90
94
  Returns:
91
- int: Length of the longest base-pair span
95
+ int: Length of the longest base pair span
92
96
  """
93
- stack = []
94
97
  max_span = 0
95
- for i, char in enumerate(structure):
98
+ current_span = 0
99
+
100
+ for char in structure:
96
101
  if char == "(":
97
- stack.append(i)
98
- elif char == ")" and stack:
99
- left_index = stack.pop()
100
- max_span = max(max_span, i - left_index)
102
+ current_span += 1
103
+ max_span = max(max_span, current_span)
104
+ elif char == ")":
105
+ current_span = max(0, current_span - 1)
106
+ else:
107
+ current_span = 0
108
+
101
109
  return max_span
102
110
 
103
111
  @staticmethod
104
112
  def _predict_structure_single(sequence, bp_span=-1):
105
113
  """
106
- Predict the RNA structure and minimum free energy (MFE) for a single sequence.
114
+ Predict structure for a single sequence (worker function for multiprocessing).
107
115
 
108
116
  Args:
109
- sequence (str): RNA sequence
110
- bp_span (int): Maximum base pair span for folding (default: -1, no limit)
117
+ sequence (str): RNA sequence to fold
118
+ bp_span (int): Base pair span parameter
111
119
 
112
120
  Returns:
113
- tuple: (structure, mfe) where structure is in dot-bracket notation
121
+ tuple: (structure, mfe) tuple
114
122
  """
115
- md = ViennaRNA.md()
116
- md.max_bp_span = bp_span
117
- fc = ViennaRNA.fold_compound(sequence, md)
118
- return fc.mfe()
123
+ try:
124
+ return ViennaRNA.fold(sequence)
125
+ except Exception as e:
126
+ warnings.warn(f"Failed to fold sequence {sequence}: {e}")
127
+ return ("." * len(sequence), 0.0)
119
128
 
120
129
  def _predict_structure(self, sequences, bp_span=-1):
121
130
  """
122
- Predict RNA structures for multiple sequences.
131
+ Predict structures for multiple sequences.
123
132
 
124
133
  Args:
125
134
  sequences (list): List of RNA sequences
126
- bp_span (int): Maximum base pair span for folding (default: -1, no limit)
135
+ bp_span (int): Base pair span parameter
127
136
 
128
137
  Returns:
129
138
  list: List of (structure, mfe) tuples
130
139
  """
131
- return [self._predict_structure_single(seq, bp_span) for seq in sequences]
140
+ if not self.parallel or len(sequences) <= 1:
141
+ # Sequential processing
142
+ return [self._predict_structure_single(seq, bp_span) for seq in sequences]
143
+
144
+ # Parallel processing with improved error handling
145
+ try:
146
+ # Determine number of workers
147
+ max_workers = min(os.cpu_count(), len(sequences), 8) # Limit to 8 workers
148
+
149
+ with ProcessPoolExecutor(max_workers=max_workers) as executor:
150
+ # Submit all tasks
151
+ future_to_seq = {
152
+ executor.submit(self._predict_structure_single, seq, bp_span): seq
153
+ for seq in sequences
154
+ }
155
+
156
+ # Collect results
157
+ results = []
158
+ for future in as_completed(future_to_seq):
159
+ try:
160
+ result = future.result()
161
+ results.append(result)
162
+ except Exception as e:
163
+ seq = future_to_seq[future]
164
+ warnings.warn(f"Failed to process sequence {seq}: {e}")
165
+ # Fallback to dot structure
166
+ results.append(("." * len(seq), 0.0))
167
+
168
+ return results
169
+
170
+ except Exception as e:
171
+ warnings.warn(f"Parallel processing failed, falling back to sequential: {e}")
172
+ # Fallback to sequential processing
173
+ return [self._predict_structure_single(seq, bp_span) for seq in sequences]
132
174
 
133
175
  def _init_population(self, structure, num_population):
134
176
  """
135
- Initialize the population with masked sequences.
177
+ Initialize the population with random sequences.
136
178
 
137
179
  Args:
138
- structure (str): Target RNA structure in dot-bracket notation
139
- num_population (int): Number of individuals in the population
180
+ structure (str): Target RNA structure
181
+ num_population (int): Population size
140
182
 
141
183
  Returns:
142
- list: List of (sequence, bp_span) tuples representing the initial population
184
+ list: List of (sequence, bp_span) tuples
143
185
  """
144
186
  population = []
145
- mlm_inputs = []
187
+ bp_span = self._longest_bp_span(structure)
188
+
146
189
  for _ in range(num_population):
147
- masked_sequence = "".join(
148
- [random.choice(["G", "C", "<mask>"]) for _ in structure]
149
- )
150
- mlm_inputs.append(f"{masked_sequence}<eos>{structure}")
151
-
152
- outputs = self._mlm_predict(mlm_inputs, structure)
153
-
154
- for i, output in enumerate(outputs):
155
- sequence = self.tokenizer.convert_ids_to_tokens(output.tolist())
156
- fixed_sequence = [
157
- x if x in "AGCT" else random.choice(["A", "T", "G", "C"])
158
- for x in sequence
159
- ]
160
- bp_span = self._random_bp_span(len(structure))
161
- population.append(("".join(fixed_sequence), bp_span))
162
-
190
+ # Generate random sequence
191
+ sequence = "".join(random.choice("ACGU") for _ in range(len(structure)))
192
+ population.append((sequence, bp_span))
193
+
163
194
  return population
164
195
 
165
196
  def _mlm_mutate(self, population, structure, mutation_ratio):
166
197
  """
167
- Apply mutation to the population using the masked language model (MLM).
198
+ Mutate population using masked language modeling.
168
199
 
169
200
  Args:
170
- population (list): Current population of (sequence, bp_span) tuples
201
+ population (list): Current population
171
202
  structure (str): Target RNA structure
172
203
  mutation_ratio (float): Ratio of tokens to mutate
173
204
 
174
205
  Returns:
175
- list: Mutated population of (sequence, bp_span) tuples
206
+ list: Mutated population
176
207
  """
177
-
178
208
  def mutate(sequence, mutation_rate):
179
- sequence = np.array(list(sequence))
180
- masked_indices = np.random.rand(len(sequence)) < mutation_rate
181
- sequence[masked_indices] = "$"
182
- return "".join(sequence).replace("$", "<mask>")
183
-
209
+ # Create masked sequence
210
+ masked_sequence = list(sequence)
211
+ num_mutations = int(len(sequence) * mutation_rate)
212
+ mutation_positions = random.sample(range(len(sequence)), num_mutations)
213
+
214
+ for pos in mutation_positions:
215
+ masked_sequence[pos] = self.tokenizer.mask_token
216
+
217
+ return "".join(masked_sequence)
218
+
219
+ # Prepare inputs for MLM
184
220
  mlm_inputs = []
185
221
  for sequence, bp_span in population:
186
- masked_sequence = mutate(sequence, mutation_ratio)
187
- mlm_inputs.append(f"{masked_sequence}<eos>{structure}")
188
-
189
- outputs = self._mlm_predict(mlm_inputs, structure)
190
-
191
- mut_population = []
192
- for i, (seq, bp_span) in enumerate(population):
193
- sequence = self.tokenizer.convert_ids_to_tokens(outputs[i].tolist())
194
- fixed_sequence = [
195
- x if x in "AGCT" else random.choice(["A", "T", "G", "C"])
196
- for x in sequence
197
- ]
198
- bp_span = self._random_bp_span(bp_span)
199
- mut_population.append(("".join(fixed_sequence), bp_span))
200
-
201
- return mut_population
222
+ masked_seq = mutate(sequence, mutation_ratio)
223
+ mlm_inputs.append(masked_seq)
224
+
225
+ # Get predictions from MLM
226
+ predicted_tokens = self._mlm_predict(mlm_inputs, structure)
227
+
228
+ # Convert predictions back to sequences
229
+ mutated_population = []
230
+ for i, (sequence, bp_span) in enumerate(population):
231
+ # Convert token IDs back to nucleotides
232
+ new_sequence = self.tokenizer.decode(predicted_tokens[i], skip_special_tokens=True)
233
+ # Ensure the sequence has the correct length
234
+ if len(new_sequence) != len(structure):
235
+ new_sequence = new_sequence[:len(structure)].ljust(len(structure), "A")
236
+ mutated_population.append((new_sequence, bp_span))
237
+
238
+ return mutated_population
202
239
 
203
240
  def _crossover(self, population, num_points=3):
204
241
  """
205
- Perform crossover operation to create offspring.
242
+ Perform crossover operation on the population.
206
243
 
207
244
  Args:
208
- population (list): Current population of (sequence, bp_span) tuples
209
- num_points (int): Number of crossover points (default: 3)
245
+ population (list): Current population
246
+ num_points (int): Number of crossover points
210
247
 
211
248
  Returns:
212
- list: Offspring population after crossover
249
+ list: Population after crossover
213
250
  """
214
- population_size = len(population)
215
- sequence_length = len(population[0][0])
216
-
217
- parent_indices = np.random.choice(population_size // 10, (population_size, 2))
218
- crossover_points = np.sort(
219
- np.random.randint(1, sequence_length, size=(population_size, num_points)),
220
- axis=1,
221
- )
222
-
223
- masks = np.zeros((population_size, sequence_length), dtype=bool)
224
- for i in range(population_size):
225
- last_point = 0
251
+ if len(population) < 2:
252
+ return population
253
+
254
+ # Create crossover masks
255
+ num_sequences = len(population)
256
+ masks = np.zeros((num_sequences, len(population[0][0])), dtype=bool)
257
+
258
+ # Generate random crossover points
259
+ crossover_points = np.random.randint(0, len(population[0][0]), (num_sequences, num_points))
260
+
261
+ # Create parent indices
262
+ parent_indices = np.random.randint(0, num_sequences, (num_sequences, 2))
263
+
264
+ # Generate crossover masks
265
+ for i in range(num_sequences):
226
266
  for j in range(num_points):
227
- masks[i, last_point : crossover_points[i, j]] = j % 2 == 0
228
- last_point = crossover_points[i, j]
267
+ if j == 0:
268
+ masks[i, :crossover_points[i, j]] = True
269
+ else:
270
+ last_point = crossover_points[i, j-1]
271
+ masks[i, last_point:crossover_points[i, j]] = j % 2 == 0
272
+
273
+ # Handle the last segment
274
+ last_point = crossover_points[i, -1]
229
275
  masks[i, last_point:] = num_points % 2 == 0
230
276
 
277
+ # Perform crossover
231
278
  population_array = np.array([list(seq[0]) for seq in population])
232
279
  child1_array = np.where(
233
280
  masks,
@@ -259,15 +306,11 @@ class OmniModelForRNADesign(torch.nn.Module):
259
306
  Returns:
260
307
  list: Sorted population with fitness scores and MFE values
261
308
  """
262
- if self.parallel:
263
- with ProcessPoolExecutor() as executor:
264
- structures_mfe = list(
265
- executor.map(
266
- self._predict_structure_single, [seq for seq, _ in sequences]
267
- )
268
- )
269
- else:
270
- structures_mfe = self._predict_structure([seq for seq, _ in sequences])
309
+ # Get sequences for structure prediction
310
+ seq_list = [seq for seq, _ in sequences]
311
+
312
+ # Predict structures (with improved multiprocessing)
313
+ structures_mfe = self._predict_structure(seq_list)
271
314
 
272
315
  sorted_population = []
273
316
  for (seq, bp_span), (ss, mfe) in zip(sequences, structures_mfe):
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: omnigenome
3
- Version: 0.3.0a0
3
+ Version: 0.3.0a1
4
4
  Summary: OmniGenome: A comprehensive toolkit for genome analysis.
5
- Home-page: https://github.com/yangheng95/omnigenome
5
+ Home-page: https://github.com/yangheng95/OmniGenomeBench
6
6
  Author: Yang, Heng
7
7
  Author-email: hy345@exeter.ac.uk
8
- License: MIT
8
+ License: Apache-2.0
9
9
  Platform: Windows
10
10
  Platform: Linux
11
11
  Platform: Mac OS-X
@@ -1,4 +1,4 @@
1
- omnigenome/__init__.py,sha256=rG1SRLhIMfh9IbKkpDjoaa99jx67AItyNmW9hmQ6WF0,10719
1
+ omnigenome/__init__.py,sha256=ueMMkmyP6EjSvPUwNGLupoWT0W673sRbMXULhjbPjnU,9863
2
2
  omnigenome/auto/__init__.py,sha256=UhcuYy43WsR7IowjajlcGwNVFFFDaufl8KqtNDmVqz0,97
3
3
  omnigenome/auto/auto_bench/__init__.py,sha256=o0sPxaZM_KP5lRgidFUySr12OWguqB6PlL9ZhvWV1DM,411
4
4
  omnigenome/auto/auto_bench/auto_bench.py,sha256=nprUgDGLLh4OIG9Qys6Aing1j8n_aw3ndSmx4PzAYN4,20781
@@ -34,7 +34,7 @@ omnigenome/src/metric/metric.py,sha256=mDd-8huMv9PiyWSaVWiIqNIaXQC5yI-zc_5WOTXWA
34
34
  omnigenome/src/metric/ranking_metric.py,sha256=DTyNyhleDPDPEyg5HlDjlUpLS5uYne17SdDUejpXmCs,5826
35
35
  omnigenome/src/metric/regression_metric.py,sha256=J_XOZ1jXSdqzkOgw4adHA-YLA4A_QcGlW8g0lgIm9xs,7753
36
36
  omnigenome/src/misc/__init__.py,sha256=Dpa-uCQdwKVKkprqy26Np71mRobcWglCjgtITjU6yw0,63
37
- omnigenome/src/misc/utils.py,sha256=8b7FHp0OlyIbmbINOEgHa9nlhKz5qZ92x1tfAy7S0ko,15296
37
+ omnigenome/src/misc/utils.py,sha256=U8wk7-F2YhODKfSWhzkP8aJuoWIm49H5pAt3jHoJmVE,17241
38
38
  omnigenome/src/model/__init__.py,sha256=vu1vJVYp8FR9BgF7X2msKkwMfa6jbzsfAsUHduTB21w,621
39
39
  omnigenome/src/model/module_utils.py,sha256=rPJJfAcA4C8KumxSBJRCrCRxUSrwiRvLdbilIYIPS5U,9286
40
40
  omnigenome/src/model/augmentation/__init__.py,sha256=JEZ1rszRUq7NBzwyu02eyNb_TTph2K3lXnXOCbHTtJc,396
@@ -49,7 +49,7 @@ omnigenome/src/model/regression/__init__.py,sha256=Qdd4ctbc6jqTJDxHLe5MzSA3eDvW4
49
49
  omnigenome/src/model/regression/model.py,sha256=sgFqZ00J_gmeP9eRt1JYlbNN_KZhWLP1m4bEKKzV1Z8,28177
50
50
  omnigenome/src/model/regression/resnet.py,sha256=YgzUAhGdXG_pAmvjQOpEjjzwxtm7sOb-a4et0CPJ09Y,17093
51
51
  omnigenome/src/model/rna_design/__init__.py,sha256=jHAhyxuJScz1h1HY1UfZ3_fSVmwJOwsSACQkTItAl38,396
52
- omnigenome/src/model/rna_design/model.py,sha256=_4RQtlmLPCpMCDXWweV_FOiWPNN-ZrjceWcnw9Gphsc,15826
52
+ omnigenome/src/model/rna_design/model.py,sha256=HW5KcJiN-SWCvLalYS3w5ZprDK3GXR1sGr_15OybRlM,17343
53
53
  omnigenome/src/model/seq2seq/__init__.py,sha256=OAi4RVSwCbFOIvEwQZCDTImBOFrLkHs1JXwipL_4fqs,406
54
54
  omnigenome/src/model/seq2seq/model.py,sha256=-dGUjg7uRmnbR4rPH_lF8SgpR-U5lCoVJm4oNqzCOGg,1715
55
55
  omnigenome/src/tokenizer/__init__.py,sha256=zYUgX-FJ-fw0GNJuuW8ovo9kflDmGDd8Z0F3AMDFXF4,556
@@ -70,16 +70,9 @@ omnigenome/utility/model_hub/model_hub.py,sha256=kgyjrU9qUb_pflIKqOQOUrk3zlF5pM8
70
70
  omnigenome/utility/pipeline_hub/__init__.py,sha256=rm7k6GDXyrYGQyLO3ZFpYLnjAYf6s8xmJuOPypDNQ-g,395
71
71
  omnigenome/utility/pipeline_hub/pipeline.py,sha256=F_pDC_JKJF3b8OZtqzKzl99Q1FLMRQdBaGURi8CjZzg,20121
72
72
  omnigenome/utility/pipeline_hub/pipeline_hub.py,sha256=9HB5xZTr8HZtsuC6MrWWNbR4cg_5BW0CVXKQk2AwcWA,5384
73
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE,sha256=oQoefBV6siHctF0ET-OO3EaSZgtqGtf-wdIAmokS8iY,11560
74
- tests/__init__.py,sha256=MsAPLRxLTpyXAhwM2gnJ4ibJT6h5-SvyFd7gglSfZ2c,270
75
- tests/conftest.py,sha256=YNK66YqdtjofE65R59JJ2aiq24a3ltQ1ISSdf4Uqvlg,4344
76
- tests/test_dataset_patterns.py,sha256=x0pv09jOircm2fzbZ1xseCitZCSEftoOvVKv-3O_BJ4,11020
77
- tests/test_examples_syntax.py,sha256=0ERqLxOoi05zGZqkKKaAoHkhWggxyXGd7h2HVVd2Wtc,3277
78
- tests/test_model_loading.py,sha256=H5Ug1jNns74_CzL_j5fzqm_eFke4VlQF9HEmAV733eY,7145
79
- tests/test_rna_functions.py,sha256=f5RsT0n1dWv8YCuHkAaXzUjrn3nLqNoe3CIyGfMDYNY,10066
80
- tests/test_training_patterns.py,sha256=ouAP-tDlAbUR2EmHjqDcsMnfOyp3Y4s7rfftzxZPF0I,10979
81
- omnigenome-0.3.0a0.dist-info/METADATA,sha256=gQmzq0zgIiL7Lbl8qvMqraVDPqRu74C_WTDF9LODX0M,10306
82
- omnigenome-0.3.0a0.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
83
- omnigenome-0.3.0a0.dist-info/entry_points.txt,sha256=uu40UgMPxY65ASdRbrhkwH94r7CIYgyG_iDBmqFQbD8,84
84
- omnigenome-0.3.0a0.dist-info/top_level.txt,sha256=m8gQveMmM9nKDt36SOZTsagU7jEtZq7seCOwmDws-Lw,17
85
- omnigenome-0.3.0a0.dist-info/RECORD,,
73
+ omnigenome-0.3.0a1.dist-info/licenses/LICENSE,sha256=oQoefBV6siHctF0ET-OO3EaSZgtqGtf-wdIAmokS8iY,11560
74
+ omnigenome-0.3.0a1.dist-info/METADATA,sha256=yT37KTD8T7iMB8nrqAasko3IxhpVR5L3QIkRdT6Qf3o,10318
75
+ omnigenome-0.3.0a1.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
76
+ omnigenome-0.3.0a1.dist-info/entry_points.txt,sha256=uu40UgMPxY65ASdRbrhkwH94r7CIYgyG_iDBmqFQbD8,84
77
+ omnigenome-0.3.0a1.dist-info/top_level.txt,sha256=LVFxm_WPaxjj9KnAqdW94W4D4lbOk30gdsaKlJiSzTo,11
78
+ omnigenome-0.3.0a1.dist-info/RECORD,,
tests/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- """
2
- OmniGenBench test suite.
3
-
4
- This test suite validates functionality based on examples in the examples/ directory.
5
- Tests are designed to be fast and avoid heavy dependencies while ensuring
6
- code patterns and interfaces work correctly.
7
- """
8
-
9
- __version__ = "0.1.0"
tests/conftest.py DELETED
@@ -1,160 +0,0 @@
1
- """
2
- Pytest configuration and shared fixtures for OmniGenBench tests.
3
- """
4
- import pytest
5
- import sys
6
- import os
7
- from pathlib import Path
8
-
9
- # Add the project root to Python path
10
- ROOT_DIR = Path(__file__).parent.parent
11
- sys.path.insert(0, str(ROOT_DIR))
12
-
13
-
14
- def pytest_configure(config):
15
- """Configure pytest with custom markers."""
16
- config.addinivalue_line(
17
- "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')"
18
- )
19
- config.addinivalue_line(
20
- "markers", "gpu: marks tests that require GPU (deselect with '-m \"not gpu\"')"
21
- )
22
- config.addinivalue_line(
23
- "markers", "integration: marks tests as integration tests"
24
- )
25
-
26
-
27
- def pytest_collection_modifyitems(config, items):
28
- """Auto-mark slow tests and skip GPU tests if CUDA not available."""
29
- try:
30
- import torch
31
- cuda_available = torch.cuda.is_available()
32
- except ImportError:
33
- cuda_available = False
34
-
35
- for item in items:
36
- # Auto-mark slow tests
37
- if "slow" in item.nodeid or "model_loading" in item.nodeid:
38
- item.add_marker(pytest.mark.slow)
39
-
40
- # Skip GPU tests if CUDA not available
41
- if item.get_closest_marker("gpu") and not cuda_available:
42
- item.add_marker(pytest.mark.skip(reason="CUDA not available"))
43
-
44
-
45
- @pytest.fixture
46
- def sample_rna_sequences():
47
- """Sample RNA sequences for testing."""
48
- return [
49
- "AUGGCUACG",
50
- "CGGAUACGGC",
51
- "UGGCCAAGUC",
52
- "AUGCUGCUAUGCUA"
53
- ]
54
-
55
-
56
- @pytest.fixture
57
- def sample_rna_structures():
58
- """Sample RNA secondary structures for testing."""
59
- return [
60
- "(((())))",
61
- "(((...)))",
62
- "........",
63
- "((..))"
64
- ]
65
-
66
-
67
- @pytest.fixture
68
- def sample_dataset_entries():
69
- """Sample dataset entries in the format used by examples."""
70
- return [
71
- {"seq": "AUCG", "label": "(..)"},
72
- {"seq": "AUGC", "label": "().."},
73
- {"seq": "CGAU", "label": "(())"},
74
- {"seq": "GAUC", "label": "...."}
75
- ]
76
-
77
-
78
- @pytest.fixture
79
- def mock_model_config():
80
- """Mock model configuration for testing."""
81
- from unittest.mock import MagicMock
82
- config = MagicMock()
83
- config.hidden_size = 768
84
- config.num_labels = 2
85
- config.vocab_size = 32
86
- config.max_position_embeddings = 512
87
- return config
88
-
89
-
90
- @pytest.fixture
91
- def mock_tokenizer():
92
- """Mock tokenizer for testing."""
93
- from unittest.mock import MagicMock
94
- tokenizer = MagicMock()
95
- tokenizer.encode.return_value = [1, 2, 3, 4, 5]
96
- tokenizer.decode.return_value = "AUGC"
97
- tokenizer.convert_ids_to_tokens.return_value = ["A", "U", "G", "C"]
98
- tokenizer.vocab_size = 32
99
- tokenizer.pad_token_id = 0
100
- tokenizer.eos_token_id = 2
101
- return tokenizer
102
-
103
-
104
- @pytest.fixture
105
- def temp_data_dir(tmp_path):
106
- """Create temporary directory with sample data files."""
107
- data_dir = tmp_path / "data"
108
- data_dir.mkdir()
109
-
110
- # Create sample train.json
111
- train_file = data_dir / "train.json"
112
- train_data = [
113
- '{"seq": "AUCG", "label": "(..)"}',
114
- '{"seq": "AUGC", "label": "().."}',
115
- '{"seq": "CGAU", "label": "(())"}'
116
- ]
117
- train_file.write_text("\n".join(train_data))
118
-
119
- # Create sample test.json
120
- test_file = data_dir / "test.json"
121
- test_data = [
122
- '{"seq": "GAUC", "label": "...."}',
123
- '{"seq": "UCGA", "label": "(.)"}'
124
- ]
125
- test_file.write_text("\n".join(test_data))
126
-
127
- # Create sample config.py
128
- config_file = data_dir / "config.py"
129
- config_content = '''
130
- # Dataset configuration
131
- max_length = 512
132
- num_labels = 4
133
- task_type = "classification"
134
- '''
135
- config_file.write_text(config_content)
136
-
137
- return data_dir
138
-
139
-
140
- @pytest.fixture(scope="session")
141
- def examples_dir():
142
- """Path to examples directory."""
143
- return ROOT_DIR / "examples"
144
-
145
-
146
- @pytest.fixture
147
- def skip_if_no_omnigenome():
148
- """Skip test if omnigenome package is not available."""
149
- try:
150
- import omnigenome
151
- return False
152
- except ImportError:
153
- pytest.skip("omnigenome package not available")
154
-
155
-
156
- # Custom pytest markers
157
- pytestmark = [
158
- pytest.mark.filterwarnings("ignore:.*:DeprecationWarning"),
159
- pytest.mark.filterwarnings("ignore:.*:UserWarning"),
160
- ]