braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -2,21 +2,24 @@
|
|
|
2
2
|
# Cedric Rommel <cedric.rommel@inria.fr>
|
|
3
3
|
#
|
|
4
4
|
# License: BSD (3-clause)
|
|
5
|
-
|
|
5
|
+
import math
|
|
6
6
|
|
|
7
|
-
from torch import nn
|
|
8
7
|
from einops.layers.torch import Rearrange
|
|
8
|
+
from torch import nn
|
|
9
9
|
|
|
10
|
-
from .
|
|
11
|
-
from .
|
|
12
|
-
from .
|
|
13
|
-
from .base import EEGModuleMixin, deprecated_args
|
|
10
|
+
from braindecode.functional import glorot_weight_zero_bias
|
|
11
|
+
from braindecode.models.base import EEGModuleMixin
|
|
12
|
+
from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
|
|
14
13
|
|
|
15
14
|
|
|
16
15
|
class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
|
|
17
|
-
"""EEG Inception for ERP-based
|
|
16
|
+
"""EEG Inception for ERP-based from Santamaria-Vazquez et al (2020) [santamaria2020]_.
|
|
17
|
+
|
|
18
|
+
.. figure:: https://braindecode.org/dev/_static/model/eeginceptionerp.jpg
|
|
19
|
+
:align: center
|
|
20
|
+
:alt: EEGInceptionERP Architecture
|
|
18
21
|
|
|
19
|
-
The code for the paper and this model is also available at [
|
|
22
|
+
The code for the paper and this model is also available at [santamaria2020]_
|
|
20
23
|
and an adaptation for PyTorch [2]_.
|
|
21
24
|
|
|
22
25
|
The model is strongly based on the original InceptionNet for an image. The main goal is
|
|
@@ -27,9 +30,10 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
|
|
|
27
30
|
|
|
28
31
|
One advantage of the EEG-Inception block is that it allows a network
|
|
29
32
|
to learn simultaneous components of low and high frequency associated with the signal.
|
|
30
|
-
The winners of BEETL Competition/NeurIps 2021 used parts of the
|
|
33
|
+
The winners of BEETL Competition/NeurIps 2021 used parts of the
|
|
34
|
+
model [beetl]_.
|
|
31
35
|
|
|
32
|
-
The model is fully described in [
|
|
36
|
+
The model is fully described in [santamaria2020]_.
|
|
33
37
|
|
|
34
38
|
Notes
|
|
35
39
|
-----
|
|
@@ -40,85 +44,70 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
|
|
|
40
44
|
----------
|
|
41
45
|
n_times : int, optional
|
|
42
46
|
Size of the input, in number of samples. Set to 128 (1s) as in
|
|
43
|
-
[
|
|
47
|
+
[santamaria2020]_.
|
|
44
48
|
sfreq : float, optional
|
|
45
|
-
EEG sampling frequency. Defaults to 128 as in [
|
|
49
|
+
EEG sampling frequency. Defaults to 128 as in [santamaria2020]_.
|
|
46
50
|
drop_prob : float, optional
|
|
47
51
|
Dropout rate inside all the network. Defaults to 0.5 as in
|
|
48
|
-
[
|
|
52
|
+
[santamaria2020]_.
|
|
49
53
|
scales_samples_s: list(float), optional
|
|
50
54
|
Windows for inception block. Temporal scale (s) of the convolutions on
|
|
51
55
|
each Inception module. This parameter determines the kernel sizes of
|
|
52
56
|
the filters. Defaults to 0.5, 0.25, 0.125 seconds, as in
|
|
53
|
-
[
|
|
57
|
+
[santamaria2020]_.
|
|
54
58
|
n_filters : int, optional
|
|
55
59
|
Initial number of convolutional filters. Defaults to 8 as in
|
|
56
|
-
[
|
|
60
|
+
[santamaria2020]_.
|
|
57
61
|
activation: nn.Module, optional
|
|
58
62
|
Activation function. Defaults to ELU activation as in
|
|
59
|
-
[
|
|
63
|
+
[santamaria2020]_.
|
|
60
64
|
batch_norm_alpha: float, optional
|
|
61
65
|
Momentum for BatchNorm2d. Defaults to 0.01.
|
|
62
66
|
depth_multiplier: int, optional
|
|
63
67
|
Depth multiplier for the depthwise convolution. Defaults to 2 as in
|
|
64
|
-
[
|
|
68
|
+
[santamaria2020]_.
|
|
65
69
|
pooling_sizes: list(int), optional
|
|
66
70
|
Pooling sizes for the inception blocks. Defaults to 4, 2, 2 and 2, as
|
|
67
|
-
in [
|
|
68
|
-
|
|
69
|
-
Alias for n_chans.
|
|
70
|
-
n_classes : int
|
|
71
|
-
Alias for n_outputs.
|
|
72
|
-
input_window_samples : int
|
|
73
|
-
Alias for n_times.
|
|
71
|
+
in [santamaria2020]_.
|
|
72
|
+
|
|
74
73
|
|
|
75
74
|
References
|
|
76
75
|
----------
|
|
77
|
-
.. [
|
|
76
|
+
.. [santamaria2020] Santamaria-Vazquez, E., Martinez-Cagigal, V.,
|
|
78
77
|
Vaquerizo-Villar, F., & Hornero, R. (2020).
|
|
79
78
|
EEG-inception: A novel deep convolutional neural network for assistive
|
|
80
79
|
ERP-based brain-computer interfaces.
|
|
81
80
|
IEEE Transactions on Neural Systems and Rehabilitation Engineering , v. 28.
|
|
82
81
|
Online: http://dx.doi.org/10.1109/TNSRE.2020.3048106
|
|
83
|
-
.. [2]
|
|
84
|
-
Online: https://github.com/Grifcc/EEG/
|
|
85
|
-
.. [beetl]
|
|
82
|
+
.. [2] Grifcc. Implementation of the EEGInception in torch (2022).
|
|
83
|
+
Online: https://github.com/Grifcc/EEG/
|
|
84
|
+
.. [beetl] Wei, X., Faisal, A.A., Grosse-Wentrup, M., Gramfort, A., Chevallier, S.,
|
|
86
85
|
Jayaram, V., Jeunet, C., Bakas, S., Ludwig, S., Barmpas, K., Bahri, M., Panagakis,
|
|
87
86
|
Y., Laskaris, N., Adamos, D.A., Zafeiriou, S., Duong, W.C., Gordon, S.M.,
|
|
88
|
-
Lawhern, V.J., Śliwowski, M., Rouanne, V. & Tempczyk, P
|
|
87
|
+
Lawhern, V.J., Śliwowski, M., Rouanne, V. & Tempczyk, P. (2022).
|
|
89
88
|
2021 BEETL Competition: Advancing Transfer Learning for Subject Independence &
|
|
90
|
-
Heterogeneous EEG Data Sets.
|
|
91
|
-
Demonstrations Track
|
|
89
|
+
Heterogeneous EEG Data Sets. Proceedings of the NeurIPS 2021 Competitions and
|
|
90
|
+
Demonstrations Track, in Proceedings of Machine Learning Research
|
|
92
91
|
176:205-219 Available from https://proceedings.mlr.press/v176/wei22a.html.
|
|
93
92
|
|
|
94
93
|
"""
|
|
95
94
|
|
|
96
95
|
def __init__(
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
in_channels=None,
|
|
112
|
-
n_classes=None,
|
|
113
|
-
input_window_samples=None,
|
|
114
|
-
add_log_softmax=True,
|
|
96
|
+
self,
|
|
97
|
+
n_chans=None,
|
|
98
|
+
n_outputs=None,
|
|
99
|
+
n_times=1000,
|
|
100
|
+
sfreq=128,
|
|
101
|
+
drop_prob=0.5,
|
|
102
|
+
scales_samples_s=(0.5, 0.25, 0.125),
|
|
103
|
+
n_filters=8,
|
|
104
|
+
activation: nn.Module = nn.ELU,
|
|
105
|
+
batch_norm_alpha=0.01,
|
|
106
|
+
depth_multiplier=2,
|
|
107
|
+
pooling_sizes=(4, 2, 2, 2),
|
|
108
|
+
chs_info=None,
|
|
109
|
+
input_window_seconds=None,
|
|
115
110
|
):
|
|
116
|
-
n_chans, n_outputs, n_times, = deprecated_args(
|
|
117
|
-
self,
|
|
118
|
-
('in_channels', 'n_chans', in_channels, n_chans),
|
|
119
|
-
('n_classes', 'n_outputs', n_classes, n_outputs),
|
|
120
|
-
('input_window_samples', 'n_times', input_window_samples, n_times),
|
|
121
|
-
)
|
|
122
111
|
super().__init__(
|
|
123
112
|
n_outputs=n_outputs,
|
|
124
113
|
n_chans=n_chans,
|
|
@@ -126,23 +115,23 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
|
|
|
126
115
|
n_times=n_times,
|
|
127
116
|
input_window_seconds=input_window_seconds,
|
|
128
117
|
sfreq=sfreq,
|
|
129
|
-
add_log_softmax=add_log_softmax
|
|
130
118
|
)
|
|
131
119
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
132
|
-
del in_channels, n_classes, input_window_samples
|
|
133
120
|
self.drop_prob = drop_prob
|
|
134
121
|
self.n_filters = n_filters
|
|
135
122
|
self.scales_samples_s = scales_samples_s
|
|
136
123
|
self.scales_samples = tuple(
|
|
137
|
-
int(size_s * self.sfreq) for size_s in self.scales_samples_s
|
|
124
|
+
int(size_s * self.sfreq) for size_s in self.scales_samples_s
|
|
125
|
+
)
|
|
138
126
|
self.activation = activation
|
|
139
127
|
self.alpha_momentum = batch_norm_alpha
|
|
140
128
|
self.depth_multiplier = depth_multiplier
|
|
141
129
|
self.pooling_sizes = pooling_sizes
|
|
142
130
|
|
|
143
131
|
self.mapping = {
|
|
144
|
-
|
|
145
|
-
|
|
132
|
+
"classification.1.weight": "final_layer.fc.weight",
|
|
133
|
+
"classification.1.bias": "final_layer.fc.bias",
|
|
134
|
+
}
|
|
146
135
|
|
|
147
136
|
self.add_module("ensuredims", Ensure4d())
|
|
148
137
|
|
|
@@ -177,7 +166,9 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
|
|
|
177
166
|
depth_multiplier=self.depth_multiplier,
|
|
178
167
|
)
|
|
179
168
|
|
|
180
|
-
self.add_module(
|
|
169
|
+
self.add_module(
|
|
170
|
+
"inception_block_1", InceptionBlock((block11, block12, block13))
|
|
171
|
+
)
|
|
181
172
|
|
|
182
173
|
self.add_module("avg_pool_1", nn.AvgPool2d((1, self.pooling_sizes[0])))
|
|
183
174
|
|
|
@@ -190,7 +181,7 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
|
|
|
190
181
|
kernel_length=self.scales_samples[0] // 4,
|
|
191
182
|
alpha_momentum=self.alpha_momentum,
|
|
192
183
|
activation=self.activation,
|
|
193
|
-
drop_prob=self.drop_prob
|
|
184
|
+
drop_prob=self.drop_prob,
|
|
194
185
|
)
|
|
195
186
|
block22 = self._get_inception_branch_2(
|
|
196
187
|
in_channels=n_concat_dw_filters,
|
|
@@ -198,7 +189,7 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
|
|
|
198
189
|
kernel_length=self.scales_samples[1] // 4,
|
|
199
190
|
alpha_momentum=self.alpha_momentum,
|
|
200
191
|
activation=self.activation,
|
|
201
|
-
drop_prob=self.drop_prob
|
|
192
|
+
drop_prob=self.drop_prob,
|
|
202
193
|
)
|
|
203
194
|
block23 = self._get_inception_branch_2(
|
|
204
195
|
in_channels=n_concat_dw_filters,
|
|
@@ -206,44 +197,44 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
|
|
|
206
197
|
kernel_length=self.scales_samples[2] // 4,
|
|
207
198
|
alpha_momentum=self.alpha_momentum,
|
|
208
199
|
activation=self.activation,
|
|
209
|
-
drop_prob=self.drop_prob
|
|
200
|
+
drop_prob=self.drop_prob,
|
|
210
201
|
)
|
|
211
202
|
|
|
212
203
|
self.add_module(
|
|
213
|
-
"inception_block_2",
|
|
204
|
+
"inception_block_2", InceptionBlock((block21, block22, block23))
|
|
205
|
+
)
|
|
214
206
|
|
|
215
207
|
self.add_module("avg_pool_2", nn.AvgPool2d((1, self.pooling_sizes[1])))
|
|
216
208
|
|
|
217
|
-
self.add_module(
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
209
|
+
self.add_module(
|
|
210
|
+
"final_block",
|
|
211
|
+
nn.Sequential(
|
|
212
|
+
nn.Conv2d(
|
|
213
|
+
n_concat_filters,
|
|
214
|
+
n_concat_filters // 2,
|
|
215
|
+
(1, 8),
|
|
216
|
+
padding="same",
|
|
217
|
+
bias=False,
|
|
218
|
+
),
|
|
219
|
+
nn.BatchNorm2d(n_concat_filters // 2, momentum=self.alpha_momentum),
|
|
220
|
+
activation(),
|
|
221
|
+
nn.Dropout(self.drop_prob),
|
|
222
|
+
nn.AvgPool2d((1, self.pooling_sizes[2])),
|
|
223
|
+
nn.Conv2d(
|
|
224
|
+
n_concat_filters // 2,
|
|
225
|
+
n_concat_filters // 4,
|
|
226
|
+
(1, 4),
|
|
227
|
+
padding="same",
|
|
228
|
+
bias=False,
|
|
229
|
+
),
|
|
230
|
+
nn.BatchNorm2d(n_concat_filters // 4, momentum=self.alpha_momentum),
|
|
231
|
+
activation(),
|
|
232
|
+
nn.Dropout(self.drop_prob),
|
|
233
|
+
nn.AvgPool2d((1, self.pooling_sizes[3])),
|
|
224
234
|
),
|
|
225
|
-
|
|
226
|
-
momentum=self.alpha_momentum),
|
|
227
|
-
activation,
|
|
228
|
-
nn.Dropout(self.drop_prob),
|
|
229
|
-
nn.AvgPool2d((1, self.pooling_sizes[2])),
|
|
235
|
+
)
|
|
230
236
|
|
|
231
|
-
|
|
232
|
-
n_concat_filters // 2,
|
|
233
|
-
n_concat_filters // 4,
|
|
234
|
-
(1, 4),
|
|
235
|
-
padding="same",
|
|
236
|
-
bias=False
|
|
237
|
-
),
|
|
238
|
-
nn.BatchNorm2d(n_concat_filters // 4,
|
|
239
|
-
momentum=self.alpha_momentum),
|
|
240
|
-
activation,
|
|
241
|
-
nn.Dropout(self.drop_prob),
|
|
242
|
-
nn.AvgPool2d((1, self.pooling_sizes[3])),
|
|
243
|
-
))
|
|
244
|
-
|
|
245
|
-
spatial_dim_last_layer = (
|
|
246
|
-
self.n_times // prod(self.pooling_sizes))
|
|
237
|
+
spatial_dim_last_layer = self.n_times // math.prod(self.pooling_sizes)
|
|
247
238
|
n_channels_last_layer = self.n_filters * len(self.scales_samples) // 4
|
|
248
239
|
|
|
249
240
|
self.add_module("flat", nn.Flatten())
|
|
@@ -251,63 +242,63 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
|
|
|
251
242
|
# Incorporating classification module and subsequent ones in one final layer
|
|
252
243
|
module = nn.Sequential()
|
|
253
244
|
|
|
254
|
-
module.add_module(
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
), )
|
|
245
|
+
module.add_module(
|
|
246
|
+
"fc",
|
|
247
|
+
nn.Linear(spatial_dim_last_layer * n_channels_last_layer, self.n_outputs),
|
|
248
|
+
)
|
|
259
249
|
|
|
260
|
-
|
|
261
|
-
module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
|
|
262
|
-
else:
|
|
263
|
-
module.add_module("identity", nn.Identity())
|
|
250
|
+
module.add_module("identity", nn.Identity())
|
|
264
251
|
|
|
265
252
|
self.add_module("final_layer", module)
|
|
266
253
|
|
|
267
|
-
|
|
254
|
+
glorot_weight_zero_bias(self)
|
|
268
255
|
|
|
269
256
|
@staticmethod
|
|
270
|
-
def _get_inception_branch_1(
|
|
271
|
-
|
|
272
|
-
|
|
257
|
+
def _get_inception_branch_1(
|
|
258
|
+
in_channels,
|
|
259
|
+
out_channels,
|
|
260
|
+
kernel_length,
|
|
261
|
+
alpha_momentum,
|
|
262
|
+
drop_prob,
|
|
263
|
+
activation,
|
|
264
|
+
depth_multiplier,
|
|
265
|
+
):
|
|
273
266
|
return nn.Sequential(
|
|
274
267
|
nn.Conv2d(
|
|
275
268
|
1,
|
|
276
269
|
out_channels,
|
|
277
270
|
kernel_size=(1, kernel_length),
|
|
278
271
|
padding="same",
|
|
279
|
-
bias=True
|
|
272
|
+
bias=True,
|
|
280
273
|
),
|
|
281
274
|
nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
|
|
282
|
-
activation,
|
|
275
|
+
activation(),
|
|
283
276
|
nn.Dropout(drop_prob),
|
|
284
|
-
|
|
277
|
+
DepthwiseConv2d(
|
|
285
278
|
out_channels,
|
|
286
279
|
kernel_size=(in_channels, 1),
|
|
287
280
|
depth_multiplier=depth_multiplier,
|
|
288
281
|
bias=False,
|
|
289
282
|
padding="valid",
|
|
290
283
|
),
|
|
291
|
-
nn.BatchNorm2d(
|
|
292
|
-
|
|
293
|
-
momentum=alpha_momentum
|
|
294
|
-
),
|
|
295
|
-
activation,
|
|
284
|
+
nn.BatchNorm2d(depth_multiplier * out_channels, momentum=alpha_momentum),
|
|
285
|
+
activation(),
|
|
296
286
|
nn.Dropout(drop_prob),
|
|
297
287
|
)
|
|
298
288
|
|
|
299
289
|
@staticmethod
|
|
300
|
-
def _get_inception_branch_2(
|
|
301
|
-
|
|
290
|
+
def _get_inception_branch_2(
|
|
291
|
+
in_channels, out_channels, kernel_length, alpha_momentum, drop_prob, activation
|
|
292
|
+
):
|
|
302
293
|
return nn.Sequential(
|
|
303
294
|
nn.Conv2d(
|
|
304
295
|
in_channels,
|
|
305
296
|
out_channels,
|
|
306
297
|
kernel_size=(1, kernel_length),
|
|
307
298
|
padding="same",
|
|
308
|
-
bias=False
|
|
299
|
+
bias=False,
|
|
309
300
|
),
|
|
310
301
|
nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
|
|
311
|
-
activation,
|
|
302
|
+
activation(),
|
|
312
303
|
nn.Dropout(drop_prob),
|
|
313
304
|
)
|