omnigenome 0.3.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic. Click here for more details.
- omnigenome/__init__.py +281 -0
- omnigenome/auto/__init__.py +3 -0
- omnigenome/auto/auto_bench/__init__.py +12 -0
- omnigenome/auto/auto_bench/auto_bench.py +484 -0
- omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
- omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
- omnigenome/auto/auto_bench/config_check.py +34 -0
- omnigenome/auto/auto_train/__init__.py +13 -0
- omnigenome/auto/auto_train/auto_train.py +430 -0
- omnigenome/auto/auto_train/auto_train_cli.py +222 -0
- omnigenome/auto/bench_hub/__init__.py +12 -0
- omnigenome/auto/bench_hub/bench_hub.py +25 -0
- omnigenome/cli/__init__.py +13 -0
- omnigenome/cli/commands/__init__.py +13 -0
- omnigenome/cli/commands/base.py +83 -0
- omnigenome/cli/commands/bench/__init__.py +13 -0
- omnigenome/cli/commands/bench/bench_cli.py +202 -0
- omnigenome/cli/commands/rna/__init__.py +13 -0
- omnigenome/cli/commands/rna/rna_design.py +178 -0
- omnigenome/cli/omnigenome_cli.py +128 -0
- omnigenome/src/__init__.py +12 -0
- omnigenome/src/abc/__init__.py +12 -0
- omnigenome/src/abc/abstract_dataset.py +622 -0
- omnigenome/src/abc/abstract_metric.py +114 -0
- omnigenome/src/abc/abstract_model.py +689 -0
- omnigenome/src/abc/abstract_tokenizer.py +267 -0
- omnigenome/src/dataset/__init__.py +16 -0
- omnigenome/src/dataset/omni_dataset.py +435 -0
- omnigenome/src/lora/__init__.py +13 -0
- omnigenome/src/lora/lora_model.py +294 -0
- omnigenome/src/metric/__init__.py +15 -0
- omnigenome/src/metric/classification_metric.py +184 -0
- omnigenome/src/metric/metric.py +199 -0
- omnigenome/src/metric/ranking_metric.py +142 -0
- omnigenome/src/metric/regression_metric.py +191 -0
- omnigenome/src/misc/__init__.py +3 -0
- omnigenome/src/misc/utils.py +439 -0
- omnigenome/src/model/__init__.py +19 -0
- omnigenome/src/model/augmentation/__init__.py +12 -0
- omnigenome/src/model/augmentation/model.py +219 -0
- omnigenome/src/model/classification/__init__.py +12 -0
- omnigenome/src/model/classification/model.py +642 -0
- omnigenome/src/model/embedding/__init__.py +12 -0
- omnigenome/src/model/embedding/model.py +263 -0
- omnigenome/src/model/mlm/__init__.py +12 -0
- omnigenome/src/model/mlm/model.py +177 -0
- omnigenome/src/model/module_utils.py +232 -0
- omnigenome/src/model/regression/__init__.py +12 -0
- omnigenome/src/model/regression/model.py +786 -0
- omnigenome/src/model/regression/resnet.py +483 -0
- omnigenome/src/model/rna_design/__init__.py +12 -0
- omnigenome/src/model/rna_design/model.py +426 -0
- omnigenome/src/model/seq2seq/__init__.py +12 -0
- omnigenome/src/model/seq2seq/model.py +44 -0
- omnigenome/src/tokenizer/__init__.py +16 -0
- omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
- omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
- omnigenome/src/trainer/__init__.py +14 -0
- omnigenome/src/trainer/accelerate_trainer.py +739 -0
- omnigenome/src/trainer/hf_trainer.py +75 -0
- omnigenome/src/trainer/trainer.py +579 -0
- omnigenome/utility/__init__.py +3 -0
- omnigenome/utility/dataset_hub/__init__.py +13 -0
- omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
- omnigenome/utility/ensemble.py +324 -0
- omnigenome/utility/hub_utils.py +517 -0
- omnigenome/utility/model_hub/__init__.py +12 -0
- omnigenome/utility/model_hub/model_hub.py +231 -0
- omnigenome/utility/pipeline_hub/__init__.py +12 -0
- omnigenome/utility/pipeline_hub/pipeline.py +483 -0
- omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
- omnigenome-0.3.0a0.dist-info/METADATA +224 -0
- omnigenome-0.3.0a0.dist-info/RECORD +85 -0
- omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
- omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
- omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
- tests/__init__.py +9 -0
- tests/conftest.py +160 -0
- tests/test_dataset_patterns.py +291 -0
- tests/test_examples_syntax.py +83 -0
- tests/test_model_loading.py +183 -0
- tests/test_rna_functions.py +255 -0
- tests/test_training_patterns.py +302 -0
|
@@ -0,0 +1,426 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# file: model.py
|
|
3
|
+
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
4
|
+
# github: https://github.com/yangheng95
|
|
5
|
+
# huggingface: https://huggingface.co/yangheng
|
|
6
|
+
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
7
|
+
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
8
|
+
"""
|
|
9
|
+
RNA design model using masked language modeling and evolutionary algorithms.
|
|
10
|
+
|
|
11
|
+
This module provides an RNA design model that combines masked language modeling
|
|
12
|
+
with evolutionary algorithms to design RNA sequences that fold into specific
|
|
13
|
+
target structures. It uses a multi-objective optimization approach to balance
|
|
14
|
+
structure similarity and thermodynamic stability.
|
|
15
|
+
"""
|
|
16
|
+
import random
|
|
17
|
+
import numpy as np
|
|
18
|
+
import torch
|
|
19
|
+
import autocuda
|
|
20
|
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
|
21
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
22
|
+
import ViennaRNA
|
|
23
|
+
from scipy.spatial.distance import hamming
|
|
24
|
+
|
|
25
|
+
from omnigenome.src.misc.utils import fprint
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class OmniModelForRNADesign(torch.nn.Module):
    """
    RNA design model using masked language modeling and evolutionary algorithms.

    This model combines a pre-trained masked language model (MLM) with an
    evolutionary search to design RNA sequences that fold into a target
    structure. Candidates are ranked on two objectives, both minimized:

    * the normalized Hamming distance between the ViennaRNA-predicted
      dot-bracket structure and the target structure, and
    * the minimum free energy (MFE) of the predicted fold.

    Individuals are ``(sequence, bp_span)`` tuples. ``bp_span`` is sampled and
    perturbed between generations. However, ``_evaluate_structure_fitness``
    currently folds with the default (unlimited) span and does not pass it on.

    Attributes:
        device: Device to run the model on (CPU or GPU)
        parallel: Whether to use a process pool for structure prediction
        tokenizer: Tokenizer for processing RNA sequences
        model: Pre-trained masked language model
    """

    # Nucleotide tokens accepted verbatim from MLM output. Any other token
    # (special tokens, unknown tokens, multi-character pieces) is replaced by a
    # random nucleotide, so every predicted position yields exactly one base.
    _NUCLEOTIDES = ("A", "T", "G", "C")
    _NUCLEOTIDE_SET = frozenset(_NUCLEOTIDES)

    def __init__(
        self,
        model="yangheng/OmniGenome-186M",
        device=None,
        parallel=False,
        *args,
        **kwargs,
    ):
        """
        Initialize the RNA design model.

        Args:
            model (str): Model name or path for the pre-trained MLM model
            device: Device to run the model on (default: None, auto-detected
                with ``autocuda.auto_cuda()``)
            parallel (bool): Whether to fold candidates in a
                ``ProcessPoolExecutor`` (default: False)
            *args: Additional positional arguments forwarded to ``torch.nn.Module``
            **kwargs: Additional keyword arguments forwarded to ``torch.nn.Module``
        """
        super().__init__(*args, **kwargs)
        self.device = autocuda.auto_cuda() if device is None else device
        self.parallel = parallel
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForMaskedLM.from_pretrained(model, trust_remote_code=True)
        self.model.to(self.device)
        # Half precision is only used on accelerators. Many CPU kernels have no
        # float16 implementation (or a very slow one), so casting on CPU can
        # make inference fail outright.
        if torch.device(self.device).type != "cpu":
            self.model.to(torch.float16)

    @staticmethod
    def _random_bp_span(bp_span=None, window=50, max_span=400):
        """
        Sample a random base-pair span near ``bp_span``.

        Args:
            bp_span (int, optional): Span to centre the sample around. If None,
                the sample is drawn uniformly from ``[0, max_span)``.
            window (int): Half-width of the sampling window (default: 50).
            max_span (int): Exclusive upper bound of the sample (default: 400).

        Returns:
            int: A value drawn uniformly from
            ``[max(0, bp_span - window), min(bp_span + window, max_span))``.
            If that interval is empty, the lower bound is clamped so that a
            value is still returned. For example, ``bp_span >= max_span + window``
            yields ``max_span - 1`` instead of raising.
        """
        if bp_span is None:
            return random.choice(range(0, max_span))
        high = max(1, min(bp_span + window, max_span))
        # Clamp the lower bound below ``high`` so the range is never empty
        # (long structures used to raise IndexError here).
        low = min(max(0, bp_span - window), high - 1)
        return random.choice(range(low, high))

    @staticmethod
    def _longest_bp_span(structure):
        """
        Compute the longest base-pair span in a dot-bracket structure.

        Args:
            structure (str): RNA structure in dot-bracket notation

        Returns:
            int: Largest ``j - i`` over matched ``(`` at ``i`` / ``)`` at ``j``
            (0 if there are no pairs). Unmatched ``)`` characters are ignored.
        """
        stack = []
        max_span = 0
        for i, char in enumerate(structure):
            if char == "(":
                stack.append(i)
            elif char == ")" and stack:
                left_index = stack.pop()
                max_span = max(max_span, i - left_index)
        return max_span

    @staticmethod
    def _predict_structure_single(sequence, bp_span=-1):
        """
        Predict the RNA structure and minimum free energy (MFE) for one sequence.

        Args:
            sequence (str): RNA sequence
            bp_span (int): Value assigned to ViennaRNA's ``md.max_bp_span``
                (default: -1, no limit)

        Returns:
            tuple: ``(structure, mfe)`` as returned by ``fold_compound.mfe()``,
            where structure is in dot-bracket notation
        """
        md = ViennaRNA.md()
        md.max_bp_span = bp_span
        fc = ViennaRNA.fold_compound(sequence, md)
        return fc.mfe()

    def _predict_structure(self, sequences, bp_span=-1):
        """
        Predict RNA structures sequentially for multiple sequences.

        Args:
            sequences (list): List of RNA sequences
            bp_span (int): Maximum base pair span for folding (default: -1, no limit)

        Returns:
            list: List of ``(structure, mfe)`` tuples, in input order
        """
        return [self._predict_structure_single(seq, bp_span) for seq in sequences]

    def _ids_to_sequence(self, token_ids):
        """
        Decode one row of MLM output ids into a nucleotide string.

        Tokens that are exactly ``A``, ``T``, ``G`` or ``C`` are kept. Every other
        token is replaced by a uniformly random nucleotide, so the result has
        exactly one character per predicted position.

        Args:
            token_ids: 1-D tensor of predicted token ids for one sequence.

        Returns:
            str: The decoded nucleotide sequence.
        """
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids.tolist())
        return "".join(
            token if token in self._NUCLEOTIDE_SET else random.choice(self._NUCLEOTIDES)
            for token in tokens
        )

    def _init_population(self, structure, num_population):
        """
        Create the initial population by MLM-filling random partial templates.

        Each template position is independently ``G``, ``C`` or ``<mask>``. The
        template is joined to the target structure with ``<eos>`` and completed
        by the MLM.

        Args:
            structure (str): Target RNA structure in dot-bracket notation
            num_population (int): Number of individuals in the population

        Returns:
            list: List of ``(sequence, bp_span)`` tuples. ``bp_span`` is sampled
            around ``len(structure)``.
        """
        mlm_inputs = []
        for _ in range(num_population):
            masked_sequence = "".join(
                random.choice(["G", "C", "<mask>"]) for _ in structure
            )
            mlm_inputs.append(f"{masked_sequence}<eos>{structure}")

        outputs = self._mlm_predict(mlm_inputs, structure)
        return [
            (self._ids_to_sequence(output), self._random_bp_span(len(structure)))
            for output in outputs
        ]

    def _mlm_mutate(self, population, structure, mutation_ratio):
        """
        Mutate the population by masking random positions and refilling them with the MLM.

        Args:
            population (list): Current population of ``(sequence, bp_span)`` tuples
            structure (str): Target RNA structure
            mutation_ratio (float): Per-position probability of being masked

        Returns:
            list: Mutated population of ``(sequence, bp_span)`` tuples, with each
            ``bp_span`` re-sampled around its previous value.
        """

        def mutate(sequence, mutation_rate):
            # One uniform draw per position; positions below the rate become <mask>.
            masked = np.random.rand(len(sequence)) < mutation_rate
            return "".join(
                "<mask>" if is_masked else base
                for base, is_masked in zip(sequence, masked)
            )

        mlm_inputs = [
            f"{mutate(sequence, mutation_ratio)}<eos>{structure}"
            for sequence, _ in population
        ]
        outputs = self._mlm_predict(mlm_inputs, structure)
        return [
            (self._ids_to_sequence(output), self._random_bp_span(bp_span))
            for output, (_, bp_span) in zip(outputs, population)
        ]

    def _crossover(self, population, num_points=3):
        """
        Perform multi-point crossover to create offspring.

        Parents are drawn from the leading 10% of ``population`` (at least one
        individual). When the population is ordered best-first, as after
        ``_evaluate_structure_fitness``, this biases selection toward the fittest.

        Args:
            population (list): Current population of ``(sequence, bp_span)``
                tuples. All sequences must have the same length.
            num_points (int): Number of crossover points (default: 3)

        Returns:
            list: ``2 * len(population)`` offspring ``(sequence, bp_span)``
            tuples. The i-th child of each half inherits ``bp_span`` from
            ``population[i]``.
        """
        population_size = len(population)
        if population_size == 0:
            return []
        sequence_length = len(population[0][0])

        # max(1, ...) keeps populations smaller than 10 from raising in np.random.choice.
        num_parents = max(1, population_size // 10)
        parent_indices = np.random.choice(num_parents, (population_size, 2))
        # randint's high bound must exceed its low bound; length-1 sequences
        # would otherwise raise.
        crossover_points = np.sort(
            np.random.randint(
                1, max(sequence_length, 2), size=(population_size, num_points)
            ),
            axis=1,
        )

        # True -> take from the first parent, False -> from the second.
        # Segments between consecutive crossover points alternate.
        masks = np.zeros((population_size, sequence_length), dtype=bool)
        for i in range(population_size):
            last_point = 0
            for j in range(num_points):
                masks[i, last_point : crossover_points[i, j]] = j % 2 == 0
                last_point = crossover_points[i, j]
            masks[i, last_point:] = num_points % 2 == 0

        population_array = np.array([list(seq[0]) for seq in population])
        child1_array = np.where(
            masks,
            population_array[parent_indices[:, 0]],
            population_array[parent_indices[:, 1]],
        )
        child2_array = np.where(
            masks,
            population_array[parent_indices[:, 1]],
            population_array[parent_indices[:, 0]],
        )

        return [
            ("".join(child), bp_span)
            for child, (_, bp_span) in zip(child1_array, population)
        ] + [
            ("".join(child), bp_span)
            for child, (_, bp_span) in zip(child2_array, population)
        ]

    def _evaluate_structure_fitness(self, sequences, structure):
        """
        Fold each candidate, score it against the target, and order by Pareto front.

        Args:
            sequences (list): List of ``(sequence, bp_span)`` tuples to evaluate
            structure (str): Target RNA structure

        Returns:
            list: ``(sequence, bp_span, score, mfe)`` tuples ordered front by
            front. ``score`` is SciPy's normalized Hamming distance between the
            target and predicted structures (0 means an exact match).
        """
        plain_sequences = [seq for seq, _ in sequences]
        # NOTE: bp_span is not forwarded here, so folding uses the default
        # (unlimited) span.
        if self.parallel:
            with ProcessPoolExecutor() as executor:
                structures_mfe = list(
                    executor.map(self._predict_structure_single, plain_sequences)
                )
        else:
            structures_mfe = self._predict_structure(plain_sequences)

        scored_population = [
            (seq, bp_span, hamming(list(structure), list(ss)), mfe)
            for (seq, bp_span), (ss, mfe) in zip(sequences, structures_mfe)
        ]

        fronts = self._non_dominated_sorting(
            [x[2] for x in scored_population], [x[3] for x in scored_population]
        )
        return self._select_next_generation(scored_population, fronts)

    @staticmethod
    def _non_dominated_sorting(scores, mfe_values):
        """
        Perform non-dominated sorting for two minimized objectives.

        Solution ``p`` dominates ``q`` only when it is strictly lower on BOTH
        objectives. Ties on either objective therefore never dominate.

        Args:
            scores (list): Structure distance scores (lower is better)
            mfe_values (list): Minimum free energy values (lower is better)

        Returns:
            list: Pareto fronts as lists of indices. Front 0 holds the
            non-dominated solutions. Empty input yields ``[]``.
        """
        num_solutions = len(scores)
        domination_count = [0] * num_solutions
        dominated_solutions = [[] for _ in range(num_solutions)]
        fronts = [[]]

        # O(n^2) pairwise comparison.
        for p in range(num_solutions):
            for q in range(num_solutions):
                if scores[p] < scores[q] and mfe_values[p] < mfe_values[q]:
                    dominated_solutions[p].append(q)
                elif scores[q] < scores[p] and mfe_values[q] < mfe_values[p]:
                    domination_count[p] += 1

            if domination_count[p] == 0:
                fronts[0].append(p)

        i = 0
        while len(fronts[i]) > 0:
            next_front = []
            for p in fronts[i]:
                for q in dominated_solutions[p]:
                    domination_count[q] -= 1
                    if domination_count[q] == 0:
                        next_front.append(q)
            i += 1
            fronts.append(next_front)

        if not fronts[-1]:  # the loop always terminates by appending an empty front
            fronts.pop(-1)

        return fronts

    @staticmethod
    def _select_next_generation(next_generation, fronts):
        """
        Reorder the population front by front.

        Args:
            next_generation (list): Current population with fitness scores
            fronts (list): Pareto fronts (lists of indices into ``next_generation``)

        Returns:
            list: Individuals in front order (index order within a front),
            truncated to ``len(next_generation)``.
        """
        sorted_population = []
        for front in fronts:
            front_population = [next_generation[i] for i in front]
            sorted_population.extend(front_population)
            if len(sorted_population) >= len(next_generation):
                break

        return sorted_population[: len(next_generation)]

    def _mlm_predict(self, mlm_inputs, structure, batch_size=8):
        """
        Run the MLM and return the argmax token id for each structure position.

        Args:
            mlm_inputs (list): Masked input strings (``sequence<eos>structure``)
            structure (str): Target RNA structure; its length selects how many
                predicted positions are returned
            batch_size (int): Number of inputs tokenized and run per forward
                pass (default: 8)

        Returns:
            torch.Tensor: Predicted token ids of shape
            ``(len(mlm_inputs), <= len(structure))``. For empty input, an
            empty ``(0, len(structure))`` long tensor.
        """
        if not mlm_inputs:
            return torch.empty((0, len(structure)), dtype=torch.long)

        all_outputs = []
        with torch.no_grad():
            for start in range(0, len(mlm_inputs), batch_size):
                # padding=False: all inputs share the same token length here
                # (one token per nucleotide/mask) -- assumes a per-character
                # tokenizer; TODO confirm for other tokenizers.
                inputs = self.tokenizer(
                    mlm_inputs[start : start + batch_size],
                    padding=False,
                    max_length=1024,
                    truncation=True,
                    return_tensors="pt",
                )
                inputs = {
                    key: value.to(self.model.device) for key, value in inputs.items()
                }
                all_outputs.append(self.model(**inputs)[0].argmax(dim=-1))

        # Skip position 0 (presumably a leading special token such as BOS/CLS;
        # verify against the tokenizer) and keep one prediction per structure position.
        return torch.cat(all_outputs, dim=0)[:, 1 : 1 + len(structure)]

    def design(
        self, structure, mutation_ratio=0.5, num_population=100, num_generation=100
    ):
        """
        Design RNA sequences for a target structure using evolutionary search.

        Args:
            structure (str): Target RNA structure in dot-bracket notation
            mutation_ratio (float): Per-position masking probability (default: 0.5)
            num_population (int): Population size, at least 1 (default: 100)
            num_generation (int): Number of generations (default: 100)

        Returns:
            list or str: After each generation, if any individual folds exactly
            into ``structure`` (score 0), the list of all such sequences from
            that generation is returned immediately. Otherwise, once all
            generations have run, the single top-ranked sequence is returned
            as a ``str``.

        Raises:
            ValueError: If ``num_population`` is less than 1.
        """
        if num_population < 1:
            raise ValueError("num_population must be at least 1")

        population = self._init_population(structure, num_population)
        population = self._mlm_mutate(population, structure, mutation_ratio)

        for _ in range(num_generation):
            next_generation = self._crossover(population)
            next_generation = self._mlm_mutate(
                next_generation, structure, mutation_ratio
            )
            next_generation = self._evaluate_structure_fitness(
                next_generation, structure
            )[:num_population]

            candidate_sequences = [
                seq for seq, _, score, _ in next_generation if score == 0
            ]
            if candidate_sequences:
                return candidate_sequences

            population = [(seq, bp_span) for seq, bp_span, _, _ in next_generation]

        return population[0][0]
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
# Example usage
if __name__ == "__main__":
    model = OmniModelForRNADesign(model="anonymous8/OmniGenome-186M")
    designed = model.design(
        structure="(((....)))",
        mutation_ratio=0.5,
        num_population=100,
        num_generation=100,
    )
    # design() returns a list of sequences that fold exactly into the target
    # when any are found, and a single best-effort sequence (str) otherwise.
    if isinstance(designed, list):
        fprint(f"Found {len(designed)} RNA sequence(s) matching the target structure:")
        for sequence in designed:
            fprint(sequence)
    else:
        fprint(f"Best RNA sequence: {designed}")
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# file: __init__.py
|
|
3
|
+
# time: 22:21 08/04/2024
|
|
4
|
+
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
+
# github: https://github.com/yangheng95
|
|
6
|
+
# huggingface: https://huggingface.co/yangheng
|
|
7
|
+
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
+
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
+
"""
|
|
10
|
+
This package contains modules for sequence-to-sequence models.
|
|
11
|
+
"""
|
|
12
|
+
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# file: model.py
|
|
3
|
+
# time: 11:40 14/04/2024
|
|
4
|
+
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
+
# github: https://github.com/yangheng95
|
|
6
|
+
# huggingface: https://huggingface.co/yangheng
|
|
7
|
+
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
+
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
+
"""
|
|
10
|
+
Sequence-to-sequence model for genomic sequences.
|
|
11
|
+
|
|
12
|
+
This module provides a sequence-to-sequence model implementation for genomic
|
|
13
|
+
sequences. It's designed for tasks where the input and output are both
|
|
14
|
+
sequences, such as sequence translation, structure prediction, or sequence
|
|
15
|
+
transformation tasks.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from ...abc.abstract_model import OmniModel
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OmniModelForSeq2Seq(OmniModel):
    """
    Sequence-to-sequence model for genomic sequences.

    A thin specialisation of ``OmniModel`` for tasks whose input and output
    are both sequences, such as sequence translation, structure prediction,
    or sequence transformation. It adds no behaviour of its own: construction
    is delegated wholesale to ``OmniModel``. Concrete seq2seq tasks are meant
    to extend it by overriding the forward, predict, and inference methods.
    """

    def __init__(self, config_or_model, tokenizer, *extra_args, **extra_kwargs):
        """
        Build the model by delegating entirely to ``OmniModel.__init__``.

        Args:
            config_or_model: Model configuration or pre-trained model, passed
                through unchanged.
            tokenizer: Tokenizer for processing input sequences, passed
                through unchanged.
            *extra_args: Further positional arguments forwarded to the base class.
            **extra_kwargs: Further keyword arguments forwarded to the base class.
        """
        base_init = super().__init__
        base_init(config_or_model, tokenizer, *extra_args, **extra_kwargs)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# file: __init__.py
|
|
3
|
+
# time: 18:05 08/04/2024
|
|
4
|
+
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
+
# github: https://github.com/yangheng95
|
|
6
|
+
# huggingface: https://huggingface.co/yangheng
|
|
7
|
+
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
+
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
+
"""
|
|
10
|
+
This package contains tokenizer implementations.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
from .bpe_tokenizer import OmniBPETokenizer
|
|
15
|
+
from .kmers_tokenizer import OmniKmersTokenizer
|
|
16
|
+
from .single_nucleotide_tokenizer import OmniSingleNucleotideTokenizer
|