braindecode 1.3.0.dev181065563__py3-none-any.whl → 1.3.0.dev183667303__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (56)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/augmentation/functional.py +154 -54
  3. braindecode/augmentation/transforms.py +2 -2
  4. braindecode/datasets/__init__.py +10 -2
  5. braindecode/datasets/base.py +116 -152
  6. braindecode/datasets/bcicomp.py +4 -4
  7. braindecode/datasets/bids.py +3 -3
  8. braindecode/datasets/experimental.py +2 -2
  9. braindecode/datasets/mne.py +3 -5
  10. braindecode/datasets/moabb.py +2 -2
  11. braindecode/datasets/nmt.py +2 -2
  12. braindecode/datasets/sleep_physio_challe_18.py +4 -3
  13. braindecode/datasets/sleep_physionet.py +2 -2
  14. braindecode/datasets/tuh.py +2 -2
  15. braindecode/datasets/xy.py +2 -2
  16. braindecode/datautil/serialization.py +18 -13
  17. braindecode/eegneuralnet.py +2 -0
  18. braindecode/functional/functions.py +6 -2
  19. braindecode/functional/initialization.py +2 -3
  20. braindecode/models/__init__.py +6 -0
  21. braindecode/models/atcnet.py +32 -33
  22. braindecode/models/attentionbasenet.py +39 -32
  23. braindecode/models/base.py +280 -2
  24. braindecode/models/bendr.py +469 -0
  25. braindecode/models/biot.py +3 -1
  26. braindecode/models/ctnet.py +6 -3
  27. braindecode/models/deepsleepnet.py +27 -18
  28. braindecode/models/eegconformer.py +2 -2
  29. braindecode/models/eeginception_erp.py +31 -25
  30. braindecode/models/eegnet.py +1 -1
  31. braindecode/models/labram.py +193 -87
  32. braindecode/models/patchedtransformer.py +640 -0
  33. braindecode/models/signal_jepa.py +109 -27
  34. braindecode/models/sinc_shallow.py +10 -9
  35. braindecode/models/sstdpn.py +869 -0
  36. braindecode/models/summary.csv +9 -6
  37. braindecode/models/usleep.py +26 -21
  38. braindecode/models/util.py +3 -0
  39. braindecode/modules/attention.py +10 -10
  40. braindecode/modules/blocks.py +3 -3
  41. braindecode/modules/filter.py +2 -3
  42. braindecode/modules/layers.py +18 -17
  43. braindecode/preprocessing/__init__.py +24 -0
  44. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  45. braindecode/preprocessing/preprocess.py +23 -14
  46. braindecode/preprocessing/util.py +166 -0
  47. braindecode/preprocessing/windowers.py +24 -19
  48. braindecode/samplers/base.py +8 -8
  49. braindecode/version.py +1 -1
  50. {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev183667303.dist-info}/METADATA +6 -2
  51. braindecode-1.3.0.dev183667303.dist-info/RECORD +106 -0
  52. braindecode-1.3.0.dev181065563.dist-info/RECORD +0 -101
  53. {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev183667303.dist-info}/WHEEL +0 -0
  54. {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev183667303.dist-info}/licenses/LICENSE.txt +0 -0
  55. {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev183667303.dist-info}/licenses/NOTICE.txt +0 -0
  56. {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev183667303.dist-info}/top_level.txt +0 -0
braindecode/models/bendr.py (new file)
@@ -0,0 +1,469 @@
1
+ import copy
2
+
3
+ import torch
4
+ from einops.layers.torch import Rearrange
5
+ from torch import nn
6
+
7
+ from braindecode.models.base import EEGModuleMixin
8
+
9
+
10
+ class BENDR(EEGModuleMixin, nn.Module):
11
+ """BENDR (BErt-inspired Neural Data Representations) from Kostas et al. (2021) [bendr]_.
12
+
13
+ :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`
14
+
15
+ .. figure:: https://www.frontiersin.org/files/Articles/653659/fnhum-15-653659-HTML/image_m/fnhum-15-653659-g001.jpg
16
+ :align: center
17
+ :alt: BENDR Architecture
18
+ :width: 1000px
19
+
20
+
21
+ The **BENDR** architecture adapts techniques used for language modeling (LM) toward the
22
+ development of encephalography modeling (EM) [bendr]_. It utilizes a self-supervised
23
+ training objective to learn compressed representations of raw EEG signals [bendr]_. The
24
+ model is capable of modeling completely novel raw EEG sequences recorded with differing
25
+ hardware and subjects, aiming for transferable performance across a variety of downstream
26
+ BCI and EEG classification tasks [bendr]_.
27
+
28
+ .. rubric:: Architectural Overview
29
+
30
+ BENDR is adapted from wav2vec 2.0 [wav2vec2]_ and is composed of two main stages: a
31
+ feature extractor (Convolutional stage) that produces BErt-inspired Neural Data
32
+ Representations (BENDR), followed by a transformer encoder (Contextualizer) [bendr]_.
33
+
34
+ .. rubric:: Macro Components
35
+
36
+ - `BENDR.encoder` **(Convolutional Stage/Feature Extractor)**
37
+ - *Operations.* A stack of six short-receptive field 1D convolutions [bendr]_. Each
38
+ block consists of 1D convolution, GroupNorm, and GELU activation.
39
+ - *Role.* Takes raw data :math:`X_{raw}` and dramatically downsamples it to a new
40
+ sequence of vectors (BENDR) [bendr]_. Each resulting vector has a length of 512.
41
+ - `BENDR.contextualizer` **(Transformer Encoder)**
42
+ - *Operations.* A transformer encoder that uses layered, multi-head self-attention
43
+ [bendr]_. It employs T-Fixup weight initialization [tfixup]_ and uses 8 layers
44
+ and 8 heads.
45
+ - *Role.* Maps the sequence of BENDR vectors to a contextualized sequence. The output
46
+ of a fixed start token is typically used as the aggregate representation for
47
+ downstream classification [bendr]_.
48
+ - `Contextualizer.position_encoder` **(Positional Encoding)**
49
+ - *Operations.* An additive (grouped) convolution layer with a receptive field of 25
50
+ and 16 groups [bendr]_.
51
+ - *Role.* Encodes position information before the input enters the transformer.
52
+
53
+ .. rubric:: How the information is encoded temporally, spatially, and spectrally
54
+
55
+ * **Temporal.**
56
+ The convolutional encoder uses a stack of blocks where the stride matches the receptive
57
+ field (e.g., 3 for the first block, 2 for subsequent blocks) [bendr]_. This process
58
+ downsamples the raw data by a factor of 96 (the product of the strides), giving an effective
59
+ sampling rate of roughly 2.67 Hz for the 256 Hz recordings used in the paper.
60
+ * **Spatial.**
61
+ To maintain simplicity and reduce complexity, the convolutional stage uses **1D
62
+ convolutions** and elects not to mix EEG channels across the first stage [bendr]_. The
63
+ input includes 20 channels (19 EEG channels and one relative amplitude channel).
64
+ * **Spectral.**
65
+ The convolution operations implicitly extract features from the raw EEG signal [bendr]_.
66
+ The representations (BENDR) are derived from the raw waveform using convolutional
67
+ operations followed by sequence modeling [wav2vec2]_.
68
+
69
+ .. rubric:: Additional Mechanisms
70
+
71
+ - **Self-Supervision (Pre-training).** Uses a masked sequence learning approach (adapted
72
+ from wav2vec 2.0 [wav2vec2]_) where contiguous spans of BENDR sequences are masked, and
73
+ the model attempts to reconstruct the original underlying encoded vector based on the
74
+ transformer output and a set of negative distractors [bendr]_.
75
+ - **Regularization.** LayerDrop [layerdrop]_ and Dropout (at probabilities 0.01 and 0.15,
76
+ respectively) are used during pre-training [bendr]_. The implementation also uses T-Fixup
77
+ scaling for parameter initialization [tfixup]_.
78
+ - **Input Conditioning.** A fixed token (a vector filled with the value **-5**) is
79
+ prepended to the BENDR sequence before input to the transformer, serving as the aggregate
80
+ representation token [bendr]_.
81
+
82
+ Notes
83
+ -----
84
+ * The full BENDR architecture contains a large number of parameters; configuration (1)
85
+ involved training over **one billion parameters** [bendr]_.
86
+ * A randomly initialized full BENDR architecture was generally ineffective at solving
87
+ downstream tasks without prior self-supervised training [bendr]_.
88
+ * The pre-training task (contrastive predictive coding via masking) is generalizable,
89
+ exhibiting strong uniformity of performance across novel subjects, hardware, and
90
+ tasks [bendr]_.
91
+
92
+ .. warning::
93
+
94
+ **Important:** To utilize the full potential of BENDR, the model requires
95
+ **self-supervised pre-training** on large, unlabeled EEG datasets (like TUEG) followed
96
+ by subsequent fine-tuning on the specific downstream classification task [bendr]_.
97
+
98
+ Parameters
99
+ ----------
100
+ encoder_h : int, default=512
101
+ Hidden size (number of output channels) of the convolutional encoder. This determines
102
+ the dimensionality of the BENDR feature vectors produced by the encoder.
103
+ contextualizer_hidden : int, default=3076
104
+ Hidden size of the feedforward layer within each transformer block. The paper uses
105
+ approximately 2x the transformer dimension (3076 ~ 2 x 1536).
106
+ projection_head : bool, default=False
107
+ If True, adds a projection layer at the end of the encoder to project back to the
108
+ input feature size. This is used during self-supervised pre-training but typically
109
+ disabled during fine-tuning.
110
+ drop_prob : float, default=0.1
111
+ Dropout probability applied throughout the model. The paper recommends 0.15 for
112
+ pre-training and 0.0 for fine-tuning. Default is 0.1 as a compromise.
113
+ layer_drop : float, default=0.0
114
+ Probability of dropping entire transformer layers during training (LayerDrop
115
+ regularization [layerdrop]_). The paper uses 0.01 for pre-training and 0.0 for
116
+ fine-tuning.
117
+ activation : :class:`torch.nn.Module`, default=:class:`torch.nn.GELU`
118
+ Activation function used in the encoder convolutional blocks. The paper uses GELU
119
+ activation throughout.
120
+ transformer_layers : int, default=8
121
+ Number of transformer encoder layers in the contextualizer. The paper uses 8 layers.
122
+ transformer_heads : int, default=8
123
+ Number of attention heads in each transformer layer. The paper uses 8 heads with
124
+ head dimension of 192 (1536 / 8).
125
+ position_encoder_length : int, default=25
126
+ Kernel size for the convolutional positional encoding layer. The paper uses a
127
+ receptive field of 25 with 16 groups.
128
+ enc_width : tuple of int, default=(3, 2, 2, 2, 2, 2)
129
+ Kernel sizes for each of the 6 convolutional blocks in the encoder. Each value
130
+ corresponds to one block.
131
+ enc_downsample : tuple of int, default=(3, 2, 2, 2, 2, 2)
132
+ Stride values for each of the 6 convolutional blocks in the encoder. The total
133
+ downsampling factor is the product of all strides (3 x 2 x 2 x 2 x 2 x 2 = 96).
134
+ start_token : int or float, default=-5
135
+ Value used to fill the start token embedding that is prepended to the BENDR sequence
136
+ before input to the transformer. This token's output is used as the aggregate
137
+ representation for classification.
138
+ final_layer : bool, default=True
139
+ If True, includes a final linear classification layer that maps from encoder_h to
140
+ n_outputs. If False, the model outputs the contextualized features directly.
141
+
142
+ References
143
+ ----------
144
+ .. [bendr] Kostas, D., Aroca-Ouellette, S., & Rudzicz, F. (2021).
145
+ BENDR: Using transformers and a contrastive self-supervised learning task to learn from
146
+ massive amounts of EEG data.
147
+ Frontiers in Human Neuroscience, 15, 653659.
148
+ https://doi.org/10.3389/fnhum.2021.653659
149
+ .. [wav2vec2] Baevski, A., Zhou, Y., Mohamed, A., & Auli, M. (2020).
150
+ wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.
151
+ In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, & H. Lin (Eds.),
152
+ Advances in Neural Information Processing Systems (Vol. 33, pp. 12449-12460).
153
+ https://dl.acm.org/doi/10.5555/3495724.3496768
154
+ .. [tfixup] Huang, T. K., Liang, S., Jha, A., & Salakhutdinov, R. (2020).
155
+ Improving Transformer Optimization Through Better Initialization.
156
+ In International Conference on Machine Learning (pp. 4475-4483). PMLR.
157
+ https://dl.acm.org/doi/10.5555/3524938.3525354
158
+ .. [layerdrop] Fan, A., Grave, E., & Joulin, A. (2020).
159
+ Reducing Transformer Depth on Demand with Structured Dropout.
160
+ International Conference on Learning Representations.
161
+ Retrieved from https://openreview.net/forum?id=SylO2yStDr
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ # Signal related parameters
167
+ n_chans=None,
168
+ n_outputs=None,
169
+ n_times=None,
170
+ chs_info=None,
171
+ input_window_seconds=None,
172
+ sfreq=None,
173
+ # Model parameters
174
+ encoder_h=512, # Hidden size of the encoder convolutional layers
175
+ contextualizer_hidden=3076, # Feedforward hidden size in transformer
176
+ projection_head=False, # Whether encoder should project back to input feature size (unused in original fine-tuning)
177
+ drop_prob=0.1, # General dropout probability (paper: 0.15 for pretraining, 0.0 for fine-tuning)
178
+ layer_drop=0.0, # Probability of dropping transformer layers during training (paper: 0.01 for pretraining)
179
+ activation=nn.GELU, # Activation function
180
+ # Transformer specific parameters
181
+ transformer_layers=8,
182
+ transformer_heads=8,
183
+ position_encoder_length=25, # Kernel size for positional encoding conv
184
+ enc_width=(3, 2, 2, 2, 2, 2),
185
+ enc_downsample=(3, 2, 2, 2, 2, 2),
186
+ start_token=-5, # Value for start token embedding
187
+ final_layer=True, # Whether to include the final linear layer
188
+ ):
189
+ super().__init__(
190
+ n_outputs=n_outputs,
191
+ n_chans=n_chans,
192
+ chs_info=chs_info,
193
+ n_times=n_times,
194
+ input_window_seconds=input_window_seconds,
195
+ sfreq=sfreq,
196
+ )
197
+
198
+ # These arguments are stored by EEGModuleMixin; delete the local copies so they are not used by mistake
199
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
200
+
201
+ self.encoder_h = encoder_h
202
+ self.contextualizer_hidden = contextualizer_hidden
203
+ self.include_final_layer = final_layer
204
+
205
+ # Encoder: Use parameters from __init__
206
+ self.encoder = _ConvEncoderBENDR(
207
+ in_features=self.n_chans,
208
+ encoder_h=encoder_h,
209
+ dropout=drop_prob,
210
+ projection_head=projection_head,
211
+ enc_width=enc_width,
212
+ enc_downsample=enc_downsample,
213
+ activation=activation,
214
+ )
215
+
216
+ self.contextualizer = _BENDRContextualizer(
217
+ in_features=self.encoder.encoder_h, # Use the output feature size of the encoder
218
+ hidden_feedforward=contextualizer_hidden,
219
+ heads=transformer_heads,
220
+ layers=transformer_layers,
221
+ dropout=drop_prob, # Use general dropout probability
222
+ activation=activation,
223
+ position_encoder=position_encoder_length, # Pass position encoder kernel size
224
+ layer_drop=layer_drop,
225
+ start_token=start_token, # Keep fixed start token value
226
+ )
227
+ in_features = self.encoder.encoder_h # Set in_features for final layer
228
+
229
+ self.final_layer = None
230
+ if self.include_final_layer:
231
+ # Input to Linear will be [batch_size, encoder_h] after selecting the start-token output (index 0)
232
+ linear = nn.Linear(in_features=in_features, out_features=self.n_outputs)
233
+ self.final_layer = nn.utils.parametrizations.weight_norm(
234
+ linear, name="weight", dim=1
235
+ )
236
+
237
+ def forward(self, x):
238
+ # Input x: [batch_size, n_chans, n_times]
239
+ encoded = self.encoder(x)
240
+ # encoded: [batch_size, encoder_h, n_encoded_times]
241
+
242
+ context = self.contextualizer(encoded)
243
+ # context: [batch_size, encoder_h, n_encoded_times + 1] (due to start token)
244
+
245
+ # Extract features - use the output corresponding to the start token (index 0)
246
+ # The output has shape [batch_size, features, seq_len+1]
247
+ # The start token was prepended at position 0 in the contextualizer before the
248
+ # transformer layers. After permutation back to [batch, features, seq_len+1],
249
+ # the start token output is at index 0.
250
+ # According to the paper (Kostas et al. 2021): "The transformer output of this
251
+ # initial position was not modified during pre-training, and only used for
252
+ # downstream tasks." This follows BERT's [CLS] token convention where the
253
+ # first token aggregates sequence information via self-attention.
254
+ feature = context[:, :, 0]
255
+ # feature: [batch_size, encoder_h]
256
+
257
+ if self.final_layer is not None:
258
+ feature = self.final_layer(feature)
259
+ # feature: [batch_size, n_outputs]
260
+
261
+ return feature
262
+
263
+
264
+ class _ConvEncoderBENDR(nn.Module):
265
+ def __init__(
266
+ self,
267
+ in_features,
268
+ encoder_h=512,
269
+ enc_width=(3, 2, 2, 2, 2, 2),
270
+ dropout=0.0,
271
+ projection_head=False,
272
+ enc_downsample=(3, 2, 2, 2, 2, 2),
273
+ activation=nn.GELU,
274
+ ):
275
+ super().__init__()
276
+ self.encoder_h = encoder_h
277
+ self.in_features = in_features
278
+
279
+ if not isinstance(enc_width, (list, tuple)):
280
+ enc_width = [enc_width]
281
+ if not isinstance(enc_downsample, (list, tuple)):
282
+ enc_downsample = [enc_downsample]
283
+ if len(enc_downsample) != len(enc_width):
284
+ raise ValueError(
285
+ "Encoder width and downsampling factors must have the same length."
286
+ )
287
+
288
+ # Centerable convolutions make life simpler
289
+ enc_width = [
290
+ e if e % 2 != 0 else e + 1 for e in enc_width
291
+ ] # Ensure odd kernel size
292
+ self._downsampling = enc_downsample
293
+ self._width = enc_width
294
+
295
+ current_in_features = in_features
296
+ self.encoder = nn.Sequential()
297
+ for i, (width, downsample) in enumerate(zip(enc_width, enc_downsample)):
298
+ self.encoder.add_module(
299
+ "Encoder_{}".format(i),
300
+ nn.Sequential(
301
+ nn.Conv1d(
302
+ current_in_features,
303
+ encoder_h,
304
+ width,
305
+ stride=downsample,
306
+ padding=width
307
+ // 2, # Correct padding for 'same' output length before stride
308
+ ),
309
+ nn.Dropout2d(dropout), # Dropout2d on the (batch, channels, time) activations drops whole feature channels
310
+ nn.GroupNorm(
311
+ encoder_h // 2, encoder_h
312
+ ), # Consider making num_groups configurable or ensure encoder_h is divisible by 2
313
+ activation(),
314
+ ),
315
+ )
316
+ current_in_features = encoder_h
317
+ if projection_head:
318
+ self.encoder.add_module(
319
+ "projection_head",
320
+ nn.Conv1d(encoder_h, encoder_h, 1),
321
+ )
322
+
323
+ def forward(self, x):
324
+ return self.encoder(x)
325
+
326
+
327
+ class _BENDRContextualizer(nn.Module):
328
+ """Transformer-based contextualizer for BENDR."""
329
+
330
+ def __init__(
331
+ self,
332
+ in_features,
333
+ hidden_feedforward=3076,
334
+ heads=8,
335
+ layers=8,
336
+ dropout=0.1, # Default dropout
337
+ activation=nn.GELU, # Activation for transformer FF layer
338
+ position_encoder=25, # Kernel size for conv positional encoding
339
+ layer_drop=0.0, # Probability of dropping a whole layer
340
+ start_token=-5, # Value for start token embedding
341
+ ):
342
+ super().__init__()
343
+ self.dropout = dropout
344
+ self.layer_drop = layer_drop
345
+ self.start_token = start_token # Store start token value
346
+
347
+ # Paper specification: transformer_dim = 3 * encoder_h (1536 for encoder_h=512)
348
+ self.transformer_dim = 3 * in_features
349
+ self.in_features = in_features
350
+ # --- Positional Encoding --- (Applied before projection)
351
+ self.position_encoder = None
352
+ if position_encoder > 0:
353
+ conv = nn.Conv1d(
354
+ in_features,
355
+ in_features,
356
+ kernel_size=position_encoder,
357
+ padding=position_encoder // 2,
358
+ groups=16, # grouped convolution with 16 groups, matching the positional encoder described above
359
+ )
360
+ # T-Fixup positional encoding initialization (paper specification)
361
+ nn.init.normal_(conv.weight, mean=0, std=2 / self.transformer_dim)
362
+ nn.init.constant_(conv.bias, 0)
363
+
364
+ conv = nn.utils.parametrizations.weight_norm(conv, name="weight", dim=2)
365
+ self.relative_position = nn.Sequential(conv, activation())
366
+
367
+ # --- Input Conditioning --- (Includes projection up to transformer_dim)
368
+ # Rearrange, Norm, Dropout, Project, Rearrange
369
+ self.input_conditioning = nn.Sequential(
370
+ Rearrange(
371
+ "batch channel time -> batch time channel"
372
+ ), # Batch, Time, Channels
373
+ nn.LayerNorm(in_features),
374
+ nn.Dropout(dropout),
375
+ nn.Linear(in_features, self.transformer_dim), # Project up using Linear
376
+ # nn.Conv1d(in_features, self.transformer_dim, 1), # Alternative: Project up using Conv1d
377
+ Rearrange(
378
+ "batch time channel -> time batch channel"
379
+ ), # Time, Batch, Channels (Transformer expected format)
380
+ )
381
+
382
+ # --- Transformer Encoder Layers ---
383
+ # Paper uses T-Fixup: remove internal LayerNorm layers
384
+
385
+ encoder_layer = nn.TransformerEncoderLayer(
386
+ d_model=self.transformer_dim, # Use projected dimension
387
+ nhead=heads,
388
+ dim_feedforward=hidden_feedforward,
389
+ dropout=dropout, # Dropout within transformer layer
390
+ activation=activation(),
391
+ batch_first=False, # Expects (T, B, C)
392
+ norm_first=False, # Standard post-norm architecture
393
+ )
394
+
395
+ # T-Fixup: Replace internal LayerNorms with Identity
396
+ encoder_layer.norm1 = nn.Identity()
397
+ encoder_layer.norm2 = nn.Identity()
398
+
399
+ self.transformer_layers = nn.ModuleList(
400
+ [copy.deepcopy(encoder_layer) for _ in range(layers)]
401
+ )
402
+
403
+ # Add final LayerNorm after all transformer layers (paper specification)
404
+ self.norm = nn.LayerNorm(self.transformer_dim)
405
+
406
+ # --- Output Layer --- (Project back down to in_features)
407
+ # Paper uses Conv1d with kernel=1 (equivalent to Linear but expects (B,C,T) input)
408
+ self.output_layer = nn.Conv1d(self.transformer_dim, in_features, 1)
409
+
410
+ # Initialize parameters with T-Fixup (paper specification)
411
+ self.apply(self._init_bert_params)
412
+
413
+ def _init_bert_params(self, module):
414
+ """Initialize linear layers and apply T-Fixup scaling."""
415
+ if isinstance(module, nn.Linear):
416
+ # Standard init
417
+ nn.init.xavier_uniform_(module.weight.data)
418
+ if module.bias is not None:
419
+ module.bias.data.zero_()
420
+ # Tfixup Scaling
421
+ module.weight.data = (
422
+ 0.67 * len(self.transformer_layers) ** (-0.25) * module.weight.data
423
+ )
424
+ elif isinstance(module, nn.LayerNorm):
425
+ module.bias.data.zero_()
426
+ module.weight.data.fill_(1.0)
427
+
428
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
429
+ # Input x: [batch_size, in_features, seq_len]
430
+
431
+ # Apply relative positional encoding
432
+ if hasattr(self, "relative_position"):
433
+ pos_enc = self.relative_position(x)
434
+ x = x + pos_enc
435
+
436
+ # Apply input conditioning (includes projection up and rearrange)
437
+ x = self.input_conditioning(x)
438
+ # x: [seq_len, batch_size, transformer_dim]
439
+
440
+ # Prepend start token
441
+ if self.start_token is not None:
442
+ token_emb = torch.full(
443
+ (1, x.shape[1], x.shape[2]),
444
+ float(self.start_token),
445
+ device=x.device,
446
+ requires_grad=False,
447
+ )
448
+ x = torch.cat([token_emb, x], dim=0)
449
+ # x: [seq_len + 1, batch_size, transformer_dim]
450
+
451
+ # Apply transformer layers with layer drop
452
+ for layer in self.transformer_layers:
453
+ if not self.training or torch.rand(1) > self.layer_drop:
454
+ x = layer(x)
455
+ # x: [seq_len + 1, batch_size, transformer_dim]
456
+
457
+ # Apply final LayerNorm (T-Fixup: norm only at the end)
458
+ x = self.norm(x)
459
+ # x: [seq_len + 1, batch_size, transformer_dim]
460
+
461
+ # Permute to (B, C, T) format for Conv1d output layer
462
+ x = Rearrange("time batch channel -> batch channel time")(x)
463
+ # x: [batch_size, transformer_dim, seq_len + 1]
464
+
465
+ # Apply output projection (Conv1d expects B, C, T)
466
+ x = self.output_layer(x)
467
+ # x: [batch_size, in_features, seq_len + 1]
468
+
469
+ return x
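
For orientation, here is a minimal usage sketch of the BENDR model added in this file (not part of the diff). It assumes the class is importable from the new braindecode/models/bendr.py module shown above; the channel count, class count, and 6 s window at 256 Hz are arbitrary example values, and all model-specific arguments are left at their defaults.

import torch

from braindecode.models.bendr import BENDR

# Example dimensions only: 20 channels, 2 classes, 6 s of 256 Hz EEG (1536 samples).
model = BENDR(n_chans=20, n_outputs=2, n_times=1536)

x = torch.randn(8, 20, 1536)  # (batch, n_chans, n_times)
with torch.no_grad():
    features = model.encoder(x)  # (8, 512, 16): strides 3*2*2*2*2*2 give 96x downsampling
    logits = model(x)            # start-token output through the weight-normed final layer
print(features.shape, logits.shape)  # torch.Size([8, 512, 16]) torch.Size([8, 2])

As the docstring's warning notes, useful downstream accuracy is only expected after self-supervised pre-training; a randomly initialized instance is mainly good for shape and plumbing checks like this one.
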
braindecode/models/biot.py
@@ -11,13 +11,15 @@ from braindecode.models.base import EEGModuleMixin
11
11
  class BIOT(EEGModuleMixin, nn.Module):
12
12
  """BIOT from Yang et al. (2023) [Yang2023]_
13
13
 
14
+ :bdg-danger:`Large Brain Model`
15
+
14
16
  .. figure:: https://braindecode.org/dev/_static/model/biot.jpg
15
17
  :align: center
16
18
  :alt: BioT
17
19
 
18
20
  BIOT: Cross-data Biosignal Learning in the Wild.
19
21
 
20
- BIOT is a large language model for biosignal classification. It is
22
+ BIOT is a large brain model for biosignal classification. It is
21
23
  a wrapper around the `BIOTEncoder` and `ClassificationHead` modules.
22
24
 
23
25
  It is designed for N-dimensional biosignal data such as EEG, ECG, etc.
braindecode/models/ctnet.py
@@ -39,16 +39,19 @@ class CTNet(EEGModuleMixin, nn.Module):
39
39
  The architecture consists of three main components:
40
40
 
41
41
  1. **Convolutional Module**:
42
+
42
43
  - Apply :class:`EEGNet` to perform some feature extraction, denoted here as
43
- _PatchEmbeddingEEGNet module.
44
+ _PatchEmbeddingEEGNet module.
44
45
 
45
46
  2. **Transformer Encoder Module**:
47
+
46
48
  - Utilizes multi-head self-attention mechanisms as EEGConformer but
47
- with residual blocks.
49
+ with residual blocks.
48
50
 
49
51
  3. **Classifier Module**:
52
+
50
53
  - Combines features from both the convolutional module
51
- and the Transformer encoder.
54
+ and the Transformer encoder.
52
55
  - Flattens the combined features and applies dropout for regularization.
53
56
  - Uses a fully connected layer to produce the final classification output.
54
57
 
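
The CTNet docstring reflowed above describes a three-stage composition: an EEGNet-style convolutional patch embedding, a transformer encoder with residual blocks, and a classifier that combines both feature sets before a flatten, dropout, and fully connected layer. The snippet below is only a schematic of that composition in plain PyTorch, not CTNet's actual implementation; every layer size, the pooling step, and the token count are illustrative assumptions.

import torch
from torch import nn

d_model, n_tokens, n_outputs = 16, 32, 4

# 1. Convolutional module (schematic): temporal filtering reduced to a token sequence.
patch_embed = nn.Sequential(
    nn.Conv2d(1, d_model, kernel_size=(1, 64), padding=(0, 32)),
    nn.BatchNorm2d(d_model),
    nn.ELU(),
    nn.AdaptiveAvgPool2d((1, n_tokens)),
    nn.Flatten(start_dim=2),  # -> (batch, d_model, n_tokens)
)

# 2. Transformer encoder module: multi-head self-attention over the tokens.
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True), num_layers=2
)

# 3. Classifier module: combine convolutional and transformer features, then classify.
classifier = nn.Sequential(
    nn.Flatten(), nn.Dropout(0.5), nn.Linear(2 * d_model * n_tokens, n_outputs)
)

x = torch.randn(2, 1, 22, 1000)            # (batch, 1, n_chans, n_times)
tokens = patch_embed(x).transpose(1, 2)    # (batch, n_tokens, d_model)
combined = torch.cat([tokens, encoder(tokens)], dim=-1)
logits = classifier(combined)              # (batch, n_outputs)
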
braindecode/models/deepsleepnet.py
@@ -43,16 +43,21 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
43
43
 
44
44
  - :class:`_RepresentationLearning` **(dual-path CNN → epoch feature)**
45
45
 
46
- - *Operations.*
47
- - **Small-filter CNN** 4 times:
48
- - :class:`~torch.nn.Conv1d`
49
- - :class:`~torch.nn.BatchNorm1d`
50
- - :class:`~torch.nn.ReLU`
51
- - :class:`~torch.nn.MaxPool1d` after.
52
- First conv uses **filter length ≈ Fs/2** and **stride ≈ Fs/16** to emphasize *timing* of graphoelements.
46
+ - *Operations.*
47
+ - **Small-filter CNN** 4 times:
48
+
49
+ - :class:`~torch.nn.Conv1d`
50
+ - :class:`~torch.nn.BatchNorm1d`
51
+ - :class:`~torch.nn.ReLU`
52
+ - :class:`~torch.nn.MaxPool1d` after.
53
+
54
+ - First conv uses **filter length ≈ Fs/2** and **stride ≈ Fs/16** to emphasize *timing* of graphoelements.
55
+
53
56
  - **Large-filter CNN**:
57
+
54
58
  - Same stack but first conv uses **filter length ≈ 4·Fs** and
55
59
  - **stride ≈ Fs/2** to emphasize *frequency* content.
60
+
56
61
  - Outputs from both paths are **concatenated** into the epoch embedding ``a_t``.
57
62
 
58
63
  - *Rationale.*
@@ -81,20 +86,22 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
81
86
 
82
87
  - **Temporal (where time-domain patterns are learned).**
83
88
 
84
- Both CNN paths use **1-D temporal convolutions**. The *small-filter* path (first kernel ≈ Fs/2,
85
- stride ≈ Fs/16) captures *when* characteristic transients occur; the *large-filter* path
86
- (first kernel ≈ 4·Fs, stride ≈ Fs/2) captures *which* frequency components dominate over the
87
- epoch. Deeper layers use **small kernels** to refine features with fewer parameters, interleaved
88
- with **max pooling** for downsampling.
89
+ Both CNN paths use **1-D temporal convolutions**. The *small-filter* path (first kernel ≈ Fs/2,
90
+ stride ≈ Fs/16) captures *when* characteristic transients occur; the *large-filter* path
91
+ (first kernel ≈ 4·Fs, stride ≈ Fs/2) captures *which* frequency components dominate over the
92
+ epoch. Deeper layers use **small kernels** to refine features with fewer parameters, interleaved
93
+ with **max pooling** for downsampling.
89
94
 
90
95
  - **Spatial (how channels are processed).**
91
- The original model operates on **single-channel** raw EEG; convolutions therefore mix only
92
- along time (no spatial convolution across electrodes).
96
+
97
+ The original model operates on **single-channel** raw EEG; convolutions therefore mix only
98
+ along time (no spatial convolution across electrodes).
93
99
 
94
100
  - **Spectral (how frequency information emerges).**
95
- No explicit Fourier/wavelet transform is used. The **large-filter path** serves as a
96
- *frequency-sensitive* analyzer, while the **small-filter path** remains *time-sensitive*,
97
- together functioning as a **two-band learned filter bank** at the first layer.
101
+
102
+ No explicit Fourier/wavelet transform is used. The **large-filter path** serves as a
103
+ *frequency-sensitive* analyzer, while the **small-filter path** remains *time-sensitive*,
104
+ together functioning as a **two-band learned filter bank** at the first layer.
98
105
 
99
106
  .. rubric:: Attention / Sequential Modules
100
107
 
@@ -112,9 +119,11 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
112
119
  - **Residual shortcut over sequence encoder.** Adds projected CNN features to BiLSTM outputs,
113
120
  improving gradient flow and retaining discriminative content from representation learning.
114
121
  - **Two-step training.**
122
+
115
123
  - (i) **Pretrain** the CNN paths with class-balanced sampling;
116
124
  - (ii) **fine-tune** the full network with sequential batches, using **lower LR** for CNNs and **higher LR** for the
117
- sequence encoder.
125
+ sequence encoder.
126
+
118
127
  - **State handling.** BiLSTM states are **reinitialized per subject** so that temporal context
119
128
  does not leak across recordings.
120
129
 
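
The reformatted DeepSleepNet docstring above gives the first-layer kernels only as fractions of the sampling rate. As a quick illustration (the 100 Hz value is an example sampling rate, not something stated in this diff), those heuristics work out as follows; together the two settings form the "two-band learned filter bank" the docstring describes.

# First-conv settings implied by the kernel/stride heuristics in the docstring,
# evaluated for an example sampling rate of fs = 100 Hz.
fs = 100
small_filter = {"kernel_size": fs // 2, "stride": fs // 16}  # time-sensitive path: 50, 6
large_filter = {"kernel_size": 4 * fs, "stride": fs // 2}    # frequency-sensitive path: 400, 50
print(small_filter, large_filter)
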
braindecode/models/eegconformer.py
@@ -51,7 +51,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):
51
51
  The result is rearranged to a token sequence ``(B, S_tokens, D)``, where ``D = n_filters_time``.
52
52
 
53
53
  *Interpretability/robustness.* Temporal kernels can be inspected as FIR filters;
54
- the spatial conv yields channel projections analogous to :class:`ShallowFBCSPNet`’s learned
54
+ the spatial conv yields channel projections analogous to :class:`ShallowFBCSPNet`'s learned
55
55
  spatial filters. Temporal pooling stabilizes statistics and reduces sequence length.
56
56
 
57
57
  - :class:`_TransformerEncoder` **(context over temporal tokens)**
@@ -119,7 +119,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):
119
119
  capacity (but higher compute); larger strides → fewer tokens and stronger inductive bias.
120
120
 
121
121
  - **Embedding dimension = filters.** ``n_filters_time`` serves double duty as both the
122
- number of temporal filters in the stem and the transformers embedding size ``D``,
122
+ number of temporal filters in the stem and the transformer's embedding size ``D``,
123
123
  simplifying dimensional alignment.
124
124
 
125
125
  .. rubric:: Usage and Configuration