braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/brainmodule.py (removed)
@@ -1,845 +0,0 @@
- # Authors: Meta Platforms, Inc. and affiliates (original)
- # Bruno Aristimunha <b.aristimunha@gmail.com> (Braindecode adaptation)
- # Hubert Banville <hubertjb@meta.com> (Braindecode adaptation and review)
- #
- # License: Attribution-NonCommercial 4.0 International
-
- """BrainModule: Dilated Convolutional Encoder for EEG decoding."""
-
- from __future__ import annotations
-
- import math
- import typing as tp
-
- import torch
- import torchaudio as ta
- from torch import nn
- from torch.nn import functional as F
-
- from braindecode.models.base import EEGModuleMixin
- from braindecode.modules.layers import SubjectLayers
-
- __all__ = ["BrainModule"]
-
-
- class BrainModule(EEGModuleMixin, nn.Module):
-     r"""BrainModule from [brainmagick]_, also known as SimpleConv.
-
-     A dilated convolutional encoder for EEG decoding, using residual
-     connections and optional GLU gating for improved expressivity.
-
-     :bdg-success:`Convolution`
-
-     .. figure:: ../_static/model/simpleconv.png
-        :align: center
-        :alt: BrainModule Architecture
-        :width: 500px
-
-     Figure adapted from Extended Data Fig. 4 of [brainmagick]_ to highlight only the model part.
-     Architecture of the brain module, used to process the brain recordings.
-     For each layer, the authors first note the number of output channels; the number of time steps
-     is constant throughout the layers. The model is composed of a spatial attention layer,
-     then a 1x1 convolution without activation. A 'Subject Layer' is selected based on the subject
-     index s; it consists of a 1x1 convolution with no activation, learned only for that subject.
-     Then, the authors apply five convolutional blocks made of three convolutions. The first
-     two use residual skip connections and increasing dilation, followed by a BatchNorm layer and a
-     GELU activation. The third convolution is not residual, and uses a GLU activation
-     (which halves the number of channels) and no normalization.
-     Finally, the authors apply two 1x1 convolutions with a GELU in between.
-
-     The BrainModule (also referred to as SimpleConv) is a deep dilated
-     convolutional encoder specifically designed to decode perceived speech from
-     non-invasive brain recordings like EEG and MEG. It is engineered to address
-     the high noise levels and inter-individual variability inherent in
-     non-invasive neuroimaging by using a single architecture trained across
-     large cohorts while accommodating participant-specific differences.
-
-     .. rubric:: Architecture Overview
-
-     The BrainModule integrates three primary mechanisms to align brain activity
-     with deep speech representations:
-
-     1. **Spatial-temporal feature extraction.** The model uses a dedicated
-        spatial attention layer to remap sensor data based on physical
-        locations, followed by temporal processing through dilated convolutions.
-     2. **Subject-specific adaptation.** To leverage inter-subject variability,
-        the architecture includes a "Subject Layer", a participant-specific
-        1x1 convolution that allows the model to share core weights across a
-        cohort while learning individual-specific neural patterns.
-     3. **Dilated residual blocks with gating.** The core encoder employs a
-        stack of convolutional blocks featuring skip connections and increasing
-        dilation to expand the receptive field without losing temporal
-        resolution, supplemented by optional Gated Linear Units (GLU) for
-        increased expressivity.
-
-     .. rubric:: Macro Components
-
-     ``BrainModule.input_projection`` (Initial Processing)
-         **Operations.** Raw M/EEG input
-         :math:`\mathbf{X} \in \mathbb{R}^{C \times T}` is first processed
-         through a spatial attention layer that projects sensor locations onto a
-         2D plane using Fourier-parameterized functions. This is followed by a
-         subject-specific 1x1 convolution
-         :math:`\mathbf{M}_s \in \mathbb{R}^{D_1 \times D_1}` if subject
-         features are enabled. The resulting features are projected to the
-         ``hidden_dim`` (default 320) to ensure compatibility with subsequent
-         residual connections.
-
-         **Role.** Converts high-dimensional, subject-dependent sensor data into
-         a standardized latent space while preserving spatial and temporal
-         relationships.
-
-     ``BrainModule.encoder`` (Convolutional Sequence)
-         **Operations.** Implemented via
-         :class:`~braindecode.models.brainmodule._ConvSequence`, this component
-         consists of a stack of ``k`` convolutional blocks. Each block typically
-         contains: (a) **Residual dilated convolutions.** Two layers with kernel
-         size 3, residual skip connections, and dilation factors that grow
-         exponentially (e.g., powers of two with periodic resets) to capture
-         multi-scale temporal context. (b) **GLU gating.** Every ``N`` layers
-         (defined by ``glu``), a Gated Linear Unit is applied, which halves the
-         channel dimension and introduces non-linear gating to filter
-         intermediate representations.
-
-         **Role.** Extracts deep hierarchical temporal features from the brain
-         signal, significantly expanding the model's receptive field to align
-         with the contextual windows of speech modules like wav2vec 2.0.
-
-     .. rubric:: Temporal, Spatial, and Spectral Encoding
-
-     - **Temporal:** Increasing dilation factors across layers allow the model to
-       integrate information over large time windows without the computational
-       cost of standard large kernels, while a 150 ms input shift facilitates
-       alignment between stimulus and brain response.
-     - **Spatial:** The spatial attention layer learns a softmax weighting over
-       input sensors based on their 3D coordinates, allowing the model to focus
-       on regions typically activated during auditory stimulation (e.g., the
-       temporal cortex).
-     - **Spectral:** Through the optional ``n_fft`` parameter, the model can
-       apply an STFT transformation, converting time-domain signals into a
-       spectrogram representation before encoding.
-
-     .. rubric:: Additional Mechanisms
-
-     - **Clamping and scaling:** The model relies on clamping input values
-       (e.g., at 20 standard deviations) to prevent outliers and large
-       electromagnetic artifacts from destabilizing the BatchNorm estimates and
-       the optimization process.
-     - **Scaled subject embeddings:** When ``subject_dim`` is used, the
-       :class:`~braindecode.models.brainmodule._ScaledEmbedding` layer scales up
-       the effective learning rate for subject-specific features to prevent slow
-       convergence in multi-participant training.
-     - **_ConvSequence and residual logic:** This class handles the actual
-       stacking of layers. It is designed to be flexible with the ``growth``
-       parameter; if the channel size changes between layers (``growth != 1.0``),
-       it automatically applies a 1x1 ``skip_projection`` convolution to the
-       residual path so dimensions match for addition.
-     - **_ChannelDropout:** Unlike standard dropout, which zeroes individual
-       neurons, this zeroes entire channels. It includes a rescale feature that
-       multiplies the remaining channels by a factor
-       ``total_channels / active_channels`` to maintain the expected value of the
-       signal during training.
-     - **_ScaledEmbedding:** This is a clever optimization for multi-subject
-       learning. By dividing the initial weights by a scale and then multiplying
-       the output by the same scale, it effectively increases the gradient
-       magnitude for the embedding weights, allowing subject-specific features to
-       learn faster than the shared backbone.
-
-     Parameters
-     ----------
-     hidden_dim : int, default=320
-         Hidden dimension for convolutional layers. Input is projected to this
-         dimension before the convolutional blocks.
-     depth : int, default=10
-         Number of convolutional blocks. Each block contains a dilated convolution
-         with batch normalization and activation, followed by a residual connection.
-     kernel_size : int, default=3
-         Convolutional kernel size. Must be odd for proper padding with dilation.
-     growth : float, default=1.0
-         Channel size multiplier: hidden_dim * (growth ** layer_index).
-         Values > 1.0 grow channels with depth; values < 1.0 shrink them.
-         Note: when growth != 1.0, layers with mismatched channel sizes use a
-         1x1 skip projection on the residual path instead of a direct addition.
-     dilation_growth : int, default=2
-         Dilation multiplier per layer (e.g., 2 means dilation doubles each layer).
-         Grows the receptive field exponentially. Requires an odd kernel_size.
-     dilation_period : int, default=5
-         Reset dilation to 1 every N layers. Prevents dilation from growing
-         too large and maintains local connectivity.
-     conv_drop_prob : float, default=0.0
-         Dropout probability for convolutional layers.
-     dropout_input : float, default=0.0
-         Dropout probability applied to the model input only.
-     batch_norm : bool, default=True
-         If True, apply batch normalization after each convolution.
-     activation : type[nn.Module], default=nn.GELU
-         Activation function class to use (e.g., nn.GELU, nn.ReLU, nn.ELU).
-     n_subjects : int, default=200
-         Number of unique subjects (for subject-specific pathways).
-         Only used if subject_dim > 0.
-     subject_dim : int, default=0
-         Dimension of subject embeddings. If 0, no subject-specific features.
-         If > 0, adds subject embeddings to the input before encoding.
-     subject_layers : bool, default=False
-         If True, apply subject-specific linear transformations to input channels.
-         Each subject has its own weight matrix. Requires subject_dim > 0.
-     subject_layers_dim : str, default="input"
-         Where to apply subject layers: "input" or "hidden".
-     subject_layers_id : bool, default=False
-         If True, initialize subject layers as identity matrices.
-     embedding_scale : float, default=1.0
-         Scaling factor for the subject embeddings' effective learning rate.
-     n_fft : int, optional
-         FFT size for STFT processing. If None, no STFT is applied.
-         If specified, applies a spectrogram transform before encoding.
-     fft_complex : bool, default=True
-         If True, keep the complex spectrogram. If False, use the power spectrogram.
-         Only used when n_fft is not None.
-     channel_dropout_prob : float, default=0.0
-         Probability of dropping each channel during training (0.0 to 1.0).
-         If 0.0, no channel dropout is applied.
-     channel_dropout_type : str, optional
-         If specified with chs_info, only drop channels of this type
-         (e.g., 'eeg', 'ref', 'eog'). If None with channel_dropout_prob > 0,
-         drops any channel.
-     glu : int, default=2
-         If > 0, applies Gated Linear Units (GLU) every N convolutional layers.
-         GLUs gate intermediate representations for more expressivity.
-         If 0, no GLU is applied.
-     glu_context : int, default=1
-         Context window size for GLU gates. If > 0, uses contextual information
-         from neighboring time steps for gating. Requires glu > 0.
-
-     References
-     ----------
-     .. [brainmagick] Défossez, A., Caucheteux, C., Rapin, J., Kabeli, O., & King, J. R.
-        (2023). Decoding speech perception from non-invasive brain recordings. Nature
-        Machine Intelligence, 5(10), 1097-1107.
-
-     Notes
-     -----
-     - Input shape: (batch, n_chans, n_times)
-     - Output shape: (batch, n_outputs)
-     - The model uses dilated convolutions with stride=1 to maintain temporal
-       resolution while achieving large receptive fields.
-     - Residual connections are applied at every layer where input and output
-       channels match.
-     - Subject-specific features (subject_dim > 0, subject_layers) require passing
-       subject indices in the forward pass as an optional parameter or via batch.
-     - STFT processing (n_fft > 0) automatically transforms the input to the
-       spectrogram domain.
-
-     .. versionadded:: 1.2
-
-     """
-
-     def __init__(
-         self,
-         # braindecode EEGModuleMixin parameters
-         n_chans: int | None = None,
-         n_outputs: int | None = None,
-         n_times: int | None = None,
-         sfreq: float | None = None,
-         chs_info: list[dict] | None = None,
-         input_window_seconds: float | None = None,
-         ########
-         # Model related parameters
-         # Architecture
-         hidden_dim: int = 320,
-         depth: int = 10,
-         kernel_size: int = 3,
-         growth: float = 1.0,
-         dilation_growth: int = 2,
-         dilation_period: int = 5,
-         # Regularization
-         conv_drop_prob: float = 0.0,
-         dropout_input: float = 0.0,
-         batch_norm: bool = True,
-         activation: type[nn.Module] = nn.GELU,
-         # Subject-specific features (optional)
-         n_subjects: int = 200,
-         subject_dim: int = 0,
-         subject_layers: bool = False,
-         subject_layers_dim: str = "input",
-         subject_layers_id: bool = False,
-         embedding_scale: float = 1.0,
-         # STFT/Spectrogram (optional)
-         n_fft: int | None = None,
-         fft_complex: bool = True,
-         # Channel dropout (optional)
-         channel_dropout_prob: float = 0.0,
-         channel_dropout_type: str | None = None,
-         # GLU gates (optional)
-         glu: int = 2,
-         glu_context: int = 1,
-     ):
-         # Initialize EEGModuleMixin
-         super().__init__(
-             n_outputs=n_outputs,
-             n_chans=n_chans,
-             chs_info=chs_info,
-             n_times=n_times,
-             input_window_seconds=input_window_seconds,
-             sfreq=sfreq,
-         )
-
-         # Store parameters for later use
-         self.subject_dim = subject_dim
-         self.n_subjects = n_subjects
-         self.n_fft = n_fft
-         self.fft_complex = fft_complex
-         self.hidden_dim = hidden_dim
-
-         # Validate inputs
-         _validate_brainmodule_params(
-             subject_layers=subject_layers,
-             subject_dim=subject_dim,
-             depth=depth,
-             kernel_size=kernel_size,
-             growth=growth,
-             dilation_growth=dilation_growth,
-             channel_dropout_prob=channel_dropout_prob,
-             channel_dropout_type=channel_dropout_type,
-             glu=glu,
-             glu_context=glu_context,
-         )
-
-         # Initialize channel dropout (optional)
-         self.channel_dropout = None
-         if channel_dropout_prob > 0:
-             self.channel_dropout = _ChannelDropout(
-                 dropout_prob=channel_dropout_prob,
-                 ch_info=chs_info,
-                 channel_type=channel_dropout_type,
-             )
-
-         # Initialize subject-specific modules (optional)
-         self.subject_embedding = None
-         self.subject_layers_module = None
-         input_channels = self.n_chans
-
-         if subject_dim > 0:
-             self.subject_embedding = _ScaledEmbedding(
-                 n_subjects, subject_dim, embedding_scale
-             )
-             input_channels += subject_dim
-
-         if subject_layers:
-             assert subject_dim > 0, "subject_layers requires subject_dim > 0"
-             # Use n_chans for input dim since subject_layers is applied before
-             # subject embeddings are concatenated in forward()
-             meg_dim = self.n_chans
-             dim = hidden_dim if subject_layers_dim == "hidden" else meg_dim
-             self.subject_layers_module = SubjectLayers(
-                 meg_dim, dim, n_subjects, subject_layers_id
-             )
-             # After subject_layers, we have 'dim' channels, then add subject_dim
-             input_channels = dim + subject_dim
-
-         # Initialize STFT module (optional)
-         self.stft = None
-         if n_fft is not None:
-             self.stft = ta.transforms.Spectrogram(
-                 n_fft=n_fft,
-                 hop_length=n_fft // 2,
-                 normalized=True,
-                 power=None if fft_complex else 1,
-                 return_complex=True,
-             )
-             # Update input channels for spectrogram
-             freq_bins = n_fft // 2 + 1
-             if fft_complex:
-                 input_channels *= 2 * freq_bins
-             else:
-                 input_channels *= freq_bins
-
-         # Initial projection layer: project input channels to hidden_dim
-         # This is crucial for residual connections to work properly
-         self.input_projection = nn.Conv1d(input_channels, hidden_dim, 1)
-
-         # Build channel dimensions for encoder (all same size for residuals)
-         # With growth=1.0, all layers have hidden_dim channels (residuals work)
-         # With growth!=1.0, channels vary (residuals only where dims match)
-         encoder_dims = [hidden_dim] + [
-             int(round(hidden_dim * growth**k)) for k in range(depth)
-         ]
-
-         # Build encoder (stride=1, no downsampling)
-         self.encoder = _ConvSequence(
-             channels=encoder_dims,
-             kernel_size=kernel_size,
-             dilation_growth=dilation_growth,
-             dilation_period=dilation_period,
-             dropout=conv_drop_prob,
-             dropout_input=dropout_input,
-             batch_norm=batch_norm,
-             glu=glu,
-             glu_context=glu_context,
-             activation=activation,
-         )
-
-         # Final layer: temporal aggregation + output projection
-         # Use the last encoder dimension (may differ from hidden_dim if growth != 1)
-         final_hidden_dim = encoder_dims[-1]
-         self.final_layer = nn.Sequential(
-             nn.AdaptiveAvgPool1d(1),
-             nn.Flatten(start_dim=1),
-             nn.Linear(final_hidden_dim, self.n_outputs),
-         )
-
-     def forward(
-         self, x: torch.Tensor, subject_index: torch.Tensor | None = None
-     ) -> torch.Tensor:
-         """
-         Forward pass.
-
-         Parameters
-         ----------
-         x : torch.Tensor
-             Input EEG data of shape (batch, n_chans, n_times).
-         subject_index : torch.Tensor, optional
-             Subject indices of shape (batch,). Required if subject_dim > 0.
-
-         Returns
-         -------
-         torch.Tensor
-             Output logits/predictions of shape (batch, n_outputs).
-         """
-         # Validate input shape
-         if x.dim() != 3:
-             raise ValueError(
-                 f"Expected 3D input (batch, channels, time), got shape {x.shape}"
-             )
-         if x.shape[1] != self.n_chans:
-             raise ValueError(f"Expected {self.n_chans} channels, got {x.shape[1]}")
-
-         # Apply STFT if enabled
-         if self.stft is not None:
-             # Pad for STFT window
-             assert self.n_fft is not None, "n_fft must be set if stft is not None"
-             pad_size = self.n_fft // 4
-             x = F.pad(
-                 _pad_multiple(x, self.n_fft // 2), (pad_size, pad_size), mode="reflect"
-             )
-
-             # Apply STFT
-             spec = self.stft(x)  # (batch, channels, freq, time)
-             B, C, Fr, T = spec.shape
-
-             if self.fft_complex:
-                 # Convert complex to real/imag channels
-                 spec = torch.view_as_real(spec).permute(0, 1, 2, 4, 3)
-                 x = spec.reshape(B, C * 2 * Fr, T)
-             else:
-                 x = spec.reshape(B, C * Fr, T)
-
-         # Apply channel dropout if enabled
-         if self.channel_dropout is not None:
-             x = self.channel_dropout(x)
-
-         # Apply subject layers if enabled
-         if self.subject_layers_module is not None:
-             if subject_index is None:
-                 raise ValueError(
-                     "subject_index is required when subject_layers is enabled"
-                 )
-             x = self.subject_layers_module(x, subject_index)
-
-         # Apply subject embedding if enabled
-         if self.subject_embedding is not None:
-             if subject_index is None:
-                 raise ValueError("subject_index is required when subject_dim > 0")
-             emb = self.subject_embedding(subject_index)  # (batch, subject_dim)
-             emb = emb[:, :, None].expand(
-                 -1, -1, x.shape[-1]
-             )  # (batch, subject_dim, time)
-             x = torch.cat([x, emb], dim=1)  # Concatenate along channel dimension
-
-         # Project input to hidden dimension
-         x = self.input_projection(x)
-
-         # Encode with residual dilated convolutions
-         x = self.encoder(x)
-
-         # Apply final layer (pool + linear)
-         x = self.final_layer(x)
-
-         return x
-
-
- class _ConvSequence(nn.Module):
-     """Sequence of residual dilated convolutional layers with GLU activation.
-
-     This is a simplified encoder-only architecture that maintains temporal
-     resolution (stride=1) and applies residual connections at every layer
-     where input and output channels match.
-
-     Parameters
-     ----------
-     channels : Sequence[int]
-         Channel dimensions for each layer. E.g., [320, 320, 320, 320] for
-         a 3-layer network with 320 hidden dims.
-     kernel_size : int, default=3
-         Convolutional kernel size. Must be odd for proper padding with dilation.
-     dilation_growth : int, default=2
-         Dilation multiplier per layer. Grows the receptive field exponentially.
-     dilation_period : int, default=5
-         Reset dilation to 1 every N layers.
-     dropout : float, default=0.0
-         Dropout probability after activation.
-     dropout_input : float, default=0.0
-         Dropout probability applied to the input only.
-     batch_norm : bool, default=True
-         Whether to apply batch normalization.
-     glu : int, default=2
-         Apply GLU gating every N layers. If 0, no GLU.
-     glu_context : int, default=1
-         Context window for the GLU convolution.
-     activation : type, default=nn.GELU
-         Activation function class.
-     """
-
-     def __init__(
-         self,
-         channels: tp.Sequence[int],
-         kernel_size: int = 3,
-         dilation_growth: int = 2,
-         dilation_period: int = 5,
-         dropout: float = 0.0,
-         dropout_input: float = 0.0,
-         batch_norm: bool = True,
-         glu: int = 2,
-         glu_context: int = 1,
-         activation: tp.Any = None,
-     ) -> None:
-         super().__init__()
-
-         if dilation_growth > 1:
-             assert kernel_size % 2 != 0, (
-                 "Supports only odd kernel with dilation for now"
-             )
-
-         if activation is None:
-             activation = nn.GELU
-
-         self.sequence = nn.ModuleList()
-         self.glus = nn.ModuleList()
-         self.skip_projections = nn.ModuleList()  # For when chin != chout
-
-         dilation = 1
-         channels = tuple(channels)
-
-         for k, (chin, chout) in enumerate(zip(channels[:-1], channels[1:])):
-             layers: tp.List[nn.Module] = []
-
-             # Input dropout (only on first layer)
-             if k == 0 and dropout_input > 0:
-                 layers.append(nn.Dropout(dropout_input))
-
-             # Reset dilation periodically
-             if dilation_period and (k % dilation_period) == 0:
-                 dilation = 1
-
-             # Dilated convolution with proper padding to maintain temporal size
-             pad = kernel_size // 2 * dilation
-             layers.extend(
-                 [
-                     nn.Conv1d(
-                         chin,
-                         chout,
-                         kernel_size=kernel_size,
-                         stride=1,  # Always stride=1 for residual connections
-                         padding=pad,
-                         dilation=dilation,
-                     ),
-                 ]
-             )
-
-             # Batch norm + activation + dropout
-             if batch_norm:
-                 layers.append(nn.BatchNorm1d(num_features=chout))
-             layers.append(activation())
-             if dropout > 0:
-                 layers.append(nn.Dropout(dropout))
-
-             dilation *= dilation_growth
-
-             self.sequence.append(nn.Sequential(*layers))
-
-             # Add skip projection if channels don't match (for growth != 1.0)
-             if chin != chout:
-                 self.skip_projections.append(nn.Conv1d(chin, chout, 1))
-             else:
-                 self.skip_projections.append(None)
-
-             # GLU gating every N layers
-             if glu > 0 and (k + 1) % glu == 0:
-                 self.glus.append(
-                     nn.Sequential(
-                         nn.Conv1d(
-                             chout, 2 * chout, 1 + 2 * glu_context, padding=glu_context
-                         ),
-                         nn.GLU(dim=1),
-                     )
-                 )
-             else:
-                 self.glus.append(None)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         for module, glu, skip_proj in zip(
-             self.sequence, self.glus, self.skip_projections
-         ):
-             # Apply residual connection:
-             # if channels match, add directly; otherwise use the projection
-             if skip_proj is not None:
-                 x = skip_proj(x) + module(x)
-             else:
-                 x = x + module(x)
-             if glu is not None:
-                 x = glu(x)
-         return x
-
-
- def _pad_multiple(x: torch.Tensor, base: int) -> torch.Tensor:
-     """Pad tensor length to be a multiple of base."""
-     length = x.shape[-1]
-     target = math.ceil(length / base) * base
-     return F.pad(x, (0, target - length))
-
-
- class _ScaledEmbedding(nn.Module):
-     """Scaled embedding layer for subjects.
-
-     Scales up the effective learning rate for the embedding to prevent slow
-     convergence. Used for subject-specific representations.
-     """
-
-     def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0):
-         super().__init__()
-         self.embedding = nn.Embedding(num_embeddings, embedding_dim)
-         self.embedding.weight.data /= scale
-         self.scale = scale
-
-     @property
-     def weight(self) -> torch.Tensor:
-         """Get scaled embedding weights."""
-         return self.embedding.weight * self.scale
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Forward pass.
-
-         Parameters
-         ----------
-         x : torch.Tensor
-             Subject indices of shape (batch,).
-
-         Returns
-         -------
-         torch.Tensor
-             Scaled embeddings of shape (batch, embedding_dim).
-         """
-         return self.embedding(x) * self.scale
-
-
- class _ChannelDropout(nn.Module):
-     """Channel dropout with rescaling and optional ch_info support.
-
-     Randomly drops channels during training and rescales the output to
-     maintain its expected value. Optionally supports selective channel dropout
-     based on channel type (EEG, reference, EOG, etc.) using ch_info metadata.
-
-     Parameters
-     ----------
-     dropout_prob : float, default=0.0
-         Probability of dropping each channel (0.0 to 1.0).
-         If 0.0, no dropout is applied.
-     ch_info : list of dict, optional
-         Channel information from MNE (e.g., from raw.info['chs']).
-         Each dict should have 'ch_name' and 'ch_type' keys.
-         If provided, enables selective channel dropout by type.
-     channel_type : str, optional
-         If specified with ch_info, only drop channels of this type
-         (e.g., 'eeg', 'ref', 'eog'). If None, drop from all available channels.
-     rescale : bool, default=True
-         If True, rescale output to maintain expected value:
-         scale_factor = n_channels / (n_channels - n_dropped)
-
-     Examples
-     --------
-     >>> # Random channel dropout
-     >>> dropout = _ChannelDropout(dropout_prob=0.1)
-     >>> x = torch.randn(4, 32, 1000)
-     >>> x_dropped = dropout(x)
-
-     >>> # Selective EEG dropout using ch_info
-     >>> ch_info = [{'ch_name': 'Fp1', 'ch_type': 'eeg'}, ...]
-     >>> dropout_eeg = _ChannelDropout(
-     ...     dropout_prob=0.1,
-     ...     ch_info=ch_info,
-     ...     channel_type='eeg'  # Only drop EEG channels
-     ... )
-     >>> x_dropped = dropout_eeg(x)  # Reference channels never dropped
-     """
-
-     def __init__(
-         self,
-         dropout_prob: float = 0.0,
-         ch_info: list[dict] | None = None,
-         channel_type: str | None = None,
-         rescale: bool = True,
-     ):
-         super().__init__()
-
-         if not 0.0 <= dropout_prob <= 1.0:
-             raise ValueError(f"dropout_prob must be in [0.0, 1.0], got {dropout_prob}")
-         if channel_type is not None and ch_info is None:
-             raise ValueError("channel_type requires ch_info to be provided")
-
-         self.dropout_prob = dropout_prob
-         self.rescale = rescale
-         self.ch_info = ch_info
-         self.channel_type = channel_type
-
-         # Compute droppable channel indices
-         self.droppable_indices: list[int] | None = None
-         if ch_info is not None:
-             if channel_type is not None:
-                 # Drop only the specified type
-                 self.droppable_indices = [
-                     i
-                     for i, ch in enumerate(ch_info)
-                     if ch.get("ch_type", ch.get("kind")) == channel_type
-                 ]
-             else:
-                 # Drop any channel
-                 self.droppable_indices = list(range(len(ch_info)))
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Forward pass with channel dropout.
-
-         Parameters
-         ----------
-         x : torch.Tensor
-             Input of shape (batch, channels, time).
-
-         Returns
-         -------
-         torch.Tensor
-             Output of same shape as input, with selected channels randomly zeroed.
-         """
-         if not self.training or self.dropout_prob == 0:
-             return x
-
-         _, channels, _ = x.shape
-
-         # Determine which channels to drop
-         if self.droppable_indices is not None:
-             # Only drop from the specified indices
-             n_droppable = len(self.droppable_indices)
-             n_to_drop = max(1, int(n_droppable * self.dropout_prob))
-             if n_to_drop > 0:
-                 drop_indices = torch.tensor(
-                     self.droppable_indices, device=x.device, dtype=torch.long
-                 )
-                 # Randomly select which droppable indices to actually drop
-                 selected = torch.randperm(n_droppable, device=x.device)[:n_to_drop]
-                 drop_indices = drop_indices[selected]
-             else:
-                 drop_indices = torch.tensor([], device=x.device, dtype=torch.long)
-         else:
-             # Drop from any channel
-             n_to_drop = max(1, int(channels * self.dropout_prob))
-             drop_indices = torch.randperm(channels, device=x.device)[:n_to_drop]
-
-         # Clone and apply dropout
-         if len(drop_indices) > 0:
-             x_out = x.clone()
-             x_out[:, drop_indices, :] = 0
-
-             # Rescale to maintain expected value
-             if self.rescale:
-                 scale_factor = channels / (channels - len(drop_indices))
-                 x_out = x_out * scale_factor
-         else:
-             x_out = x
-
-         return x_out
-
-     def __repr__(self) -> str:
-         return (
-             f"ChannelDropout(dropout_prob={self.dropout_prob}, "
-             f"rescale={self.rescale}, "
-             f"channel_type={self.channel_type})"
-         )
-
-
- def _validate_brainmodule_params(
-     subject_layers: bool,
-     subject_dim: int,
-     depth: int,
-     kernel_size: int,
-     growth: float,
-     dilation_growth: int,
-     channel_dropout_prob: float,
-     channel_dropout_type: str | None,
-     glu: int,
-     glu_context: int,
- ) -> None:
-     """Validate BrainModule parameters.
-
-     Parameters
-     ----------
-     subject_layers : bool
-         Whether to use subject-specific layer transformations.
-     subject_dim : int
-         Dimension of subject embeddings.
-     depth : int
-         Number of convolutional blocks.
-     kernel_size : int
-         Convolutional kernel size.
-     growth : float
-         Channel size multiplier per layer.
-     dilation_growth : int
-         Dilation multiplier per layer.
-     channel_dropout_prob : float
-         Channel dropout probability.
-     channel_dropout_type : str or None
-         Channel type to selectively drop.
-     glu : int
-         GLU gating interval.
-     glu_context : int
-         GLU context window size.
-
-     Raises
-     ------
-     ValueError
-         If any parameter combination is invalid.
-     """
-     validations = [
-         (
-             subject_layers and subject_dim == 0,
-             "subject_layers=True requires subject_dim > 0",
-         ),
-         (depth < 1, "depth must be >= 1"),
-         (kernel_size <= 0, "kernel_size must be > 0"),
-         (kernel_size % 2 == 0, "kernel_size must be odd for proper padding"),
-         (growth <= 0, "growth must be > 0"),
-         (dilation_growth < 1, "dilation_growth must be >= 1"),
-         (
-             not 0.0 <= channel_dropout_prob <= 1.0,
-             "channel_dropout_prob must be in [0.0, 1.0]",
-         ),
-         (
-             channel_dropout_type is not None and channel_dropout_prob == 0.0,
-             "channel_dropout_type requires channel_dropout_prob > 0",
-         ),
-         (glu < 0, "glu must be >= 0"),
-         (glu_context < 0, "glu_context must be >= 0"),
-         (glu_context > 0 and glu == 0, "glu_context > 0 requires glu > 0"),
-         (glu_context >= kernel_size, "glu_context must be < kernel_size"),
-     ]
-
-     for condition, message in validations:
-         if condition:
-             raise ValueError(message)
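
This release removes `braindecode/models/brainmodule.py` outright (item 98 in the file list). For reference, a minimal usage sketch of the removed model, runnable only against the older `1.3.0.dev177069446` wheel; the module path and shape conventions follow the docstring and `forward` signature in the diff above, while the concrete channel counts, window size, and class count are illustrative assumptions:

```python
import torch
from braindecode.models.brainmodule import BrainModule  # pre-removal module path

# Hypothetical setup: 32-channel EEG, 2 s windows at 250 Hz, 4 output classes,
# with subject-specific layers enabled for a cohort of 10 subjects.
model = BrainModule(
    n_chans=32, n_outputs=4, n_times=500, sfreq=250.0,
    subject_dim=16, subject_layers=True, n_subjects=10,
)
x = torch.randn(8, 32, 500)                 # (batch, n_chans, n_times)
subject_index = torch.randint(0, 10, (8,))  # required when subject_dim > 0
y = model(x, subject_index=subject_index)   # (batch, n_outputs) == (8, 4)
```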
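The exponential receptive-field growth produced by `dilation_growth` and `dilation_period` can be tallied directly from the `_ConvSequence` loop above; a small sketch using the default hyperparameters (ignoring the extra context added by the GLU convolutions):

```python
def receptive_field(depth=10, kernel_size=3, dilation_growth=2, dilation_period=5):
    """Receptive field of the stacked stride-1 dilated convolutions."""
    rf, dilation = 1, 1
    for k in range(depth):
        if dilation_period and k % dilation_period == 0:
            dilation = 1                    # periodic reset, as in _ConvSequence
        rf += (kernel_size - 1) * dilation  # each layer adds (k - 1) * d samples
        dilation *= dilation_growth
    return rf

print(receptive_field())  # 125 samples with the defaults
```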
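Likewise, the `_ScaledEmbedding` trick, dividing the stored weights by `scale` and multiplying the forward output by the same factor, can be checked in isolation with plain PyTorch; a minimal sketch with an arbitrary scale of 10:

```python
import torch
from torch import nn

scale = 10.0
emb = nn.Embedding(4, 8)
emb.weight.data /= scale                   # shrink the stored weights ...
out = emb(torch.tensor([1])) * scale       # ... and undo it in the forward pass
out.sum().backward()

# The forward output is unchanged, but d(out)/d(weight) picks up the factor
# `scale`: each element of the selected row gets gradient 10 instead of 1,
# so subject embeddings learn faster than the shared backbone.
print(emb.weight.grad[1])  # row of 10.0s
```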