braindecode 1.3.0.dev180329405__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/datasets/__init__.py +12 -4
  3. braindecode/datasets/base.py +115 -151
  4. braindecode/datasets/bcicomp.py +4 -4
  5. braindecode/datasets/bids.py +3 -3
  6. braindecode/datasets/experimental.py +2 -2
  7. braindecode/datasets/mne.py +3 -5
  8. braindecode/datasets/moabb.py +17 -7
  9. braindecode/datasets/nmt.py +2 -2
  10. braindecode/datasets/sleep_physio_challe_18.py +2 -2
  11. braindecode/datasets/sleep_physionet.py +2 -2
  12. braindecode/datasets/tuh.py +2 -2
  13. braindecode/datasets/xy.py +2 -2
  14. braindecode/datautil/__init__.py +11 -1
  15. braindecode/datautil/channel_utils.py +114 -0
  16. braindecode/datautil/serialization.py +7 -7
  17. braindecode/functional/functions.py +6 -2
  18. braindecode/functional/initialization.py +2 -3
  19. braindecode/models/__init__.py +6 -0
  20. braindecode/models/atcnet.py +26 -27
  21. braindecode/models/attentionbasenet.py +37 -32
  22. braindecode/models/attn_sleep.py +2 -0
  23. braindecode/models/base.py +280 -2
  24. braindecode/models/bendr.py +469 -0
  25. braindecode/models/biot.py +2 -0
  26. braindecode/models/contrawr.py +2 -0
  27. braindecode/models/ctnet.py +8 -3
  28. braindecode/models/deepsleepnet.py +28 -19
  29. braindecode/models/eegconformer.py +2 -2
  30. braindecode/models/eeginception_erp.py +31 -25
  31. braindecode/models/eegitnet.py +2 -0
  32. braindecode/models/eegminer.py +2 -0
  33. braindecode/models/eegnet.py +1 -1
  34. braindecode/models/eegsym.py +917 -0
  35. braindecode/models/eegtcnet.py +2 -0
  36. braindecode/models/fbcnet.py +5 -1
  37. braindecode/models/fblightconvnet.py +2 -0
  38. braindecode/models/fbmsnet.py +20 -6
  39. braindecode/models/ifnet.py +2 -0
  40. braindecode/models/labram.py +33 -26
  41. braindecode/models/medformer.py +758 -0
  42. braindecode/models/msvtnet.py +2 -0
  43. braindecode/models/patchedtransformer.py +1 -1
  44. braindecode/models/signal_jepa.py +111 -27
  45. braindecode/models/sinc_shallow.py +12 -9
  46. braindecode/models/sstdpn.py +11 -11
  47. braindecode/models/summary.csv +3 -0
  48. braindecode/models/syncnet.py +2 -0
  49. braindecode/models/tcn.py +2 -0
  50. braindecode/models/usleep.py +26 -21
  51. braindecode/models/util.py +3 -0
  52. braindecode/modules/attention.py +10 -10
  53. braindecode/modules/blocks.py +3 -3
  54. braindecode/modules/filter.py +2 -9
  55. braindecode/modules/layers.py +18 -17
  56. braindecode/preprocessing/__init__.py +232 -3
  57. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  58. braindecode/preprocessing/mne_preprocess.py +142 -10
  59. braindecode/preprocessing/preprocess.py +28 -18
  60. braindecode/preprocessing/util.py +166 -0
  61. braindecode/preprocessing/windowers.py +26 -20
  62. braindecode/samplers/base.py +8 -8
  63. braindecode/version.py +1 -1
  64. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +6 -2
  65. braindecode-1.3.0.dev182330353.dist-info/RECORD +109 -0
  66. braindecode-1.3.0.dev180329405.dist-info/RECORD +0 -103
  67. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
  68. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
  69. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
  70. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
braindecode/models/eegsym.py (new file)
@@ -0,0 +1,917 @@
from __future__ import annotations

from typing import Any, List, Tuple, cast

import torch
from einops.layers.torch import Rearrange
from torch import nn

from braindecode.datautil.channel_utils import (
    division_channels_idx,
    match_hemisphere_chans,
)
from braindecode.models.base import EEGModuleMixin


class EEGSym(EEGModuleMixin, nn.Module):
    """EEGSym from Pérez-Velasco et al. (2022) [eegsym2022]_.

    :bdg-success:`Convolution` :bdg-dark-line:`Channel`

    .. figure:: ../../docs/_static/model/eegsym.png
        :align: center
        :alt: EEGSym Architecture

    **EEGSym** is a Convolutional Neural Network (CNN) architecture designed for
    Motor Imagery (MI) based Brain-Computer Interfaces (BCIs), primarily aimed at
    **overcoming inter-subject variability** and **reducing BCI inefficiency**
    [eegsym2022]_.

    The architecture integrates advances from Deep Learning (DL), complemented by
    Transfer Learning (TL) techniques and Data Augmentation (DA), to achieve strong
    performance in inter-subject MI classification [eegsym2022]_.

    .. rubric:: Architectural Overview

    EEGSym systematically incorporates three core features:

    #. **Inception modules** for multi-scale temporal analysis [eegsym2022]_.
    #. **Residual connections** that maintain the spatio-temporal structure of
       the signal and enable deeper feature extraction [eegsym2022]_.
    #. A **siamese-network design** that exploits the inherent symmetry of the
       brain across the mid-sagittal plane [eegsym2022]_.

    .. rubric:: Macro Components

    - **Symmetric division (input processing)**
        - *Operations.* The input is virtually split into left, right, and middle
          channels. Middle (central) channels are duplicated and concatenated to
          both the left and right lateralized electrodes to form the two
          hemisphere inputs [eegsym2022]_.
        - *Role.* Prepares the data for the siamese-network approach, reducing
          the number of parameters in the spatial filters of the tempospatial
          analysis stage [eegsym2022]_.

    - `EEGSym.inception_block1` and `EEGSym.inception_block2`
      **(Tempospatial Analysis - Temporal Feature Extraction)**
        - *Operations.* Use :class:`_InceptionBlock` modules, which apply
          parallel temporal convolutions with different kernel sizes (scales),
          followed by concatenation, residual connections, and average pooling
          for temporal dimensionality reduction [eegsym2022]_.
        - *Role.* Captures detailed temporal relationships, similarly to
          :class:`~braindecode.models.eeginception_mi.EEGInceptionMI`
          [eeginception2020]_. The first block uses large temporal kernels
          (e.g., 500 ms, 250 ms, 125 ms) [eegsym2022]_.

    - `EEGSym.residual_blocks` **(Tempospatial Analysis - Spatial Feature Extraction)**
        - *Operations.* Composed of multiple :class:`_ResidualBlock` modules
          (typically three instances) [eegsym2022]_. Each block applies temporal
          convolution, pooling, and a spatial analysis layer (convolution or
          grouped convolution) [eegsym2022]_.
        - *Role.* Enhances spatial feature extraction by incorporating residual
          connections across all CNN stages, which helps maintain the
          spatio-temporal structure of the signal through deeper layers
          [eegsym2022]_.

    - `EEGSym.channel_merging` **(Hemisphere Merging)**
        - *Operations.* The :class:`_ChannelMergingBlock` reduces the spatial
          dimensionality (Z and C) to 1, performing two residual convolutions
          followed by a final grouped convolution that merges the feature
          information of the two hemispheres [eegsym2022]_.
        - *Role.* Extracts complex relationships between channels of both
          hemispheres as part of the symmetry exploitation [eegsym2022]_.

    - `EEGSym.temporal_merging` **(Temporal Collapse)**
        - *Operations.* The :class:`_TemporalMergingBlock` uses a residual
          convolution followed by a grouped convolution to reduce the temporal
          dimension (S) to 1 [eegsym2022]_.
        - *Role.* Final step of temporal aggregation before the output module
          [eegsym2022]_.

    - `EEGSym.output_blocks` **(Output Processing)**
        - *Operations.* The :class:`_OutputBlock` applies four residual
          convolution iterations (1x1x1 convolutions) followed by flattening
          [eegsym2022]_.
        - *Role.* Final feature refinement through residual connections before
          the fully connected classification layer [eegsym2022]_.

    .. rubric:: How the information is encoded temporally, spatially, and spectrally

    * **Temporal.**
      Temporal features are extracted at multiple scales in the inception
      modules using different temporal convolution kernel sizes (e.g.,
      corresponding to 500 ms, 250 ms, and 125 ms windows at a 128 Hz sampling
      rate), closely following [eeginception2020]_. Subsequent pooling
      operations and residual blocks progressively reduce the temporal
      dimension [eegsym2022]_.

    * **Spatial.**
      Spatial features are extracted via two main mechanisms:

      - (1) The **siamese-network design** implicitly introduces brain symmetry
        by treating the two hemispheres equally during feature extraction
        [eegsym2022]_.
      - (2) **Residual connections** in the tempospatial analysis stage enhance
        the extraction of spatial correlations between electrodes [eegsym2022]_.

    * **Spectral.**
      Spectral information is implicitly captured by the varying kernel sizes
      of the temporal convolutions in the inception modules [eegsym2022]_.
      These kernels filter the signal over different temporal windows,
      corresponding to different frequency characteristics.

    Notes
    -----
    * EEGSym achieved competitive accuracies across five large MI datasets
      [eegsym2022]_.
    * The model maintained high accuracy with a reduced set of electrodes
      (8 or 16 channels) [eegsym2022]_.
    * This is a PyTorch implementation of EEGSym, ported from the original
      TensorFlow code [eegsym2022code]_.

    Parameters
    ----------
    filters_per_branch : int, optional
        Number of filters in each inception branch. Should be a multiple of 8.
        Default is 12 [eegsym2022]_.
    scales_time : tuple of int, optional
        Temporal scales (in milliseconds) for the temporal convolutions in the
        first inception module. Default is (500, 250, 125) [eegsym2022]_.
    drop_prob : float, optional
        Dropout probability. Default is 0.25 [eegsym2022]_.
    activation : type[nn.Module], optional
        Activation function class to use. Default is :class:`nn.ELU`
        [eegsym2022]_.
    spatial_resnet_repetitions : int, optional
        Number of repetitions of the spatial analysis operations at each step.
        Default is 5 [eegsym2022]_.
    left_right_chs : list of tuple of str, optional
        List of tuples pairing left and right hemisphere channel names,
        e.g., ``[('C3', 'C4'), ('FC5', 'FC6')]``. If not provided, channels
        are automatically split into left/right hemispheres using
        :func:`~braindecode.datautil.channel_utils.division_channels_idx` and
        :func:`~braindecode.datautil.channel_utils.match_hemisphere_chans`.
        Must be provided together with ``middle_chs`` [eegsym2022]_.
    middle_chs : list of str, optional
        List of midline (central) channel names that lie on the mid-sagittal
        plane, e.g., ``['FZ', 'CZ', 'PZ']``. These channels are duplicated and
        concatenated to both hemispheres. If not provided, channels are
        automatically identified using
        :func:`~braindecode.datautil.channel_utils.division_channels_idx`.
        Must be provided together with ``left_right_chs`` [eegsym2022]_.

    References
    ----------
    .. [eegsym2022] Pérez-Velasco, S., Santamaría-Vázquez, E.,
       Martínez-Cagigal, V., Marcos-Martínez, D., & Hornero, R. (2022). EEGSym:
       Overcoming inter-subject variability in motor imagery based BCIs with
       deep learning. IEEE Transactions on Neural Systems and Rehabilitation
       Engineering, 30, 1766-1775.
    .. [eegsym2022code] Pérez-Velasco, S. EEGSym source code.
       https://github.com/Serpeve/EEGSym
    .. [eeginception2020] Santamaría-Vázquez, E., Martínez-Cagigal, V.,
       Vaquerizo-Villar, F., & Hornero, R. (2020). EEG-Inception: A novel deep
       convolutional neural network for assistive ERP-based brain-computer
       interfaces. IEEE Transactions on Neural Systems and Rehabilitation
       Engineering, 28, 2773-2782.
    """

    def __init__(
        self,
        # braindecode parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        filters_per_branch: int = 12,
        scales_time: Tuple[int, int, int] = (500, 250, 125),
        drop_prob: float = 0.25,
        activation: type[nn.Module] = nn.ELU,
        spatial_resnet_repetitions: int = 5,
        left_right_chs: list[tuple[str, str]] | None = None,
        middle_chs: list[str] | None = None,
    ):
        if (left_right_chs is None) != (middle_chs is None):
            raise ValueError(
                "Either both or none of left_right_chs and middle_chs must be provided."
            )
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.filters_per_branch = filters_per_branch
        self.scales_time = scales_time
        self.drop_prob = drop_prob
        self.activation = activation()
        self.spatial_resnet_repetitions = spatial_resnet_repetitions

        # Calculate scales in samples
        self.scales_samples = [int(s * self.sfreq / 2000) * 2 + 1 for s in scales_time]
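        # This yields odd kernel lengths spanning roughly `s` ms each, e.g.
        # 500 ms at sfreq=128 -> int(500 * 128 / 2000) * 2 + 1 = 65 samples
        # (odd lengths keep the `scale // 2` padding symmetric).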

        # Note: chs_info is actually list[dict] despite the base class type
        # hint saying list[str]
        ch_names = [cast(dict[str, Any], ch)["ch_name"] for ch in self.chs_info]
        if left_right_chs is None:
            left_chs, right_chs, middle_chs = division_channels_idx(ch_names)
            try:
                # Try to match hemispheres based on channel naming
                left_chs, right_chs = zip(*match_hemisphere_chans(left_chs, right_chs))
            except (ValueError, IndexError):
                # Fallback: if matching fails, treat all channels as one hemisphere.
                # This allows the model to work with arbitrary channel configurations.
                left_chs = ch_names
                right_chs = ch_names
                middle_chs = []
        else:
            left_chs, right_chs = zip(*left_right_chs)
            # middle_chs is guaranteed to be not None when left_right_chs is
            # not None (checked at the top of __init__)
            assert middle_chs is not None, (
                "middle_chs must be provided with left_right_chs"
            )

        # Convert names to indices
        left_idx = [ch_names.index(ch) for ch in left_chs]
        right_idx = [ch_names.index(ch) for ch in right_chs]
        middle_idx = [ch_names.index(ch) for ch in middle_chs]

        # Register as buffers (non-trainable tensors) for TorchScript compatibility
        self.register_buffer("left_idx", torch.tensor(left_idx, dtype=torch.long))
        self.register_buffer("right_idx", torch.tensor(right_idx, dtype=torch.long))
        self.register_buffer("middle_idx", torch.tensor(middle_idx, dtype=torch.long))

        self.n_channels_per_hemi = len(left_idx) + len(middle_idx)
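        # Illustrative example (assuming standard 10-20 names, where odd
        # numbers lie on the left, even numbers on the right, and "z" marks
        # the midline): ["C3", "C4", "FC1", "FC2", "CZ"] splits into
        # left ("C3", "FC1"), right ("C4", "FC2"), and middle ("CZ",),
        # giving n_channels_per_hemi = 2 + 1 = 3.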
        ##################
        # Build the model
        ##################
        self.include_extra_dim = Rearrange("batch channel time -> batch 1 channel time")

        self.permute_layer = Rearrange(
            "batch features z time space -> batch features z space time"
        )
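        # Note: relative to the layout built in forward(),
        # (batch, features, z, space, time), this Rearrange swaps the last two
        # axes to (batch, features, z, time, space); the "time"/"space" labels
        # in the pattern are positional, not semantic.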

        self.inception_block1 = _InceptionBlock(
            in_channels=1,
            scales_samples=self.scales_samples,
            filters_per_branch=self.filters_per_branch,
            ncha=self.n_channels_per_hemi,
            activation=self.activation,
            drop_prob=self.drop_prob,
            average_pool=2,
            spatial_resnet_repetitions=self.spatial_resnet_repetitions,
            init=True,
        )
        self.inception_block2 = _InceptionBlock(
            in_channels=self.filters_per_branch * len(self.scales_samples),
            scales_samples=[max(1, s // 4) for s in self.scales_samples],
            filters_per_branch=self.filters_per_branch,
            ncha=self.n_channels_per_hemi,
            activation=self.activation,
            drop_prob=self.drop_prob,
            average_pool=2,
            spatial_resnet_repetitions=self.spatial_resnet_repetitions,
            init=False,
        )

        # Residual blocks (the spatial dim stays n_channels_per_hemi through the network)
        self.residual_blocks = nn.Sequential(
            _ResidualBlock(
                in_channels=self.filters_per_branch * len(self.scales_samples),
                filters=self.filters_per_branch * len(self.scales_samples),  # no reduction
                kernel_size=16,
                ncha=self.n_channels_per_hemi,
                activation=self.activation,
                drop_prob=self.drop_prob,
                average_pool=2,
                spatial_resnet_repetitions=self.spatial_resnet_repetitions,
            ),
            _ResidualBlock(
                in_channels=self.filters_per_branch * len(self.scales_samples),
                filters=int(self.filters_per_branch * len(self.scales_samples) / 2),  # reduce by 2
                kernel_size=8,
                ncha=self.n_channels_per_hemi,
                activation=self.activation,
                drop_prob=self.drop_prob,
                average_pool=2,
                spatial_resnet_repetitions=self.spatial_resnet_repetitions,
            ),
            _ResidualBlock(
                in_channels=int(self.filters_per_branch * len(self.scales_samples) / 2),
                filters=int(self.filters_per_branch * len(self.scales_samples) / 4),  # reduce by 2 again
                kernel_size=4,
                ncha=self.n_channels_per_hemi,
                activation=self.activation,
                drop_prob=self.drop_prob,
                average_pool=2,
                spatial_resnet_repetitions=self.spatial_resnet_repetitions,
            ),
        )

        # Temporal reduction
        self.temporal_reduction = nn.Sequential(
            _TemporalBlock(
                in_channels=int(self.filters_per_branch * len(self.scales_samples) / 4),
                filters=int(self.filters_per_branch * len(self.scales_samples) / 4),
                kernel_size=4,
                activation=self.activation,
                drop_prob=self.drop_prob,
            ),
            nn.AvgPool3d(kernel_size=(1, 2, 1)),
        )

        # Channel merging
        self.channel_merging = _ChannelMergingBlock(
            in_channels=int(self.filters_per_branch * len(self.scales_samples) / 4),
            filters=int(self.filters_per_branch * len(self.scales_samples) / 4),
            groups=int(self.filters_per_branch * len(self.scales_samples) / 12),  # 36 / 12 = 3 groups with the defaults
            ncha=self.n_channels_per_hemi,
            division=2,
            activation=self.activation,
            drop_prob=self.drop_prob,
        )

        # Temporal merging
        # Calculate the temporal dimension at this point.
        # After: Inc1 (pool /2), Inc2 (pool /2), Res1-3 (pool /2 each), TempRed (pool /2)
        # Total reduction: 2**6 = 64
        temporal_dim_at_merging = self.n_times // 64
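        # Worked example: 3 s of EEG at 128 Hz gives n_times = 384, so
        # 384 // 64 = 6 samples remain on the temporal axis here.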

        self.temporal_merging = _TemporalMergingBlock(
            in_channels=int(self.filters_per_branch * len(self.scales_samples) / 4),
            filters=int(self.filters_per_branch * len(self.scales_samples) / 2),
            groups=int(self.filters_per_branch * len(self.scales_samples) / 4),
            n_times=temporal_dim_at_merging,
            activation=self.activation,
            drop_prob=self.drop_prob,
        )

        # Output layers
        self.output_blocks = nn.Sequential(
            _OutputBlock(
                in_channels=int(self.filters_per_branch * len(self.scales_samples) / 2),
                activation=self.activation,
                drop_prob=self.drop_prob,
            ),
            nn.Flatten(),
        )

        # Final fully connected layer
        self.final_layer = nn.Linear(
            in_features=int(self.filters_per_branch * len(self.scales_samples) / 2),
            out_features=self.n_outputs,
        )

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # Input: (B, C, T) = (batch, channels, time)
        # Step 1: Add a feature dimension
        x = self.include_extra_dim(x)  # (B, 1, C, T)

        # Step 2: Split into left, right, and middle channels.
        # Use index_select for TorchScript compatibility.
        left_data = torch.index_select(x, 2, self.left_idx)  # (B, 1, n_left, T)
        right_data = torch.index_select(x, 2, self.right_idx)  # (B, 1, n_right, T)
        middle_data = torch.index_select(x, 2, self.middle_idx)  # (B, 1, n_middle, T)

        # Step 3: Concatenate the middle channels to both hemispheres
        left_hemi = torch.cat(
            [left_data, middle_data], dim=2
        )  # (B, 1, n_left + n_middle, T)
        right_hemi = torch.cat(
            [right_data, middle_data], dim=2
        )  # (B, 1, n_right + n_middle, T)

        # Step 4: Stack along the Z (hemisphere) dimension
        x = torch.stack([left_hemi, right_hemi], dim=2)  # (B, 1, 2, n_ch_per_hemi, T)

        # Step 5: From (B, F, Z, Space, Time) to (B, F, Z, Time, Space)
        x = self.permute_layer(x)

        # Now x is in the expected format: (Batch, Features, Z, Time, Space)

        # Initial inception modules (list in, list out; take the first element)
        x = self.inception_block1([x])[0]
        x = self.inception_block2([x])[0]

        # Residual blocks
        x = self.residual_blocks(x)

        # Temporal reduction
        x = self.temporal_reduction(x)

        # Channel merging
        x = self.channel_merging(x)

        # Temporal merging
        x = self.temporal_merging(x)

        # Output blocks
        x = self.output_blocks(x)

        # Final fully connected layer
        x = self.final_layer(x)

        return x
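
    # Minimal usage sketch (illustrative; it assumes EEGSym is exported from
    # braindecode.models in this build, that chs_info entries only need a
    # "ch_name" key, and that n_times is divisible by the total temporal
    # reduction of 2**6 = 64, e.g. 384 samples of 128 Hz EEG):
    #
    # >>> import torch
    # >>> from braindecode.models import EEGSym
    # >>> chs = [dict(ch_name=name) for name in ("C3", "C4", "FC1", "FC2", "CZ")]
    # >>> model = EEGSym(n_outputs=2, n_times=384, chs_info=chs, sfreq=128.0)
    # >>> model(torch.randn(8, 5, 384)).shape
    # torch.Size([8, 2])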


class _InceptionBlock(nn.Module):
    """Inception module used in the EEGSym architecture.

    Parameters
    ----------
    in_channels : int
        Number of input feature maps.
    scales_samples : list of int
        Kernel lengths (in samples) for the temporal convolutions.
    filters_per_branch : int
        Number of filters in each inception branch.
    ncha : int
        Number of EEG channels per hemisphere (spatial dimension).
    activation : nn.Module
        Activation function to use.
    drop_prob : float
        Dropout probability.
    average_pool : int
        Kernel size for average pooling.
    spatial_resnet_repetitions : int
        Number of repetitions of the spatial analysis operations.
    init : bool
        If True, this is the first inception block of the network.
    """

    def __init__(
        self,
        in_channels: int,
        scales_samples: List[int],
        filters_per_branch: int,
        ncha: int,
        activation: nn.Module,
        drop_prob: float,
        average_pool: int,
        spatial_resnet_repetitions: int,
        init: bool = False,
    ):
        super().__init__()
        self.activation = activation
        self.drop_prob = drop_prob
        self.average_pool = average_pool
        self.init = init

        # Temporal convolutions
        self.temporal_convs = nn.ModuleList()
        for scale in scales_samples:
            self.temporal_convs.append(
                nn.Sequential(
                    nn.Conv3d(
                        in_channels=in_channels,
                        out_channels=filters_per_branch,
                        kernel_size=(1, scale, 1),
                        padding=(0, scale // 2, 0),
                    ),
                    nn.BatchNorm3d(filters_per_branch),
                    activation,
                    nn.Dropout(drop_prob),
                )
            )

        # Spatial convolutions
        if ncha != 1:
            self.spatial_convs = nn.ModuleList()
            for _ in range(spatial_resnet_repetitions):
                self.spatial_convs.append(
                    nn.Sequential(
                        nn.Conv3d(
                            in_channels=filters_per_branch * len(scales_samples),
                            out_channels=filters_per_branch * len(scales_samples),
                            kernel_size=(1, 1, ncha),
                            padding=(0, 0, 0),
                        ),
                        nn.BatchNorm3d(filters_per_branch * len(scales_samples)),
                        activation,
                        nn.Dropout(drop_prob),
                    )
                )

        self.pool = (
            nn.AvgPool3d(kernel_size=(1, average_pool, 1))
            if average_pool != 1
            else nn.Identity()
        )

    def forward(self, x_list: list[torch.Tensor]) -> list[torch.Tensor]:
        outputs: list[torch.Tensor] = []
        for x in x_list:
            # Apply temporal convolutions
            temp_outputs = [conv(x) for conv in self.temporal_convs]
            x_out = torch.cat(temp_outputs, dim=1)

            # Trim the temporal dimension if needed (even kernel sizes with padding)
            if x_out.shape[3] > x.shape[3]:
                x_out = x_out[:, :, :, : x.shape[3], :]

            # Residual connection
            x_out = x_out + x
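            # (For the first block the input has a single feature map, so this
            # addition broadcasts (B, 1, Z, T, S) against the concatenated
            # (B, filters_per_branch * len(scales), Z, T, S) output; in later
            # blocks both operands have the same number of feature maps.)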

            # Average pooling
            x_out = self.pool(x_out)

            # Apply spatial convolutions
            if hasattr(self, "spatial_convs"):
                for spatial_conv in self.spatial_convs:
                    x_spatial = spatial_conv(x_out)
                    x_out = x_out + x_spatial  # always with a residual connection

            outputs.append(x_out)
        return outputs


class _ResidualBlock(nn.Module):
    """Residual block used in the EEGSym architecture.

    Parameters
    ----------
    in_channels : int
        Number of input feature maps.
    filters : int
        Number of filters for the convolutional layers.
    kernel_size : int
        Kernel size for the temporal convolution.
    ncha : int
        Number of EEG channels per hemisphere (spatial dimension).
    activation : nn.Module
        Activation function to use.
    drop_prob : float
        Dropout probability.
    average_pool : int
        Kernel size for average pooling.
    spatial_resnet_repetitions : int
        Number of repetitions of the spatial analysis operations.
    """

    def __init__(
        self,
        in_channels: int,
        filters: int,
        kernel_size: int,
        ncha: int,
        activation: nn.Module,
        drop_prob: float,
        average_pool: int,
        spatial_resnet_repetitions: int = 5,
    ):
        super().__init__()
        self.activation = activation
        self.drop_prob = drop_prob

        # Temporal convolution
        self.temporal_conv = nn.Sequential(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=filters,
                kernel_size=(1, kernel_size, 1),
                padding=(0, kernel_size // 2, 0),
            ),
            nn.BatchNorm3d(filters),
            activation,
            nn.Dropout(drop_prob),
        )

        # Projection layer for matching channel dimensions if needed
        if in_channels != filters:
            self.projection = nn.Conv3d(
                in_channels=in_channels,
                out_channels=filters,
                kernel_size=(1, 1, 1),
            )
        else:
            self.projection = None

        # Average pooling over the time dimension
        self.avg_pool = nn.AvgPool3d(kernel_size=(1, average_pool, 1))

        # Spatial convolutions (multiple repetitions, as in _InceptionBlock)
        if ncha != 1:
            self.spatial_convs = nn.ModuleList()
            for _ in range(spatial_resnet_repetitions):
                self.spatial_convs.append(
                    nn.Sequential(
                        nn.Conv3d(
                            in_channels=filters,
                            out_channels=filters,
                            kernel_size=(1, 1, ncha),  # spatial convolution
                            padding=(0, 0, 0),
                        ),
                        nn.BatchNorm3d(filters),
                        activation,
                        nn.Dropout(drop_prob),
                    )
                )
        else:
            self.spatial_convs = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_res = self.temporal_conv(x)

        # Trim the temporal dimension if needed (even kernel sizes with padding)
        if x_res.shape[3] > x.shape[3]:
            x_res = x_res[:, :, :, : x.shape[3], :]

        # Match the channel dimension if needed
        if self.projection is not None:
            x = self.projection(x)

        x_out = x_res + x  # residual connection
        x_out = self.avg_pool(x_out)

        # Apply spatial convolutions if present (multiple repetitions)
        if self.spatial_convs is not None:
            for spatial_conv in self.spatial_convs:
                x_spatial = spatial_conv(x_out)
                x_out = x_out + x_spatial  # residual connection with broadcasting

        return x_out
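
    # Shape sketch (assuming default EEGSym settings and n_times = 384): the
    # first residual block maps (B, 36, 2, 96, ncha) -> (B, 36, 2, 48, ncha).
    # The temporal conv keeps the time axis at 96 (after trimming), average
    # pooling halves it, and each spatial conv collapses the channel axis and
    # is broadcast back by the residual addition.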


class _TemporalBlock(nn.Module):
    """Temporal reduction block used in the EEGSym architecture.

    Parameters
    ----------
    in_channels : int
        Number of input feature maps.
    filters : int
        Number of filters for the convolutional layers.
    kernel_size : int
        Kernel size for the temporal convolution.
    activation : nn.Module
        Activation function to use.
    drop_prob : float
        Dropout probability.
    """

    def __init__(
        self,
        in_channels: int,
        filters: int,
        kernel_size: int,
        activation: nn.Module,
        drop_prob: float,
    ):
        super().__init__()
        self.activation = activation
        self.drop_prob = drop_prob

        self.conv = nn.Sequential(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=filters,
                kernel_size=(1, kernel_size, 1),
                padding=(0, kernel_size // 2, 0),
            ),
            nn.BatchNorm3d(filters),
            activation,
            nn.Dropout(drop_prob),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_res = self.conv(x)

        # Trim the temporal dimension if needed (even kernel sizes with padding)
        if x_res.shape[3] > x.shape[3]:
            x_res = x_res[:, :, :, : x.shape[3], :]

        x_res = x_res + x
        return x_res


class _ChannelMergingBlock(nn.Module):
    """Channel merging block used in the EEGSym architecture.

    This block performs hemisphere merging through:

    1. Two residual convolution iterations (with a full spatial kernel)
    2. One grouped convolution (merges the Z dimension from 2 to 1)

    Parameters
    ----------
    in_channels : int
        Number of input feature maps.
    filters : int
        Number of filters for the convolutional layers.
    groups : int
        Number of groups for the final grouped convolution.
    ncha : int
        Number of spatial channels to merge.
    division : int
        Z dimension size to merge (typically 2, for the two hemispheres).
    activation : nn.Module
        Activation function to use.
    drop_prob : float
        Dropout probability.
    """

    def __init__(
        self,
        in_channels: int,
        filters: int,
        groups: int,
        ncha: int,
        division: int,
        activation: nn.Module,
        drop_prob: float,
    ):
        super().__init__()
        self.activation = activation
        self.drop_prob = drop_prob

        # Two residual convolution iterations; each collapses the spatial
        # dimension (ncha -> 1) and the Z dimension before broadcasting back
        self.residual_convs = nn.ModuleList()
        for _ in range(2):
            self.residual_convs.append(
                nn.Sequential(
                    nn.Conv3d(
                        in_channels=in_channels,
                        out_channels=filters,
                        kernel_size=(division, 1, ncha),  # (Z, Time, Space)
                        padding=(0, 0, 0),  # valid padding
                    ),
                    nn.BatchNorm3d(filters),
                    activation,
                    nn.Dropout(drop_prob),
                )
            )

        # Final grouped convolution; merges the Z dimension from 2 to 1
        self.grouped_conv = nn.Sequential(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=filters,
                kernel_size=(division, 1, ncha),  # (Z, Time, Space)
                groups=groups,
                padding=(0, 0, 0),
            ),
            nn.BatchNorm3d(filters),
            activation,
            nn.Dropout(drop_prob),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply the two residual iterations: each conv reduces the dims and
        # the addition broadcasts the result back
        for residual_conv in self.residual_convs:
            x_res = residual_conv(x)
            x = x + x_res  # broadcasts x_res (1, T, 1) to match x (2, T, ncha)
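            # Shape sketch, assuming ncha = 3 channels per hemisphere: the
            # (division, 1, ncha) valid conv maps (B, F, 2, T, 3) to
            # (B, F, 1, T, 1), and the addition broadcasts it back over both
            # hemispheres and all spatial positions.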

        # Apply the final grouped conv (permanently reduces the dimensions)
        x = self.grouped_conv(x)

        return x


class _TemporalMergingBlock(nn.Module):
    """Temporal merging block used in the EEGSym architecture.

    This block collapses the temporal dimension through:

    1. One residual convolution (temporal collapse with a residual connection)
    2. One grouped convolution (temporal collapse, doubling the filters)

    Parameters
    ----------
    in_channels : int
        Number of input feature maps.
    filters : int
        Number of output filters (should be twice the input channels).
    groups : int
        Number of groups for the grouped convolution.
    n_times : int
        Current size of the temporal dimension.
    activation : nn.Module
        Activation function to use.
    drop_prob : float
        Dropout probability.
    """

    def __init__(
        self,
        in_channels: int,
        filters: int,
        groups: int,
        n_times: int,
        activation: nn.Module,
        drop_prob: float,
    ):
        super().__init__()
        self.activation = activation
        self.drop_prob = drop_prob

        # The temporal kernel spans the whole remaining temporal dimension,
        # which has already been reduced by the preceding pooling stages
        self.temporal_kernel = n_times  # e.g., 6 for 384 input samples

        # Residual convolution (collapses the time dimension)
        self.residual_conv = nn.Sequential(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=in_channels,  # same channels for the residual
                kernel_size=(1, self.temporal_kernel, 1),  # (Z, Time, Space)
                padding=(0, 0, 0),  # valid padding reduces time to 1
            ),
            nn.BatchNorm3d(in_channels),
            activation,
            nn.Dropout(drop_prob),
        )

        # Grouped convolution (collapses the time dimension, doubles the filters)
        self.grouped_conv = nn.Sequential(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=filters,  # double the channels
                kernel_size=(1, self.temporal_kernel, 1),  # (Z, Time, Space)
                groups=groups,
                padding=(0, 0, 0),
            ),
            nn.BatchNorm3d(filters),
            activation,
            nn.Dropout(drop_prob),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual convolution with broadcasting
        x_res = self.residual_conv(x)
        x = x + x_res  # broadcasts x_res (1, 1, 1) back to x shape (1, n_times, 1)

        # Grouped convolution (reduces time to 1, doubles the channels)
        x = self.grouped_conv(x)

        return x


class _OutputBlock(nn.Module):
    """Output block used in the EEGSym architecture.

    Applies ``n_residual`` residual 1x1x1 convolution iterations for final
    feature refinement before flattening.

    Parameters
    ----------
    in_channels : int
        Number of input feature maps.
    activation : nn.Module
        Activation function to use.
    drop_prob : float
        Dropout probability.
    n_residual : int
        Number of residual convolution iterations. Default is 4.
    """

    def __init__(
        self,
        in_channels: int,
        activation: nn.Module,
        drop_prob: float,
        n_residual: int = 4,
    ):
        super().__init__()
        self.activation = activation
        self.drop_prob = drop_prob

        self.conv_blocks = nn.ModuleList()
        for _ in range(n_residual):
            self.conv_blocks.append(
                nn.Sequential(
                    nn.Conv3d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=(1, 1, 1),
                        padding=(0, 0, 0),
                    ),
                    nn.BatchNorm3d(in_channels),
                    activation,
                    nn.Dropout(drop_prob),
                )
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv_block in self.conv_blocks:
            x_res = conv_block(x)
            x = x + x_res
        return x