braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
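The listing above shows that the reusable building blocks were reorganized between 0.8.1 and 1.1.0: braindecode/models/modules.py and braindecode/models/functions.py were removed, while new braindecode/modules/ and braindecode/functional/ packages were added. As a rough orientation, the sketch below shows the import style used by the updated braindecode/models/eegresnet.py further down; only names that actually appear in that diff are listed, and the commented-out 0.8.1 path is an assumption about where Ensure4d used to live, not verified here.

# Import-path sketch for the 1.1.0 layout (names taken from the eegresnet.py
# diff below; the commented 0.8.1 path is an assumption, not verified here).

# 0.8.1 style -- building blocks lived under braindecode.models:
# from braindecode.models.modules import Ensure4d

# 1.1.0 style -- building blocks are exposed from the new braindecode.modules package:
from braindecode.models.base import EEGModuleMixin
from braindecode.modules import AvgPool2dWithConv, Ensure4d, SqueezeFinalOutput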
braindecode/models/eegresnet.py  CHANGED

@@ -4,62 +4,67 @@
 # License: BSD-3

 import numpy as np
-
 import torch
+from einops.layers.torch import Rearrange
 from torch import nn
 from torch.nn import init
-from torch.nn.functional import elu
-from einops.layers.torch import Rearrange

-from .
-from .modules import
-
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import (
+    AvgPool2dWithConv,
+    Ensure4d,
+    SqueezeFinalOutput,
+)


 class EEGResNet(EEGModuleMixin, nn.Sequential):
-    """
+    """EEGResNet from Schirrmeister et al. 2017 [Schirrmeister2017]_.
+
+    .. figure:: https://onlinelibrary.wiley.com/cms/asset/bed1b768-809f-4bc6-b942-b36970d81271/hbm23730-fig-0003-m.jpg
+        :align: center
+        :alt: EEGResNet Architecture

-
+    Model described in [Schirrmeister2017]_.

     Parameters
     ----------
-
-        Alias for
-
-        Alias for
-
-
+    in_chans :
+        Alias for ``n_chans``.
+    n_classes :
+        Alias for ``n_outputs``.
+    input_window_samples :
+        Alias for ``n_times``.
+    activation: nn.Module, default=nn.ELU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

+    References
+    ----------
+    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
+       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
+       & Ball, T. (2017). Deep learning with convolutional neural networks for
+       EEG decoding and visualization. Human Brain Mapping, Aug. 2017.
+       Online: http://dx.doi.org/10.1002/hbm.23730
     """

     def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        in_chans=None,
-        n_classes=None,
-        input_window_samples=None,
-        add_log_softmax=True,
+        self,
+        n_chans=None,
+        n_outputs=None,
+        n_times=None,
+        final_pool_length="auto",
+        n_first_filters=20,
+        n_layers_per_block=2,
+        first_filter_length=3,
+        activation=nn.ELU,
+        split_first_layer=True,
+        batch_norm_alpha=0.1,
+        batch_norm_epsilon=1e-4,
+        conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0),
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=250,
     ):
-        n_chans, n_outputs, n_times = deprecated_args(
-            self,
-            ("in_chans", "n_chans", in_chans, n_chans),
-            ("n_classes", "n_outputs", n_classes, n_outputs),
-            ("input_window_samples", "n_times", input_window_samples, n_times),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -67,19 +72,17 @@ class EEGResNet(EEGModuleMixin, nn.Sequential):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del in_chans, n_classes, input_window_samples

-        if final_pool_length ==
+        if final_pool_length == "auto":
             assert self.n_times is not None
         assert first_filter_length % 2 == 1
         self.final_pool_length = final_pool_length
         self.n_first_filters = n_first_filters
         self.n_layers_per_block = n_layers_per_block
         self.first_filter_length = first_filter_length
-        self.nonlinearity =
+        self.nonlinearity = activation
         self.split_first_layer = split_first_layer
         self.batch_norm_alpha = batch_norm_alpha
         self.batch_norm_epsilon = batch_norm_epsilon
@@ -87,147 +90,207 @@ class EEGResNet(EEGModuleMixin, nn.Sequential):

         self.mapping = {
             "conv_classifier.weight": "final_layer.conv_classifier.weight",
-            "conv_classifier.bias": "final_layer.conv_classifier.bias"
+            "conv_classifier.bias": "final_layer.conv_classifier.bias",
         }

         self.add_module("ensuredims", Ensure4d())
         if self.split_first_layer:
-            self.add_module(
-
-
-
-
-
-
-
-
-
+            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
+            self.add_module(
+                "conv_time",
+                nn.Conv2d(
+                    1,
+                    self.n_first_filters,
+                    (self.first_filter_length, 1),
+                    stride=1,
+                    padding=(self.first_filter_length // 2, 0),
+                ),
+            )
+            self.add_module(
+                "conv_spat",
+                nn.Conv2d(
+                    self.n_first_filters,
+                    self.n_first_filters,
+                    (1, self.n_chans),
+                    stride=(1, 1),
+                    bias=False,
+                ),
+            )
         else:
-            self.add_module(
-
-
-
-
+            self.add_module(
+                "conv_time",
+                nn.Conv2d(
+                    self.n_chans,
+                    self.n_first_filters,
+                    (self.first_filter_length, 1),
+                    stride=(1, 1),
+                    padding=(self.first_filter_length // 2, 0),
+                    bias=False,
+                ),
+            )
         n_filters_conv = self.n_first_filters
-        self.add_module(
-
-
-
-
+        self.add_module(
+            "bnorm",
+            nn.BatchNorm2d(
+                n_filters_conv, momentum=self.batch_norm_alpha, affine=True, eps=1e-5
+            ),
+        )
+        self.add_module("conv_nonlin", self.nonlinearity())
         cur_dilation = np.array([1, 1])
         n_cur_filters = n_filters_conv
         i_block = 1
         for i_layer in range(self.n_layers_per_block):
-            self.add_module(
-
-
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )
         i_block += 1
         cur_dilation[0] *= 2
         n_out_filters = int(2 * n_cur_filters)
-        self.add_module(
-
-
+        self.add_module(
+            "res_{:d}_{:d}".format(i_block, 0),
+            _ResidualBlock(
+                n_cur_filters,
+                n_out_filters,
+                dilation=cur_dilation,
+            ),
+        )
         n_cur_filters = n_out_filters
         for i_layer in range(1, self.n_layers_per_block):
-            self.add_module(
-
-
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )

         i_block += 1
         cur_dilation[0] *= 2
         n_out_filters = int(1.5 * n_cur_filters)
-        self.add_module(
-
-
+        self.add_module(
+            "res_{:d}_{:d}".format(i_block, 0),
+            _ResidualBlock(
+                n_cur_filters,
+                n_out_filters,
+                dilation=cur_dilation,
+            ),
+        )
         n_cur_filters = n_out_filters
         for i_layer in range(1, self.n_layers_per_block):
-            self.add_module(
-
-
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )

         i_block += 1
         cur_dilation[0] *= 2
-        self.add_module(
-
-
+        self.add_module(
+            "res_{:d}_{:d}".format(i_block, 0),
+            _ResidualBlock(
+                n_cur_filters,
+                n_cur_filters,
+                dilation=cur_dilation,
+            ),
+        )
         for i_layer in range(1, self.n_layers_per_block):
-            self.add_module(
-
-
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )

         i_block += 1
         cur_dilation[0] *= 2
-        self.add_module(
-
-
+        self.add_module(
+            "res_{:d}_{:d}".format(i_block, 0),
+            _ResidualBlock(
+                n_cur_filters,
+                n_cur_filters,
+                dilation=cur_dilation,
+            ),
+        )
         for i_layer in range(1, self.n_layers_per_block):
-            self.add_module(
-
-
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )
         i_block += 1
         cur_dilation[0] *= 2
-        self.add_module(
-
-
+        self.add_module(
+            "res_{:d}_{:d}".format(i_block, 0),
+            _ResidualBlock(
+                n_cur_filters,
+                n_cur_filters,
+                dilation=cur_dilation,
+            ),
+        )
         for i_layer in range(1, self.n_layers_per_block):
-            self.add_module(
-
-
+            self.add_module(
+                "res_{:d}_{:d}".format(i_block, i_layer),
+                _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
+            )

         self.eval()
-        if self.final_pool_length ==
-            self.add_module(
+        if self.final_pool_length == "auto":
+            self.add_module("mean_pool", nn.AdaptiveAvgPool2d((1, 1)))
         else:
             pool_dilation = int(cur_dilation[0]), int(cur_dilation[1])
-            self.add_module(
-
-
+            self.add_module(
+                "mean_pool",
+                AvgPool2dWithConv(
+                    (self.final_pool_length, 1), (1, 1), dilation=pool_dilation
+                ),
+            )

         # Incorporating classification module and subsequent ones in one final layer
         module = nn.Sequential()

-        module.add_module(
-
-
-
-
+        module.add_module(
+            "conv_classifier",
+            nn.Conv2d(
+                n_cur_filters,
+                self.n_outputs,
+                (1, 1),
+                bias=True,
+            ),
+        )

-        module.add_module("squeeze",
+        module.add_module("squeeze", SqueezeFinalOutput())

         self.add_module("final_layer", module)

         # Initialize all weights
-        self.apply(lambda module: _weights_init(module, self.conv_weight_init_fn))
-
-        # Start in eval mode
-        self.eval()
+        self.apply(lambda module: self._weights_init(module, self.conv_weight_init_fn))

+        # Start in train mode
+        self.train()

-
-
-
-
-
-
-
-
+    @staticmethod
+    def _weights_init(module, conv_weight_init_fn):
+        """
+        initialize weights
+        """
+        classname = module.__class__.__name__
+        if "Conv" in classname and classname != "AvgPool2dWithConv":
+            conv_weight_init_fn(module.weight)
+            if module.bias is not None:
+                init.constant_(module.bias, 0)
+        elif "BatchNorm" in classname:
+            init.constant_(module.weight, 1)
             init.constant_(module.bias, 0)
-    elif 'BatchNorm' in classname:
-        init.constant_(module.weight, 1)
-        init.constant_(module.bias, 0)


 class _ResidualBlock(nn.Module):
@@ -235,46 +298,65 @@ class _ResidualBlock(nn.Module):
     create a residual learning building block with two stacked 3x3 convlayers as in paper
     """

-    def __init__(
-
-
-
-
-
+    def __init__(
+        self,
+        in_filters,
+        out_num_filters,
+        dilation,
+        filter_time_length=3,
+        nonlinearity: nn.Module = nn.ELU,
+        batch_norm_alpha=0.1,
+        batch_norm_epsilon=1e-4,
+    ):
         super(_ResidualBlock, self).__init__()
         time_padding = int((filter_time_length - 1) * dilation[0])
         assert time_padding % 2 == 0
         time_padding = int(time_padding // 2)
         dilation = (int(dilation[0]), int(dilation[1]))
         assert (out_num_filters - in_filters) % 2 == 0, (
-            "Need even number of extra channels in order to be able to "
-
+            "Need even number of extra channels in order to be able to pad correctly"
+        )
         self.n_pad_chans = out_num_filters - in_filters

         self.conv_1 = nn.Conv2d(
-            in_filters,
+            in_filters,
+            out_num_filters,
+            (filter_time_length, 1),
+            stride=(1, 1),
             dilation=dilation,
-            padding=(time_padding, 0)
+            padding=(time_padding, 0),
+        )
         self.bn1 = nn.BatchNorm2d(
-            out_num_filters,
-
+            out_num_filters,
+            momentum=batch_norm_alpha,
+            affine=True,
+            eps=batch_norm_epsilon,
+        )
         self.conv_2 = nn.Conv2d(
-            out_num_filters,
+            out_num_filters,
+            out_num_filters,
+            (filter_time_length, 1),
+            stride=(1, 1),
             dilation=dilation,
-            padding=(time_padding, 0)
+            padding=(time_padding, 0),
+        )
         self.bn2 = nn.BatchNorm2d(
-            out_num_filters,
-
+            out_num_filters,
+            momentum=batch_norm_alpha,
+            affine=True,
+            eps=batch_norm_epsilon,
+        )
         # also see https://mail.google.com/mail/u/0/#search/ilya+joos/1576137dd34c3127
         # for resnet options as ilya used them
-        self.nonlinearity = nonlinearity
+        self.nonlinearity = nonlinearity()

     def forward(self, x):
         stack_1 = self.nonlinearity(self.bn1(self.conv_1(x)))
         stack_2 = self.bn2(self.conv_2(stack_1)) # next nonlin after sum
         if self.n_pad_chans != 0:
-            zeros_for_padding =
-
+            zeros_for_padding = x.new_zeros(
+                (x.shape[0], self.n_pad_chans // 2, x.shape[2], x.shape[3])
+            )
             x = torch.cat((zeros_for_padding, x, zeros_for_padding), dim=1)
         out = self.nonlinearity(x + stack_2)
         return out