omnigenome 0.3.1a0__py3-none-any.whl → 0.3.4a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic. Click here for more details.
- omnigenome/__init__.py +252 -266
- {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/METADATA +9 -9
- omnigenome-0.3.4a0.dist-info/RECORD +7 -0
- omnigenome/auto/__init__.py +0 -3
- omnigenome/auto/auto_bench/__init__.py +0 -11
- omnigenome/auto/auto_bench/auto_bench.py +0 -494
- omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
- omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
- omnigenome/auto/auto_bench/config_check.py +0 -34
- omnigenome/auto/auto_train/__init__.py +0 -12
- omnigenome/auto/auto_train/auto_train.py +0 -429
- omnigenome/auto/auto_train/auto_train_cli.py +0 -222
- omnigenome/auto/bench_hub/__init__.py +0 -11
- omnigenome/auto/bench_hub/bench_hub.py +0 -25
- omnigenome/cli/__init__.py +0 -12
- omnigenome/cli/commands/__init__.py +0 -12
- omnigenome/cli/commands/base.py +0 -83
- omnigenome/cli/commands/bench/__init__.py +0 -12
- omnigenome/cli/commands/bench/bench_cli.py +0 -202
- omnigenome/cli/commands/rna/__init__.py +0 -12
- omnigenome/cli/commands/rna/rna_design.py +0 -177
- omnigenome/cli/omnigenome_cli.py +0 -128
- omnigenome/src/__init__.py +0 -11
- omnigenome/src/abc/__init__.py +0 -11
- omnigenome/src/abc/abstract_dataset.py +0 -641
- omnigenome/src/abc/abstract_metric.py +0 -114
- omnigenome/src/abc/abstract_model.py +0 -690
- omnigenome/src/abc/abstract_tokenizer.py +0 -269
- omnigenome/src/dataset/__init__.py +0 -16
- omnigenome/src/dataset/omni_dataset.py +0 -437
- omnigenome/src/lora/__init__.py +0 -12
- omnigenome/src/lora/lora_model.py +0 -300
- omnigenome/src/metric/__init__.py +0 -15
- omnigenome/src/metric/classification_metric.py +0 -184
- omnigenome/src/metric/metric.py +0 -199
- omnigenome/src/metric/ranking_metric.py +0 -142
- omnigenome/src/metric/regression_metric.py +0 -191
- omnigenome/src/misc/__init__.py +0 -3
- omnigenome/src/misc/utils.py +0 -503
- omnigenome/src/model/__init__.py +0 -19
- omnigenome/src/model/augmentation/__init__.py +0 -11
- omnigenome/src/model/augmentation/model.py +0 -219
- omnigenome/src/model/classification/__init__.py +0 -11
- omnigenome/src/model/classification/model.py +0 -638
- omnigenome/src/model/embedding/__init__.py +0 -11
- omnigenome/src/model/embedding/model.py +0 -263
- omnigenome/src/model/mlm/__init__.py +0 -11
- omnigenome/src/model/mlm/model.py +0 -177
- omnigenome/src/model/module_utils.py +0 -232
- omnigenome/src/model/regression/__init__.py +0 -11
- omnigenome/src/model/regression/model.py +0 -781
- omnigenome/src/model/regression/resnet.py +0 -483
- omnigenome/src/model/rna_design/__init__.py +0 -11
- omnigenome/src/model/rna_design/model.py +0 -476
- omnigenome/src/model/seq2seq/__init__.py +0 -11
- omnigenome/src/model/seq2seq/model.py +0 -44
- omnigenome/src/tokenizer/__init__.py +0 -16
- omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
- omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
- omnigenome/src/trainer/__init__.py +0 -14
- omnigenome/src/trainer/accelerate_trainer.py +0 -747
- omnigenome/src/trainer/hf_trainer.py +0 -75
- omnigenome/src/trainer/trainer.py +0 -591
- omnigenome/utility/__init__.py +0 -3
- omnigenome/utility/dataset_hub/__init__.py +0 -12
- omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
- omnigenome/utility/ensemble.py +0 -324
- omnigenome/utility/hub_utils.py +0 -517
- omnigenome/utility/model_hub/__init__.py +0 -11
- omnigenome/utility/model_hub/model_hub.py +0 -232
- omnigenome/utility/pipeline_hub/__init__.py +0 -11
- omnigenome/utility/pipeline_hub/pipeline.py +0 -483
- omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
- omnigenome-0.3.1a0.dist-info/RECORD +0 -78
- {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/top_level.txt +0 -0
|
@@ -1,476 +0,0 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
# file: model.py
|
|
3
|
-
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
4
|
-
# github: https://github.com/yangheng95
|
|
5
|
-
# huggingface: https://huggingface.co/yangheng
|
|
6
|
-
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
7
|
-
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
8
|
-
"""
|
|
9
|
-
RNA design model using masked language modeling and evolutionary algorithms.
|
|
10
|
-
|
|
11
|
-
This module provides an RNA design model that combines masked language modeling
|
|
12
|
-
with evolutionary algorithms to design RNA sequences that fold into specific
|
|
13
|
-
target structures. It uses a multi-objective optimization approach to balance
|
|
14
|
-
structure similarity and thermodynamic stability.
|
|
15
|
-
"""
|
|
16
|
-
import random
|
|
17
|
-
import numpy as np
|
|
18
|
-
import torch
|
|
19
|
-
import autocuda
|
|
20
|
-
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
|
21
|
-
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
22
|
-
import ViennaRNA
|
|
23
|
-
from scipy.spatial.distance import hamming
|
|
24
|
-
import warnings
|
|
25
|
-
import os
|
|
26
|
-
|
|
27
|
-
from omnigenome.src.misc.utils import fprint
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class OmniModelForRNADesign(torch.nn.Module):
    """
    RNA design model using masked language modeling and evolutionary algorithms.

    This model combines a pre-trained masked language model (MLM) with an
    evolutionary loop to design RNA sequences that fold into a target
    dot-bracket structure. Each generation:

    1. applies multi-point crossover,
    2. re-fills randomly masked positions with MLM predictions,
    3. folds every candidate with ViennaRNA,
    4. ranks candidates by non-dominated sorting on two objectives, both
       minimised: the normalised Hamming distance between the predicted and
       target structures, and the predicted minimum free energy (MFE).

    Attributes:
        device: Device to run the model on (auto-detected with ``autocuda`` when None)
        parallel: Whether to use a process pool for structure prediction
        tokenizer: Tokenizer for processing RNA sequences
        model: Pre-trained masked language model, cast to float16
    """

    def __init__(
        self,
        model="yangheng/OmniGenome-186M",
        device=None,
        parallel=False,
        *args,
        **kwargs,
    ):
        """
        Initialize the RNA design model.

        Args:
            model (str): Model name or path for the pre-trained MLM model
            device: Device to run the model on (default: None, auto-detect)
            parallel (bool): Whether to fold sequences in a process pool
                (default: False)
            *args: Additional positional arguments forwarded to ``torch.nn.Module``
            **kwargs: Additional keyword arguments forwarded to ``torch.nn.Module``
        """
        super().__init__(*args, **kwargs)
        self.device = autocuda.auto_cuda() if device is None else device
        self.parallel = parallel
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForMaskedLM.from_pretrained(model, trust_remote_code=True)
        # The MLM is only used for inference (under torch.no_grad() in
        # _mlm_predict), so it is kept in half precision.
        self.model.to(self.device).to(torch.float16)

    @staticmethod
    def _random_bp_span(bp_span=None):
        """
        Generate a random base pair span.

        Args:
            bp_span (int, optional): Fixed base pair span. If None, a random
                integer in [1, 10] is returned.

        Returns:
            int: Base pair span value
        """
        if bp_span is None:
            return random.randint(1, 10)
        return bp_span

    @staticmethod
    def _longest_bp_span(structure):
        """
        Return the deepest run of '(' seen in a dot-bracket structure.

        Each '(' increments a counter and each ')' decrements it (never below
        zero). Any other character (e.g. '.') resets it to zero. The maximum
        value reached is returned. For example, "(((....)))" gives 3.

        Args:
            structure (str): RNA structure in dot-bracket notation

        Returns:
            int: Maximum counter value reached (0 for structures without '(')
        """
        max_span = 0
        current_span = 0

        for char in structure:
            if char == "(":
                current_span += 1
                max_span = max(max_span, current_span)
            elif char == ")":
                current_span = max(0, current_span - 1)
            else:
                current_span = 0

        return max_span

    @staticmethod
    def _predict_structure_single(sequence, bp_span=-1):
        """
        Fold a single sequence with ViennaRNA (also used as a process-pool worker).

        Args:
            sequence (str): RNA sequence to fold
            bp_span (int): Base pair span parameter; accepted for interface
                compatibility but not used by the fold call

        Returns:
            tuple: (structure, mfe) as returned by ``ViennaRNA.fold``. On any
            folding error a warning is issued and ("." * len(sequence), 0.0)
            is returned.
        """
        try:
            return ViennaRNA.fold(sequence)
        except Exception as e:
            warnings.warn(f"Failed to fold sequence {sequence}: {e}")
            return ("." * len(sequence), 0.0)

    def _predict_structure(self, sequences, bp_span=-1):
        """
        Predict structures for multiple sequences.

        Results are always returned in the same order as ``sequences``, so
        callers can ``zip`` them with the population they came from.

        Args:
            sequences (list): List of RNA sequences
            bp_span (int): Base pair span parameter (forwarded to the worker)

        Returns:
            list: One (structure, mfe) tuple per input sequence, in input order
        """
        if not self.parallel or len(sequences) <= 1:
            # Sequential processing
            return [self._predict_structure_single(seq, bp_span) for seq in sequences]

        try:
            # os.cpu_count() can return None; limit the pool to 8 workers.
            max_workers = min(os.cpu_count() or 1, len(sequences), 8)

            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                futures = [
                    executor.submit(self._predict_structure_single, seq, bp_span)
                    for seq in sequences
                ]

                # Collect in submission order (NOT completion order) so that
                # results[i] corresponds to sequences[i].
                results = []
                for seq, future in zip(sequences, futures):
                    try:
                        results.append(future.result())
                    except Exception as e:
                        warnings.warn(f"Failed to process sequence {seq}: {e}")
                        # Fallback to dot structure
                        results.append(("." * len(seq), 0.0))

            return results

        except Exception as e:
            warnings.warn(
                f"Parallel processing failed, falling back to sequential: {e}"
            )
            # Fallback to sequential processing
            return [self._predict_structure_single(seq, bp_span) for seq in sequences]

    def _init_population(self, structure, num_population):
        """
        Initialize the population with uniformly random sequences.

        Args:
            structure (str): Target RNA structure
            num_population (int): Population size

        Returns:
            list: ``num_population`` (sequence, bp_span) tuples. Each sequence is
            drawn from "ACGU" with the length of ``structure``, and all share
            ``bp_span = _longest_bp_span(structure)``.
        """
        population = []
        bp_span = self._longest_bp_span(structure)

        for _ in range(num_population):
            # Generate random sequence
            sequence = "".join(random.choice("ACGU") for _ in range(len(structure)))
            population.append((sequence, bp_span))

        return population

    def _mlm_mutate(self, population, structure, mutation_ratio):
        """
        Mutate the population using masked language modeling.

        For every sequence, ``int(len(sequence) * mutation_ratio)`` distinct
        positions are replaced by the tokenizer's mask token. The MLM is then
        run, and its argmax prediction for every position is decoded back into
        a sequence.

        Args:
            population (list): Current population of (sequence, bp_span) tuples
            structure (str): Target RNA structure
            mutation_ratio (float): Ratio of positions to mask

        Returns:
            list: Mutated population of (sequence, bp_span) tuples whose
            sequences have exactly ``len(structure)`` characters (truncated, or
            right-padded with "A")
        """

        def mutate(sequence, mutation_rate):
            # Create masked sequence
            masked_sequence = list(sequence)
            num_mutations = int(len(sequence) * mutation_rate)
            mutation_positions = random.sample(range(len(sequence)), num_mutations)

            for pos in mutation_positions:
                masked_sequence[pos] = self.tokenizer.mask_token

            return "".join(masked_sequence)

        # Prepare inputs for MLM
        mlm_inputs = [mutate(sequence, mutation_ratio) for sequence, _ in population]

        # Get predictions from MLM
        predicted_tokens = self._mlm_predict(mlm_inputs, structure)

        # Convert predictions back to sequences
        mutated_population = []
        for i, (_, bp_span) in enumerate(population):
            new_sequence = self.tokenizer.decode(
                predicted_tokens[i], skip_special_tokens=True
            )
            # Some character-level tokenizers join decoded tokens with spaces;
            # drop all whitespace so it cannot count towards the length below.
            new_sequence = "".join(new_sequence.split())
            # Ensure the sequence has the correct length
            if len(new_sequence) != len(structure):
                new_sequence = new_sequence[: len(structure)].ljust(len(structure), "A")
            mutated_population.append((new_sequence, bp_span))

        return mutated_population

    def _crossover(self, population, num_points=3):
        """
        Perform multi-point crossover on the population.

        For every slot ``i``, two parents are drawn at random (with
        replacement) and ``num_points`` cut points are drawn in
        ``[0, seq_len)``. The cut points are sorted, so each sequence is split
        into ``num_points + 1`` consecutive segments that alternate between the
        two parents. The second child takes the complementary segments.

        Args:
            population (list): Current population of (sequence, bp_span) tuples;
                all sequences are expected to have the same length
            num_points (int): Number of crossover points

        Returns:
            list: ``2 * len(population)`` (sequence, bp_span) tuples: all first
            children followed by all second children, with ``bp_span`` taken
            from the slot's original member. The input is returned unchanged
            when it has fewer than two members or its sequences are empty.
        """
        if len(population) < 2:
            return population

        num_sequences = len(population)
        seq_len = len(population[0][0])
        if seq_len == 0:
            # np.random.randint(0, 0) would raise; nothing to recombine.
            return population

        # Draw cut points first, then parents (same RNG draw order as before).
        # Sorting is required: unsorted points produce empty/overlapping slices.
        crossover_points = np.sort(
            np.random.randint(0, seq_len, (num_sequences, num_points)), axis=1
        )
        parent_indices = np.random.randint(0, num_sequences, (num_sequences, 2))

        # Segment k (between consecutive boundaries) is taken from the first
        # parent when k is even and from the second parent when k is odd.
        masks = np.zeros((num_sequences, seq_len), dtype=bool)
        for i in range(num_sequences):
            bounds = [0, *crossover_points[i].tolist(), seq_len]
            for k in range(0, len(bounds) - 1, 2):
                masks[i, bounds[k] : bounds[k + 1]] = True

        # Perform crossover
        population_array = np.array([list(seq) for seq, _ in population])
        first_parents = population_array[parent_indices[:, 0]]
        second_parents = population_array[parent_indices[:, 1]]
        child1_array = np.where(masks, first_parents, second_parents)
        child2_array = np.where(masks, second_parents, first_parents)

        bp_spans = [bp_span for _, bp_span in population]
        return [
            ("".join(child), bp_span) for child, bp_span in zip(child1_array, bp_spans)
        ] + [
            ("".join(child), bp_span) for child, bp_span in zip(child2_array, bp_spans)
        ]

    def _evaluate_structure_fitness(self, sequences, structure):
        """
        Score candidates against the target structure and rank them.

        Args:
            sequences (list): List of (sequence, bp_span) tuples to evaluate
            structure (str): Target RNA structure

        Returns:
            list: (sequence, bp_span, score, mfe) tuples ordered by Pareto
            front, where ``score`` is scipy's normalised Hamming distance
            between the target and predicted structures (0 means exact match)
        """
        # Get sequences for structure prediction
        seq_list = [seq for seq, _ in sequences]

        # Predict structures; results are aligned with seq_list
        structures_mfe = self._predict_structure(seq_list)

        scored_population = []
        for (seq, bp_span), (ss, mfe) in zip(sequences, structures_mfe):
            score = hamming(list(structure), list(ss))
            scored_population.append((seq, bp_span, score, mfe))

        fronts = self._non_dominated_sorting(
            [x[2] for x in scored_population], [x[3] for x in scored_population]
        )
        return self._select_next_generation(scored_population, fronts)

    @staticmethod
    def _non_dominated_sorting(scores, mfe_values):
        """
        Perform non-dominated sorting for multi-objective optimization.

        Solution p dominates q when it is strictly smaller on BOTH objectives
        (``scores[p] < scores[q]`` and ``mfe_values[p] < mfe_values[q]``).

        Args:
            scores (list): Structure distance scores (lower is better)
            mfe_values (list): Minimum free energy values (lower is better)

        Returns:
            list: List of fronts (lists of indices); the first front contains
            the non-dominated solutions. Empty fronts are not included.
        """
        num_solutions = len(scores)
        domination_count = [0] * num_solutions
        dominated_solutions = [[] for _ in range(num_solutions)]
        fronts = [[]]

        for p in range(num_solutions):
            for q in range(num_solutions):
                if scores[p] < scores[q] and mfe_values[p] < mfe_values[q]:
                    dominated_solutions[p].append(q)
                elif scores[q] < scores[p] and mfe_values[q] < mfe_values[p]:
                    domination_count[p] += 1

            if domination_count[p] == 0:
                fronts[0].append(p)

        i = 0
        while len(fronts[i]) > 0:
            next_front = []
            for p in fronts[i]:
                for q in dominated_solutions[p]:
                    domination_count[q] -= 1
                    if domination_count[q] == 0:
                        next_front.append(q)
            i += 1
            fronts.append(next_front)

        if not fronts[-1]:  # Ensure the last front is not empty before removing
            fronts.pop(-1)

        return fronts

    @staticmethod
    def _select_next_generation(next_generation, fronts):
        """
        Order the population front by front.

        Args:
            next_generation (list): Current population with fitness scores
            fronts (list): Pareto fronts (lists of indices into next_generation)

        Returns:
            list: At most ``len(next_generation)`` members, front 0 first
        """
        sorted_population = []
        for front in fronts:
            sorted_population.extend(next_generation[i] for i in front)
            if len(sorted_population) >= len(next_generation):
                break

        return sorted_population[: len(next_generation)]

    def _mlm_predict(self, mlm_inputs, structure):
        """
        Perform masked language model prediction in batches of 8.

        Args:
            mlm_inputs (list): List of masked input sequences
            structure (str): Target RNA structure (only its length is used)

        Returns:
            torch.Tensor: Argmax token IDs with the first position dropped and
            at most ``len(structure)`` positions kept per input
        """
        batch_size = 8
        all_outputs = []

        with torch.no_grad():
            for i in range(0, len(mlm_inputs), batch_size):
                # padding=False with return_tensors="pt" assumes every input in
                # a batch tokenizes to the same length — TODO confirm per tokenizer.
                inputs = self.tokenizer(
                    mlm_inputs[i : i + batch_size],
                    padding=False,
                    max_length=1024,
                    truncation=True,
                    return_tensors="pt",
                )
                inputs = {
                    key: value.to(self.model.device) for key, value in inputs.items()
                }
                outputs = self.model(**inputs)[0].argmax(dim=-1)
                all_outputs.append(outputs)

        # Drop the leading special token and keep one prediction per structure position.
        return torch.cat(all_outputs, dim=0)[:, 1 : 1 + len(structure)]

    def design(
        self, structure, mutation_ratio=0.5, num_population=100, num_generation=100
    ):
        """
        Design RNA sequences for a target structure using evolutionary algorithms.

        Args:
            structure (str): Target RNA structure in dot-bracket notation
            mutation_ratio (float): Ratio of positions to mask per mutation (default: 0.5)
            num_population (int): Population size (default: 100)
            num_generation (int): Maximum number of generations (default: 100)

        Returns:
            list | str: As soon as any generation contains sequences whose
            predicted structure matches ``structure`` exactly (score 0), the
            list of those sequences is returned. Otherwise a single ``str`` is
            returned: the top-ranked sequence after the last generation (or the
            first initial sequence when ``num_generation`` is 0). The mixed
            return type is kept for backward compatibility.
        """
        population = self._init_population(structure, num_population)
        population = self._mlm_mutate(population, structure, mutation_ratio)

        for _generation_id in range(num_generation):
            next_generation = self._crossover(population)
            next_generation = self._mlm_mutate(
                next_generation, structure, mutation_ratio
            )
            next_generation = self._evaluate_structure_fitness(
                next_generation, structure
            )[:num_population]

            candidate_sequences = [
                seq for seq, _bp_span, score, _mfe in next_generation if score == 0
            ]
            if candidate_sequences:
                return candidate_sequences

            population = [
                (seq, bp_span) for seq, bp_span, _score, _mfe in next_generation
            ]

        return population[0][0]
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
# Example usage: design a sequence for a short hairpin and print the result.
if __name__ == "__main__":
    designer = OmniModelForRNADesign(model="anonymous8/OmniGenome-186M")
    design_kwargs = dict(
        structure="(((....)))",
        mutation_ratio=0.5,
        num_population=100,
        num_generation=100,
    )
    best_sequence = designer.design(**design_kwargs)
    fprint(f"Best RNA sequence: {best_sequence}")
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
# file: __init__.py
|
|
3
|
-
# time: 22:21 08/04/2024
|
|
4
|
-
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
-
# github: https://github.com/yangheng95
|
|
6
|
-
# huggingface: https://huggingface.co/yangheng
|
|
7
|
-
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
-
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
-
"""
|
|
10
|
-
This package contains modules for sequence-to-sequence models.
|
|
11
|
-
"""
|
|
@@ -1,44 +0,0 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
# file: model.py
|
|
3
|
-
# time: 11:40 14/04/2024
|
|
4
|
-
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
-
# github: https://github.com/yangheng95
|
|
6
|
-
# huggingface: https://huggingface.co/yangheng
|
|
7
|
-
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
-
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
-
"""
|
|
10
|
-
Sequence-to-sequence model for genomic sequences.
|
|
11
|
-
|
|
12
|
-
This module provides a sequence-to-sequence model implementation for genomic
|
|
13
|
-
sequences. It's designed for tasks where the input and output are both
|
|
14
|
-
sequences, such as sequence translation, structure prediction, or sequence
|
|
15
|
-
transformation tasks.
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
from ...abc.abstract_model import OmniModel
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class OmniModelForSeq2Seq(OmniModel):
    """
    Sequence-to-sequence model for genomic sequences.

    Both the input and the output of this model are sequences, which suits
    tasks such as sequence translation, structure prediction or sequence
    transformation. The class adds no behavior of its own: every constructor
    argument is handed straight to ``OmniModel``. Specific seq2seq tasks can
    be implemented by overriding the forward, predict and inference methods.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Build the model by delegating entirely to ``OmniModel``.

        Args:
            config_or_model: Model configuration or pre-trained model
            tokenizer: Tokenizer for processing input sequences
            *args: Extra positional arguments passed through unchanged
            **kwargs: Extra keyword arguments passed through unchanged
        """
        super(OmniModelForSeq2Seq, self).__init__(
            config_or_model, tokenizer, *args, **kwargs
        )
|
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
# file: __init__.py
|
|
3
|
-
# time: 18:05 08/04/2024
|
|
4
|
-
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
-
# github: https://github.com/yangheng95
|
|
6
|
-
# huggingface: https://huggingface.co/yangheng
|
|
7
|
-
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
-
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
-
"""
|
|
10
|
-
This package contains tokenizer implementations.
|
|
11
|
-
"""
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
from .bpe_tokenizer import OmniBPETokenizer
|
|
15
|
-
from .kmers_tokenizer import OmniKmersTokenizer
|
|
16
|
-
from .single_nucleotide_tokenizer import OmniSingleNucleotideTokenizer
|