braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import math
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
from einops.layers.torch import Rearrange
|
|
11
|
+
|
|
12
|
+
from braindecode.models.base import EEGModuleMixin
|
|
13
|
+
from braindecode.modules import Conv2dWithConstraint, LinearWithConstraint
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class EEGNeX(EEGModuleMixin, nn.Module):
    """EEGNeX model from Chen et al. (2024) [eegnex]_.

    .. figure:: https://braindecode.org/dev/_static/model/eegnex.jpg
        :align: center
        :alt: EEGNeX Architecture

    Parameters
    ----------
    activation : type[nn.Module], optional
        Activation class, instantiated once in block 3 and once in block 5.
        Default is ``nn.ELU``.
    depth_multiplier : int, optional
        Depth multiplier of the depthwise spatial convolution of block 3,
        which outputs ``filter_2 * depth_multiplier`` feature maps.
        Default is 2.
    filter_1 : int, optional
        Number of filters of the block 1 convolution; block 5 maps back to
        this number of filters. Default is 8.
    filter_2 : int, optional
        Number of filters of the block 2 convolution; block 4 maps back to
        this number of filters. Default is 32.
    drop_prob : float, optional
        Dropout rate used in blocks 3 and 5. Default is 0.5.
    kernel_block_1_2 : int, optional
        Temporal kernel length of the convolutions in blocks 1 and 2.
        Default is 64.
    kernel_block_4 : int, optional
        Temporal kernel length for block 4. Default is 16.
    dilation_block_4 : int, optional
        Temporal dilation rate for block 4. Default is 2.
    avg_pool_block4 : int, optional
        Temporal average-pooling size applied at the end of block 3, i.e.
        right before block 4. Default is 4.
    kernel_block_5 : int, optional
        Temporal kernel length for block 5. Default is 16.
    dilation_block_5 : int, optional
        Temporal dilation rate for block 5. Default is 4.
    avg_pool_block5 : int, optional
        Temporal average-pooling size for block 5. Default is 8.
    max_norm_conv : float, optional
        ``max_norm`` passed to the depthwise ``Conv2dWithConstraint`` of
        block 3. Default is 1.0.
    max_norm_linear : float, optional
        ``max_norm`` passed to the final ``LinearWithConstraint``.
        Default is 0.25.

    Notes
    -----
    This implementation is not guaranteed to be correct, has not been checked
    by original authors, only reimplemented from the paper description and
    source code in tensorflow [EEGNexCode]_.

    All kernel, dilation and pooling sizes are given as integers and act on
    the time axis only; they are stored internally as ``(1, value)`` tuples.

    References
    ----------
    .. [eegnex] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
       Toward reliable signals decoding for electroencephalogram: A benchmark
       study to EEGNeX. Biomedical Signal Processing and Control, 87, 105475.
    .. [EEGNexCode] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
       Toward reliable signals decoding for electroencephalogram: A benchmark
       study to EEGNeX. https://github.com/chenxiachan/EEGNeX
    """

    # Padding of both average-pooling layers. Shared by the layer definitions
    # and ``_calculate_output_length`` so the two can never disagree.
    _POOL_PADDING = (0, 1)

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        activation: type[nn.Module] = nn.ELU,
        depth_multiplier: int = 2,
        filter_1: int = 8,
        filter_2: int = 32,
        drop_prob: float = 0.5,
        kernel_block_1_2: int = 64,
        kernel_block_4: int = 16,
        dilation_block_4: int = 2,
        avg_pool_block4: int = 4,
        kernel_block_5: int = 16,
        dilation_block_5: int = 4,
        avg_pool_block5: int = 8,
        max_norm_conv: float = 1.0,
        max_norm_linear: float = 0.25,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.depth_multiplier = depth_multiplier
        self.filter_1 = filter_1
        self.filter_2 = filter_2
        # Output width of the depthwise spatial convolution in block 3.
        self.filter_3 = self.filter_2 * self.depth_multiplier
        self.drop_prob = drop_prob
        self.activation = activation
        # Integer hyper-parameters act on the time axis only -> (1, value).
        self.kernel_block_1_2 = (1, kernel_block_1_2)
        self.kernel_block_4 = (1, kernel_block_4)
        self.dilation_block_4 = (1, dilation_block_4)
        self.avg_pool_block4 = (1, avg_pool_block4)
        self.kernel_block_5 = (1, kernel_block_5)
        self.dilation_block_5 = (1, dilation_block_5)
        self.avg_pool_block5 = (1, avg_pool_block5)

        # Number of flattened features entering the final layer.
        self.in_features = self._calculate_output_length()

        # Following paper nomenclature
        self.block_1 = nn.Sequential(
            # Add a singleton "image channel" axis for the 2D convolutions.
            Rearrange("batch ch time -> batch 1 ch time"),
            nn.Conv2d(
                in_channels=1,
                out_channels=self.filter_1,
                kernel_size=self.kernel_block_1_2,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_1),
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_1,
                out_channels=self.filter_2,
                kernel_size=self.kernel_block_1_2,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_2),
        )

        self.block_3 = nn.Sequential(
            # Depthwise spatial convolution: kernel spans all EEG channels, so
            # the channel axis collapses to size 1.
            Conv2dWithConstraint(
                in_channels=self.filter_2,
                out_channels=self.filter_3,
                max_norm=max_norm_conv,
                kernel_size=(self.n_chans, 1),
                groups=self.filter_2,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_3),
            self.activation(),
            nn.AvgPool2d(
                kernel_size=self.avg_pool_block4,
                padding=self._POOL_PADDING,
            ),
            nn.Dropout(p=self.drop_prob),
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_3,
                out_channels=self.filter_2,
                kernel_size=self.kernel_block_4,
                dilation=self.dilation_block_4,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_2),
        )

        self.block_5 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_2,
                out_channels=self.filter_1,
                kernel_size=self.kernel_block_5,
                dilation=self.dilation_block_5,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_1),
            self.activation(),
            nn.AvgPool2d(
                kernel_size=self.avg_pool_block5,
                padding=self._POOL_PADDING,
            ),
            nn.Dropout(p=self.drop_prob),
            nn.Flatten(),
        )

        self.final_layer = LinearWithConstraint(
            in_features=self.in_features,
            out_features=self.n_outputs,
            max_norm=max_norm_linear,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the EEGNeX model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # x: (batch_size, n_chans, n_times)
        x = self.block_1(x)
        # (batch_size, filter_1, n_chans, n_times) -- "same" padding
        x = self.block_2(x)
        # (batch_size, filter_2, n_chans, n_times)
        x = self.block_3(x)
        # (batch_size, filter_2 * depth_multiplier, 1, T3), T3 after avg_pool_block4
        x = self.block_4(x)
        # (batch_size, filter_2, 1, T3)
        x = self.block_5(x)
        # (batch_size, filter_1 * T5), T5 after avg_pool_block5, then flattened
        x = self.final_layer(x)
        # (batch_size, n_outputs)
        return x

    def _calculate_output_length(self) -> int:
        """Return the number of flattened features entering ``final_layer``.

        All convolutions use "same" temporal padding, so only the two
        average-pooling layers (stride = kernel size, ``ceil_mode=False``)
        shrink the time axis. The depthwise convolution of block 3 reduces
        the channel axis to 1, so the result is ``filter_1 * T5``.
        """
        pad = self._POOL_PADDING[1]
        p4 = self.avg_pool_block4[1]
        p5 = self.avg_pool_block5[1]

        # PyTorch pooling length: floor((L_in + 2*padding - kernel) / stride) + 1
        t3 = math.floor((self.n_times + 2 * pad - p4) / p4) + 1
        t5 = math.floor((t3 + 2 * pad - p5) / p5) + 1

        # filter_1 is the number of feature maps before the flatten.
        return self.filter_1 * t5
|
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
# Authors: Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
# Tonio Ball
|
|
3
|
+
#
|
|
4
|
+
# License: BSD-3
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from einops.layers.torch import Rearrange
|
|
9
|
+
from torch import nn
|
|
10
|
+
from torch.nn import init
|
|
11
|
+
|
|
12
|
+
from braindecode.models.base import EEGModuleMixin
|
|
13
|
+
from braindecode.modules import (
|
|
14
|
+
AvgPool2dWithConv,
|
|
15
|
+
Ensure4d,
|
|
16
|
+
SqueezeFinalOutput,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class EEGResNet(EEGModuleMixin, nn.Sequential):
    """EEGResNet from Schirrmeister et al. 2017 [Schirrmeister2017]_.

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/bed1b768-809f-4bc6-b942-b36970d81271/hbm23730-fig-0003-m.jpg
        :align: center
        :alt: EEGResNet Architecture

    Model described in [Schirrmeister2017]_.

    Parameters
    ----------
    final_pool_length : int or "auto", default="auto"
        Temporal length of the final pooling. With ``"auto"`` an
        ``nn.AdaptiveAvgPool2d((1, 1))`` is used; otherwise an
        ``AvgPool2dWithConv`` with kernel ``(final_pool_length, 1)`` and the
        temporal dilation of the last residual stage is used.
    n_first_filters : int, default=20
        Number of filters of the first (temporal) convolution.
    n_layers_per_block : int, default=2
        Number of residual blocks in each of the 7 residual stages.
    first_filter_length : int, default=3
        Temporal kernel length of the first convolution; must be odd.
    activation : type[nn.Module], default=nn.ELU
        Activation class, used after the first batch norm and inside every
        residual block. Should be a PyTorch activation module class like
        ``nn.ReLU`` or ``nn.ELU``.
    split_first_layer : bool, default=True
        If True, the first layer is a temporal convolution followed by a
        separate spatial convolution over all channels; otherwise a single
        temporal convolution takes the EEG channels as input channels.
    batch_norm_alpha : float, default=0.1
        Momentum of all batch-norm layers.
    batch_norm_epsilon : float, default=1e-4
        Epsilon of the batch-norm layers inside the residual blocks (the
        first batch norm uses a fixed ``eps=1e-5``).
    conv_weight_init_fn : callable, default=kaiming_normal_ with ``a=0``
        Function applied to the weight of every convolution module (except
        ``AvgPool2dWithConv``); biases are zeroed.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017). Deep learning with convolutional neural networks for
       EEG decoding and visualization. Human Brain Mapping, Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    # Factor applied to the filter count by the first residual block of
    # stages 2..7 (stage 1 keeps ``n_first_filters``).
    _STAGE_FILTER_FACTORS = (2, 1.5, 1, 1, 1, 1)

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        final_pool_length="auto",
        n_first_filters=20,
        n_layers_per_block=2,
        first_filter_length=3,
        activation=nn.ELU,
        split_first_layer=True,
        batch_norm_alpha=0.1,
        batch_norm_epsilon=1e-4,
        conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0),
        chs_info=None,
        input_window_seconds=None,
        sfreq=250,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        if final_pool_length == "auto":
            assert self.n_times is not None, "n_times is required"
        # Odd length is needed for symmetric "same" padding of conv_time.
        assert first_filter_length % 2 == 1, "first_filter_length must be odd"
        self.final_pool_length = final_pool_length
        self.n_first_filters = n_first_filters
        self.n_layers_per_block = n_layers_per_block
        self.first_filter_length = first_filter_length
        self.nonlinearity = activation
        self.split_first_layer = split_first_layer
        self.batch_norm_alpha = batch_norm_alpha
        self.batch_norm_epsilon = batch_norm_epsilon
        self.conv_weight_init_fn = conv_weight_init_fn

        # Maps legacy parameter names to their location inside final_layer.
        self.mapping = {
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        self.add_module("ensuredims", Ensure4d())
        if self.split_first_layer:
            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    1,
                    self.n_first_filters,
                    (self.first_filter_length, 1),
                    stride=1,
                    padding=(self.first_filter_length // 2, 0),
                ),
            )
            self.add_module(
                "conv_spat",
                nn.Conv2d(
                    self.n_first_filters,
                    self.n_first_filters,
                    (1, self.n_chans),
                    stride=(1, 1),
                    bias=False,
                ),
            )
        else:
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.n_chans,
                    self.n_first_filters,
                    (self.first_filter_length, 1),
                    stride=(1, 1),
                    padding=(self.first_filter_length // 2, 0),
                    bias=False,
                ),
            )
        n_filters_conv = self.n_first_filters
        self.add_module(
            "bnorm",
            nn.BatchNorm2d(
                # NOTE(review): eps is fixed here and does not follow
                # batch_norm_epsilon; kept for numerical backward-compatibility.
                n_filters_conv, momentum=self.batch_norm_alpha, affine=True, eps=1e-5
            ),
        )
        self.add_module("conv_nonlin", self.nonlinearity())

        # Temporal dilation doubles at every stage; the array is mutated in
        # place and each _ResidualBlock copies its value at construction.
        cur_dilation = np.array([1, 1])
        n_cur_filters = n_filters_conv

        # Stage 1: constant filter count, no temporal dilation.
        for i_layer in range(self.n_layers_per_block):
            self._add_residual_block(
                1, i_layer, n_cur_filters, n_cur_filters, cur_dilation
            )

        # Stages 2..7: double the dilation; the first block of each stage
        # rescales the filter count, the remaining blocks keep it.
        for i_block, factor in enumerate(self._STAGE_FILTER_FACTORS, start=2):
            cur_dilation[0] *= 2
            n_out_filters = int(factor * n_cur_filters)
            self._add_residual_block(
                i_block, 0, n_cur_filters, n_out_filters, cur_dilation
            )
            n_cur_filters = n_out_filters
            for i_layer in range(1, self.n_layers_per_block):
                self._add_residual_block(
                    i_block, i_layer, n_cur_filters, n_cur_filters, cur_dilation
                )

        if self.final_pool_length == "auto":
            self.add_module("mean_pool", nn.AdaptiveAvgPool2d((1, 1)))
        else:
            pool_dilation = int(cur_dilation[0]), int(cur_dilation[1])
            self.add_module(
                "mean_pool",
                AvgPool2dWithConv(
                    (self.final_pool_length, 1), (1, 1), dilation=pool_dilation
                ),
            )

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()
        module.add_module(
            "conv_classifier",
            nn.Conv2d(
                n_cur_filters,
                self.n_outputs,
                (1, 1),
                bias=True,
            ),
        )
        module.add_module("squeeze", SqueezeFinalOutput())
        self.add_module("final_layer", module)

        # Initialize all weights
        self.apply(lambda module: self._weights_init(module, self.conv_weight_init_fn))

        # Start in train mode
        self.train()

    def _add_residual_block(self, i_block, i_layer, in_filters, out_filters, dilation):
        """Register one ``_ResidualBlock`` under the name ``res_<i_block>_<i_layer>``.

        The module names are part of the ``state_dict`` keys and must stay
        stable. The block inherits this model's activation and batch-norm
        settings.
        """
        self.add_module(
            "res_{:d}_{:d}".format(i_block, i_layer),
            _ResidualBlock(
                in_filters,
                out_filters,
                dilation=dilation,
                nonlinearity=self.nonlinearity,
                batch_norm_alpha=self.batch_norm_alpha,
                batch_norm_epsilon=self.batch_norm_epsilon,
            ),
        )

    @staticmethod
    def _weights_init(module, conv_weight_init_fn):
        """Initialize conv weights with ``conv_weight_init_fn`` and zero biases;
        set batch-norm scale to 1 and shift to 0.
        """
        classname = module.__class__.__name__
        if "Conv" in classname and classname != "AvgPool2dWithConv":
            conv_weight_init_fn(module.weight)
            if module.bias is not None:
                init.constant_(module.bias, 0)
        elif "BatchNorm" in classname:
            init.constant_(module.weight, 1)
            init.constant_(module.bias, 0)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class _ResidualBlock(nn.Module):
    """Residual block with two stacked temporal convolutions.

    Each convolution has a ``(filter_time_length, 1)`` kernel with the given
    dilation and symmetric temporal padding that preserves the time length,
    and is followed by batch norm. The nonlinearity is applied after the
    first batch norm and again after the residual sum. When
    ``out_num_filters > in_filters``, the identity path is zero-padded
    symmetrically along the filter axis to match.

    Parameters
    ----------
    in_filters : int
        Number of input feature maps.
    out_num_filters : int
        Number of output feature maps; must be ``>= in_filters`` and differ
        from it by an even number.
    dilation : indexable of two numbers
        Dilation ``(time, space)`` of both convolutions.
    filter_time_length : int, default=3
        Temporal kernel length. ``(filter_time_length - 1) * dilation[0]``
        must be even so that the padding can be symmetric.
    nonlinearity : type[nn.Module], default=nn.ELU
        Activation class, instantiated once and shared by both applications.
    batch_norm_alpha : float, default=0.1
        Momentum of both batch-norm layers.
    batch_norm_epsilon : float, default=1e-4
        Epsilon of both batch-norm layers.

    Raises
    ------
    ValueError
        If the temporal padding would be asymmetric, or if the filter
        difference is negative or odd.
    """

    def __init__(
        self,
        in_filters,
        out_num_filters,
        dilation,
        filter_time_length=3,
        nonlinearity: type[nn.Module] = nn.ELU,
        batch_norm_alpha=0.1,
        batch_norm_epsilon=1e-4,
    ):
        super().__init__()
        # Total padding needed to keep the time length: (k - 1) * dilation.
        time_padding = int((filter_time_length - 1) * dilation[0])
        if time_padding % 2 != 0:
            raise ValueError(
                "(filter_time_length - 1) * dilation[0] must be even for "
                "symmetric padding"
            )
        time_padding = time_padding // 2
        dilation = (int(dilation[0]), int(dilation[1]))
        if out_num_filters < in_filters:
            raise ValueError("out_num_filters must be >= in_filters")
        if (out_num_filters - in_filters) % 2 != 0:
            raise ValueError(
                "Need even number of extra channels in order to be able to pad correctly"
            )
        # Zero channels added (half on each side) to the identity path.
        self.n_pad_chans = out_num_filters - in_filters

        self.conv_1 = nn.Conv2d(
            in_filters,
            out_num_filters,
            (filter_time_length, 1),
            stride=(1, 1),
            dilation=dilation,
            padding=(time_padding, 0),
        )
        self.bn1 = nn.BatchNorm2d(
            out_num_filters,
            momentum=batch_norm_alpha,
            affine=True,
            eps=batch_norm_epsilon,
        )
        self.conv_2 = nn.Conv2d(
            out_num_filters,
            out_num_filters,
            (filter_time_length, 1),
            stride=(1, 1),
            dilation=dilation,
            padding=(time_padding, 0),
        )
        self.bn2 = nn.BatchNorm2d(
            out_num_filters,
            momentum=batch_norm_alpha,
            affine=True,
            eps=batch_norm_epsilon,
        )
        self.nonlinearity = nonlinearity()

    def forward(self, x):
        stack_1 = self.nonlinearity(self.bn1(self.conv_1(x)))
        stack_2 = self.bn2(self.conv_2(stack_1))  # next nonlin after sum
        if self.n_pad_chans != 0:
            # Grow the identity path to out_num_filters by zero-padding the
            # filter axis symmetrically.
            zeros_for_padding = x.new_zeros(
                (x.shape[0], self.n_pad_chans // 2, x.shape[2], x.shape[3])
            )
            x = torch.cat((zeros_for_padding, x, zeros_for_padding), dim=1)
        out = self.nonlinearity(x + stack_2)
        return out
|