braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/eegconformer.py
@@ -0,0 +1,372 @@
+ # Authors: Yonghao Song <eeyhsong@gmail.com>
+ #
+ # License: BSD (3-clause)
+ import warnings
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from einops.layers.torch import Rearrange
+ from torch import Tensor, nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import FeedForwardBlock, MultiHeadAttention
+
+
+ class EEGConformer(EEGModuleMixin, nn.Module):
+     """EEG Conformer from Song et al. (2022) [song2022]_.
+
+     .. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
+        :align: center
+        :alt: EEGConformer Architecture
+
+     Convolutional Transformer for EEG decoding.
+
+     The paper and original code with more details about the methodological
+     choices are available at [song2022]_ and [ConformerCode]_.
+
+     This neural network architecture receives a traditional braindecode input.
+     The input shape should be a three-dimensional tensor representing the EEG
+     signals,
+
+     `(batch_size, n_channels, n_timesteps)`.
+
+     The EEG Conformer architecture is composed of three modules:
+     - PatchEmbedding
+     - TransformerEncoder
+     - ClassificationHead
+
+     Notes
+     -----
+     The authors recommend using data augmentation before using Conformer,
+     e.g. segmentation and recombination.
+     Please refer to the original paper and code for more details.
+
+     The model was initially tuned on 4 seconds of 250 Hz data.
+     Please adjust the scale of the temporal convolutional layer,
+     and the pooling layer for better performance.
+
+     .. versionadded:: 0.8
+
+     We aggregate the parameters based on the parts of the models, or
+     when the parameters were used first, e.g. n_filters_time.
+
+     Parameters
+     ----------
+     n_filters_time: int
+         Number of temporal filters, defines also embedding size.
+     filter_time_length: int
+         Length of the temporal filter.
+     pool_time_length: int
+         Length of temporal pooling filter.
+     pool_time_stride: int
+         Length of stride between temporal pooling filters.
+     drop_prob: float
+         Dropout rate of the convolutional layer.
+     att_depth: int
+         Number of self-attention layers.
+     att_heads: int
+         Number of attention heads.
+     att_drop_prob: float
+         Dropout rate of the self-attention layer.
+     final_fc_length: int | str
+         The dimension of the fully connected layer.
+     return_features: bool
+         If True, the forward method returns the features before the
+         last classification layer. Defaults to False.
+     activation: nn.Module
+         Activation function as parameter. Default is nn.ELU.
+     activation_transfor: nn.Module
+         Activation function as parameter, applied at the FeedForwardBlock module
+         inside the transformer. Default is nn.GELU.
+
+     References
+     ----------
+     .. [song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
+        conformer: Convolutional transformer for EEG decoding and visualization.
+        IEEE Transactions on Neural Systems and Rehabilitation Engineering,
+        31, pp.710-719. https://ieeexplore.ieee.org/document/9991178
+     .. [ConformerCode] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
+        conformer: Convolutional transformer for EEG decoding and visualization.
+        https://github.com/eeyhsong/EEG-Conformer.
+     """
+
+     def __init__(
+         self,
+         n_outputs=None,
+         n_chans=None,
+         n_filters_time=40,
+         filter_time_length=25,
+         pool_time_length=75,
+         pool_time_stride=15,
+         drop_prob=0.5,
+         att_depth=6,
+         att_heads=10,
+         att_drop_prob=0.5,
+         final_fc_length="auto",
+         return_features=False,
+         activation: nn.Module = nn.ELU,
+         activation_transfor: nn.Module = nn.GELU,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         self.mapping = {
+             "classification_head.fc.6.weight": "final_layer.final_layer.0.weight",
+             "classification_head.fc.6.bias": "final_layer.final_layer.0.bias",
+         }
+
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         if not (self.n_chans <= 64):
+             warnings.warn(
+                 "This model has only been tested on no more "
+                 + "than 64 channels. no guarantee to work with "
+                 + "more channels.",
+                 UserWarning,
+             )
+
+         self.return_features = return_features
+
+         self.patch_embedding = _PatchEmbedding(
+             n_filters_time=n_filters_time,
+             filter_time_length=filter_time_length,
+             n_channels=self.n_chans,
+             pool_time_length=pool_time_length,
+             stride_avg_pool=pool_time_stride,
+             drop_prob=drop_prob,
+             activation=activation,
+         )
+
+         if final_fc_length == "auto":
+             assert self.n_times is not None
+             self.final_fc_length = self.get_fc_size()
+
+         self.transformer = _TransformerEncoder(
+             att_depth=att_depth,
+             emb_size=n_filters_time,
+             att_heads=att_heads,
+             att_drop=att_drop_prob,
+             activation=activation_transfor,
+         )
+
+         self.fc = _FullyConnected(
+             final_fc_length=self.final_fc_length, activation=activation
+         )
+
+         self.final_layer = nn.Linear(self.fc.hidden_channels, self.n_outputs)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = torch.unsqueeze(x, dim=1)  # add one extra dimension
+         x = self.patch_embedding(x)
+         feature = self.transformer(x)
+
+         if self.return_features:
+             return feature
+
+         x = self.fc(feature)
+         x = self.final_layer(x)
+         return x
+
+     def get_fc_size(self):
+         out = self.patch_embedding(torch.ones((1, 1, self.n_chans, self.n_times)))
+         size_embedding_1 = out.cpu().data.numpy().shape[1]
+         size_embedding_2 = out.cpu().data.numpy().shape[2]
+
+         return size_embedding_1 * size_embedding_2
+
+
+ class _PatchEmbedding(nn.Module):
+     """Patch Embedding.
+
+     The authors used a convolution module to capture local features,
+     instead of position embedding.
+
+     Parameters
+     ----------
+     n_filters_time: int
+         Number of temporal filters, defines also embedding size.
+     filter_time_length: int
+         Length of the temporal filter.
+     n_channels: int
+         Number of channels to be used as number of spatial filters.
+     pool_time_length: int
+         Length of temporal pooling filter.
+     stride_avg_pool: int
+         Length of stride between temporal pooling filters.
+     drop_prob: float
+         Dropout rate of the convolutional layer.
+
+     Returns
+     -------
+     x: torch.Tensor
+         The output tensor of the patch embedding layer.
+     """
+
+     def __init__(
+         self,
+         n_filters_time,
+         filter_time_length,
+         n_channels,
+         pool_time_length,
+         stride_avg_pool,
+         drop_prob,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+
+         self.shallownet = nn.Sequential(
+             nn.Conv2d(1, n_filters_time, (1, filter_time_length), (1, 1)),
+             nn.Conv2d(n_filters_time, n_filters_time, (n_channels, 1), (1, 1)),
+             nn.BatchNorm2d(num_features=n_filters_time),
+             activation(),
+             nn.AvgPool2d(
+                 kernel_size=(1, pool_time_length), stride=(1, stride_avg_pool)
+             ),
+             # pooling acts as slicing to obtain 'patch' along the
+             # time dimension as in ViT
+             nn.Dropout(p=drop_prob),
+         )
+
+         self.projection = nn.Sequential(
+             nn.Conv2d(
+                 n_filters_time, n_filters_time, (1, 1), stride=(1, 1)
+             ),  # transpose, conv could enhance fitting ability slightly
+             Rearrange("b d_model 1 seq -> b seq d_model"),
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.shallownet(x)
+         x = self.projection(x)
+         return x
+
+
+ class _ResidualAdd(nn.Module):
+     def __init__(self, fn):
+         super().__init__()
+         self.fn = fn
+
+     def forward(self, x):
+         res = x
+         x = self.fn(x)
+         x += res
+         return x
+
+
+ class _TransformerEncoderBlock(nn.Sequential):
+     def __init__(
+         self,
+         emb_size,
+         att_heads,
+         att_drop,
+         forward_expansion=4,
+         activation: nn.Module = nn.GELU,
+     ):
+         super().__init__(
+             _ResidualAdd(
+                 nn.Sequential(
+                     nn.LayerNorm(emb_size),
+                     MultiHeadAttention(emb_size, att_heads, att_drop),
+                     nn.Dropout(att_drop),
+                 )
+             ),
+             _ResidualAdd(
+                 nn.Sequential(
+                     nn.LayerNorm(emb_size),
+                     FeedForwardBlock(
+                         emb_size,
+                         expansion=forward_expansion,
+                         drop_p=att_drop,
+                         activation=activation,
+                     ),
+                     nn.Dropout(att_drop),
+                 )
+             ),
+         )
+
+
+ class _TransformerEncoder(nn.Sequential):
+     """Transformer encoder module for the transformer encoder.
+
+     Similar to the layers used in ViT.
+
+     Parameters
+     ----------
+     att_depth : int
+         Number of transformer encoder blocks.
+     emb_size : int
+         Embedding size of the transformer encoder.
+     att_heads : int
+         Number of attention heads.
+     att_drop : float
+         Dropout probability for the attention layers.
+
+     """
+
+     def __init__(
+         self, att_depth, emb_size, att_heads, att_drop, activation: nn.Module = nn.GELU
+     ):
+         super().__init__(
+             *[
+                 _TransformerEncoderBlock(
+                     emb_size, att_heads, att_drop, activation=activation
+                 )
+                 for _ in range(att_depth)
+             ]
+         )
+
+
+ class _FullyConnected(nn.Module):
+     def __init__(
+         self,
+         final_fc_length,
+         drop_prob_1=0.5,
+         drop_prob_2=0.3,
+         out_channels=256,
+         hidden_channels=32,
+         activation: nn.Module = nn.ELU,
+     ):
+         """Fully-connected layer for the transformer encoder.
+
+         Parameters
+         ----------
+         final_fc_length : int
+             Length of the final fully connected layer.
+         drop_prob_1 : float
+             Dropout probability for the first dropout layer.
+         drop_prob_2 : float
+             Dropout probability for the second dropout layer.
+         out_channels : int
+             Number of output channels for the first linear layer.
+         hidden_channels : int
+             Number of output channels for the second linear layer.
+         """
+
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.fc = nn.Sequential(
+             nn.Linear(final_fc_length, out_channels),
+             activation(),
+             nn.Dropout(drop_prob_1),
+             nn.Linear(out_channels, hidden_channels),
+             activation(),
+             nn.Dropout(drop_prob_2),
+         )
+
+     def forward(self, x):
+         x = x.contiguous().view(x.size(0), -1)
+         out = self.fc(x)
+         return out
braindecode/models/eeginception_erp.py
@@ -0,0 +1,304 @@
+ # Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
+ #          Cedric Rommel <cedric.rommel@inria.fr>
+ #
+ # License: BSD (3-clause)
+ import math
+
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ from braindecode.functional import glorot_weight_zero_bias
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
+
+
+ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
+     """EEG-Inception for ERP-based BCIs from Santamaria-Vazquez et al. (2020) [santamaria2020]_.
+
+     .. figure:: https://braindecode.org/dev/_static/model/eeginceptionerp.jpg
+        :align: center
+        :alt: EEGInceptionERP Architecture
+
+     The code for the paper and this model is also available at [santamaria2020]_
+     and an adaptation for PyTorch [2]_.
+
+     The model is strongly based on the original InceptionNet for images. The main goal is
+     to extract features in parallel at different scales. The authors extracted three scales
+     proportional to the window sample size. The network has three parts:
+     (1) a larger inception block, (2) a smaller inception block, followed by
+     (3) a bottleneck for classification.
+
+     One advantage of the EEG-Inception block is that it allows a network
+     to learn low- and high-frequency components of the signal simultaneously.
+     The winners of the BEETL Competition/NeurIPS 2021 used parts of the
+     model [beetl]_.
+
+     The model is fully described in [santamaria2020]_.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been checked
+     by the original authors; it was only reimplemented from the paper based on [2]_.
+
+     Parameters
+     ----------
+     n_times : int, optional
+         Size of the input, in number of samples. Set to 128 (1s) as in
+         [santamaria2020]_.
+     sfreq : float, optional
+         EEG sampling frequency. Defaults to 128 as in [santamaria2020]_.
+     drop_prob : float, optional
+         Dropout rate inside all the network. Defaults to 0.5 as in
+         [santamaria2020]_.
+     scales_samples_s: list(float), optional
+         Windows for inception block. Temporal scale (s) of the convolutions on
+         each Inception module. This parameter determines the kernel sizes of
+         the filters. Defaults to 0.5, 0.25, 0.125 seconds, as in
+         [santamaria2020]_.
+     n_filters : int, optional
+         Initial number of convolutional filters. Defaults to 8 as in
+         [santamaria2020]_.
+     activation: nn.Module, optional
+         Activation function. Defaults to ELU activation as in
+         [santamaria2020]_.
+     batch_norm_alpha: float, optional
+         Momentum for BatchNorm2d. Defaults to 0.01.
+     depth_multiplier: int, optional
+         Depth multiplier for the depthwise convolution. Defaults to 2 as in
+         [santamaria2020]_.
+     pooling_sizes: list(int), optional
+         Pooling sizes for the inception blocks. Defaults to 4, 2, 2 and 2, as
+         in [santamaria2020]_.
+
+
+     References
+     ----------
+     .. [santamaria2020] Santamaria-Vazquez, E., Martinez-Cagigal, V.,
+        Vaquerizo-Villar, F., & Hornero, R. (2020).
+        EEG-inception: A novel deep convolutional neural network for assistive
+        ERP-based brain-computer interfaces.
+        IEEE Transactions on Neural Systems and Rehabilitation Engineering, v. 28.
+        Online: http://dx.doi.org/10.1109/TNSRE.2020.3048106
+     .. [2] Grifcc. Implementation of the EEGInception in torch (2022).
+        Online: https://github.com/Grifcc/EEG/
+     .. [beetl] Wei, X., Faisal, A.A., Grosse-Wentrup, M., Gramfort, A., Chevallier, S.,
+        Jayaram, V., Jeunet, C., Bakas, S., Ludwig, S., Barmpas, K., Bahri, M., Panagakis,
+        Y., Laskaris, N., Adamos, D.A., Zafeiriou, S., Duong, W.C., Gordon, S.M.,
+        Lawhern, V.J., Śliwowski, M., Rouanne, V. & Tempczyk, P. (2022).
+        2021 BEETL Competition: Advancing Transfer Learning for Subject Independence &
+        Heterogeneous EEG Data Sets. Proceedings of the NeurIPS 2021 Competitions and
+        Demonstrations Track, in Proceedings of Machine Learning Research
+        176:205-219. Available from https://proceedings.mlr.press/v176/wei22a.html.
+
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         n_outputs=None,
+         n_times=1000,
+         sfreq=128,
+         drop_prob=0.5,
+         scales_samples_s=(0.5, 0.25, 0.125),
+         n_filters=8,
+         activation: nn.Module = nn.ELU,
+         batch_norm_alpha=0.01,
+         depth_multiplier=2,
+         pooling_sizes=(4, 2, 2, 2),
+         chs_info=None,
+         input_window_seconds=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         self.drop_prob = drop_prob
+         self.n_filters = n_filters
+         self.scales_samples_s = scales_samples_s
+         self.scales_samples = tuple(
+             int(size_s * self.sfreq) for size_s in self.scales_samples_s
+         )
+         self.activation = activation
+         self.alpha_momentum = batch_norm_alpha
+         self.depth_multiplier = depth_multiplier
+         self.pooling_sizes = pooling_sizes
+
+         self.mapping = {
+             "classification.1.weight": "final_layer.fc.weight",
+             "classification.1.bias": "final_layer.fc.bias",
+         }
+
+         self.add_module("ensuredims", Ensure4d())
+
+         self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 C T"))
+
+         # ======== Inception branches ========================
+         block11 = self._get_inception_branch_1(
+             in_channels=self.n_chans,
+             out_channels=self.n_filters,
+             kernel_length=self.scales_samples[0],
+             alpha_momentum=self.alpha_momentum,
+             activation=self.activation,
+             drop_prob=self.drop_prob,
+             depth_multiplier=self.depth_multiplier,
+         )
+         block12 = self._get_inception_branch_1(
+             in_channels=self.n_chans,
+             out_channels=self.n_filters,
+             kernel_length=self.scales_samples[1],
+             alpha_momentum=self.alpha_momentum,
+             activation=self.activation,
+             drop_prob=self.drop_prob,
+             depth_multiplier=self.depth_multiplier,
+         )
+         block13 = self._get_inception_branch_1(
+             in_channels=self.n_chans,
+             out_channels=self.n_filters,
+             kernel_length=self.scales_samples[2],
+             alpha_momentum=self.alpha_momentum,
+             activation=self.activation,
+             drop_prob=self.drop_prob,
+             depth_multiplier=self.depth_multiplier,
+         )
+
+         self.add_module(
+             "inception_block_1", InceptionBlock((block11, block12, block13))
+         )
+
+         self.add_module("avg_pool_1", nn.AvgPool2d((1, self.pooling_sizes[0])))
+
+         # ======== Inception branches ========================
+         n_concat_filters = len(self.scales_samples) * self.n_filters
+         n_concat_dw_filters = n_concat_filters * self.depth_multiplier
+         block21 = self._get_inception_branch_2(
+             in_channels=n_concat_dw_filters,
+             out_channels=self.n_filters,
+             kernel_length=self.scales_samples[0] // 4,
+             alpha_momentum=self.alpha_momentum,
+             activation=self.activation,
+             drop_prob=self.drop_prob,
+         )
+         block22 = self._get_inception_branch_2(
+             in_channels=n_concat_dw_filters,
+             out_channels=self.n_filters,
+             kernel_length=self.scales_samples[1] // 4,
+             alpha_momentum=self.alpha_momentum,
+             activation=self.activation,
+             drop_prob=self.drop_prob,
+         )
+         block23 = self._get_inception_branch_2(
+             in_channels=n_concat_dw_filters,
+             out_channels=self.n_filters,
+             kernel_length=self.scales_samples[2] // 4,
+             alpha_momentum=self.alpha_momentum,
+             activation=self.activation,
+             drop_prob=self.drop_prob,
+         )
+
+         self.add_module(
+             "inception_block_2", InceptionBlock((block21, block22, block23))
+         )
+
+         self.add_module("avg_pool_2", nn.AvgPool2d((1, self.pooling_sizes[1])))
+
+         self.add_module(
+             "final_block",
+             nn.Sequential(
+                 nn.Conv2d(
+                     n_concat_filters,
+                     n_concat_filters // 2,
+                     (1, 8),
+                     padding="same",
+                     bias=False,
+                 ),
+                 nn.BatchNorm2d(n_concat_filters // 2, momentum=self.alpha_momentum),
+                 activation(),
+                 nn.Dropout(self.drop_prob),
+                 nn.AvgPool2d((1, self.pooling_sizes[2])),
+                 nn.Conv2d(
+                     n_concat_filters // 2,
+                     n_concat_filters // 4,
+                     (1, 4),
+                     padding="same",
+                     bias=False,
+                 ),
+                 nn.BatchNorm2d(n_concat_filters // 4, momentum=self.alpha_momentum),
+                 activation(),
+                 nn.Dropout(self.drop_prob),
+                 nn.AvgPool2d((1, self.pooling_sizes[3])),
+             ),
+         )
+
+         spatial_dim_last_layer = self.n_times // math.prod(self.pooling_sizes)
+         n_channels_last_layer = self.n_filters * len(self.scales_samples) // 4
+
+         self.add_module("flat", nn.Flatten())
+
+         # Incorporating classification module and subsequent ones in one final layer
+         module = nn.Sequential()
+
+         module.add_module(
+             "fc",
+             nn.Linear(spatial_dim_last_layer * n_channels_last_layer, self.n_outputs),
+         )
+
+         module.add_module("identity", nn.Identity())
+
+         self.add_module("final_layer", module)
+
+         glorot_weight_zero_bias(self)
+
+     @staticmethod
+     def _get_inception_branch_1(
+         in_channels,
+         out_channels,
+         kernel_length,
+         alpha_momentum,
+         drop_prob,
+         activation,
+         depth_multiplier,
+     ):
+         return nn.Sequential(
+             nn.Conv2d(
+                 1,
+                 out_channels,
+                 kernel_size=(1, kernel_length),
+                 padding="same",
+                 bias=True,
+             ),
+             nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
+             activation(),
+             nn.Dropout(drop_prob),
+             DepthwiseConv2d(
+                 out_channels,
+                 kernel_size=(in_channels, 1),
+                 depth_multiplier=depth_multiplier,
+                 bias=False,
+                 padding="valid",
+             ),
+             nn.BatchNorm2d(depth_multiplier * out_channels, momentum=alpha_momentum),
+             activation(),
+             nn.Dropout(drop_prob),
+         )
+
+     @staticmethod
+     def _get_inception_branch_2(
+         in_channels, out_channels, kernel_length, alpha_momentum, drop_prob, activation
+     ):
+         return nn.Sequential(
+             nn.Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=(1, kernel_length),
+                 padding="same",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(out_channels, momentum=alpha_momentum),
+             activation(),
+             nn.Dropout(drop_prob),
+         )
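
Again for orientation only, a minimal sketch of how the new class could be exercised, not part of the diff. It assumes EEGInceptionERP is exported from braindecode.models; the shapes follow the defaults visible in the hunk (n_times=1000, sfreq=128), while the channel and class counts are illustrative.

    import torch
    from braindecode.models import EEGInceptionERP  # export via braindecode.models is assumed

    # 8 channels, binary ERP decoding; with pooling_sizes=(4, 2, 2, 2) the temporal
    # dimension 1000 is reduced to 1000 // 32 = 31 before the final linear layer.
    model = EEGInceptionERP(n_chans=8, n_outputs=2, n_times=1000, sfreq=128)
    x = torch.randn(16, 8, 1000)  # (batch_size, n_channels, n_times)
    with torch.no_grad():
        out = model(x)  # expected shape: (16, 2)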