braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,354 @@
1
+ # Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
2
+ #
3
+ # License: BSD (3-clause)
4
+ from __future__ import annotations
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops.layers.torch import Rearrange
11
+
12
+ from braindecode.models.base import EEGModuleMixin
13
+ from braindecode.modules import Conv2dWithConstraint, LinearWithConstraint
14
+
15
+
16
class EEGNeX(EEGModuleMixin, nn.Module):
    r"""EEGNeX model from Chen et al (2024) [eegnex]_.

    :bdg-success:`Convolution`

    .. figure:: https://braindecode.org/dev/_static/model/eegnex.jpg
       :align: center
       :alt: EEGNeX Architecture
       :width: 620px

    .. rubric:: Architectural Overview

    EEGNeX is a **purely convolutional** architecture that refines the EEGNet-style stem
    and deepens the temporal stack with **dilated temporal convolutions**. The end-to-end
    flow is:

    - (i) **Block-1/2**: two temporal convolutions ``(1 x L)`` with BN refine a
      learned FIR-like *temporal filter bank* (no pooling yet);
    - (ii) **Block-3**: depthwise **spatial** convolution across electrodes
      ``(n_chans x 1)`` with max-norm constraint, followed by activation (ELU by
      default) → AvgPool (time) → Dropout;
    - (iii) **Block-4/5**: two additional **temporal** convolutions with increasing **dilation**
      to expand the receptive field; the last block applies activation → AvgPool →
      Dropout → Flatten;
    - (iv) **Classifier**: a max-norm–constrained linear layer.

    The published work positions EEGNeX as a compact, conv-only alternative that consistently
    outperforms prior baselines across MOABB-style benchmarks, with the popular
    “EEGNeX-8,32” shorthand denoting *8 temporal filters* and *kernel length 32*.


    .. rubric:: Macro Components

    - **Block-1 / Block-2 — Temporal filter (learned).**

      - *Operations.*
      - :class:`torch.nn.Conv2d` with kernels ``(1, L)``
      - :class:`torch.nn.BatchNorm2d` (no nonlinearity until Block-3, mirroring a linear FIR analysis stage).
        These layers set up frequency-selective detectors before spatial mixing.

      - *Interpretability.* Kernels can be inspected as FIR filters; two stacked temporal
        convs allow longer effective kernels without parameter blow-up.

    - **Block-3 — Spatial projection + condensation.**

      - *Operations.*
      - :class:`braindecode.modules.Conv2dWithConstraint` with kernel ``(n_chans, 1)``
        and ``groups = filter_2`` (depthwise across filters)
      - :class:`torch.nn.BatchNorm2d`
      - activation (:class:`torch.nn.ELU` by default)
      - :class:`torch.nn.AvgPool2d` (time, size ``avg_pool_block4``)
      - :class:`torch.nn.Dropout`.

      **Role**: Learns per-filter spatial patterns over the **full montage** while temporal
      pooling stabilizes and compresses features; max-norm encourages well-behaved spatial
      weights similar to EEGNet practice.

    - **Block-4 / Block-5 — Dilated temporal integration.**

      - *Operations (Block-4).*
      - :class:`torch.nn.Conv2d` with kernel ``(1, kernel_block_4)`` and dilation
        ``(1, dilation_block_4)`` (2 by default)
      - :class:`torch.nn.BatchNorm2d`

      - *Operations (Block-5).*
      - :class:`torch.nn.Conv2d` with kernel ``(1, kernel_block_5)`` and dilation
        ``(1, dilation_block_5)`` (4 by default)
      - :class:`torch.nn.BatchNorm2d`
      - activation (:class:`torch.nn.ELU` by default)
      - :class:`torch.nn.AvgPool2d` (time, size ``avg_pool_block5``)
      - :class:`torch.nn.Dropout`
      - :class:`torch.nn.Flatten`.

      **Role**: Expands the temporal receptive field efficiently to capture rhythms and
      long-range context after condensation.

    - **Final Classifier — Max-norm linear.**

      - *Operations.*
      - :class:`braindecode.modules.LinearWithConstraint` maps the flattened
        vector to the target classes; the max-norm constraint regularizes the readout.


    .. rubric:: Convolutional Details

    - **Temporal (where time-domain patterns are learned).**
      Blocks 1-2 learn the primary filter bank (oscillations/transients), while Blocks 4-5
      use **dilation** to integrate over longer horizons without extra pooling. The final
      AvgPool in Block-5 sets the output token rate and helps noise suppression.

    - **Spatial (how electrodes are processed).**
      A *single* depthwise spatial conv (Block-3) spans the entire electrode set
      (kernel ``(n_chans, 1)``), producing per-temporal-filter topographies; no cross-filter
      mixing occurs at this stage, aiding interpretability.

    - **Spectral (how frequency content is captured).**
      Frequency selectivity emerges from the learned temporal kernels; dilation broadens effective
      bandwidth coverage by composing multiple scales.

    .. rubric:: Additional Mechanisms

    - **EEGNeX-8,32 naming.** “8,32” indicates *8 temporal filters* and *kernel length 32*,
      reflecting the paper's ablation path from EEGNet-8,2 toward thicker temporal kernels
      and a deeper conv stack.
    - **Max-norm constraints.** Spatial (Block-3) and final linear layers use max-norm
      regularization—standard in EEG CNNs—to reduce overfitting and encourage stable spatial
      patterns.

    .. rubric:: Usage and Configuration

    - **Kernel schedule.** Start with the canonical **EEGNeX-8,32** (``filter_1=8``,
      ``kernel_block_1_2=32``; note that this implementation defaults to
      ``kernel_block_1_2=64``) and keep **Block-3** depth multiplier modest (e.g., 2)
      to match the paper's “pure conv” profile.
    - **Pooling vs. dilation.** Use pooling in Blocks 3 and 5 to control compute and variance;
      increase dilations (Blocks 4-5) to widen temporal context when windows are short.
      A pooling size of 1 disables pooling in the corresponding block.
    - **Regularization.** Combine dropout (Blocks 3 & 5) with max-norm on spatial and
      classifier layers; prefer ELU activations for stable training on small EEG datasets.


    - The braindecode implementation follows the paper's conv-only design with five blocks
      and reproduces the depthwise spatial step and dilated temporal stack. See the class
      reference for exact kernel sizes, dilations, and pooling defaults. You can check the
      original implementation at [EEGNexCode]_.

    .. versionadded:: 1.1


    Parameters
    ----------
    activation : type[nn.Module], optional
        Activation class, instantiated after the spatial convolution (block 3)
        and after the last temporal convolution (block 5). Default is ``nn.ELU``.
    depth_multiplier : int, optional
        Depth multiplier of the depthwise spatial convolution in block 3, which
        outputs ``filter_2 * depth_multiplier`` feature maps. Default is 2.
    filter_1 : int, optional
        Number of filters of the first temporal convolution (block 1) and of the
        last temporal convolution (block 5). Default is 8.
    filter_2 : int, optional
        Number of filters of the second temporal convolution (block 2) and of the
        dilated convolution of block 4. Default is 32.
    drop_prob : float, optional
        Dropout probability used in blocks 3 and 5. Default is 0.5.
    kernel_block_1_2 : int, optional
        Temporal kernel length of the convolutions in blocks 1 and 2, used as
        ``(1, kernel_block_1_2)``. Default is 64.
    kernel_block_4 : int, optional
        Temporal kernel length of the block-4 convolution. Default is 16.
    dilation_block_4 : int, optional
        Temporal dilation of the block-4 convolution. Default is 2.
    avg_pool_block4 : int, optional
        Temporal average-pooling size (and stride) applied at the end of
        block 3, i.e. right before block 4. Default is 4.
    kernel_block_5 : int, optional
        Temporal kernel length of the block-5 convolution. Default is 16.
    dilation_block_5 : int, optional
        Temporal dilation of the block-5 convolution. Default is 4.
    avg_pool_block5 : int, optional
        Temporal average-pooling size (and stride) applied in block 5.
        Default is 8.
    max_norm_conv : float, optional
        Max-norm constraint of the block-3 spatial convolution weights.
        Default is 1.0.
    max_norm_linear : float, optional
        Max-norm constraint of the final linear layer weights. Default is 0.25.

    References
    ----------
    .. [eegnex] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
       Toward reliable signals decoding for electroencephalogram: A benchmark
       study to EEGNeX. Biomedical Signal Processing and Control, 87, 105475.
    .. [EEGNexCode] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
       Toward reliable signals decoding for electroencephalogram: A benchmark
       study to EEGNeX. https://github.com/chenxiachan/EEGNeX
    """

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        activation: type[nn.Module] = nn.ELU,
        depth_multiplier: int = 2,
        filter_1: int = 8,
        filter_2: int = 32,
        drop_prob: float = 0.5,
        kernel_block_1_2: int = 64,
        kernel_block_4: int = 16,
        dilation_block_4: int = 2,
        avg_pool_block4: int = 4,
        kernel_block_5: int = 16,
        dilation_block_5: int = 4,
        avg_pool_block5: int = 8,
        max_norm_conv: float = 1.0,
        max_norm_linear: float = 0.25,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.depth_multiplier = depth_multiplier
        self.filter_1 = filter_1
        self.filter_2 = filter_2
        # Output maps of the depthwise spatial convolution (block 3).
        self.filter_3 = self.filter_2 * self.depth_multiplier
        self.drop_prob = drop_prob
        self.activation = activation
        # All kernels/dilations/pools act on the time axis only: (1, size).
        self.kernel_block_1_2 = (1, kernel_block_1_2)
        self.kernel_block_4 = (1, kernel_block_4)
        self.dilation_block_4 = (1, dilation_block_4)
        self.avg_pool_block4 = (1, avg_pool_block4)
        self.kernel_block_5 = (1, kernel_block_5)
        self.dilation_block_5 = (1, dilation_block_5)
        self.avg_pool_block5 = (1, avg_pool_block5)

        # Size of the flattened vector fed to the final linear layer.
        self.in_features = self._calculate_output_length()

        # Following paper nomenclature
        self.block_1 = nn.Sequential(
            # Add a singleton "image channel" so 2D convs can treat
            # (electrodes, time) as the spatial plane.
            Rearrange("batch ch time -> batch 1 ch time"),
            nn.Conv2d(
                in_channels=1,
                out_channels=self.filter_1,
                kernel_size=self.kernel_block_1_2,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_1),
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_1,
                out_channels=self.filter_2,
                kernel_size=self.kernel_block_1_2,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_2),
        )

        self.block_3 = nn.Sequential(
            # Depthwise spatial conv: kernel spans all electrodes and
            # groups=filter_2 keeps each temporal filter's maps separate.
            Conv2dWithConstraint(
                in_channels=self.filter_2,
                out_channels=self.filter_3,
                max_norm=max_norm_conv,
                kernel_size=(self.n_chans, 1),
                groups=self.filter_2,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_3),
            self.activation(),
            nn.AvgPool2d(
                kernel_size=self.avg_pool_block4,
                padding=(0, self._pool_padding(self.avg_pool_block4[1])),
            ),
            nn.Dropout(p=self.drop_prob),
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_3,
                out_channels=self.filter_2,
                kernel_size=self.kernel_block_4,
                dilation=self.dilation_block_4,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_2),
        )

        self.block_5 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.filter_2,
                out_channels=self.filter_1,
                kernel_size=self.kernel_block_5,
                dilation=self.dilation_block_5,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.filter_1),
            self.activation(),
            nn.AvgPool2d(
                kernel_size=self.avg_pool_block5,
                padding=(0, self._pool_padding(self.avg_pool_block5[1])),
            ),
            nn.Dropout(p=self.drop_prob),
            nn.Flatten(),
        )

        self.final_layer = LinearWithConstraint(
            in_features=self.in_features,
            out_features=self.n_outputs,
            max_norm=max_norm_linear,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the EEGNeX model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # x: (batch_size, n_chans, n_times)
        x = self.block_1(x)
        # (batch_size, filter_1, n_chans, n_times)  -- "same" padding keeps time
        x = self.block_2(x)
        # (batch_size, filter_2, n_chans, n_times)
        x = self.block_3(x)
        # (batch_size, filter_2 * depth_multiplier, 1, T3)  -- electrodes collapsed
        x = self.block_4(x)
        # (batch_size, filter_2, 1, T3)
        x = self.block_5(x)
        # (batch_size, filter_1 * T5)  -- flattened
        x = self.final_layer(x)
        # (batch_size, n_outputs)
        return x

    @staticmethod
    def _pool_padding(pool_size: int) -> int:
        """Return the temporal zero-padding used by an average-pooling layer.

        One sample per side is used, as in the original layout, but it is
        capped at ``pool_size // 2`` because :class:`torch.nn.AvgPool2d`
        rejects padding larger than half its kernel. For every
        ``pool_size >= 2`` this is exactly 1; for ``pool_size == 1`` (no
        pooling) it is 0 instead of an invalid 1.
        """
        return min(1, pool_size // 2)

    def _calculate_output_length(self) -> int:
        """Return the number of features entering the final linear layer.

        Blocks 1, 2, 4 and 5's convolutions use ``padding="same"`` and block 3's
        spatial conv has kernel length 1 in time, so only the two average-pooling
        layers change the time length.
        """
        # Pooling kernel sizes for the time dimension (stride == kernel size,
        # since AvgPool2d's stride defaults to its kernel size).
        p4 = self.avg_pool_block4[1]
        p5 = self.avg_pool_block5[1]

        # Must match the padding used when building the pooling layers.
        pad4 = self._pool_padding(p4)
        pad5 = self._pool_padding(p5)

        # AvgPool output length (ceil_mode=False):
        # floor((L_in + 2*padding - kernel_size) / stride) + 1
        t3 = math.floor((self.n_times + 2 * pad4 - p4) / p4) + 1
        t5 = math.floor((t3 + 2 * pad5 - p5) / p5) + 1

        # block_5 emits filter_1 maps of height 1 (electrodes were collapsed by
        # block 3's (n_chans, 1) kernel), flattened to filter_1 * t5 features.
        return self.filter_1 * t5
@@ -0,0 +1,201 @@
1
+ """
2
+ EEG-SimpleConv is a 1D Convolutional Neural Network from Yassine El Ouahidi et al. (2023).
3
+
4
+ Originally designed for Motor Imagery decoding, from EEG signals.
5
+ The model offers competitive performances, with a low latency and is mainly composed of
6
+ 1D convolutional layers.
7
+
8
+ """
9
+
10
+ # Authors: Yassine El Ouahidi <eloua.yas@gmail.com>
11
+ #
12
+ # License: BSD-3
13
+
14
+ import torch
15
+ from torch import nn
16
+ from torchaudio.transforms import Resample
17
+
18
+ from braindecode.models.base import EEGModuleMixin
19
+
20
+
21
class EEGSimpleConv(EEGModuleMixin, torch.nn.Module):
    r"""EEGSimpleConv from Ouahidi, YE et al (2023) [Yassine2023]_.

    :bdg-success:`Convolution`

    .. figure:: https://raw.githubusercontent.com/elouayas/EEGSimpleConv/refs/heads/main/architecture.png
       :align: center
       :alt: EEGSimpleConv Architecture

    EEGSimpleConv is a 1D Convolutional Neural Network originally designed
    for decoding motor imagery from EEG signals. The model aims to have a
    very simple and straightforward architecture that allows a low latency,
    while still achieving very competitive performance.

    The input is first resampled to ``resampling_freq`` (when it differs from
    ``sfreq``). EEG-SimpleConv then applies a 1D convolutional layer whose
    input channels are the EEG channels. This is followed by a series of
    blocks of two 1D convolutional layers. Between the two convolutional layers
    of each block is a max pooling layer, which downsamples the data by a factor
    of 2. Each convolution is followed by a batch normalisation layer and an
    activation function (ReLU by default). Finally, a global average pooling
    (in the time domain) is performed to obtain a single value per feature map,
    which is then fed into a linear layer to obtain the final classification
    prediction output.

    The paper and original code with more details about the methodological
    choices are available at the [Yassine2023]_ and [Yassine2023Code]_.

    The input should be a three-dimensional tensor representing the EEG
    signals, of shape

    ``(batch_size, n_channels, n_timesteps)``.

    Notes
    -----
    The authors recommend using the default parameters for MI decoding.
    Please refer to the original paper and code for more details.

    Recommended range for the choice of the hyperparameters, regarding the
    evaluation paradigm.

    ===============  ==============  =============
    Parameter        Within-Subject  Cross-Subject
    ===============  ==============  =============
    feature_maps     [64-144]        [64-144]
    n_convs          1               [2-4]
    resampling_freq  [70-100]        [50-80]
    kernel_size      [12-17]         [5-8]
    ===============  ==============  =============

    An intensive ablation study is included in the paper to understand the
    impact of each parameter on the model performance.

    .. versionadded:: 0.9

    Parameters
    ----------
    feature_maps : int
        Number of Feature Maps at the first Convolution, width of the model.
    n_convs : int
        Number of blocks of convolutions (2 convolutions per block), depth of the model.
    resampling_freq : int
        Resampling Frequency (Hz) applied to the input when it differs from ``sfreq``.
    kernel_size : int
        Size of the convolutions kernels.
    return_feature : bool, default=False
        If True, ``forward`` returns the time-averaged feature vector of shape
        ``(batch_size, n_features)`` instead of the output of the final linear layer.
    activation : type[nn.Module], default=nn.ReLU
        Activation function class to apply after every convolution (including
        the first one). Should be a PyTorch activation module class like
        ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.

    References
    ----------
    .. [Yassine2023] Yassine El Ouahidi, V. Gripon, B. Pasdeloup, G. Bouallegue
       N. Farrugia, G. Lioi, 2023. A Strong and Simple Deep Learning Baseline for
       BCI Motor Imagery Decoding. Arxiv preprint. arxiv.org/abs/2309.07159
    .. [Yassine2023Code] Yassine El Ouahidi, V. Gripon, B. Pasdeloup, G. Bouallegue
       N. Farrugia, G. Lioi, 2023. A Strong and Simple Deep Learning Baseline for
       BCI Motor Imagery Decoding. GitHub repository.
       https://github.com/elouayas/EEGSimpleConv.

    """

    def __init__(
        self,
        # Base arguments
        n_outputs=None,
        n_chans=None,
        sfreq=None,
        # Model specific arguments
        feature_maps=128,
        n_convs=2,
        resampling_freq=80,
        kernel_size=8,
        return_feature=False,
        activation: type[nn.Module] = nn.ReLU,
        # Other ways to initialize the model
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        self.return_feature = return_feature
        # Resample only when the input rate differs from the target rate.
        self.resample = (
            Resample(orig_freq=int(self.sfreq), new_freq=int(resampling_freq))
            if self.sfreq != resampling_freq
            else torch.nn.Identity()
        )

        # Stem: EEG channels are the input channels of a single 1D conv.
        self.conv = torch.nn.Conv1d(
            self.n_chans,
            feature_maps,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )
        self.bn = torch.nn.BatchNorm1d(feature_maps)
        # Stem activation honours ``activation`` like the rest of the network
        # (previously a hard-coded ReLU; identical for the default nn.ReLU).
        self.stem_activation = activation()

        blocks = []
        new_feature_maps = feature_maps
        old_feature_maps = feature_maps
        for i in range(n_convs):
            if i > 0:
                # 1.414 = sqrt(2) allow constant flops.
                new_feature_maps = int(1.414 * new_feature_maps)
            blocks.append(
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        old_feature_maps,
                        new_feature_maps,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    torch.nn.BatchNorm1d(new_feature_maps),
                    # Every block halves the time axis.
                    torch.nn.MaxPool1d(2),
                    activation(),
                    torch.nn.Conv1d(
                        new_feature_maps,
                        new_feature_maps,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        bias=False,
                    ),
                    torch.nn.BatchNorm1d(new_feature_maps),
                    activation(),
                )
            )
            old_feature_maps = new_feature_maps
        self.blocks = torch.nn.ModuleList(blocks)
        self.final_layer = torch.nn.Linear(old_feature_maps, self.n_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_channels, n_times)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs), or the time-averaged
            features of shape (batch_size, n_features) when ``return_feature``
            is True.
        """
        x_rs = self.resample(x.contiguous())
        feat = self.stem_activation(self.bn(self.conv(x_rs)))
        for seq in self.blocks:
            feat = seq(feat)
        # Global average pooling over time: (batch, maps, time) -> (batch, maps).
        feat = feat.mean(dim=2)
        if self.return_feature:
            return feat
        return self.final_layer(feat)