braindecode 1.3.0.dev181065563__py3-none-any.whl → 1.3.0.dev181594385__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.
Files changed (68)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/augmentation/functional.py +154 -54
  3. braindecode/augmentation/transforms.py +2 -2
  4. braindecode/datasets/__init__.py +10 -2
  5. braindecode/datasets/base.py +116 -152
  6. braindecode/datasets/bcicomp.py +4 -4
  7. braindecode/datasets/bids.py +3 -3
  8. braindecode/datasets/experimental.py +2 -2
  9. braindecode/datasets/mne.py +3 -5
  10. braindecode/datasets/moabb.py +2 -2
  11. braindecode/datasets/nmt.py +2 -2
  12. braindecode/datasets/sleep_physio_challe_18.py +4 -3
  13. braindecode/datasets/sleep_physionet.py +2 -2
  14. braindecode/datasets/tuh.py +2 -2
  15. braindecode/datasets/xy.py +2 -2
  16. braindecode/datautil/serialization.py +18 -13
  17. braindecode/eegneuralnet.py +2 -0
  18. braindecode/functional/functions.py +6 -2
  19. braindecode/functional/initialization.py +2 -3
  20. braindecode/models/__init__.py +6 -0
  21. braindecode/models/atcnet.py +33 -34
  22. braindecode/models/attentionbasenet.py +39 -32
  23. braindecode/models/attn_sleep.py +2 -0
  24. braindecode/models/base.py +280 -2
  25. braindecode/models/bendr.py +469 -0
  26. braindecode/models/biot.py +3 -1
  27. braindecode/models/contrawr.py +2 -0
  28. braindecode/models/ctnet.py +8 -3
  29. braindecode/models/deepsleepnet.py +28 -19
  30. braindecode/models/eegconformer.py +2 -2
  31. braindecode/models/eeginception_erp.py +31 -25
  32. braindecode/models/eegitnet.py +2 -0
  33. braindecode/models/eegminer.py +2 -0
  34. braindecode/models/eegnet.py +1 -1
  35. braindecode/models/eegtcnet.py +2 -0
  36. braindecode/models/fbcnet.py +2 -0
  37. braindecode/models/fblightconvnet.py +2 -0
  38. braindecode/models/fbmsnet.py +2 -0
  39. braindecode/models/ifnet.py +2 -0
  40. braindecode/models/labram.py +193 -87
  41. braindecode/models/msvtnet.py +2 -0
  42. braindecode/models/patchedtransformer.py +640 -0
  43. braindecode/models/signal_jepa.py +111 -27
  44. braindecode/models/sinc_shallow.py +12 -9
  45. braindecode/models/sstdpn.py +869 -0
  46. braindecode/models/summary.csv +9 -6
  47. braindecode/models/syncnet.py +2 -0
  48. braindecode/models/tcn.py +2 -0
  49. braindecode/models/usleep.py +26 -21
  50. braindecode/models/util.py +3 -0
  51. braindecode/modules/attention.py +10 -10
  52. braindecode/modules/blocks.py +3 -3
  53. braindecode/modules/filter.py +2 -3
  54. braindecode/modules/layers.py +18 -17
  55. braindecode/preprocessing/__init__.py +24 -0
  56. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  57. braindecode/preprocessing/preprocess.py +23 -14
  58. braindecode/preprocessing/util.py +166 -0
  59. braindecode/preprocessing/windowers.py +24 -19
  60. braindecode/samplers/base.py +8 -8
  61. braindecode/version.py +1 -1
  62. {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev181594385.dist-info}/METADATA +6 -2
  63. braindecode-1.3.0.dev181594385.dist-info/RECORD +106 -0
  64. braindecode-1.3.0.dev181065563.dist-info/RECORD +0 -101
  65. {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev181594385.dist-info}/WHEEL +0 -0
  66. {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev181594385.dist-info}/licenses/LICENSE.txt +0 -0
  67. {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev181594385.dist-info}/licenses/NOTICE.txt +0 -0
  68. {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev181594385.dist-info}/top_level.txt +0 -0
braindecode/models/patchedtransformer.py
@@ -0,0 +1,640 @@
+ # Authors: Jose Mauricio <josemaurici3991@gmail.com>
+ #          Bruno Aristimunha <b.aristimunha@gmail.com>
+ # License: BSD (3-clause)
+ from __future__ import annotations
+
+ import math
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class PBT(EEGModuleMixin, nn.Module):
+     r"""Patched Brain Transformer (PBT) model from Klein et al. (2025) [pbt]_.
+
+     :bdg-danger:`Large Brain Model`
+
+     This implementation is based on https://github.com/timonkl/PatchedBrainTransformer/
+
+     .. figure:: https://raw.githubusercontent.com/timonkl/PatchedBrainTransformer/refs/heads/main/PBT_sketch.png
+         :align: center
+         :alt: Patched Brain Transformer Architecture
+         :width: 680px
+
+     PBT tokenizes EEG trials into per-channel patches, linearly projects each
+     patch to a model embedding dimension, prepends a classification token and
+     adds channel-aware positional embeddings. The token sequence is processed
+     by a Transformer encoder stack and classification is performed from the
+     classification token.
+
+     .. rubric:: Macro Components
+
+     - ``PBT.patch_signal`` **(patch extraction)**
+
+       *Operations.* The pre-processed EEG signal :math:`X \in \mathbb{R}^{C \times T}`
+       (with :math:`C = \text{n_chans}` and :math:`T = \text{n_times}`) is divided into
+       non-overlapping patches of size :math:`d_{\text{input}}` along the time axis.
+       This process yields :math:`N` total patches, calculated as
+       :math:`N = C \left\lfloor \frac{T}{D} \right\rfloor` (where :math:`D = d_{\text{input}}`).
+       When time shifts are applied, :math:`N` decreases to
+       :math:`N = C \left\lfloor \frac{T - T_{\text{aug}}}{D} \right\rfloor`.
+
+       *Role.* Tokenizes EEG trials into fixed-size, per-channel patches so the model
+       remains adaptive to different numbers of channels and recording lengths.
+       The process is inspired by Vision Transformers [visualtransformer]_ and
+       adapted for the GPT context following [efficient-batchpacking]_.
+
+     - ``PBT.patching_projection`` **(patch embedding)**
+
+       *Operations.* The linear layer ``PBT.patching_projection`` maps the tokens from dimension
+       :math:`d_{\text{input}}` to the Transformer embedding dimension :math:`d_{\text{model}}`.
+       Patches :math:`X_P` are projected as :math:`X_E = X_P W_E^\top`, where
+       :math:`W_E \in \mathbb{R}^{d_{\text{model}} \times D}`. In the default configuration
+       :math:`d_{\text{model}} = 2D` with :math:`D = d_{\text{input}}`.
+
+       *Interpretability.* Learns periodic structures similar to the frequency filters in
+       the first convolutional layers of CNNs (for example :class:`~braindecode.models.EEGNet`).
+       The learned filters frequently focus on the high-frequency range (20-40 Hz),
+       which correlates with beta and gamma waves linked to higher concentration levels.
+
+     - ``PBT.cls_token`` **(classification token)**
+
+       *Operations.* A classification token :math:`[c_{\text{ls}}] \in \mathbb{R}^{1 \times d_{\text{model}}}`
+       is prepended to the projected patch sequence :math:`X_E`. The CLS token can optionally
+       be learnable (see ``learnable_cls``).
+
+       *Role.* Acts as a dedicated readout token that aggregates information through the
+       Transformer encoder stack.
+
+     - ``PBT.pos_embedding`` **(positional embedding)**
+
+       *Operations.* Positional indices are generated by ``PBT.patch_signal``, an instance
+       of :class:`~braindecode.models.patchedtransformer._Patcher`, and mapped to vectors
+       through :class:`~torch.nn.Embedding`. The embedding table
+       :math:`W_{\text{pos}} \in \mathbb{R}^{(N+1) \times d_{\text{model}}}` is added to the token
+       sequence, yielding :math:`X_{\text{pos}} = [c_{\text{ls}}, X_E] + W_{\text{pos}}`.
+
+       *Role/Interpretability.* Introduces spatial and temporal dependence to counter the
+       position invariance of the Transformer encoder. The learned positional embedding
+       exposes spatial relationships, often revealing a symmetric pattern in central regions
+       (C1-C6) associated with the motor cortex.
+
+     - ``PBT.transformer_encoder`` **(sequence processing and attention)**
+
+       *Operations.* The token sequence passes through :math:`n_{\text{blocks}}` Transformer
+       encoder layers. Each block combines a Multi-Head Self-Attention (MHSA) module with
+       ``num_heads`` attention heads and a Feed-Forward Network (FFN). Both MHSA
+       and FFN use pre-norm residual connections, with Layer Normalization applied before
+       each sub-module, and apply dropout (``drop_prob``) within the Transformer components.
+
+       *Role/Robustness.* Self-attention enables every token to consider all others, capturing
+       global temporal and spatial dependencies immediately and adaptively. This architecture
+       accommodates arbitrary numbers of patches and channels, supporting pre-training across
+       diverse datasets.
+
+     - ``PBT.final_layer`` **(readout)**
+
+       *Operations.* A linear layer operates on the processed CLS token only, and the model
+       predicts class probabilities as :math:`y = \operatorname{softmax}([c_{\text{ls}}] W_{\text{class}}^\top + b_{\text{class}})`.
+
+       *Role.* Performs the final classification from the information aggregated into the CLS
+       token after the Transformer encoder stack.
+
+     .. rubric:: Convolutional Details
+
+     PBT omits convolutional layers; equivalent feature extraction is carried out by the patch
+     pipeline and attention stack.
+
+     * **Temporal.** Tokenization slices the EEG into fixed windows of size :math:`D = d_{\text{input}}`
+       (for the default configuration, :math:`D=64` samples :math:`\approx 0.256\,\text{s}` at
+       :math:`250\,\text{Hz}`), while ``PBT.patching_projection`` learns periodic patterns within each
+       patch. The Transformer encoder then models long- and short-range temporal dependencies through
+       self-attention.
+
+     * **Spatial.** Patches are channel-specific, keeping the architecture adaptive to any electrode
+       montage. Channel-aware positional encodings :math:`W_{\text{pos}}` capture relationships between
+       nearby sensors; learned embeddings often form symmetric motifs across motor cortex electrodes
+       (C1–C6), and self-attention propagates information across all channels jointly.
+
+     * **Spectral.** ``PBT.patching_projection`` acts similarly to the first convolutional layer in
+       :class:`~braindecode.models.EEGNet`, learning frequency-selective filters without an explicit
+       Fourier transform. The highest-energy filters typically reside between :math:`20` and
+       :math:`40\,\text{Hz}`, aligning with beta/gamma rhythms tied to focused motor imagery.
+
+     .. rubric:: Attention / Sequential Modules
+
+     * **Attention Details.** ``PBT.transformer_encoder`` stacks :math:`n_{\text{blocks}}` Transformer
+       encoder layers with Multi-Head Self-Attention. Every token attends to all others, enabling
+       immediate global integration across time and channels and supporting heterogeneous datasets.
+       Attention rollout visualisations highlight strong activations over motor cortex electrodes
+       (C3, C4, Cz) during motor imagery decoding.
+
+
+     .. warning::
+
+         **Important:** Like the other Large Brain Models in Braindecode, :class:`PBT` is
+         designed for large-scale pre-training and fine-tuning. Training from
+         scratch on small datasets may lead to suboptimal results. Cross-dataset
+         pre-training followed by fine-tuning is recommended to leverage the
+         full potential of this architecture.
+
+     Parameters
+     ----------
+     d_input : int, optional
+         Size (in samples) of each patch (token) extracted along the time axis.
+     d_model : int, optional
+         Transformer embedding dimensionality.
+     n_blocks : int, optional
+         Number of Transformer encoder layers.
+     num_heads : int, optional
+         Number of attention heads.
+     drop_prob : float, optional
+         Dropout probability used in Transformer components.
+     learnable_cls : bool, optional
+         Whether the classification token is learnable.
+     bias_transformer : bool, optional
+         Whether to use bias in Transformer linear layers.
+     activation : nn.Module, optional
+         Activation function class to use in Transformer feed-forward layers.
+
+     References
+     ----------
+     .. [pbt] Klein, T., Minakowski, P., & Sager, S. (2025).
+        Flexible Patched Brain Transformer model for EEG decoding.
+        Scientific Reports, 15(1), 1-12.
+        https://www.nature.com/articles/s41598-025-86294-3
+     .. [visualtransformer] Dosovitskiy, A., Beyer, L., Kolesnikov, A.,
+        Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M.,
+        Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J. & Houlsby,
+        N. (2021). An Image is Worth 16x16 Words: Transformers for Image
+        Recognition at Scale. International Conference on Learning
+        Representations (ICLR).
+     .. [efficient-batchpacking] Krell, M. M., Kosec, M., Perez, S. P., &
+        Fitzgibbon, A. (2021). Efficient sequence packing without
+        cross-contamination: Accelerating large language models without
+        impacting performance. arXiv preprint arXiv:2107.02027.
+     """
+
+     def __init__(
+         self,
+         # Signal related parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model parameters
+         d_input: int = 64,
+         d_model: int = 128,
+         n_blocks: int = 4,
+         num_heads: int = 4,
+         drop_prob: float = 0.1,
+         learnable_cls=True,
+         bias_transformer=False,
+         activation: nn.Module = nn.GELU,
+     ) -> None:
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         # Store hyperparameters
+         self.d_input = d_input
+         self.d_model = d_model
+         self.n_blocks = n_blocks
+         self.num_heads = num_heads
+         self.drop_prob = drop_prob
+
+         # number of patch tokens (the CLS position is added separately below)
+         self.num_embeddings = (self.n_chans * self.n_times) // self.d_input
+
+         # Classification token (learnable or fixed zero)
+         if learnable_cls:
+             self.cls_token = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.002)
+         else:
+             # non-learnable zeroed tensor
+             self.cls_token = torch.full(
+                 size=(1, 1, self.d_model),
+                 fill_value=0,
+                 requires_grad=False,
+                 dtype=torch.float32,
+             )
+
+         # Patch the signal into tokens and generate channel-aware positional indices
+         self.patch_signal = _Patcher(
+             n_chans=self.n_chans, n_times=self.n_times, d_input=self.d_input
+         )
+
+         # Linear patch projection from raw token size -> d_model
+         self.patching_projection = nn.Linear(
+             in_features=self.d_input, out_features=self.d_model, bias=False
+         )
+
+         # embedding table mapping positional indices -> d_model vectors
+         self.pos_embedding = nn.Embedding(
+             num_embeddings=self.num_embeddings + 1, embedding_dim=self.d_model
+         )
+
+         # Transformer encoder stack
+         self.transformer_encoder = _TransformerEncoder(
+             n_blocks=n_blocks,
+             d_model=self.d_model,
+             n_head=num_heads,
+             drop_prob=drop_prob,
+             bias=bias_transformer,
+             activation=activation,
+         )
+
+         # classification head applied to the CLS token
+         self.final_layer = nn.Linear(
+             in_features=d_model, out_features=self.n_outputs, bias=True
+         )
+
+         # initialize weights
+         self.apply(self._init_weights)
+
+     @staticmethod
+     def _init_weights(module: nn.Module) -> None:
+         """Weight initialization following the original implementation."""
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.002)
+
+     def forward(self, X: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+
+         - patch the input into per-channel tokens of size ``d_input``
+         - project each token to ``d_model``, prepend the CLS token and add
+           the positional embeddings
+         - run the Transformer encoder and classify from the CLS output with
+           ``final_layer``
+
+         Parameters
+         ----------
+         X : torch.Tensor
+             Input tensor with shape (batch_size, n_chans, n_times)
+
+         Returns
+         -------
+         torch.Tensor
+             Output logits with shape (batch_size, n_outputs).
+         """
+         X, int_pos = self.patch_signal(X)
+
+         tokens = self.patching_projection(X)
+
+         cls_token = self.cls_token.expand(X.size(0), 1, -1)
+         # position index of the CLS token (not used further in this forward pass)
+         cls_idx = torch.zeros((X.size(0), 1), dtype=torch.long, device=X.device)
+
+         tokens = torch.cat([cls_token, tokens], dim=1)
+         pos_emb = self.pos_embedding(int_pos)
+         transformer_out = self.transformer_encoder(tokens + pos_emb)
+
+         return self.final_layer(transformer_out[:, 0])
+
+
+ class _LayerNorm(nn.Module):
+     """Layer normalization with optional bias.
+
+     Simple wrapper around :func:`torch.nn.functional.layer_norm` exposing a
+     learnable scale and optional bias.
+
+     Parameters
+     ----------
+     ndim : int
+         Number of features (normalized shape).
+     bias : bool
+         Whether to include a learnable bias term.
+     """
+
+     def __init__(self, ndim: int, bias: bool) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(ndim))
+         self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply layer normalization.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor where the last dimension has size `ndim`.
+
+         Returns
+         -------
+         torch.Tensor
+             Normalized tensor with the same shape as input.
+         """
+         return F.layer_norm(
+             x,
+             normalized_shape=self.weight.shape,
+             weight=self.weight,
+             bias=self.bias,
+             eps=1e-5,
+         )
+
+
+ class _MHSA(nn.Module):
+     """Multi-head self-attention (MHSA) block.
+
+     Implements a standard multi-head attention mechanism using PyTorch's
+     :func:`torch.nn.functional.scaled_dot_product_attention`, which can
+     dispatch to fused (FlashAttention) kernels when available.
+
+     Parameters
+     ----------
+     d_model : int
+         Dimensionality of the model / embeddings.
+     n_head : int
+         Number of attention heads.
+     bias : bool
+         Whether linear layers use bias.
+     drop_prob : float, optional
+         Dropout probability applied to the attention weights and the residual projection.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         n_head: int,
+         bias: bool,
+         drop_prob: float = 0.0,
+     ) -> None:
+         super().__init__()
+
+         assert d_model % n_head == 0, "d_model must be divisible by n_head"
+
+         # qkv and output projection
+         self.attn = nn.Linear(d_model, 3 * d_model, bias=bias)
+         self.proj = nn.Linear(d_model, d_model, bias=bias)
+
+         # dropout modules
+         self.attn_drop_prob = nn.Dropout(drop_prob)
+         self.resid_drop_prob = nn.Dropout(drop_prob)
+
+         self.n_head = n_head
+         self.d_model = d_model
+         self.drop_prob = drop_prob
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass for MHSA.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (B, T, C) where C == d_model.
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor of shape (B, T, C).
+         """
+
+         B, T, C = x.size()
+
+         # project to q, k, v and reshape for multi-head attention
+         q, k, v = self.attn(x).split(self.d_model, dim=2)
+
+         # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+
+         y = torch.nn.functional.scaled_dot_product_attention(
+             q,
+             k,
+             v,
+             attn_mask=None,
+             dropout_p=self.drop_prob if self.training else 0.0,
+             is_causal=False,
+         )
+
+         # re-assemble heads -> (B, T, C)
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+
+         # final linear projection + residual dropout
+         y = self.resid_drop_prob(self.proj(y))
+         return y
+
+
+ class _FeedForward(nn.Module):
+     """Position-wise feed-forward network from Transformer.
+
+     Implements the two-layer MLP with a configurable activation (GELU by
+     default) and dropout, as used in Transformer architectures.
+
+     Parameters
+     ----------
+     d_model : int
+         Input and output dimensionality.
+     dim_feedforward : int, optional
+         Hidden dimensionality of the feed-forward layer. Must be provided by
+         the caller; a ``ValueError`` is raised if ``None``.
+     drop_prob : float, optional
+         Dropout probability.
+     bias : bool, optional
+         Whether linear layers use bias.
+     activation : nn.Module, optional
+         Activation function class (default :class:`torch.nn.GELU`).
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         dim_feedforward: Optional[int] = None,
+         drop_prob: float = 0.0,
+         bias: bool = False,
+         activation: nn.Module = nn.GELU,
+     ) -> None:
+         super().__init__()
+
+         if dim_feedforward is None:
+             raise ValueError("dim_feedforward must be provided")
+
+         self.proj_in = nn.Linear(d_model, dim_feedforward, bias=bias)
+         self.activation = activation()
+         self.proj = nn.Linear(dim_feedforward, d_model, bias=bias)
+         self.drop_prob = nn.Dropout(drop_prob)
+         self.drop_prob1 = nn.Dropout(drop_prob)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the feed-forward block."""
+         x = self.proj_in(x)
+         x = self.activation(x)
+         x = self.drop_prob1(x)
+         x = self.proj(x)
+         x = self.drop_prob(x)
+         return x
+
+
+ class _TransformerEncoderLayer(nn.Module):
+     """Single Transformer encoder layer (pre-norm) combining MHSA and feed-forward.
+
+     The block follows the pattern::
+
+         x <- x + MHSA(_LayerNorm(x))
+         x <- x + FF(_LayerNorm(x))
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         n_head: int,
+         drop_prob: float = 0.0,
+         dim_feedforward: Optional[int] = None,
+         bias: bool = False,
+         activation: nn.Module = nn.GELU,
+     ) -> None:
+         super().__init__()
+
+         if dim_feedforward is None:
+             dim_feedforward = 4 * d_model
+             # note: preserve the original behaviour (print) from the provided code
+             print(
+                 "dim_feedforward is set to 4*d_model, the default in Vaswani et al. (Attention is all you need)"
+             )
+
+         self.layer_norm_att = _LayerNorm(d_model, bias=bias)
+         self.mhsa = _MHSA(d_model, n_head, bias, drop_prob=drop_prob)
+         self.layer_norm_ff = _LayerNorm(d_model, bias=bias)
+         self.feed_forward = _FeedForward(
+             d_model=d_model,
+             dim_feedforward=dim_feedforward,
+             drop_prob=drop_prob,
+             bias=bias,
+             activation=activation,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Execute one encoder layer.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input of shape (B, T, d_model).
+
+         Returns
+         -------
+         torch.Tensor
+             Output of the same shape as input.
+         """
+         x = x + self.mhsa(self.layer_norm_att(x))
+         x = x + self.feed_forward(self.layer_norm_ff(x))
+         return x
+
+
+ class _TransformerEncoder(nn.Module):
+     """Stack of Transformer encoder layers.
+
+     Parameters
+     ----------
+     n_blocks : int
+         Number of encoder layers to stack.
+     d_model : int
+         Dimensionality of embeddings.
+     n_head : int
+         Number of attention heads per layer.
+     drop_prob : float
+         Dropout probability.
+     bias : bool
+         Whether linear layers use bias.
+     activation : nn.Module, optional
+         Activation function class used in the feed-forward layers (default :class:`torch.nn.GELU`).
+     """
+
+     def __init__(
+         self,
+         n_blocks: int,
+         d_model: int,
+         n_head: int,
+         drop_prob: float,
+         bias: bool,
+         activation: nn.Module = nn.GELU,
+     ) -> None:
+         super().__init__()
+
+         self.encoder_block = nn.ModuleList(
+             [
+                 _TransformerEncoderLayer(
+                     d_model=d_model,
+                     n_head=n_head,
+                     drop_prob=drop_prob,
+                     dim_feedforward=None,
+                     bias=bias,
+                     activation=activation,
+                 )
+                 for _ in range(n_blocks)
+             ]
+         )
+
+         # GPT2-like initialization for linear layers
+         self.apply(self._init_weights)
+         for pn, p in self.named_parameters():
+             if pn.endswith("proj.weight"):
+                 torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_blocks))
+
+     @staticmethod
+     def _init_weights(module: nn.Module) -> None:
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward through all encoder blocks sequentially."""
+         for block in self.encoder_block:
+             x = block(x)
+         return x
+
+
+ class _Patcher(nn.Module):
+     """Patching encoding helper.
+
+     This module "patchifies" the input ``X`` in a ViT-like manner.
+
+     Parameters
+     ----------
+     n_chans : int
+         Number of EEG channels.
+     n_times : int
+         Number of time samples per trial.
+     d_input : int, optional
+         Size (in samples) of each patch (token).
+     """
+
+     def __init__(self, n_chans=2, n_times=1000, d_input=64) -> None:
+         super().__init__()
+         self.d_input = d_input
+         self.num_tokens = (n_chans * n_times) // d_input
+
+     def forward(self, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Return the patched input together with the positional indices
+         (shared across the batch).
+
+         Parameters
+         ----------
+         X : torch.Tensor
+             Input tensor of shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         x_patched : torch.Tensor
+             Tensor of shape (B, (C*T)//d_input, d_input) containing the patched signal.
+
+         position : torch.LongTensor
+             Tensor of shape ((C*T)//d_input + 1,) containing the positional token indices
+             (including the CLS position).
+         """
+
+         X_flat = X.view(X.size(0), -1)
+         x_patched = X_flat[:, : self.num_tokens * self.d_input].view(
+             X.size(0), self.num_tokens, self.d_input
+         )
+
+         position = torch.arange(x_patched.size(1) + 1, device=X.device)
+
+         return x_patched, position
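
For orientation, below is a minimal usage sketch of the new PBT model. It assumes PBT is re-exported from braindecode.models (the diff also touches braindecode/models/__init__.py and summary.csv, which suggests this, but the export path is an assumption); channel count, trial length, and class count are illustrative values, not defaults from the package.

    import torch

    from braindecode.models import PBT  # assumption: PBT is exported here

    # Hypothetical motor-imagery setup: 22 channels, 1000 samples, 4 classes.
    model = PBT(n_chans=22, n_times=1000, n_outputs=4)

    # With the default d_input=64, the trial is flattened and cut into
    # (n_chans * n_times) // d_input = 22000 // 64 = 343 tokens, plus one CLS token,
    # each projected to d_model=128 before the Transformer encoder.
    x = torch.randn(8, 22, 1000)  # (batch_size, n_chans, n_times)
    logits = model(x)
    print(logits.shape)  # expected: torch.Size([8, 4])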