omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- omnigenome/__init__.py +29 -44
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +214 -150
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +168 -118
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
- omnigenome-0.3.0a0.dist-info/RECORD +0 -85
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
omnigenome/src/model/rna_design/model.py:

@@ -18,9 +18,11 @@ import numpy as np
 import torch
 import autocuda
 from transformers import AutoModelForMaskedLM, AutoTokenizer
-from concurrent.futures import ProcessPoolExecutor
+from concurrent.futures import ProcessPoolExecutor, as_completed
 import ViennaRNA
 from scipy.spatial.distance import hamming
+import warnings
+import os
 
 from omnigenome.src.misc.utils import fprint
 
@@ -28,19 +30,19 @@ from omnigenome.src.misc.utils import fprint
 class OmniModelForRNADesign(torch.nn.Module):
     """
     RNA design model using masked language modeling and evolutionary algorithms.
-
+
     This model combines a pre-trained masked language model with evolutionary
     algorithms to design RNA sequences that fold into specific target structures.
     It uses a multi-objective optimization approach to balance structure similarity
     and thermodynamic stability.
-
+
     Attributes:
         device: Device to run the model on (CPU or GPU)
         parallel: Whether to use parallel processing for structure prediction
         tokenizer: Tokenizer for processing RNA sequences
         model: Pre-trained masked language model
     """
-
+
     def __init__(
         self,
         model="yangheng/OmniGenome-186M",
@@ -51,7 +53,7 @@ class OmniModelForRNADesign(torch.nn.Module):
     ):
         """
         Initialize the RNA design model.
-
+
         Args:
             model (str): Model name or path for the pre-trained MLM model
             device: Device to run the model on (default: None, auto-detect)
@@ -70,33 +72,39 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _random_bp_span(bp_span=None):
         """
         Generate a random base pair span.
-
+
         Args:
-            bp_span (int, optional):
-
+            bp_span (int, optional): Fixed base pair span. If None, generates random.
+
         Returns:
-            int:
+            int: Base pair span value
         """
-
+        if bp_span is None:
+            return random.randint(1, 10)
+        return bp_span
 
     @staticmethod
     def _longest_bp_span(structure):
         """
-
-
+        Find the longest base pair span in the structure.
+
         Args:
             structure (str): RNA structure in dot-bracket notation
-
+
         Returns:
-            int: Length of the longest base
+            int: Length of the longest base pair span
         """
-        stack = []
         max_span = 0
-
+        current_span = 0
+
+        for char in structure:
             if char == "(":
-
-
-
-
+                current_span += 1
+                max_span = max(max_span, current_span)
+            elif char == ")":
+                current_span = max(0, current_span - 1)
+            else:
+                current_span = 0
+
         return max_span
 
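
The replacement walks the dot-bracket string with a single counter rather than the old stack (of the removed stack-based body, only `stack = []` survived in this capture). A minimal standalone sketch of the new logic and its edge case:

```python
def longest_bp_span(structure: str) -> int:
    """Counter-based scan from the new code: '(' deepens the span,
    ')' shallows it, and any other character resets it to zero."""
    max_span = 0
    current_span = 0
    for char in structure:
        if char == "(":
            current_span += 1
            max_span = max(max_span, current_span)
        elif char == ")":
            current_span = max(0, current_span - 1)
        else:
            current_span = 0
    return max_span

assert longest_bp_span("((()))") == 3
assert longest_bp_span("((..))") == 2  # an unpaired '.' resets the counter
```

Because `.` resets the counter, this measures the deepest nesting reached within an uninterrupted run of paired characters rather than the span of the outermost pair; whether that matches the old stack-based definition cannot be checked from the captured lines.
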
@@ -103,30 +111,66 @@ class OmniModelForRNADesign(torch.nn.Module):
     @staticmethod
     def _predict_structure_single(sequence, bp_span=-1):
         """
-        Predict
-
+        Predict structure for a single sequence (worker function for multiprocessing).
+
         Args:
-            sequence (str): RNA sequence
-            bp_span (int):
-
+            sequence (str): RNA sequence to fold
+            bp_span (int): Base pair span parameter
+
         Returns:
-            tuple: (structure, mfe)
+            tuple: (structure, mfe) tuple
         """
-
-
-
-
+        try:
+            return ViennaRNA.fold(sequence)
+        except Exception as e:
+            warnings.warn(f"Failed to fold sequence {sequence}: {e}")
+            return ("." * len(sequence), 0.0)
 
     def _predict_structure(self, sequences, bp_span=-1):
         """
-        Predict
-
+        Predict structures for multiple sequences.
+
         Args:
             sequences (list): List of RNA sequences
-            bp_span (int):
-
+            bp_span (int): Base pair span parameter
+
         Returns:
             list: List of (structure, mfe) tuples
         """
-
+        if not self.parallel or len(sequences) <= 1:
+            # Sequential processing
+            return [self._predict_structure_single(seq, bp_span) for seq in sequences]
+
+        # Parallel processing with improved error handling
+        try:
+            # Determine number of workers
+            max_workers = min(os.cpu_count(), len(sequences), 8)  # Limit to 8 workers
+
+            with ProcessPoolExecutor(max_workers=max_workers) as executor:
+                # Submit all tasks
+                future_to_seq = {
+                    executor.submit(self._predict_structure_single, seq, bp_span): seq
+                    for seq in sequences
+                }
+
+                # Collect results
+                results = []
+                for future in as_completed(future_to_seq):
+                    try:
+                        result = future.result()
+                        results.append(result)
+                    except Exception as e:
+                        seq = future_to_seq[future]
+                        warnings.warn(f"Failed to process sequence {seq}: {e}")
+                        # Fallback to dot structure
+                        results.append(("." * len(seq), 0.0))
+
+            return results
+
+        except Exception as e:
+            warnings.warn(
+                f"Parallel processing failed, falling back to sequential: {e}"
+            )
+            # Fallback to sequential processing
+            return [self._predict_structure_single(seq, bp_span) for seq in sequences]
 
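
One subtlety worth flagging: `as_completed` yields futures in completion order, not submission order, so `results` is not guaranteed to line up index-for-index with `sequences` — and the `_evaluate_structure_fitness` hunk further down zips the two together. A minimal order-preserving sketch built on the same primitives (an illustrative alternative, not the shipped code; `fold_one` stands in for `_predict_structure_single` and must be picklable, i.e. defined at module level):

```python
import os
from concurrent.futures import ProcessPoolExecutor

def fold_in_order(fold_one, sequences, cap=8):
    """executor.map preserves input order, unlike as_completed."""
    max_workers = min(os.cpu_count() or 1, len(sequences), cap)
    try:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            return list(executor.map(fold_one, sequences))
    except Exception:
        # Same degradation path as the diff: fall back to sequential folding.
        return [fold_one(seq) for seq in sequences]
```
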
@@ -133,32 +177,21 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _init_population(self, structure, num_population):
         """
-        Initialize the population with
-
+        Initialize the population with random sequences.
+
         Args:
-            structure (str): Target RNA structure
-            num_population (int):
-
+            structure (str): Target RNA structure
+            num_population (int): Population size
+
         Returns:
-            list: List of (sequence, bp_span) tuples
+            list: List of (sequence, bp_span) tuples
         """
         population = []
-
-        for _ in range(num_population):
-            masked_sequence = "".join(
-                [random.choice(["G", "C", "<mask>"]) for _ in structure]
-            )
-            mlm_inputs.append(f"{masked_sequence}<eos>{structure}")
-
-        outputs = self._mlm_predict(mlm_inputs, structure)
+        bp_span = self._longest_bp_span(structure)
 
-        for
-
-
-
-                for x in sequence
-            ]
-            bp_span = self._random_bp_span(len(structure))
-            population.append(("".join(fixed_sequence), bp_span))
+        for _ in range(num_population):
+            # Generate random sequence
+            sequence = "".join(random.choice("ACGU") for _ in range(len(structure)))
+            population.append((sequence, bp_span))
 
         return population
 
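
The initializer is now self-contained: the old path seeded G/C/`<mask>` strings through the MLM (parts of that removed body, including a truncated `for`, were not captured by the viewer), while the new one draws uniform random A/C/G/U strings and shares a single `bp_span` computed from the target. A standalone sketch of the new behavior:

```python
import random

def init_population(structure: str, num_population: int, bp_span: int):
    """Mirrors the new _init_population: random ACGU strings, shared bp_span."""
    population = []
    for _ in range(num_population):
        sequence = "".join(random.choice("ACGU") for _ in range(len(structure)))
        population.append((sequence, bp_span))
    return population

pop = init_population("(((...)))", num_population=4, bp_span=3)
assert len(pop) == 4 and all(len(seq) == 9 for seq, _ in pop)
```
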
@@ -165,38 +198,47 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _mlm_mutate(self, population, structure, mutation_ratio):
         """
-
-
+        Mutate population using masked language modeling.
+
         Args:
-            population (list): Current population
+            population (list): Current population
             structure (str): Target RNA structure
             mutation_ratio (float): Ratio of tokens to mutate
-
+
         Returns:
-            list: Mutated population
+            list: Mutated population
         """
 
         def mutate(sequence, mutation_rate):
-
-
-
-
+            # Create masked sequence
+            masked_sequence = list(sequence)
+            num_mutations = int(len(sequence) * mutation_rate)
+            mutation_positions = random.sample(range(len(sequence)), num_mutations)
 
-
-
-            masked_sequence = mutate(sequence, mutation_ratio)
-            mlm_inputs.append(f"{masked_sequence}<eos>{structure}")
+            for pos in mutation_positions:
+                masked_sequence[pos] = self.tokenizer.mask_token
 
-
+            return "".join(masked_sequence)
 
-
-
-
-
-
-
-
-
-
+        # Prepare inputs for MLM
+        mlm_inputs = []
+        for sequence, bp_span in population:
+            masked_seq = mutate(sequence, mutation_ratio)
+            mlm_inputs.append(masked_seq)
+
+        # Get predictions from MLM
+        predicted_tokens = self._mlm_predict(mlm_inputs, structure)
+
+        # Convert predictions back to sequences
+        mutated_population = []
+        for i, (sequence, bp_span) in enumerate(population):
+            # Convert token IDs back to nucleotides
+            new_sequence = self.tokenizer.decode(
+                predicted_tokens[i], skip_special_tokens=True
+            )
+            # Ensure the sequence has the correct length
+            if len(new_sequence) != len(structure):
+                new_sequence = new_sequence[: len(structure)].ljust(len(structure), "A")
+            mutated_population.append((new_sequence, bp_span))
 
-        return
+        return mutated_population
 
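
The rewritten `mutate` helper masks a random sample of positions and lets the MLM repair them; note that the old code embedded the target as `f"{masked_sequence}<eos>{structure}"`, while the new code passes bare masked sequences and leaves the structure pairing to `_mlm_predict`. A standalone sketch of just the masking step, with a literal `"<mask>"` standing in for `self.tokenizer.mask_token`:

```python
import random

def mask_positions(sequence: str, mutation_rate: float, mask_token: str = "<mask>"):
    """Mirrors the inner mutate(): replace a distinct random sample of positions."""
    masked = list(sequence)
    num_mutations = int(len(sequence) * mutation_rate)
    for pos in random.sample(range(len(sequence)), num_mutations):
        masked[pos] = mask_token
    return "".join(masked)

masked = mask_positions("ACGUACGU", mutation_rate=0.5)
assert masked.count("<mask>") == 4  # 8 positions * 0.5, all distinct
```
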
@@ -203,31 +245,43 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _crossover(self, population, num_points=3):
         """
-        Perform crossover operation
-
+        Perform crossover operation on the population.
+
         Args:
-            population (list): Current population
-            num_points (int): Number of crossover points
-
+            population (list): Current population
+            num_points (int): Number of crossover points
+
         Returns:
-            list:
+            list: Population after crossover
         """
-
-
+        if len(population) < 2:
+            return population
+
+        # Create crossover masks
+        num_sequences = len(population)
+        masks = np.zeros((num_sequences, len(population[0][0])), dtype=bool)
 
-
-        crossover_points = np.
-
-            axis=1,
+        # Generate random crossover points
+        crossover_points = np.random.randint(
+            0, len(population[0][0]), (num_sequences, num_points)
         )
 
-
-
-
+        # Create parent indices
+        parent_indices = np.random.randint(0, num_sequences, (num_sequences, 2))
+
+        # Generate crossover masks
+        for i in range(num_sequences):
             for j in range(num_points):
-
-
+                if j == 0:
+                    masks[i, : crossover_points[i, j]] = True
+                else:
+                    last_point = crossover_points[i, j - 1]
+                    masks[i, last_point : crossover_points[i, j]] = j % 2 == 0
+
+            # Handle the last segment
+            last_point = crossover_points[i, -1]
             masks[i, last_point:] = num_points % 2 == 0
 
+        # Perform crossover
         population_array = np.array([list(seq[0]) for seq in population])
         child1_array = np.where(
             masks,
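
The mask construction alternates segments between two parents at `num_points` random cut positions, and the children are then assembled with `np.where`. The diff draws the cut points unsorted, so out-of-order points produce empty slices; the sketch below sorts them first for a clean alternation (an editorial simplification, not the shipped behavior):

```python
import numpy as np

def crossover_mask(length: int, num_points: int, seed=None) -> np.ndarray:
    """Boolean mask whose segments alternate at sorted random cut points.
    True positions take parent 1's base, False positions take parent 2's."""
    rng = np.random.default_rng(seed)
    points = np.sort(rng.integers(0, length, num_points))
    mask = np.zeros(length, dtype=bool)
    take, start = True, 0
    for p in points:
        mask[start:p] = take
        take, start = not take, p
    mask[start:] = take  # final segment parity matches num_points % 2 == 0
    return mask

p1, p2 = "AAAAAAAAAA", "CCCCCCCCCC"
mask = crossover_mask(10, num_points=3, seed=0)
child = "".join(np.where(mask, list(p1), list(p2)))
assert len(child) == 10 and set(child) <= {"A", "C"}
```
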
@@ -251,23 +305,19 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _evaluate_structure_fitness(self, sequences, structure):
         """
         Evaluate the fitness of the RNA structure by comparing with the target structure.
-
+
         Args:
             sequences (list): List of (sequence, bp_span) tuples to evaluate
             structure (str): Target RNA structure
-
+
         Returns:
             list: Sorted population with fitness scores and MFE values
         """
-
-
-
-
-
-            )
-        )
-        else:
-            structures_mfe = self._predict_structure([seq for seq, _ in sequences])
+        # Get sequences for structure prediction
+        seq_list = [seq for seq, _ in sequences]
+
+        # Predict structures (with improved multiprocessing)
+        structures_mfe = self._predict_structure(seq_list)
 
         sorted_population = []
         for (seq, bp_span), (ss, mfe) in zip(sequences, structures_mfe):
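
Fitness now always routes through the multiprocessing-aware `_predict_structure`; the old branching (mostly lost in this capture, apart from a dangling `else:`) special-cased something beforehand. The comparison against the target lies outside this hunk, but given the module-level `from scipy.spatial.distance import hamming` import, a plausible sketch of the similarity score — an assumption, not the shipped code:

```python
from scipy.spatial.distance import hamming

def structure_similarity(predicted: str, target: str) -> float:
    """Fraction of positions where the dot-bracket strings agree;
    scipy's hamming returns the fraction of disagreeing positions."""
    return 1.0 - hamming(list(predicted), list(target))

assert structure_similarity("((..))", "((..))") == 1.0
assert abs(structure_similarity("((..))", "(....)") - 4 / 6) < 1e-9
```
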
@@ -283,11 +333,11 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _non_dominated_sorting(scores, mfe_values):
         """
         Perform non-dominated sorting for multi-objective optimization.
-
+
         Args:
             scores (list): Structure similarity scores
             mfe_values (list): Minimum free energy values
-
+
         Returns:
             list: List of fronts (Pareto fronts)
         """
@@ -326,11 +376,11 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _select_next_generation(next_generation, fronts):
         """
         Select the next generation based on Pareto fronts.
-
+
         Args:
             next_generation (list): Current population with fitness scores
             fronts (list): Pareto fronts
-
+
         Returns:
             list: Selected population for the next generation
         """
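
These two hunks only clean docstring whitespace, but they name the selection machinery: non-dominated sorting over two objectives, structure similarity (to maximize) and MFE (to minimize). A generic sketch of the first Pareto front under that convention (illustrative; the shipped sorting is not shown in this diff):

```python
def dominates(a, b):
    """a = (score, mfe). a dominates b if it is no worse on both objectives
    (higher similarity, lower MFE) and strictly better on at least one."""
    return a[0] >= b[0] and a[1] <= b[1] and (a[0] > b[0] or a[1] < b[1])

def first_pareto_front(points):
    """Indices of points not dominated by any other point."""
    return [
        i for i, p in enumerate(points)
        if not any(dominates(q, p) for j, q in enumerate(points) if j != i)
    ]

pts = [(0.9, -12.0), (0.9, -8.0), (0.7, -15.0), (0.6, -5.0)]
assert first_pareto_front(pts) == [0, 2]  # (0.9,-8.0) and (0.6,-5.0) are dominated
```
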
@@ -346,11 +396,11 @@ class OmniModelForRNADesign(torch.nn.Module):
     def _mlm_predict(self, mlm_inputs, structure):
         """
         Perform masked language model prediction.
-
+
         Args:
             mlm_inputs (list): List of masked input sequences
             structure (str): Target RNA structure
-
+
         Returns:
             list: Predicted token IDs for each input
         """
@@ -360,7 +410,7 @@ class OmniModelForRNADesign(torch.nn.Module):
         with torch.no_grad():
             for i in range(0, len(mlm_inputs), batch_size):
                 inputs = self.tokenizer(
-                    mlm_inputs[i: i + batch_size],
+                    mlm_inputs[i : i + batch_size],
                     padding=False,
                     max_length=1024,
                     truncation=True,
@@ -379,13 +429,13 @@ class OmniModelForRNADesign(torch.nn.Module):
     ):
         """
         Design RNA sequences for a target structure using evolutionary algorithms.
-
+
         Args:
             structure (str): Target RNA structure in dot-bracket notation
             mutation_ratio (float): Ratio of tokens to mutate (default: 0.5)
             num_population (int): Population size (default: 100)
             num_generation (int): Number of generations (default: 100)
-
+
         Returns:
             list: List of designed RNA sequences with their fitness scores
         """
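
Putting the visible pieces together, a usage sketch grounded in the signatures shown in this diff (the default checkpoint, parameter names, and defaults all appear above; the exact shape of the return value is taken from the docstring, so treat it as indicative):

```python
from omnigenome.src.model.rna_design.model import OmniModelForRNADesign

# device=None auto-detects, per the __init__ docstring.
designer = OmniModelForRNADesign(model="yangheng/OmniGenome-186M")

candidates = designer.design(
    structure="(((....)))",   # target in dot-bracket notation
    mutation_ratio=0.5,
    num_population=100,
    num_generation=100,
)
print(candidates[:5])  # designed sequences with their fitness scores
```
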
omnigenome/src/model/seq2seq/model.py:

@@ -21,20 +21,20 @@ from ...abc.abstract_model import OmniModel
 class OmniModelForSeq2Seq(OmniModel):
     """
     Sequence-to-sequence model for genomic sequences.
-
+
     This model implements a sequence-to-sequence architecture for genomic
     sequences, where the input is one sequence and the output is another
     sequence. It's useful for tasks like sequence translation, structure
     prediction, or sequence transformation.
-
+
     The model can be extended to implement specific seq2seq tasks by
     overriding the forward, predict, and inference methods.
     """
-
+
     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
         """
         Initialize the sequence-to-sequence model.
-
+
         Args:
             config_or_model: Model configuration or pre-trained model
             tokenizer: Tokenizer for processing input sequences
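
The seq2seq changes are whitespace-only, but the docstring spells out the extension contract: subclass and override `forward`, `predict`, and `inference`. A skeletal sketch of that pattern (the signatures beyond the method names are assumptions; the base-class contract is not shown in this diff):

```python
from omnigenome.src.model.seq2seq.model import OmniModelForSeq2Seq

class StructurePredictionSeq2Seq(OmniModelForSeq2Seq):
    """Hypothetical specialization: genomic sequence in, structure string out."""

    def forward(self, **inputs):
        raise NotImplementedError  # task-specific forward pass

    def predict(self, inputs, **kwargs):
        raise NotImplementedError  # decode model outputs into a target sequence

    def inference(self, sequence, **kwargs):
        raise NotImplementedError  # convenience wrapper for raw strings
```
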
omnigenome/src/tokenizer/bpe_tokenizer.py:

@@ -17,17 +17,17 @@ warnings.filterwarnings("once")
 def is_bpe_tokenization(tokens, threshold=0.1):
     """
     Check if the tokenization is BPE-based by analyzing token characteristics.
-
+
     This function examines the tokens to determine if they follow BPE tokenization
     patterns by analyzing token length distributions and special token patterns.
-
+
     Args:
         tokens (list): List of tokens to analyze
         threshold (float, optional): Threshold for determining BPE tokenization. Defaults to 0.1
-
+
     Returns:
         bool: True if tokens appear to be BPE-based, False otherwise
-
+
     Example:
         >>> tokens = ["▁hello", "▁world", "▁how", "▁are", "▁you"]
        >>> is_bpe = is_bpe_tokenization(tokens)
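
The body of `is_bpe_tokenization` lies outside this hunk; the docstring describes the heuristic only as token-length distributions plus special-token patterns (note the sentencepiece `▁` markers in its example) against a 0.1 threshold. A sketch of one way such a check could work — an illustration of the described idea, not the shipped implementation:

```python
def looks_like_bpe(tokens, threshold=0.1):
    """BPE-style vocabularies tend to contain multi-character tokens and
    subword markers such as sentencepiece's '▁' or WordPiece's '##'."""
    if not tokens:
        return False
    marked = sum(1 for t in tokens if t.startswith(("▁", "##")))
    multi_char = sum(1 for t in tokens if len(t.lstrip("▁#")) > 1)
    return (marked + multi_char) / len(tokens) > threshold

assert looks_like_bpe(["▁hello", "▁world", "▁how", "▁are", "▁you"])
assert not looks_like_bpe(list("ACGU"))  # single-nucleotide tokens
```
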
@@ -52,15 +52,15 @@ def is_bpe_tokenization(tokens, threshold=0.1):
 class OmniBPETokenizer(OmniTokenizer):
     """
     A Byte Pair Encoding (BPE) tokenizer for genomic sequences.
-
+
     This tokenizer uses BPE tokenization for genomic sequences and provides
     validation to ensure the base tokenizer is BPE-based. It supports sequence
     preprocessing and handles various input formats.
-
+
     Attributes:
         base_tokenizer: The underlying BPE tokenizer
         metadata: Dictionary containing tokenizer metadata
-
+
     Example:
         >>> from omnigenome.src.tokenizer import OmniBPETokenizer
         >>> from transformers import AutoTokenizer
@@ -75,7 +75,7 @@ class OmniBPETokenizer(OmniTokenizer):
     def __init__(self, base_tokenizer=None, **kwargs):
         """
         Initialize the OmniBPETokenizer.
-
+
         Args:
             base_tokenizer: The base BPE tokenizer
             **kwargs: Additional keyword arguments passed to parent class
@@ -86,21 +86,21 @@ class OmniBPETokenizer(OmniTokenizer):
     def __call__(self, sequence, **kwargs):
         """
         Tokenize a sequence using BPE tokenization.
-
+
         This method processes the input sequence using BPE tokenization,
         handles sequence preprocessing (U/T conversion, whitespace addition),
         and validates that the tokenization is BPE-based.
-
+
         Args:
             sequence (str): Input sequence to tokenize
             **kwargs: Additional keyword arguments including max_length
-
+
         Returns:
             dict: Dictionary containing tokenized inputs with keys 'input_ids' and 'attention_mask'
-
+
         Raises:
             ValueError: If the tokenizer is not BPE-based
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> tokenized = tokenizer(sequence)
@@ -136,14 +136,14 @@ class OmniBPETokenizer(OmniTokenizer):
     def from_pretrained(model_name_or_path, **kwargs):
         """
         Create a BPE tokenizer from a pre-trained model.
-
+
         Args:
             model_name_or_path (str): Name or path of the pre-trained model
             **kwargs: Additional keyword arguments
-
+
         Returns:
             OmniBPETokenizer: Initialized BPE tokenizer
-
+
         Example:
             >>> tokenizer = OmniBPETokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
             >>> print(type(tokenizer))
@@ -159,14 +159,14 @@ class OmniBPETokenizer(OmniTokenizer):
     def tokenize(self, sequence, **kwargs):
         """
         Tokenize a sequence using the base BPE tokenizer.
-
+
         Args:
             sequence (str): Input sequence to tokenize
             **kwargs: Additional keyword arguments
-
+
         Returns:
             list: List of tokens
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> tokens = tokenizer.tokenize(sequence)
@@ -178,17 +178,17 @@ class OmniBPETokenizer(OmniTokenizer):
     def encode(self, sequence, **kwargs):
         """
         Encode a sequence using the base BPE tokenizer.
-
+
         Args:
             sequence (str): Input sequence to encode
             **kwargs: Additional keyword arguments
-
+
         Returns:
             list: List of token IDs
-
+
         Raises:
             AssertionError: If the base tokenizer is not BPE-based
-
+
         Example:
             >>> sequence = "ACGUAGGUAUCGUAGA"
             >>> token_ids = tokenizer.encode(sequence)
@@ -203,17 +203,17 @@ class OmniBPETokenizer(OmniTokenizer):
     def decode(self, sequence, **kwargs):
         """
         Decode a sequence using the base BPE tokenizer.
-
+
         Args:
             sequence: Input sequence to decode (can be token IDs or tokens)
             **kwargs: Additional keyword arguments
-
+
         Returns:
             str: Decoded sequence
-
+
         Raises:
             AssertionError: If the base tokenizer is not BPE-based
-
+
         Example:
             >>> token_ids = [1, 2, 3, 4, 5]
             >>> sequence = tokenizer.decode(token_ids)