omnigenome 0.3.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic. Click here for more details.
- omnigenome/__init__.py +281 -0
- omnigenome/auto/__init__.py +3 -0
- omnigenome/auto/auto_bench/__init__.py +12 -0
- omnigenome/auto/auto_bench/auto_bench.py +484 -0
- omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
- omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
- omnigenome/auto/auto_bench/config_check.py +34 -0
- omnigenome/auto/auto_train/__init__.py +13 -0
- omnigenome/auto/auto_train/auto_train.py +430 -0
- omnigenome/auto/auto_train/auto_train_cli.py +222 -0
- omnigenome/auto/bench_hub/__init__.py +12 -0
- omnigenome/auto/bench_hub/bench_hub.py +25 -0
- omnigenome/cli/__init__.py +13 -0
- omnigenome/cli/commands/__init__.py +13 -0
- omnigenome/cli/commands/base.py +83 -0
- omnigenome/cli/commands/bench/__init__.py +13 -0
- omnigenome/cli/commands/bench/bench_cli.py +202 -0
- omnigenome/cli/commands/rna/__init__.py +13 -0
- omnigenome/cli/commands/rna/rna_design.py +178 -0
- omnigenome/cli/omnigenome_cli.py +128 -0
- omnigenome/src/__init__.py +12 -0
- omnigenome/src/abc/__init__.py +12 -0
- omnigenome/src/abc/abstract_dataset.py +622 -0
- omnigenome/src/abc/abstract_metric.py +114 -0
- omnigenome/src/abc/abstract_model.py +689 -0
- omnigenome/src/abc/abstract_tokenizer.py +267 -0
- omnigenome/src/dataset/__init__.py +16 -0
- omnigenome/src/dataset/omni_dataset.py +435 -0
- omnigenome/src/lora/__init__.py +13 -0
- omnigenome/src/lora/lora_model.py +294 -0
- omnigenome/src/metric/__init__.py +15 -0
- omnigenome/src/metric/classification_metric.py +184 -0
- omnigenome/src/metric/metric.py +199 -0
- omnigenome/src/metric/ranking_metric.py +142 -0
- omnigenome/src/metric/regression_metric.py +191 -0
- omnigenome/src/misc/__init__.py +3 -0
- omnigenome/src/misc/utils.py +439 -0
- omnigenome/src/model/__init__.py +19 -0
- omnigenome/src/model/augmentation/__init__.py +12 -0
- omnigenome/src/model/augmentation/model.py +219 -0
- omnigenome/src/model/classification/__init__.py +12 -0
- omnigenome/src/model/classification/model.py +642 -0
- omnigenome/src/model/embedding/__init__.py +12 -0
- omnigenome/src/model/embedding/model.py +263 -0
- omnigenome/src/model/mlm/__init__.py +12 -0
- omnigenome/src/model/mlm/model.py +177 -0
- omnigenome/src/model/module_utils.py +232 -0
- omnigenome/src/model/regression/__init__.py +12 -0
- omnigenome/src/model/regression/model.py +786 -0
- omnigenome/src/model/regression/resnet.py +483 -0
- omnigenome/src/model/rna_design/__init__.py +12 -0
- omnigenome/src/model/rna_design/model.py +426 -0
- omnigenome/src/model/seq2seq/__init__.py +12 -0
- omnigenome/src/model/seq2seq/model.py +44 -0
- omnigenome/src/tokenizer/__init__.py +16 -0
- omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
- omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
- omnigenome/src/trainer/__init__.py +14 -0
- omnigenome/src/trainer/accelerate_trainer.py +739 -0
- omnigenome/src/trainer/hf_trainer.py +75 -0
- omnigenome/src/trainer/trainer.py +579 -0
- omnigenome/utility/__init__.py +3 -0
- omnigenome/utility/dataset_hub/__init__.py +13 -0
- omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
- omnigenome/utility/ensemble.py +324 -0
- omnigenome/utility/hub_utils.py +517 -0
- omnigenome/utility/model_hub/__init__.py +12 -0
- omnigenome/utility/model_hub/model_hub.py +231 -0
- omnigenome/utility/pipeline_hub/__init__.py +12 -0
- omnigenome/utility/pipeline_hub/pipeline.py +483 -0
- omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
- omnigenome-0.3.0a0.dist-info/METADATA +224 -0
- omnigenome-0.3.0a0.dist-info/RECORD +85 -0
- omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
- omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
- omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
- tests/__init__.py +9 -0
- tests/conftest.py +160 -0
- tests/test_dataset_patterns.py +291 -0
- tests/test_examples_syntax.py +83 -0
- tests/test_model_loading.py +183 -0
- tests/test_rna_functions.py +255 -0
- tests/test_training_patterns.py +302 -0
|
@@ -0,0 +1,483 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# file: resnet.py
|
|
3
|
+
# time: 14:43 29/01/2025
|
|
4
|
+
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
+
# Homepage: https://yangheng95.github.io
|
|
6
|
+
# github: https://github.com/yangheng95
|
|
7
|
+
# huggingface: https://huggingface.co/yangheng
|
|
8
|
+
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
9
|
+
# Copyright (C) 2019-2025. All Rights Reserved.
|
|
10
|
+
# Adapted from: https://github.com/terry-r123/RNABenchmark/blob/main/downstream/structure/resnet.py
|
|
11
|
+
"""
|
|
12
|
+
ResNet implementation for genomic sequence analysis.
|
|
13
|
+
|
|
14
|
+
This module provides a ResNet architecture adapted for processing genomic sequences
|
|
15
|
+
and their structural representations. It includes basic blocks, bottleneck blocks,
|
|
16
|
+
and a complete ResNet implementation optimized for genomic data.
|
|
17
|
+
"""
|
|
18
|
+
from torch import Tensor
|
|
19
|
+
import torch.nn as nn
|
|
20
|
+
from typing import Type, Callable, Union, List, Optional
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """
    Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to the dilation, so for ``stride=1`` the height and
    width of the input are preserved whatever the dilation factor.

    Args:
        in_planes (int): Number of input channels
        out_planes (int): Number of output channels
        stride (int): Convolution stride (default: 1)
        groups (int): Number of channel groups for grouped convolution (default: 1)
        dilation (int): Spacing between kernel taps; also used as padding (default: 1)

    Returns:
        nn.Conv2d: The configured 3x3 convolution layer (``bias=False``)
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,  # effective kernel is 2*dilation+1, so this keeps H/W at stride 1
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def conv1x1(in_planes, out_planes, stride=1):
    """
    Build a bias-free 1x1 ``nn.Conv2d`` (a per-position linear channel projection).

    Args:
        in_planes (int): Number of input channels
        out_planes (int): Number of output channels
        stride (int): Convolution stride (default: 1)

    Returns:
        nn.Conv2d: The configured 1x1 convolution layer (``bias=False``)
    """
    projection = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return projection
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """
    5x5 convolution with size-preserving padding.

    A dilated 5x5 kernel spans ``4 * dilation + 1`` positions, so a padding of
    ``2 * dilation`` keeps the spatial size unchanged when ``stride == 1``
    (mirroring ``conv3x3``, which pads by ``dilation`` for a 3x3 kernel).
    For the default ``dilation=1`` the padding is 2.

    Args:
        in_planes (int): Number of input channels
        out_planes (int): Number of output channels
        stride (int): Stride for the convolution (default: 1)
        groups (int): Number of groups for grouped convolution (default: 1)
        dilation (int): Dilation factor for the convolution (default: 1)

    Returns:
        nn.Conv2d: 5x5 convolution layer (``bias=False``)
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=5,
        stride=stride,
        # A fixed padding of 2 would shrink each spatial axis by 4*(dilation-1).
        padding=2 * dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class BasicBlock(nn.Module):
    """
    Pre-activation residual block with a 3x3 and a 5x5 convolution.

    Residual branch: ``norm -> ReLU -> dropout -> conv3x3 -> norm -> ReLU ->
    dropout -> conv5x5``. The block input (passed through ``downsample`` when
    one is given) is added to the branch output; no activation follows the
    addition.

    Normalization layers are applied channels-last: the tensor is permuted to
    ``[N, H, W, C]`` before each norm and back to ``[N, C, H, W]`` afterwards,
    so the default ``nn.LayerNorm(n)`` normalizes over the channel axis.

    Attributes:
        expansion (int): Output-channel multiplier relative to ``planes`` (1)
        conv1: 3x3 convolution ``inplanes -> planes`` carrying ``stride``
        bn1: Normalization of the block input (``inplanes`` features)
        conv2: 5x5 convolution ``planes -> planes``
        bn2: Normalization after ``conv1`` (``planes`` features)
        relu: ReLU activation
        drop: Dropout with ``p=0.25``
        downsample: Optional module applied to the NCHW block input so that the
            shortcut matches the branch output shape
        stride: Stride used by ``conv1``
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Initialize the BasicBlock.

        The positional layout matches ``Bottleneck`` (and torchvision), which is
        how ``ResNet._make_layer`` calls every block type.

        Args:
            inplanes (int): Number of input channels
            planes (int): Number of output channels
            stride (int): Stride of the first convolution (default: 1)
            downsample: Module mapping the block input onto the output shape
                for the residual addition (default: None)
            groups (int): Accepted for interface compatibility; not used by
                this block (default: 1)
            base_width (int): Accepted for interface compatibility with
                ``Bottleneck``; not used by this block (default: 64)
            dilation (int): Must be 1 (default: 1)
            norm_layer: Factory for the normalization layers, called with the
                feature count (default: None, uses ``nn.LayerNorm``)

        Raises:
            NotImplementedError: If dilation > 1 is specified
        """
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.LayerNorm
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        # bn1 normalizes the block *input*, which has ``inplanes`` channels.
        self.bn1 = norm_layer(inplanes)
        self.relu = nn.ReLU(inplace=False)
        self.drop = nn.Dropout(0.25, inplace=False)
        self.conv2 = conv5x5(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the BasicBlock.

        Args:
            x (Tensor): Input tensor [batch_size, inplanes, height, width]

        Returns:
            Tensor: ``branch(x) + shortcut(x)`` with ``planes`` channels; the
            spatial size is unchanged when ``stride == 1``
        """
        identity = x

        # Norm layers act on the last dimension, so move channels last for them.
        out = self.bn1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        out = self.relu(out)
        out = self.drop(out)
        out = self.conv1(out)

        out = self.bn2(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        out = self.relu(out)
        out = self.drop(out)
        out = self.conv2(out)

        if self.downsample is not None:
            # Project the original NCHW input, not its channels-last permutation.
            identity = self.downsample(x)

        return out + identity
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class Bottleneck(nn.Module):
    """
    Bottleneck residual block (1x1 -> 3x3 -> 1x1), ResNet V1.5 layout.

    Each of the three convolutions is followed by its own normalization layer;
    ReLU follows the first two and is applied once more after the residual
    addition. Unlike ``BasicBlock``, the normalization layers here are applied
    directly to NCHW tensors (no permutation), and the default is
    ``nn.BatchNorm2d``. The output has ``planes * expansion`` channels.

    Attributes:
        expansion (int): Output-channel multiplier relative to ``planes`` (4)
        conv1: 1x1 convolution ``inplanes -> width``
        bn1: Normalization after ``conv1``
        conv2: 3x3 convolution ``width -> width`` carrying ``stride``
        bn2: Normalization after ``conv2``
        conv3: 1x1 convolution ``width -> planes * expansion``
        bn3: Normalization after ``conv3``
        relu: ReLU activation
        downsample: Optional module mapping the input onto the output shape
        stride: Stride used by ``conv2``
    """

    # torchvision's Bottleneck places the downsampling stride on the 3x3 convolution
    # (self.conv2), whereas the original implementation places it on the first 1x1
    # convolution (self.conv1), per "Deep residual learning for image recognition"
    # (https://arxiv.org/abs/1512.03385). This variant is known as ResNet V1.5 and
    # improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Build the three convolution/normalization pairs.

        Args:
            inplanes (int): Number of input channels
            planes (int): Base channel count; the block outputs
                ``planes * expansion`` channels
            stride (int): Stride of the 3x3 convolution (default: 1)
            downsample: Module applied to the input for the residual addition
                (default: None, meaning the input is added unchanged)
            groups (int): Groups for the 3x3 convolution (default: 1)
            base_width (int): Scales the inner width as ``planes * base_width / 64``
                (default: 64)
            dilation (int): Dilation of the 3x3 convolution (default: 1)
            norm_layer: Factory for the normalization layers
                (default: None, uses ``nn.BatchNorm2d``)
        """
        super().__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner_width = int(planes * (base_width / 64.0)) * groups
        out_channels = planes * self.expansion

        # Both self.conv2 and self.downsample reduce the spatial size when stride != 1.
        self.conv1 = conv1x1(inplanes, inner_width)
        self.bn1 = norm(inner_width)
        self.conv2 = conv3x3(inner_width, inner_width, stride, groups, dilation)
        self.bn2 = norm(inner_width)
        self.conv3 = conv1x1(inner_width, out_channels)
        self.bn3 = norm(out_channels)
        self.relu = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """
        Run the bottleneck branch and add the (optionally downsampled) input.

        Args:
            x (Tensor): Input tensor [batch_size, inplanes, height, width]

        Returns:
            Tensor: ``relu(branch(x) + shortcut(x))`` with ``planes * expansion``
            channels
        """
        out = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class ResNet(nn.Module):
    """
    Single-stage ResNet with a one-unit linear head, for 2D genomic inputs.

    Pipeline: 3x3 stem convolution (``channels -> 48``) -> normalization ->
    ReLU -> one stage of ``layers[0]`` residual blocks -> mean over the two
    spatial axes -> ``nn.Linear(..., 1)``.

    The stem normalization is applied channels-last (the tensor is permuted to
    ``[N, H, W, C]`` and back), so the default ``nn.LayerNorm`` normalizes over
    channels.

    Attributes:
        _norm_layer: Normalization layer factory shared with the blocks
        inplanes: Running channel count; after construction, the output width
            of ``layer1``
        dilation: Dilation factor for convolutions
        groups: Number of groups passed to each block
        base_width: Width-per-group passed to each block
        conv1: Stem convolution
        bn1: Stem normalization layer
        relu: ReLU activation function
        layer1: The single stage of residual blocks
        fc1: Final linear layer producing one value per sample
    """

    def __init__(
        self,
        channels,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 1,
        replace_stride_with_dilation=None,
        norm_layer=None,
    ) -> None:
        """
        Initialize the ResNet architecture.

        Args:
            channels (int): Number of input channels
            block: Residual block class (BasicBlock or Bottleneck)
            layers (List[int]): Blocks per stage; only ``layers[0]`` is used
                because a single stage is built
            zero_init_residual (bool): Zero the weight of the last norm layer in
                each residual branch (default: False)
            groups (int): Number of groups passed to each block (default: 1)
            width_per_group (int): Width per group passed to each block (default: 1)
            replace_stride_with_dilation: Validated for API compatibility
                (None or 3 elements) but otherwise unused (default: None)
            norm_layer: Normalization layer factory (default: None, uses LayerNorm)

        Raises:
            ValueError: If replace_stride_with_dilation is not None or a 3-element tuple
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.LayerNorm
        self._norm_layer = norm_layer

        self.inplanes = 48
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            channels, self.inplanes, kernel_size=3, stride=1, padding=1
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=False)
        self.layer1 = self._make_layer(block, 48, layers[0])
        # _make_layer leaves self.inplanes == 48 * block.expansion, i.e. the
        # stage's actual output width; size the head from it rather than
        # hard-coding 48 so blocks with expansion != 1 line up with fc1.
        self.fc1 = nn.Linear(self.inplanes, 1)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """
        Create one stage of residual blocks and advance ``self.inplanes``.

        Args:
            block: Type of ResNet block to use
            planes (int): Base channel count for the blocks in the stage
            blocks (int): Number of blocks in the stage
            stride (int): Stride for the first block (default: 1)
            dilate (bool): Replace the stride with dilation (default: False)

        Returns:
            nn.Sequential: Sequential container of ResNet blocks
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # NOTE(review): this norm layer receives the NCHW output of conv1x1;
            # with the default nn.LayerNorm it would normalize the last (width)
            # axis, not channels. Verify before using stride != 1 or expansion != 1.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # The first block receives all arguments positionally, in the
        # (inplanes, planes, stride, downsample, groups, base_width, dilation,
        # norm_layer) order.
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """
        Forward pass implementation.

        Args:
            x (Tensor): Input tensor [batch_size, channels, height, width]

        Returns:
            Tensor: Output tensor of shape [batch_size, 1]
        """
        # Original shape note: [bz,hd,len,len] (presumably a pairwise map - confirm with callers)
        x = self.conv1(x)
        # Stem norm acts on the last dimension, so move channels last for it.
        x = x.permute(0, 2, 3, 1)
        x = self.bn1(x)
        x = x.permute(0, 3, 1, 2)
        x = self.relu(x)

        x = self.layer1(x)
        # Global average pooling over the two spatial axes.
        x = x.mean(dim=[2, 3])
        x = self.fc1(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the ResNet.

        Args:
            x (Tensor): Input tensor [batch_size, channels, height, width]

        Returns:
            Tensor: Output tensor of shape [batch_size, 1]
        """
        return self._forward_impl(x)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def resnet_b16(channels=128, bbn=16):
    """
    Build a single-stage ``ResNet`` made of ``BasicBlock`` modules.

    With the defaults, the stage contains 16 blocks.

    Args:
        channels (int): Number of input channels (default: 128)
        bbn (int): Number of BasicBlock modules in the stage (default: 16)

    Returns:
        ResNet: Equivalent to ``ResNet(channels, BasicBlock, [bbn])``
    """
    stage_depths = [bbn]
    return ResNet(channels, block=BasicBlock, layers=stage_depths)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# file: __init__.py
|
|
3
|
+
# time: 18:25 22/09/2024
|
|
4
|
+
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
+
# github: https://github.com/yangheng95
|
|
6
|
+
# huggingface: https://huggingface.co/yangheng
|
|
7
|
+
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
+
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
+
"""
|
|
10
|
+
This package contains modules for RNA design models.
|
|
11
|
+
"""
|
|
12
|
+
|