braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/sparcnet.py

@@ -0,0 +1,426 @@
from __future__ import annotations

from collections import OrderedDict
from math import floor, log2

import torch
import torch.nn as nn
import torch.nn.functional as F

from braindecode.models.base import EEGModuleMixin


class SPARCNet(EEGModuleMixin, nn.Module):
    r"""Seizures, Periodic and Rhythmic pattern Continuum Neural Network (SPaRCNet) from Jing et al. (2023) [jing2023]_.

    :bdg-success:`Convolution`

    This is a temporal CNN model for biosignal classification based on the DenseNet
    architecture.

    The model is based on the unofficial implementation [Code2023]_.

    .. versionadded:: 0.9

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors.

    Parameters
    ----------
    block_layers : int, optional
        Number of layers per dense block. Default is 4.
    growth_rate : int, optional
        Growth rate of the DenseNet. Default is 16.
    bottleneck_size : int, optional
        Bottleneck size. Default is 16.
    drop_prob : float, optional
        Dropout rate. Default is 0.5.
    conv_bias : bool, optional
        Whether to use bias in convolutional layers. Default is True.
    batch_norm : bool, optional
        Whether to use batch normalization. Default is True.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    References
    ----------
    .. [jing2023] Jing, J., Ge, W., Hong, S., Fernandes, M. B., Lin, Z.,
       Yang, C., ... & Westover, M. B. (2023). Development of expert-level
       classification of seizures and rhythmic and periodic
       patterns during EEG interpretation. Neurology, 100(17), e1750-e1762.
    .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT:
       Biosignal Transformer for Cross-data Learning in the Wild.
       GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13)

    """

    def __init__(
        self,
        n_chans=None,
        n_times=None,
        n_outputs=None,
        # Neural network parameters
        block_layers: int = 4,
        growth_rate: int = 16,
        bottleneck_size: int = 16,
        drop_prob: float = 0.5,
        conv_bias: bool = True,
        batch_norm: bool = True,
        activation: type[nn.Module] = nn.ELU,
        kernel_size_conv0: int = 7,
        kernel_size_conv1: int = 1,
        kernel_size_conv2: int = 3,
        kernel_size_pool: int = 3,
        stride_pool: int = 2,
        stride_conv0: int = 2,
        stride_conv1: int = 1,
        stride_conv2: int = 1,
        padding_pool: int = 1,
        padding_conv0: int = 3,
        padding_conv2: int = 1,
        kernel_size_trans: int = 2,
        stride_trans: int = 2,
        # EEGModuleMixin parameters
        # (another way to present the same parameters)
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        # add initial convolutional layer
        # the number of output channels is the smallest power of 2
        # that is greater than the number of input channels
        out_channels = 2 ** (floor(log2(self.n_chans)) + 1)
        first_conv = OrderedDict(
            [
                (
                    "conv0",
                    nn.Conv1d(
                        in_channels=self.n_chans,
                        out_channels=out_channels,
                        kernel_size=kernel_size_conv0,
                        stride=stride_conv0,
                        padding=padding_conv0,
                        bias=conv_bias,
                    ),
                )
            ]
        )
        first_conv["norm0"] = nn.BatchNorm1d(out_channels)
        first_conv["act_layer"] = activation()
        first_conv["pool0"] = nn.MaxPool1d(
            kernel_size=kernel_size_pool,
            stride=stride_pool,
            padding=padding_pool,
        )

        self.encoder = nn.Sequential(first_conv)

        n_channels = out_channels

        # Adding dense blocks
        for n_layer in range(floor(log2(self.n_times // 4))):
            block = _DenseBlock(
                num_layers=block_layers,
                in_channels=n_channels,
                growth_rate=growth_rate,
                bottleneck_size=bottleneck_size,
                drop_prob=drop_prob,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
                activation=activation,
                kernel_size_conv1=kernel_size_conv1,
                kernel_size_conv2=kernel_size_conv2,
                stride_conv1=stride_conv1,
                stride_conv2=stride_conv2,
                padding_conv2=padding_conv2,
            )
            self.encoder.add_module("denseblock%d" % (n_layer + 1), block)
            # update the number of channels after each dense block
            n_channels = n_channels + block_layers * growth_rate

            trans = _TransitionLayer(
                in_channels=n_channels,
                out_channels=n_channels // 2,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
                activation=activation,
                kernel_size_trans=kernel_size_trans,
                stride_trans=stride_trans,
            )
            self.encoder.add_module("transition%d" % (n_layer + 1), trans)
            # update the number of channels after each transition layer
            n_channels = n_channels // 2

        self.adaptative_pool = nn.AdaptiveAvgPool1d(1)
        self.activation_layer = activation()
        self.flatten_layer = nn.Flatten()

        # add final linear classification layer
        self.final_layer = nn.Linear(n_channels, self.n_outputs)

        self._init_weights()

    def _init_weights(self):
        """
        Initialize the weights of the model.

        Follows the official torch DenseNet initialization: kaiming_normal
        for convolutional weights, batch-norm weights set to 1 and biases
        to 0, and linear biases set to 0.

        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight.data)
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, X: torch.Tensor):
        """
        Forward pass of the model.

        Parameters
        ----------
        X : torch.Tensor
            The input tensor of the model with shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            The output tensor of the model with shape (batch_size, n_outputs).
        """
        emb = self.encoder(X)
        emb = self.adaptative_pool(emb)
        emb = self.activation_layer(emb)
        emb = self.flatten_layer(emb)
        out = self.final_layer(emb)
        return out


class _DenseLayer(nn.Sequential):
    r"""
    A densely connected layer with batch normalization and dropout.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    growth_rate : int
        Rate of growth of channels in this layer.
    bottleneck_size : int
        Multiplicative factor for the bottleneck layer (does not affect the output size).
    drop_prob : float, optional
        Dropout rate. Default is 0.5.
    conv_bias : bool, optional
        Whether to use bias in convolutional layers. Default is True.
    batch_norm : bool, optional
        Whether to use batch normalization. Default is True.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Examples
    --------
    >>> x = torch.randn(128, 5, 1000)
    >>> batch, channels, length = x.shape
    >>> model = _DenseLayer(channels, 5, 2)
    >>> y = model(x)
    >>> y.shape
    torch.Size([128, 10, 1000])
    """

    def __init__(
        self,
        in_channels: int,
        growth_rate: int,
        bottleneck_size: int,
        drop_prob: float = 0.5,
        conv_bias: bool = True,
        batch_norm: bool = True,
        activation: type[nn.Module] = nn.ELU,
        kernel_size_conv1: int = 1,
        kernel_size_conv2: int = 3,
        stride_conv1: int = 1,
        stride_conv2: int = 1,
        padding_conv2: int = 1,
    ):
        super().__init__()
        if batch_norm:
            self.add_module("norm1", nn.BatchNorm1d(in_channels))

        self.add_module("elu1", activation())
        self.add_module(
            "conv1",
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=bottleneck_size * growth_rate,
                kernel_size=kernel_size_conv1,
                stride=stride_conv1,
                bias=conv_bias,
            ),
        )
        if batch_norm:
            self.add_module("norm2", nn.BatchNorm1d(bottleneck_size * growth_rate))
        self.add_module("elu2", activation())
        self.add_module(
            "conv2",
            nn.Conv1d(
                in_channels=bottleneck_size * growth_rate,
                out_channels=growth_rate,
                kernel_size=kernel_size_conv2,
                stride=stride_conv2,
                padding=padding_conv2,
                bias=conv_bias,
            ),
        )
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Manually pass through each submodule
        out = x
        for layer in self:
            out = layer(out)
        # apply dropout using the functional API
        out = F.dropout(out, p=self.drop_prob, training=self.training)
        # concatenate input and new features
        return torch.cat([x, out], dim=1)


class _DenseBlock(nn.Sequential):
    r"""
    A densely connected block that uses DenseLayers.

    Parameters
    ----------
    num_layers : int
        Number of layers in this block.
    in_channels : int
        Number of input channels.
    growth_rate : int
        Rate of growth of channels in this layer.
    bottleneck_size : int
        Multiplicative factor for the bottleneck layer (does not affect the output size).
    drop_prob : float, optional
        Dropout rate. Default is 0.5.
    conv_bias : bool, optional
        Whether to use bias in convolutional layers. Default is True.
    batch_norm : bool, optional
        Whether to use batch normalization. Default is True.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Examples
    --------
    >>> x = torch.randn(128, 5, 1000)
    >>> batch, channels, length = x.shape
    >>> model = _DenseBlock(3, channels, 5, 2)
    >>> y = model(x)
    >>> y.shape
    torch.Size([128, 20, 1000])
    """

    def __init__(
        self,
        num_layers,
        in_channels,
        growth_rate,
        bottleneck_size,
        drop_prob=0.5,
        conv_bias=True,
        batch_norm=True,
        activation: type[nn.Module] = nn.ELU,
        kernel_size_conv1: int = 1,
        kernel_size_conv2: int = 3,
        stride_conv1: int = 1,
        stride_conv2: int = 1,
        padding_conv2: int = 1,
    ):
        super(_DenseBlock, self).__init__()
        for idx_layer in range(num_layers):
            layer = _DenseLayer(
                in_channels=in_channels + idx_layer * growth_rate,
                growth_rate=growth_rate,
                bottleneck_size=bottleneck_size,
                drop_prob=drop_prob,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
                activation=activation,
                kernel_size_conv1=kernel_size_conv1,
                kernel_size_conv2=kernel_size_conv2,
                stride_conv1=stride_conv1,
                stride_conv2=stride_conv2,
                padding_conv2=padding_conv2,
            )
            self.add_module(f"denselayer{idx_layer + 1}", layer)


class _TransitionLayer(nn.Sequential):
    r"""
    A pooling transition layer.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    conv_bias : bool, optional
        Whether to use bias in convolutional layers. Default is True.
    batch_norm : bool, optional
        Whether to use batch normalization. Default is True.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Examples
    --------
    >>> x = torch.randn(128, 5, 1000)
    >>> model = _TransitionLayer(5, 18)
    >>> y = model(x)
    >>> y.shape
    torch.Size([128, 18, 500])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        conv_bias=True,
        batch_norm=True,
        activation: type[nn.Module] = nn.ELU,
        kernel_size_trans: int = 2,
        stride_trans: int = 2,
    ):
        super(_TransitionLayer, self).__init__()
        if batch_norm:
            self.add_module("norm", nn.BatchNorm1d(in_channels))
        self.add_module("elu", activation())
        self.add_module(
            "conv",
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                bias=conv_bias,
            ),
        )
        self.add_module(
            "pool", nn.AvgPool1d(kernel_size=kernel_size_trans, stride=stride_trans)
        )