braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,379 @@
# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
#          Cedric Rommel <cedric.rommel@inria.fr>
#
# License: BSD (3-clause)
import math

from einops.layers.torch import Rearrange
from torch import nn

from braindecode.functional import glorot_weight_zero_bias
from braindecode.models.base import EEGModuleMixin
from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock


class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
    r"""EEG-Inception for ERP-based BCIs from Santamaria-Vazquez et al. (2020) [santamaria2020]_.

    :bdg-success:`Convolution`

    .. figure:: https://braindecode.org/dev/_static/model/eeginceptionerp.jpg
        :align: center
        :alt: EEGInceptionERP Architecture

    Figure: Overview of the EEG-Inception architecture. 2D convolution blocks and
    depthwise 2D convolution blocks include batch normalization, activation and
    dropout regularization. The kernel size is displayed for convolutional and
    average pooling layers.

    .. rubric:: Architectural Overview

    A two-stage, multi-scale CNN tailored to ERP detection from short
    (0-1000 ms) single-trial epochs. Signals are mapped through:

    * (i) :class:`_InceptionModule1`, multi-scale temporal feature extraction plus per-branch spatial mixing;
    * (ii) :class:`_InceptionModule2`, deeper multi-scale refinement at a reduced temporal resolution; and
    * (iii) :class:`_OutputModule`, compact aggregation and linear readout.

    .. rubric:: Macro Components

    - :class:`_InceptionModule1` **(multi-scale temporal + spatial mixing)**

      - *Operations.*

        - `EEGInceptionERP.c1`: :class:`torch.nn.Conv2d` ``k=(64,1)``, stride ``(1,1)``, *same* pad on input reshaped to ``(B,1,128,8)`` → BN → activation → dropout.
        - `EEGInceptionERP.d1`: :class:`torch.nn.Conv2d` (depthwise) ``k=(1,8)``, *valid* pad over channels → BN → activation → dropout.
        - `EEGInceptionERP.c2`: :class:`torch.nn.Conv2d` ``k=(32,1)`` → BN → activation → dropout; then `EEGInceptionERP.d2` depthwise ``k=(1,8)`` → BN → activation → dropout.
        - `EEGInceptionERP.c3`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout; then `EEGInceptionERP.d3` depthwise ``k=(1,8)`` → BN → activation → dropout.
        - `EEGInceptionERP.n1`: concatenation of the three branch outputs (implemented inside :class:`InceptionBlock`).
        - `EEGInceptionERP.a1`: :class:`torch.nn.AvgPool2d` ``pool=(4,1)``, stride ``(4,1)`` for temporal downsampling.

      *Interpretability/robustness.* Depthwise ``1 x n_chans`` layers act as learnable montage-wide spatial filters per temporal scale; pooling stabilizes against jitter.

    - :class:`_InceptionModule2` **(refinement at a coarser timebase)**

      - *Operations.*

        - `EEGInceptionERP.c4`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout.
        - `EEGInceptionERP.c5`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout.
        - `EEGInceptionERP.c6`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout.
        - `EEGInceptionERP.n2`: concatenation merging the C4-C6 outputs.
        - `EEGInceptionERP.a2`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``, stride ``(2,1)``.
        - `EEGInceptionERP.c7`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout; then `EEGInceptionERP.a3`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
        - `EEGInceptionERP.c8`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout; then `EEGInceptionERP.a4`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.

      *Role.* Adds higher-level, shorter-window evidence while progressively compressing the temporal dimension.

    - :class:`_OutputModule` **(aggregation + readout)**

      - *Operations.*

        - :class:`torch.nn.Flatten`
        - :class:`torch.nn.Linear` ``(features → 2)``; a schematic shape walk-through follows this list.

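    As an illustrative shape walk-through (not from the original paper; traced
    from the module definitions below, with ``n_chans=8``, ``n_times=128``,
    ``n_filters=8``, ``depth_multiplier=2``, and ``B`` the batch size)::

        (B, 1, 8, 128)  --inception_block_1-->  (B, 48, 1, 128)
                        --avg_pool_1--------->  (B, 48, 1, 32)
                        --inception_block_2-->  (B, 24, 1, 32)
                        --avg_pool_2--------->  (B, 24, 1, 16)
                        --final_block-------->  (B, 6, 1, 4)
                        --flatten + fc------->  (B, n_outputs)
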
    .. rubric:: Convolutional Details

    - **Temporal (where time-domain patterns are learned).**

      The first module uses 1D temporal kernels along the 128-sample axis: ``64``, ``32``, ``16``
      samples (≈500, 250, 125 ms at 128 Hz). After ``pool=(4,1)``, the second module applies ``16``,
      ``8``, ``4`` (≈500, 250, 125 ms at the pooled 32 Hz rate, i.e. the same temporal scales).
      All conv strides are ``1``; temporal resolution changes only via average pooling.

    - **Spatial (how electrodes are processed).**

      Depthwise convs with ``k=(1,8)`` span all channels and are applied **per temporal branch**,
      yielding scale-specific channel projections (no cross-branch mixing until concatenation).
      There is no full 2D mixing kernel; spatial mixing is factorized and lightweight.

    - **Spectral (how frequency information is captured).**

      No explicit transform; multiple temporal kernels form a *learned filter bank* over
      ERP-relevant bands. Successive pooling acts as low-pass integration to emphasize sustained
      post-stimulus components.

    .. rubric:: Additional Mechanisms

    - Every conv/depthwise block includes **BatchNorm**, a nonlinearity (the paper grid-searched the activation), and **dropout**.
    - Two Inception stages followed by short convs and pooling keep the parameter count small (≈15k reported) while preserving multi-scale evidence.
    - Expected input: epochs of shape ``(B,1,128,8)`` (time x channels as a 2D map) or reshaped from ``(B,8,128)`` with an added singleton feature dimension.

    .. rubric:: Usage and Configuration

    - **Key knobs.** Number of filters per branch; kernel lengths in both Inception modules; depthwise kernel over channels (typically ``n_chans``); pooling lengths/strides; dropout rate; choice of activation. A constructor sketch follows this list.
    - **Training tips.** Use 0-1000 ms windows at 128 Hz with CAR; tune activation and dropout (they strongly affect performance); early-stop on validation loss when overfitting emerges.

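    A minimal constructor sketch of these knobs (illustrative values only,
    mirroring the paper-like setting; see the Parameters section below)::

        # illustrative configuration, not library-enforced defaults
        model = EEGInceptionERP(
            n_chans=8,
            n_outputs=2,
            n_times=128,
            sfreq=128.0,
            drop_prob=0.5,
            activation=nn.ELU,
            pooling_sizes=(4, 2, 2, 2),
        )
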
    .. rubric:: Implementation Details

    The model is strongly based on the original InceptionNet for images. The
    main goal is to extract features in parallel at different scales. The
    authors extracted three scales proportional to the window sample size.
    The network has three parts: (1) a larger Inception block, (2) a smaller
    Inception block, followed by (3) a bottleneck for classification.

    One advantage of the EEG-Inception block is that it allows the network
    to simultaneously learn low- and high-frequency components of the signal.
    The winners of the BEETL Competition (NeurIPS 2021) used parts of this
    model [beetl]_.

    The code for the paper and this model is also available at
    [santamaria2020]_, with a PyTorch adaptation at [2]_.


    Parameters
    ----------
    n_times : int, optional
        Size of the input, in number of samples. Defaults to 1000;
        [santamaria2020]_ used 128 samples (1 s at 128 Hz).
    sfreq : float, optional
        EEG sampling frequency. Defaults to 128 as in [santamaria2020]_.
    drop_prob : float, optional
        Dropout rate applied throughout the network. Defaults to 0.5 as in
        [santamaria2020]_.
    scales_samples_s : list(float), optional
        Windows for the Inception blocks. Temporal scale (s) of the
        convolutions in each Inception module; this determines the kernel
        sizes of the filters. Defaults to 0.5, 0.25 and 0.125 seconds, as in
        [santamaria2020]_.
    n_filters : int, optional
        Initial number of convolutional filters. Defaults to 8 as in
        [santamaria2020]_.
    activation : nn.Module, optional
        Activation function. Defaults to ELU activation as in
        [santamaria2020]_.
    batch_norm_alpha : float, optional
        Momentum for BatchNorm2d. Defaults to 0.01.
    depth_multiplier : int, optional
        Depth multiplier for the depthwise convolution. Defaults to 2 as in
        [santamaria2020]_.
    pooling_sizes : list(int), optional
        Pooling sizes for the Inception blocks. Defaults to 4, 2, 2 and 2, as
        in [santamaria2020]_.


    References
    ----------
    .. [santamaria2020] Santamaria-Vazquez, E., Martinez-Cagigal, V.,
       Vaquerizo-Villar, F., & Hornero, R. (2020).
       EEG-Inception: A novel deep convolutional neural network for assistive
       ERP-based brain-computer interfaces.
       IEEE Transactions on Neural Systems and Rehabilitation Engineering,
       v. 28. Online: http://dx.doi.org/10.1109/TNSRE.2020.3048106
    .. [2] Grifcc. Implementation of the EEGInception in torch (2022).
       Online: https://github.com/Grifcc/EEG/
    .. [beetl] Wei, X., Faisal, A.A., Grosse-Wentrup, M., Gramfort, A.,
       Chevallier, S., Jayaram, V., Jeunet, C., Bakas, S., Ludwig, S.,
       Barmpas, K., Bahri, M., Panagakis, Y., Laskaris, N., Adamos, D.A.,
       Zafeiriou, S., Duong, W.C., Gordon, S.M., Lawhern, V.J., Śliwowski, M.,
       Rouanne, V. & Tempczyk, P. (2022). 2021 BEETL Competition: Advancing
       Transfer Learning for Subject Independence & Heterogeneous EEG Data
       Sets. Proceedings of the NeurIPS 2021 Competitions and Demonstrations
       Track, in Proceedings of Machine Learning Research 176:205-219.
       Available from https://proceedings.mlr.press/v176/wei22a.html.
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=1000,
        sfreq=128,
        drop_prob=0.5,
        scales_samples_s=(0.5, 0.25, 0.125),
        n_filters=8,
        activation: type[nn.Module] = nn.ELU,
        batch_norm_alpha=0.01,
        depth_multiplier=2,
        pooling_sizes=(4, 2, 2, 2),
        chs_info=None,
        input_window_seconds=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.drop_prob = drop_prob
        self.n_filters = n_filters
        self.scales_samples_s = scales_samples_s
        self.scales_samples = tuple(
            int(size_s * self.sfreq) for size_s in self.scales_samples_s
        )
        self.activation = activation
        self.alpha_momentum = batch_norm_alpha
        self.depth_multiplier = depth_multiplier
        self.pooling_sizes = pooling_sizes

        self.mapping = {
            "classification.1.weight": "final_layer.fc.weight",
            "classification.1.bias": "final_layer.fc.bias",
        }

        self.add_module("ensuredims", Ensure4d())

        self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 C T"))

        # ======== Inception branches ========================
        block11 = self._get_inception_branch_1(
            in_channels=self.n_chans,
            out_channels=self.n_filters,
            kernel_length=self.scales_samples[0],
            alpha_momentum=self.alpha_momentum,
            activation=self.activation,
            drop_prob=self.drop_prob,
            depth_multiplier=self.depth_multiplier,
        )
        block12 = self._get_inception_branch_1(
            in_channels=self.n_chans,
            out_channels=self.n_filters,
            kernel_length=self.scales_samples[1],
            alpha_momentum=self.alpha_momentum,
            activation=self.activation,
            drop_prob=self.drop_prob,
            depth_multiplier=self.depth_multiplier,
        )
        block13 = self._get_inception_branch_1(
            in_channels=self.n_chans,
            out_channels=self.n_filters,
            kernel_length=self.scales_samples[2],
            alpha_momentum=self.alpha_momentum,
            activation=self.activation,
            drop_prob=self.drop_prob,
            depth_multiplier=self.depth_multiplier,
        )

        self.add_module(
            "inception_block_1", InceptionBlock((block11, block12, block13))
        )

        self.add_module("avg_pool_1", nn.AvgPool2d((1, self.pooling_sizes[0])))

        # ======== Inception branches ========================
        n_concat_filters = len(self.scales_samples) * self.n_filters
        n_concat_dw_filters = n_concat_filters * self.depth_multiplier
        block21 = self._get_inception_branch_2(
            in_channels=n_concat_dw_filters,
            out_channels=self.n_filters,
            kernel_length=self.scales_samples[0] // 4,
            alpha_momentum=self.alpha_momentum,
            activation=self.activation,
            drop_prob=self.drop_prob,
        )
        block22 = self._get_inception_branch_2(
            in_channels=n_concat_dw_filters,
            out_channels=self.n_filters,
            kernel_length=self.scales_samples[1] // 4,
            alpha_momentum=self.alpha_momentum,
            activation=self.activation,
            drop_prob=self.drop_prob,
        )
        block23 = self._get_inception_branch_2(
            in_channels=n_concat_dw_filters,
            out_channels=self.n_filters,
            kernel_length=self.scales_samples[2] // 4,
            alpha_momentum=self.alpha_momentum,
            activation=self.activation,
            drop_prob=self.drop_prob,
        )

        self.add_module(
            "inception_block_2", InceptionBlock((block21, block22, block23))
        )

        self.add_module("avg_pool_2", nn.AvgPool2d((1, self.pooling_sizes[1])))

        self.add_module(
            "final_block",
            nn.Sequential(
                nn.Conv2d(
                    n_concat_filters,
                    n_concat_filters // 2,
                    (1, 8),
                    padding="same",
                    bias=False,
                ),
                nn.BatchNorm2d(n_concat_filters // 2, momentum=self.alpha_momentum),
                activation(),
                nn.Dropout(self.drop_prob),
                nn.AvgPool2d((1, self.pooling_sizes[2])),
                nn.Conv2d(
                    n_concat_filters // 2,
                    n_concat_filters // 4,
                    (1, 4),
                    padding="same",
                    bias=False,
                ),
                nn.BatchNorm2d(n_concat_filters // 4, momentum=self.alpha_momentum),
                activation(),
                nn.Dropout(self.drop_prob),
                nn.AvgPool2d((1, self.pooling_sizes[3])),
            ),
        )

        spatial_dim_last_layer = self.n_times // math.prod(self.pooling_sizes)
        n_channels_last_layer = self.n_filters * len(self.scales_samples) // 4

        self.add_module("flat", nn.Flatten())

        # Incorporating classification module and subsequent ones in one final layer
        module = nn.Sequential()

        module.add_module(
            "fc",
            nn.Linear(spatial_dim_last_layer * n_channels_last_layer, self.n_outputs),
        )

        module.add_module("identity", nn.Identity())

        self.add_module("final_layer", module)

        glorot_weight_zero_bias(self)

    @staticmethod
    def _get_inception_branch_1(
        in_channels,
        out_channels,
        kernel_length,
        alpha_momentum,
        drop_prob,
        activation,
        depth_multiplier,
    ):
        return nn.Sequential(
            nn.Conv2d(
                1,
                out_channels,
                kernel_size=(1, kernel_length),
                padding="same",
                bias=True,
            ),
            nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
            activation(),
            nn.Dropout(drop_prob),
            DepthwiseConv2d(
                out_channels,
                kernel_size=(in_channels, 1),
                depth_multiplier=depth_multiplier,
                bias=False,
                padding="valid",
            ),
            nn.BatchNorm2d(depth_multiplier * out_channels, momentum=alpha_momentum),
            activation(),
            nn.Dropout(drop_prob),
        )

    @staticmethod
    def _get_inception_branch_2(
        in_channels, out_channels, kernel_length, alpha_momentum, drop_prob, activation
    ):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=(1, kernel_length),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
            activation(),
            nn.Dropout(drop_prob),
        )
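
To sanity-check the module above, a minimal forward-pass sketch (an illustration,
assuming torch and this wheel are installed, and that EEGInceptionERP is
re-exported from braindecode.models as the __init__.py in the file listing
suggests):

import torch

from braindecode.models import EEGInceptionERP

# Illustrative usage sketch, not part of the wheel. Paper-like setting:
# 8 channels, 1 s epochs at 128 Hz, binary ERP decision.
model = EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128.0)

x = torch.randn(4, 8, 128)  # (batch, n_chans, n_times); Ensure4d adds a 4th dim
logits = model(x)
print(logits.shape)  # torch.Size([4, 2]): 6 channels x 4 samples -> 24 features -> fc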