omnigenome 0.3.0a1__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +16 -8
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +40 -36
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +65 -58
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +2 -2
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -0
|
@@ -23,14 +23,14 @@ from typing import Type, Callable, Union, List, Optional
|
|
|
23
23
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
|
24
24
|
"""
|
|
25
25
|
3x3 convolution with padding.
|
|
26
|
-
|
|
26
|
+
|
|
27
27
|
Args:
|
|
28
28
|
in_planes (int): Number of input channels
|
|
29
29
|
out_planes (int): Number of output channels
|
|
30
30
|
stride (int): Stride for the convolution (default: 1)
|
|
31
31
|
groups (int): Number of groups for grouped convolution (default: 1)
|
|
32
32
|
dilation (int): Dilation factor for the convolution (default: 1)
|
|
33
|
-
|
|
33
|
+
|
|
34
34
|
Returns:
|
|
35
35
|
nn.Conv2d: 3x3 convolution layer
|
|
36
36
|
"""
|
|
@@ -49,12 +49,12 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
|
|
49
49
|
def conv1x1(in_planes, out_planes, stride=1):
|
|
50
50
|
"""
|
|
51
51
|
1x1 convolution.
|
|
52
|
-
|
|
52
|
+
|
|
53
53
|
Args:
|
|
54
54
|
in_planes (int): Number of input channels
|
|
55
55
|
out_planes (int): Number of output channels
|
|
56
56
|
stride (int): Stride for the convolution (default: 1)
|
|
57
|
-
|
|
57
|
+
|
|
58
58
|
Returns:
|
|
59
59
|
nn.Conv2d: 1x1 convolution layer
|
|
60
60
|
"""
|
|
@@ -64,14 +64,14 @@ def conv1x1(in_planes, out_planes, stride=1):
|
|
|
64
64
|
def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
|
65
65
|
"""
|
|
66
66
|
5x5 convolution with padding.
|
|
67
|
-
|
|
67
|
+
|
|
68
68
|
Args:
|
|
69
69
|
in_planes (int): Number of input channels
|
|
70
70
|
out_planes (int): Number of output channels
|
|
71
71
|
stride (int): Stride for the convolution (default: 1)
|
|
72
72
|
groups (int): Number of groups for grouped convolution (default: 1)
|
|
73
73
|
dilation (int): Dilation factor for the convolution (default: 1)
|
|
74
|
-
|
|
74
|
+
|
|
75
75
|
Returns:
|
|
76
76
|
nn.Conv2d: 5x5 convolution layer
|
|
77
77
|
"""
|
|
@@ -90,10 +90,10 @@ def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
|
|
90
90
|
class BasicBlock(nn.Module):
|
|
91
91
|
"""
|
|
92
92
|
Basic ResNet block for genomic sequence processing.
|
|
93
|
-
|
|
93
|
+
|
|
94
94
|
This block implements a basic residual connection with two convolutions
|
|
95
95
|
and is optimized for processing genomic sequence data with layer normalization.
|
|
96
|
-
|
|
96
|
+
|
|
97
97
|
Attributes:
|
|
98
98
|
expansion (int): Expansion factor for the block (default: 1)
|
|
99
99
|
conv1: First 3x3 convolution layer
|
|
@@ -105,7 +105,7 @@ class BasicBlock(nn.Module):
|
|
|
105
105
|
downsample: Downsampling layer for residual connection
|
|
106
106
|
stride: Stride for the convolutions
|
|
107
107
|
"""
|
|
108
|
-
|
|
108
|
+
|
|
109
109
|
expansion: int = 1
|
|
110
110
|
|
|
111
111
|
def __init__(
|
|
@@ -121,7 +121,7 @@ class BasicBlock(nn.Module):
|
|
|
121
121
|
) -> None:
|
|
122
122
|
"""
|
|
123
123
|
Initialize the BasicBlock.
|
|
124
|
-
|
|
124
|
+
|
|
125
125
|
Args:
|
|
126
126
|
inplanes (int): Number of input channels
|
|
127
127
|
planes (int): Number of output channels
|
|
@@ -130,7 +130,7 @@ class BasicBlock(nn.Module):
|
|
|
130
130
|
groups (int): Number of groups for grouped convolution (default: 1)
|
|
131
131
|
dilation (int): Dilation factor for convolutions (default: 1)
|
|
132
132
|
norm_layer: Normalization layer type (default: None, uses LayerNorm)
|
|
133
|
-
|
|
133
|
+
|
|
134
134
|
Raises:
|
|
135
135
|
NotImplementedError: If dilation > 1 is specified
|
|
136
136
|
"""
|
|
@@ -154,10 +154,10 @@ class BasicBlock(nn.Module):
|
|
|
154
154
|
def forward(self, x: Tensor) -> Tensor:
|
|
155
155
|
"""
|
|
156
156
|
Forward pass through the BasicBlock.
|
|
157
|
-
|
|
157
|
+
|
|
158
158
|
Args:
|
|
159
159
|
x (Tensor): Input tensor [batch_size, channels, height, width]
|
|
160
|
-
|
|
160
|
+
|
|
161
161
|
Returns:
|
|
162
162
|
Tensor: Output tensor with same shape as input
|
|
163
163
|
"""
|
|
@@ -188,11 +188,11 @@ class BasicBlock(nn.Module):
|
|
|
188
188
|
class Bottleneck(nn.Module):
|
|
189
189
|
"""
|
|
190
190
|
Bottleneck ResNet block for genomic sequence processing.
|
|
191
|
-
|
|
191
|
+
|
|
192
192
|
This block implements a bottleneck residual connection with three convolutions
|
|
193
193
|
(1x1, 3x3, 1x1) and is designed for deeper networks. It's adapted from
|
|
194
194
|
the original ResNet V1.5 implementation.
|
|
195
|
-
|
|
195
|
+
|
|
196
196
|
Attributes:
|
|
197
197
|
expansion (int): Expansion factor for the block (default: 4)
|
|
198
198
|
conv1: First 1x1 convolution layer
|
|
@@ -205,7 +205,7 @@ class Bottleneck(nn.Module):
|
|
|
205
205
|
downsample: Downsampling layer for residual connection
|
|
206
206
|
stride: Stride for the convolutions
|
|
207
207
|
"""
|
|
208
|
-
|
|
208
|
+
|
|
209
209
|
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
|
210
210
|
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
|
211
211
|
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
|
@@ -227,7 +227,7 @@ class Bottleneck(nn.Module):
|
|
|
227
227
|
) -> None:
|
|
228
228
|
"""
|
|
229
229
|
Initialize the Bottleneck block.
|
|
230
|
-
|
|
230
|
+
|
|
231
231
|
Args:
|
|
232
232
|
inplanes (int): Number of input channels
|
|
233
233
|
planes (int): Number of output channels
|
|
@@ -256,10 +256,10 @@ class Bottleneck(nn.Module):
|
|
|
256
256
|
def forward(self, x: Tensor) -> Tensor:
|
|
257
257
|
"""
|
|
258
258
|
Forward pass through the Bottleneck block.
|
|
259
|
-
|
|
259
|
+
|
|
260
260
|
Args:
|
|
261
261
|
x (Tensor): Input tensor [batch_size, channels, height, width]
|
|
262
|
-
|
|
262
|
+
|
|
263
263
|
Returns:
|
|
264
264
|
Tensor: Output tensor with same shape as input
|
|
265
265
|
"""
|
|
@@ -288,11 +288,11 @@ class Bottleneck(nn.Module):
|
|
|
288
288
|
class ResNet(nn.Module):
|
|
289
289
|
"""
|
|
290
290
|
ResNet architecture adapted for genomic sequence analysis.
|
|
291
|
-
|
|
291
|
+
|
|
292
292
|
This ResNet implementation is specifically designed for processing genomic
|
|
293
293
|
sequences and their structural representations. It uses layer normalization
|
|
294
294
|
instead of batch normalization and is optimized for genomic data characteristics.
|
|
295
|
-
|
|
295
|
+
|
|
296
296
|
Attributes:
|
|
297
297
|
_norm_layer: Normalization layer type
|
|
298
298
|
inplanes: Number of input channels for the first layer
|
|
@@ -319,7 +319,7 @@ class ResNet(nn.Module):
|
|
|
319
319
|
) -> None:
|
|
320
320
|
"""
|
|
321
321
|
Initialize the ResNet architecture.
|
|
322
|
-
|
|
322
|
+
|
|
323
323
|
Args:
|
|
324
324
|
channels (int): Number of input channels
|
|
325
325
|
block: Type of ResNet block (BasicBlock or Bottleneck)
|
|
@@ -329,7 +329,7 @@ class ResNet(nn.Module):
|
|
|
329
329
|
width_per_group (int): Width per group for bottleneck blocks (default: 1)
|
|
330
330
|
replace_stride_with_dilation: Whether to replace stride with dilation (default: None)
|
|
331
331
|
norm_layer: Normalization layer type (default: None, uses LayerNorm)
|
|
332
|
-
|
|
332
|
+
|
|
333
333
|
Raises:
|
|
334
334
|
ValueError: If replace_stride_with_dilation is not None or a 3-element tuple
|
|
335
335
|
"""
|
|
@@ -379,14 +379,14 @@ class ResNet(nn.Module):
|
|
|
379
379
|
) -> nn.Sequential:
|
|
380
380
|
"""
|
|
381
381
|
Create a layer of ResNet blocks.
|
|
382
|
-
|
|
382
|
+
|
|
383
383
|
Args:
|
|
384
384
|
block: Type of ResNet block to use
|
|
385
385
|
planes (int): Number of output channels for the layer
|
|
386
386
|
blocks (int): Number of blocks in the layer
|
|
387
387
|
stride (int): Stride for the first block (default: 1)
|
|
388
388
|
dilate (bool): Whether to use dilation (default: False)
|
|
389
|
-
|
|
389
|
+
|
|
390
390
|
Returns:
|
|
391
391
|
nn.Sequential: Sequential container of ResNet blocks
|
|
392
392
|
"""
|
|
@@ -433,10 +433,10 @@ class ResNet(nn.Module):
|
|
|
433
433
|
def _forward_impl(self, x: Tensor) -> Tensor:
|
|
434
434
|
"""
|
|
435
435
|
Forward pass implementation.
|
|
436
|
-
|
|
436
|
+
|
|
437
437
|
Args:
|
|
438
438
|
x (Tensor): Input tensor [batch_size, channels, height, width]
|
|
439
|
-
|
|
439
|
+
|
|
440
440
|
Returns:
|
|
441
441
|
Tensor: Output tensor after processing through ResNet
|
|
442
442
|
"""
|
|
@@ -456,10 +456,10 @@ class ResNet(nn.Module):
|
|
|
456
456
|
def forward(self, x: Tensor) -> Tensor:
|
|
457
457
|
"""
|
|
458
458
|
Forward pass through the ResNet.
|
|
459
|
-
|
|
459
|
+
|
|
460
460
|
Args:
|
|
461
461
|
x (Tensor): Input tensor [batch_size, channels, height, width]
|
|
462
|
-
|
|
462
|
+
|
|
463
463
|
Returns:
|
|
464
464
|
Tensor: Output tensor after processing through ResNet
|
|
465
465
|
"""
|
|
@@ -469,14 +469,14 @@ class ResNet(nn.Module):
|
|
|
469
469
|
def resnet_b16(channels=128, bbn=16):
|
|
470
470
|
"""
|
|
471
471
|
Create a ResNet-B16 model for genomic sequence analysis.
|
|
472
|
-
|
|
472
|
+
|
|
473
473
|
This function creates a ResNet model with 16 basic blocks, optimized
|
|
474
474
|
for processing genomic sequences and their structural representations.
|
|
475
|
-
|
|
475
|
+
|
|
476
476
|
Args:
|
|
477
477
|
channels (int): Number of input channels (default: 128)
|
|
478
478
|
bbn (int): Number of basic blocks (default: 16)
|
|
479
|
-
|
|
479
|
+
|
|
480
480
|
Returns:
|
|
481
481
|
ResNet: Configured ResNet model
|
|
482
482
|
"""
|
|
@@ -30,19 +30,19 @@ from omnigenome.src.misc.utils import fprint
|
|
|
30
30
|
class OmniModelForRNADesign(torch.nn.Module):
|
|
31
31
|
"""
|
|
32
32
|
RNA design model using masked language modeling and evolutionary algorithms.
|
|
33
|
-
|
|
33
|
+
|
|
34
34
|
This model combines a pre-trained masked language model with evolutionary
|
|
35
35
|
algorithms to design RNA sequences that fold into specific target structures.
|
|
36
36
|
It uses a multi-objective optimization approach to balance structure similarity
|
|
37
37
|
and thermodynamic stability.
|
|
38
|
-
|
|
38
|
+
|
|
39
39
|
Attributes:
|
|
40
40
|
device: Device to run the model on (CPU or GPU)
|
|
41
41
|
parallel: Whether to use parallel processing for structure prediction
|
|
42
42
|
tokenizer: Tokenizer for processing RNA sequences
|
|
43
43
|
model: Pre-trained masked language model
|
|
44
44
|
"""
|
|
45
|
-
|
|
45
|
+
|
|
46
46
|
def __init__(
|
|
47
47
|
self,
|
|
48
48
|
model="yangheng/OmniGenome-186M",
|
|
@@ -53,7 +53,7 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
53
53
|
):
|
|
54
54
|
"""
|
|
55
55
|
Initialize the RNA design model.
|
|
56
|
-
|
|
56
|
+
|
|
57
57
|
Args:
|
|
58
58
|
model (str): Model name or path for the pre-trained MLM model
|
|
59
59
|
device: Device to run the model on (default: None, auto-detect)
|
|
@@ -72,10 +72,10 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
72
72
|
def _random_bp_span(bp_span=None):
|
|
73
73
|
"""
|
|
74
74
|
Generate a random base pair span.
|
|
75
|
-
|
|
75
|
+
|
|
76
76
|
Args:
|
|
77
77
|
bp_span (int, optional): Fixed base pair span. If None, generates random.
|
|
78
|
-
|
|
78
|
+
|
|
79
79
|
Returns:
|
|
80
80
|
int: Base pair span value
|
|
81
81
|
"""
|
|
@@ -87,16 +87,16 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
87
87
|
def _longest_bp_span(structure):
|
|
88
88
|
"""
|
|
89
89
|
Find the longest base pair span in the structure.
|
|
90
|
-
|
|
90
|
+
|
|
91
91
|
Args:
|
|
92
92
|
structure (str): RNA structure in dot-bracket notation
|
|
93
|
-
|
|
93
|
+
|
|
94
94
|
Returns:
|
|
95
95
|
int: Length of the longest base pair span
|
|
96
96
|
"""
|
|
97
97
|
max_span = 0
|
|
98
98
|
current_span = 0
|
|
99
|
-
|
|
99
|
+
|
|
100
100
|
for char in structure:
|
|
101
101
|
if char == "(":
|
|
102
102
|
current_span += 1
|
|
@@ -105,18 +105,18 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
105
105
|
current_span = max(0, current_span - 1)
|
|
106
106
|
else:
|
|
107
107
|
current_span = 0
|
|
108
|
-
|
|
108
|
+
|
|
109
109
|
return max_span
|
|
110
110
|
|
|
111
111
|
@staticmethod
|
|
112
112
|
def _predict_structure_single(sequence, bp_span=-1):
|
|
113
113
|
"""
|
|
114
114
|
Predict structure for a single sequence (worker function for multiprocessing).
|
|
115
|
-
|
|
115
|
+
|
|
116
116
|
Args:
|
|
117
117
|
sequence (str): RNA sequence to fold
|
|
118
118
|
bp_span (int): Base pair span parameter
|
|
119
|
-
|
|
119
|
+
|
|
120
120
|
Returns:
|
|
121
121
|
tuple: (structure, mfe) tuple
|
|
122
122
|
"""
|
|
@@ -129,30 +129,30 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
129
129
|
def _predict_structure(self, sequences, bp_span=-1):
|
|
130
130
|
"""
|
|
131
131
|
Predict structures for multiple sequences.
|
|
132
|
-
|
|
132
|
+
|
|
133
133
|
Args:
|
|
134
134
|
sequences (list): List of RNA sequences
|
|
135
135
|
bp_span (int): Base pair span parameter
|
|
136
|
-
|
|
136
|
+
|
|
137
137
|
Returns:
|
|
138
138
|
list: List of (structure, mfe) tuples
|
|
139
139
|
"""
|
|
140
140
|
if not self.parallel or len(sequences) <= 1:
|
|
141
141
|
# Sequential processing
|
|
142
142
|
return [self._predict_structure_single(seq, bp_span) for seq in sequences]
|
|
143
|
-
|
|
143
|
+
|
|
144
144
|
# Parallel processing with improved error handling
|
|
145
145
|
try:
|
|
146
146
|
# Determine number of workers
|
|
147
147
|
max_workers = min(os.cpu_count(), len(sequences), 8) # Limit to 8 workers
|
|
148
|
-
|
|
148
|
+
|
|
149
149
|
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
|
150
150
|
# Submit all tasks
|
|
151
151
|
future_to_seq = {
|
|
152
|
-
executor.submit(self._predict_structure_single, seq, bp_span): seq
|
|
152
|
+
executor.submit(self._predict_structure_single, seq, bp_span): seq
|
|
153
153
|
for seq in sequences
|
|
154
154
|
}
|
|
155
|
-
|
|
155
|
+
|
|
156
156
|
# Collect results
|
|
157
157
|
results = []
|
|
158
158
|
for future in as_completed(future_to_seq):
|
|
@@ -164,112 +164,119 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
164
164
|
warnings.warn(f"Failed to process sequence {seq}: {e}")
|
|
165
165
|
# Fallback to dot structure
|
|
166
166
|
results.append(("." * len(seq), 0.0))
|
|
167
|
-
|
|
167
|
+
|
|
168
168
|
return results
|
|
169
|
-
|
|
169
|
+
|
|
170
170
|
except Exception as e:
|
|
171
|
-
warnings.warn(
|
|
171
|
+
warnings.warn(
|
|
172
|
+
f"Parallel processing failed, falling back to sequential: {e}"
|
|
173
|
+
)
|
|
172
174
|
# Fallback to sequential processing
|
|
173
175
|
return [self._predict_structure_single(seq, bp_span) for seq in sequences]
|
|
174
176
|
|
|
175
177
|
def _init_population(self, structure, num_population):
|
|
176
178
|
"""
|
|
177
179
|
Initialize the population with random sequences.
|
|
178
|
-
|
|
180
|
+
|
|
179
181
|
Args:
|
|
180
182
|
structure (str): Target RNA structure
|
|
181
183
|
num_population (int): Population size
|
|
182
|
-
|
|
184
|
+
|
|
183
185
|
Returns:
|
|
184
186
|
list: List of (sequence, bp_span) tuples
|
|
185
187
|
"""
|
|
186
188
|
population = []
|
|
187
189
|
bp_span = self._longest_bp_span(structure)
|
|
188
|
-
|
|
190
|
+
|
|
189
191
|
for _ in range(num_population):
|
|
190
192
|
# Generate random sequence
|
|
191
193
|
sequence = "".join(random.choice("ACGU") for _ in range(len(structure)))
|
|
192
194
|
population.append((sequence, bp_span))
|
|
193
|
-
|
|
195
|
+
|
|
194
196
|
return population
|
|
195
197
|
|
|
196
198
|
def _mlm_mutate(self, population, structure, mutation_ratio):
|
|
197
199
|
"""
|
|
198
200
|
Mutate population using masked language modeling.
|
|
199
|
-
|
|
201
|
+
|
|
200
202
|
Args:
|
|
201
203
|
population (list): Current population
|
|
202
204
|
structure (str): Target RNA structure
|
|
203
205
|
mutation_ratio (float): Ratio of tokens to mutate
|
|
204
|
-
|
|
206
|
+
|
|
205
207
|
Returns:
|
|
206
208
|
list: Mutated population
|
|
207
209
|
"""
|
|
210
|
+
|
|
208
211
|
def mutate(sequence, mutation_rate):
|
|
209
212
|
# Create masked sequence
|
|
210
213
|
masked_sequence = list(sequence)
|
|
211
214
|
num_mutations = int(len(sequence) * mutation_rate)
|
|
212
215
|
mutation_positions = random.sample(range(len(sequence)), num_mutations)
|
|
213
|
-
|
|
216
|
+
|
|
214
217
|
for pos in mutation_positions:
|
|
215
218
|
masked_sequence[pos] = self.tokenizer.mask_token
|
|
216
|
-
|
|
219
|
+
|
|
217
220
|
return "".join(masked_sequence)
|
|
218
|
-
|
|
221
|
+
|
|
219
222
|
# Prepare inputs for MLM
|
|
220
223
|
mlm_inputs = []
|
|
221
224
|
for sequence, bp_span in population:
|
|
222
225
|
masked_seq = mutate(sequence, mutation_ratio)
|
|
223
226
|
mlm_inputs.append(masked_seq)
|
|
224
|
-
|
|
227
|
+
|
|
225
228
|
# Get predictions from MLM
|
|
226
229
|
predicted_tokens = self._mlm_predict(mlm_inputs, structure)
|
|
227
|
-
|
|
230
|
+
|
|
228
231
|
# Convert predictions back to sequences
|
|
229
232
|
mutated_population = []
|
|
230
233
|
for i, (sequence, bp_span) in enumerate(population):
|
|
231
234
|
# Convert token IDs back to nucleotides
|
|
232
|
-
new_sequence = self.tokenizer.decode(
|
|
235
|
+
new_sequence = self.tokenizer.decode(
|
|
236
|
+
predicted_tokens[i], skip_special_tokens=True
|
|
237
|
+
)
|
|
233
238
|
# Ensure the sequence has the correct length
|
|
234
239
|
if len(new_sequence) != len(structure):
|
|
235
|
-
new_sequence = new_sequence[:len(structure)].ljust(len(structure), "A")
|
|
240
|
+
new_sequence = new_sequence[: len(structure)].ljust(len(structure), "A")
|
|
236
241
|
mutated_population.append((new_sequence, bp_span))
|
|
237
|
-
|
|
242
|
+
|
|
238
243
|
return mutated_population
|
|
239
244
|
|
|
240
245
|
def _crossover(self, population, num_points=3):
|
|
241
246
|
"""
|
|
242
247
|
Perform crossover operation on the population.
|
|
243
|
-
|
|
248
|
+
|
|
244
249
|
Args:
|
|
245
250
|
population (list): Current population
|
|
246
251
|
num_points (int): Number of crossover points
|
|
247
|
-
|
|
252
|
+
|
|
248
253
|
Returns:
|
|
249
254
|
list: Population after crossover
|
|
250
255
|
"""
|
|
251
256
|
if len(population) < 2:
|
|
252
257
|
return population
|
|
253
|
-
|
|
258
|
+
|
|
254
259
|
# Create crossover masks
|
|
255
260
|
num_sequences = len(population)
|
|
256
261
|
masks = np.zeros((num_sequences, len(population[0][0])), dtype=bool)
|
|
257
|
-
|
|
262
|
+
|
|
258
263
|
# Generate random crossover points
|
|
259
|
-
crossover_points = np.random.randint(
|
|
260
|
-
|
|
264
|
+
crossover_points = np.random.randint(
|
|
265
|
+
0, len(population[0][0]), (num_sequences, num_points)
|
|
266
|
+
)
|
|
267
|
+
|
|
261
268
|
# Create parent indices
|
|
262
269
|
parent_indices = np.random.randint(0, num_sequences, (num_sequences, 2))
|
|
263
|
-
|
|
270
|
+
|
|
264
271
|
# Generate crossover masks
|
|
265
272
|
for i in range(num_sequences):
|
|
266
273
|
for j in range(num_points):
|
|
267
274
|
if j == 0:
|
|
268
|
-
masks[i, :crossover_points[i, j]] = True
|
|
275
|
+
masks[i, : crossover_points[i, j]] = True
|
|
269
276
|
else:
|
|
270
|
-
last_point = crossover_points[i, j-1]
|
|
271
|
-
masks[i, last_point:crossover_points[i, j]] = j % 2 == 0
|
|
272
|
-
|
|
277
|
+
last_point = crossover_points[i, j - 1]
|
|
278
|
+
masks[i, last_point : crossover_points[i, j]] = j % 2 == 0
|
|
279
|
+
|
|
273
280
|
# Handle the last segment
|
|
274
281
|
last_point = crossover_points[i, -1]
|
|
275
282
|
masks[i, last_point:] = num_points % 2 == 0
|
|
@@ -298,17 +305,17 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
298
305
|
def _evaluate_structure_fitness(self, sequences, structure):
|
|
299
306
|
"""
|
|
300
307
|
Evaluate the fitness of the RNA structure by comparing with the target structure.
|
|
301
|
-
|
|
308
|
+
|
|
302
309
|
Args:
|
|
303
310
|
sequences (list): List of (sequence, bp_span) tuples to evaluate
|
|
304
311
|
structure (str): Target RNA structure
|
|
305
|
-
|
|
312
|
+
|
|
306
313
|
Returns:
|
|
307
314
|
list: Sorted population with fitness scores and MFE values
|
|
308
315
|
"""
|
|
309
316
|
# Get sequences for structure prediction
|
|
310
317
|
seq_list = [seq for seq, _ in sequences]
|
|
311
|
-
|
|
318
|
+
|
|
312
319
|
# Predict structures (with improved multiprocessing)
|
|
313
320
|
structures_mfe = self._predict_structure(seq_list)
|
|
314
321
|
|
|
@@ -326,11 +333,11 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
326
333
|
def _non_dominated_sorting(scores, mfe_values):
|
|
327
334
|
"""
|
|
328
335
|
Perform non-dominated sorting for multi-objective optimization.
|
|
329
|
-
|
|
336
|
+
|
|
330
337
|
Args:
|
|
331
338
|
scores (list): Structure similarity scores
|
|
332
339
|
mfe_values (list): Minimum free energy values
|
|
333
|
-
|
|
340
|
+
|
|
334
341
|
Returns:
|
|
335
342
|
list: List of fronts (Pareto fronts)
|
|
336
343
|
"""
|
|
@@ -369,11 +376,11 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
369
376
|
def _select_next_generation(next_generation, fronts):
|
|
370
377
|
"""
|
|
371
378
|
Select the next generation based on Pareto fronts.
|
|
372
|
-
|
|
379
|
+
|
|
373
380
|
Args:
|
|
374
381
|
next_generation (list): Current population with fitness scores
|
|
375
382
|
fronts (list): Pareto fronts
|
|
376
|
-
|
|
383
|
+
|
|
377
384
|
Returns:
|
|
378
385
|
list: Selected population for the next generation
|
|
379
386
|
"""
|
|
@@ -389,11 +396,11 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
389
396
|
def _mlm_predict(self, mlm_inputs, structure):
|
|
390
397
|
"""
|
|
391
398
|
Perform masked language model prediction.
|
|
392
|
-
|
|
399
|
+
|
|
393
400
|
Args:
|
|
394
401
|
mlm_inputs (list): List of masked input sequences
|
|
395
402
|
structure (str): Target RNA structure
|
|
396
|
-
|
|
403
|
+
|
|
397
404
|
Returns:
|
|
398
405
|
list: Predicted token IDs for each input
|
|
399
406
|
"""
|
|
@@ -403,7 +410,7 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
403
410
|
with torch.no_grad():
|
|
404
411
|
for i in range(0, len(mlm_inputs), batch_size):
|
|
405
412
|
inputs = self.tokenizer(
|
|
406
|
-
mlm_inputs[i: i + batch_size],
|
|
413
|
+
mlm_inputs[i : i + batch_size],
|
|
407
414
|
padding=False,
|
|
408
415
|
max_length=1024,
|
|
409
416
|
truncation=True,
|
|
@@ -422,13 +429,13 @@ class OmniModelForRNADesign(torch.nn.Module):
|
|
|
422
429
|
):
|
|
423
430
|
"""
|
|
424
431
|
Design RNA sequences for a target structure using evolutionary algorithms.
|
|
425
|
-
|
|
432
|
+
|
|
426
433
|
Args:
|
|
427
434
|
structure (str): Target RNA structure in dot-bracket notation
|
|
428
435
|
mutation_ratio (float): Ratio of tokens to mutate (default: 0.5)
|
|
429
436
|
num_population (int): Population size (default: 100)
|
|
430
437
|
num_generation (int): Number of generations (default: 100)
|
|
431
|
-
|
|
438
|
+
|
|
432
439
|
Returns:
|
|
433
440
|
list: List of designed RNA sequences with their fitness scores
|
|
434
441
|
"""
|
|
@@ -21,20 +21,20 @@ from ...abc.abstract_model import OmniModel
|
|
|
21
21
|
class OmniModelForSeq2Seq(OmniModel):
|
|
22
22
|
"""
|
|
23
23
|
Sequence-to-sequence model for genomic sequences.
|
|
24
|
-
|
|
24
|
+
|
|
25
25
|
This model implements a sequence-to-sequence architecture for genomic
|
|
26
26
|
sequences, where the input is one sequence and the output is another
|
|
27
27
|
sequence. It's useful for tasks like sequence translation, structure
|
|
28
28
|
prediction, or sequence transformation.
|
|
29
|
-
|
|
29
|
+
|
|
30
30
|
The model can be extended to implement specific seq2seq tasks by
|
|
31
31
|
overriding the forward, predict, and inference methods.
|
|
32
32
|
"""
|
|
33
|
-
|
|
33
|
+
|
|
34
34
|
def __init__(self, config_or_model, tokenizer, *args, **kwargs):
|
|
35
35
|
"""
|
|
36
36
|
Initialize the sequence-to-sequence model.
|
|
37
|
-
|
|
37
|
+
|
|
38
38
|
Args:
|
|
39
39
|
config_or_model: Model configuration or pre-trained model
|
|
40
40
|
tokenizer: Tokenizer for processing input sequences
|