braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/reve.py (deleted)
@@ -1,843 +0,0 @@
1
- """
2
- REVE (Representation for EEG with Versatile Embeddings) model.
3
- Authors: Jonathan Lys (jonathan.lys@imt-atlantique.org)
4
- License: BSD 3 clause
5
- """
6
-
7
- import json
8
- import logging
9
- import math
10
- import os
11
- from pathlib import Path
12
- from typing import Optional, Union
13
-
14
- import requests
15
- import torch
16
- import torch.nn.functional as F
17
- from einops import rearrange
18
- from torch import nn
19
- from torch.nn.attention import SDPBackend, sdpa_kernel
20
-
21
- from braindecode.models.base import EEGModuleMixin
22
-
23
- logger = logging.getLogger(__name__)
24
-
25
-
26
- class REVE(EEGModuleMixin, nn.Module):
27
- r"""
28
- **R**\ epresentation for **E**\ EG with **V**\ ersatile **E**\ mbeddings (REVE) from El Ouahidi et al. (2025) [reve]_.
29
-
30
- :bdg-danger:`Foundation Model` :bdg-info:`Attention/Transformer`
31
-
32
- .. figure:: https://brain-bzh.github.io/reve/static/images/architecture.png
33
- :align: center
34
- :alt: REVE Training pipeline overview
35
- :width: 1000px
36
-
37
- Foundation models have transformed machine learning by reducing reliance on
38
- task-specific data and induced biases through large-scale pretraining. While
39
- successful in language and vision, their adoption in EEG has lagged due to the
40
- heterogeneity of public datasets, which are collected under varying protocols,
41
- devices, and electrode configurations. Existing EEG foundation models struggle
42
- to generalize across these variations, often restricting pretraining to a single
43
- setup and resulting in suboptimal performance, particularly under linear probing.
44
-
45
- REVE is a pretrained model explicitly designed to generalize across diverse EEG signals. It introduces
46
- a **4D positional encoding** scheme that enables processing signals of arbitrary length and electrode
47
- arrangement. Using a masked autoencoding objective, REVE was pretrained on over **60,000 hours** of EEG
48
- data from **92 datasets** spanning **25,000 subjects**, the largest EEG pretraining effort to date.
49
-
50
- .. rubric:: Channels Invariant Positional Encoding
51
-
52
- Prior EEG foundation models (:class:`~braindecode.models.Labram`, :class:`~braindecode.models.BIOT`) rely on
53
- fixed positional embeddings, making direct transfer to unseen electrode layouts infeasible. CBraMod uses
54
- convolution-based positional encoding that requires fine-tuning when adapting to new configurations.
55
- As noted in the CBraMod paper: *"fixing the pre-trained parameters during training on downstream
56
- datasets will lead to a very large performance decline."*
57
-
58
- REVE's 4D positional encoding jointly encodes spatial :math:`(x, y, z)` and temporal :math:`(t)` positions
59
- using Fourier embeddings, enabling true cross-configuration transfer without retraining. The Fourier embedding
60
- is inspired by brainmodule [brainmodule]_, generalized to 4D for EEG using the channels' spatial coordinates
61
- and the temporal patch index.
62
-
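A minimal sketch of how the 4D token coordinates are assembled, mirroring ``FourierEmb4D.add_time_patch`` defined later in this file (batch size, channel count, and patch count below are illustrative):

.. code-block:: python

    import torch

    xyz = torch.randn(2, 3, 3)  # (batch, n_chans, 3) electrode coordinates
    n_patches = 5

    # Repeat each electrode position once per temporal patch ...
    xyzt = xyz.unsqueeze(2).repeat(1, 1, n_patches, 1)  # (batch, n_chans, n_patches, 3)
    # ... and append the patch index t as the fourth coordinate.
    t = torch.arange(n_patches, dtype=torch.float32)
    t = t.view(1, 1, n_patches, 1).expand(2, 3, n_patches, 1)
    pos4d = torch.cat([xyzt, t], dim=-1).reshape(2, 3 * n_patches, 4)  # (batch, tokens, 4)
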
63
- .. rubric:: Linear Probing Performance
64
-
65
- A key advantage of REVE is that it produces useful latent representations without heavy fine-tuning. Under linear
66
- probing (frozen encoder), REVE achieves state-of-the-art results on downstream EEG tasks.
67
- This enables practical deployment in low-data scenarios where extensive fine-tuning is not feasible.
68
-
69
- .. rubric:: Architecture
70
-
71
- The model adopts modern Transformer components validated through ablation studies:
72
-
73
- - **Normalization**: RMSNorm outperforms LayerNorm;
74
- - **Activation**: GEGLU outperforms GELU (sketched below);
75
- - **Attention**: Flash Attention via PyTorch's SDPA;
76
- - **Masking ratio**: a 55% masking ratio is optimal for spatio-temporal block masking.
77
-
78
- These choices align with best practices from large language models and were empirically validated
79
- on EEG data.
80
-
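As a quick illustration of the GEGLU activation named above (the actual ``GEGLU`` module appears later in this file), the hidden projection is split in two and one half gates the other:

.. code-block:: python

    import torch
    import torch.nn.functional as F

    def geglu(x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension into a value half and a gate half,
        # then gate the values with GELU(gate).
        value, gate = x.chunk(2, dim=-1)
        return F.gelu(gate) * value
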
81
- .. rubric:: Secondary Loss
82
-
83
- A secondary reconstruction objective using attention pooling across layers prevents over-specialization
84
- in the final layer. This pooling acts as an information bottleneck, forcing the model to distill key
85
- information from the entire sequence. Ablations show this loss is crucial for linear probing quality:
86
- removing it drops average performance by 10% under the frozen evaluation.
87
-
88
- .. rubric:: Macro Components
89
-
90
- - ``REVE.to_patch_embedding`` **Patch Tokenization**
91
-
92
- The EEG signal is split into overlapping patches along the time dimension, generating
93
- :math:`p = \left\lceil \frac{T - w}{w - o} \right\rceil + \mathbf{1}[(T - w) \bmod (w - o) = 0]`
95
- patches of size :math:`w` with overlap :math:`o`, where :math:`T` is the signal length.
96
- Each patch is linearly projected to the embedding dimension (see the sketch after this list).
96
-
97
- - ``REVE.fourier4d`` + ``REVE.mlp4d`` **4D Positional Embedding (4DPE)**
98
-
99
- The 4DPE encodes each token's 4D coordinates :math:`(x, y, z, t)` where :math:`(x, y, z)` are the
100
- 3D spatial coordinates from a standardized electrode position bank, and :math:`t` is the temporal
101
- patch index. The encoding combines:
102
-
103
- 1. **Fourier embedding**: Sinusoidal encoding across multiple frequencies for smooth interpolation
104
- to unseen positions
105
- 2. **MLP embedding**: :class:`~torch.nn.Linear` (4 → embed_dim) → :class:`~torch.nn.GELU` → :class:`~torch.nn.LayerNorm` for learnable refinement
106
-
107
- Both components are summed and normalized. The 4DPE adds negligible computational overhead,
108
- scaling linearly with the number of tokens.
109
-
110
- - ``REVE.transformer`` **Transformer Encoder**
111
-
112
- Pre-norm Transformer (using :class:`~torch.nn.RMSNorm`) with multi-head self-attention, feed-forward networks (GEGLU
113
- activation), and residual connections. Default configuration: 22 layers, 8 heads, 512 embedding
114
- dimension (~72M parameters).
115
-
116
- - ``REVE.final_layer`` **Classification Head**
117
-
118
- Two modes (controlled by the ``attention_pooling`` parameter):
119
-
120
- - When ``attention_pooling`` is disabled (e.g., ``None`` or ``False``): flatten all tokens
121
- → :class:`~torch.nn.LayerNorm` → :class:`~torch.nn.Linear`
122
- - When ``attention_pooling`` is enabled: attention pooling with a learnable query token
123
- attending to all encoder outputs
124
-
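A minimal sketch of the patch tokenization step referenced above, using ``torch.Tensor.unfold`` (signal length, patch size, and overlap below are illustrative):

.. code-block:: python

    import torch

    T, w, o = 1000, 200, 20      # n_times, patch_size, patch_overlap
    eeg = torch.randn(8, 22, T)  # (batch, n_chans, n_times)

    # Overlapping patches along time: (batch, n_chans, n_patches, patch_size)
    patches = eeg.unfold(dimension=2, size=w, step=w - o)

    # unfold yields floor((T - w) / (w - o)) + 1 patches, i.e. 5 here
    assert patches.shape == (8, 22, (T - w) // (w - o) + 1, w)
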
125
- .. rubric:: Known Limitations
126
-
127
- - **Sparse electrode setups**: Performance degrades with very few channels. On motor imagery,
128
- accuracy drops from 0.824 (64 channels) to 0.660 (1 channel). For tasks requiring broad
129
- spatial coverage (e.g., imagined speech), performance with <4 channels approaches chance level.
130
- - **Demographic bias**: The pretraining corpus aggregates publicly available datasets, most
131
- originating from North America and Europe, resulting in limited demographic diversity;
132
- more details about the datasets used for pretraining can be found in the REVE paper [reve]_.
133
-
134
- .. rubric:: Pretrained Weights
135
-
136
- Weights are available on `HuggingFace <https://huggingface.co/collections/brain-bzh/reve>`_,
137
- but you must agree to the data usage terms before downloading:
138
-
139
- - ``brain-bzh/reve-base``: 72M parameters, 512 embedding dim, 22 layers (~260 A100 GPU hours)
140
- - ``brain-bzh/reve-large``: ~400M parameters, 1250 embedding dim
141
-
142
- .. important::
143
- **Pre-trained Weights Available (Registration Required)**
144
-
145
- This model has pre-trained weights available on the Hugging Face Hub.
146
- **You must first register and agree to the data usage terms on the authors'
147
- HuggingFace repository before you can access the weights.**
148
- `Link here <https://huggingface.co/collections/brain-bzh/reve>`_.
149
-
150
- You can load them using:
151
-
152
- .. code-block:: python
153
-
154
- from braindecode.models import REVE
155
-
156
- # Load pre-trained model from Hugging Face Hub
157
- model = REVE.from_pretrained("brain-bzh/reve-base")
158
-
159
- To push your own trained model to the Hub:
160
-
161
- .. code-block:: python
162
-
163
- # After training your model
164
- model.push_to_hub(
165
- repo_id="username/my-reve-model", commit_message="Upload trained REVE model"
166
- )
167
-
168
- Requires installing ``braindecode[hug]`` for Hub integration.
169
-
170
- .. rubric:: Usage
171
-
172
- .. code-block:: python
173
-
174
- from braindecode.models import REVE
175
-
176
- model = REVE(
177
- n_outputs=4, # e.g., 4-class motor imagery
178
- n_chans=22,
179
- n_times=1000, # 5 seconds at 200 Hz
180
- sfreq=200,
181
- chs_info=[{"ch_name": "C3"}, {"ch_name": "C4"}, ...],
182
- )
183
-
184
- # Forward pass: (batch, n_chans, n_times) -> (batch, n_outputs)
185
- output = model(eeg_data, pos=channel_positions)
186
-
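A hedged sketch of where ``channel_positions`` can come from: the model's own position bank via ``get_positions``. The channel names and the reduced model size are illustrative only (the pretrained REVE-Base uses ``embed_dim=512`` and ``depth=22``):

.. code-block:: python

    import torch

    from braindecode.models import REVE

    ch_names = ["C3", "Cz", "C4"]
    model = REVE(
        n_outputs=2, n_chans=len(ch_names), n_times=1000, sfreq=200,
        embed_dim=64, depth=2, heads=2, head_dim=32,
    )

    # (n_chans, 3) coordinates from the position bank (downloaded on first use)
    pos = model.get_positions(ch_names)

    eeg = torch.randn(4, len(ch_names), 1000)    # (batch, n_chans, n_times)
    out = model(eeg, pos=pos.expand(4, -1, -1))  # (batch, n_outputs)
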
187
- .. warning::
188
-
189
- Input data must be sampled at **200 Hz** to match pretraining. The model applies
190
- z-score normalization followed by clipping at 15 standard deviations internally
191
- during pretraining; users should apply similar preprocessing.
192
-
193
- Parameters
194
- ----------
195
- embed_dim : int, default=512
196
- Embedding dimension. Use 512 for REVE-Base, 1250 for REVE-Large.
197
- depth : int, default=22
198
- Number of Transformer layers.
199
- heads : int, default=8
200
- Number of attention heads.
201
- head_dim : int, default=64
202
- Dimension per attention head.
203
- mlp_dim_ratio : float, default=2.66
204
- FFN hidden dimension ratio: ``mlp_dim = embed_dim × mlp_dim_ratio``.
205
- use_geglu : bool, default=True
206
- Use GEGLU activation (recommended) or standard GELU.
207
- freqs : int, default=4
208
- Number of frequencies for Fourier positional embedding.
209
- patch_size : int, default=200
210
- Temporal patch size in samples (200 samples = 1 second at 200 Hz).
211
- patch_overlap : int, default=20
212
- Overlap between patches in samples.
213
- attention_pooling : bool, default=False
214
- Pooling strategy for aggregating transformer outputs before classification.
215
- If ``False`` (default), all tokens are flattened into a single vector of size
216
- ``(n_chans x n_patches x embed_dim)``, which is then passed through LayerNorm
217
- and a linear classifier. If ``True``, uses attention-based pooling with a
218
- learnable query token that attends to all encoder outputs, producing a single
219
- embedding of size ``embed_dim``. Attention pooling is more parameter-efficient
220
- for long sequences and variable-length inputs.
221
-
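A sketch of the attention pooling applied when ``attention_pooling=True``, mirroring ``REVE._attention_pooling`` defined below (shapes are illustrative):

.. code-block:: python

    import torch

    B, C, S, E = 4, 3, 5, 64
    x = torch.randn(B, C, S, E).reshape(B, C * S, E)  # encoder tokens
    query = torch.randn(1, 1, E).expand(B, -1, -1)    # learnable query token

    scores = query @ x.transpose(-1, -2) / E**0.5     # (B, 1, C*S)
    weights = scores.softmax(dim=-1)
    pooled = (weights @ x).squeeze(1)                  # (B, E)
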
222
- References
223
- ----------
224
- .. [reve] El Ouahidi, Y., Lys, J., Thölke, P., Farrugia, N., Pasdeloup, B.,
225
- Gripon, V., Jerbi, K. & Lioi, G. (2025). REVE: A Foundation Model for EEG -
226
- Adapting to Any Setup with Large-Scale Pretraining on 25,000 Subjects.
227
- The Thirty-Ninth Annual Conference on Neural Information Processing Systems.
228
- https://openreview.net/forum?id=ZeFMtRBy4Z
229
- .. [brainmodule] Défossez, A., Caucheteux, C., Rapin, J., Kabeli, O., & King, J. R.
230
- (2023). Decoding speech perception from non-invasive brain recordings. Nature
231
- Machine Intelligence, 5(10), 1097-1107.
232
-
233
- Notes
234
- -----
235
- The position bank is downloaded from HuggingFace on first initialization, mapping
236
- standard 10-20/10-10/10-05 electrode names to 3D coordinates. This enables the
237
- 4D positional encoding to generalize across electrode configurations without
238
- requiring matched layouts between pretraining and downstream tasks.
239
- """
240
-
241
- def __init__(
242
- self,
243
- n_outputs=None,
244
- n_chans=None,
245
- chs_info=None,
246
- n_times=None,
247
- input_window_seconds=None,
248
- sfreq=None,
249
- # REVE specific parameters
250
- embed_dim: int = 512,
251
- depth: int = 22,
252
- heads: int = 8,
253
- head_dim: int = 64,
254
- mlp_dim_ratio: float = 2.66,
255
- use_geglu: bool = True,
256
- freqs: int = 4,
257
- patch_size: int = 200,
258
- patch_overlap: int = 20,
259
- attention_pooling: bool = False,
260
- ):
261
- super().__init__(
262
- n_outputs=n_outputs,
263
- n_chans=n_chans,
264
- chs_info=chs_info,
265
- n_times=n_times,
266
- input_window_seconds=input_window_seconds,
267
- sfreq=sfreq,
268
- )
269
-
270
- self.embed_dim = embed_dim
271
- self.freqs = freqs
272
- self.patch_size = patch_size
273
- self.patch_overlap = patch_overlap
274
- self.depth = depth
275
- self.heads = heads
276
- self.head_dim = head_dim
277
- self.mlp_dim_ratio = mlp_dim_ratio
278
- self.use_geglu = use_geglu
279
-
280
- self.use_attention_pooling = attention_pooling
281
-
282
- self.to_patch_embedding = nn.Sequential(
283
- nn.Linear(in_features=self.patch_size, out_features=self.embed_dim)
284
- )
285
-
286
- self.fourier4d = FourierEmb4D(self.embed_dim, freqs=self.freqs)
287
-
288
- self.mlp4d = nn.Sequential(
289
- nn.Linear(4, self.embed_dim, bias=False),
290
- nn.GELU(),
291
- nn.LayerNorm(self.embed_dim),
292
- )
293
-
294
- self.ln = nn.LayerNorm(self.embed_dim) # 4DPE module layernorm
295
-
296
- self.transformer = TransformerBackbone(
297
- dim=self.embed_dim,
298
- depth=self.depth,
299
- heads=self.heads,
300
- head_dim=self.head_dim,
301
- mlp_dim=int(self.embed_dim * self.mlp_dim_ratio),
302
- geglu=self.use_geglu,
303
- )
304
-
305
- final_dim = self._get_flattened_output_dim()
306
-
307
- if self.use_attention_pooling:
308
- self.final_layer = nn.Sequential(
309
- nn.LayerNorm(self.embed_dim),
310
- nn.Linear(self.embed_dim, self.n_outputs),
311
- )
312
- else:
313
- self.final_layer = nn.Sequential(
314
- nn.Flatten(),
315
- nn.LayerNorm(final_dim),
316
- nn.Linear(final_dim, self.n_outputs),
317
- )
318
-
319
- if self.use_attention_pooling:
320
- self.cls_query_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
321
-
322
- self._position_bank = RevePositionBank()
323
-
324
- self.default_pos = None
325
- if chs_info is not None:
326
- self.default_pos = self.get_positions([ch["ch_name"] for ch in chs_info])
327
-
328
- def _get_flattened_output_dim(self) -> int:
329
- """Helper function to compute the flattened output dimension after the transformer."""
330
-
331
- if self.use_attention_pooling:
332
- return self.embed_dim
333
-
334
- n_patches = math.ceil(
335
- (self.n_times - self.patch_size) / (self.patch_size - self.patch_overlap)
336
- )
337
-
338
- if (self.n_times - self.patch_size) % (
339
- self.patch_size - self.patch_overlap
340
- ) == 0:
341
- n_patches += 1
342
-
343
- flat_dim = self.n_chans * n_patches * self.embed_dim
344
- return flat_dim
345
-
346
- def get_positions(self, channel_names: list[str]) -> torch.Tensor:
347
- """Fetch channel positions from the position bank. The position bank is downloaded when the model is instantiated.
348
-
349
- Parameters
350
- ----------
351
- channel_names : list[str]
352
- List of channel names for which to fetch positions.
353
-
354
- Returns
355
- -------
356
- torch.Tensor
357
- Tensor of shape (num_channels, 3) containing the (x, y, z) positions of the channels.
358
- """
359
-
360
- return self._position_bank.forward(channel_names)
361
-
362
- def forward(
363
- self,
364
- eeg: torch.Tensor,
365
- pos: Optional[torch.Tensor] = None,
366
- return_output: bool = False,
367
- ) -> Union[torch.Tensor, list[torch.Tensor]]:
368
- """
369
- Forward pass of the model.
370
-
371
- Goes through the following steps:
372
- 1. Patch extraction from the EEG signal.
373
- 2. 4D positional embedding computation.
374
- 3. Transformer encoding.
375
- 4. Final layer processing (if `return_output` is False).
376
-
377
- Parameters
378
- ----------
379
- eeg : torch.Tensor
380
- Input EEG tensor of shape (batch_size, channels, sequence_length).
381
- pos : torch.Tensor
382
- Position tensor of shape (batch_size, channels, 3) representing (x, y, z) coordinates.
383
- return_output : bool, optional
384
- If True, returns the output from the transformer directly.
385
- If False, applies the final layer and returns the processed output. Default is False.
386
-
387
- Returns
388
- -------
389
- Union[torch.Tensor, list[torch.Tensor]]
390
- - If `return_output` is False: Returns a single `torch.Tensor` (output after final layer).
391
- - If `return_output` is True: Returns a `list[torch.Tensor]` (outputs from transformer layers).
392
-
393
- The output tensor(s) from the model. If `return_output` is True,
394
- returns the transformer output; otherwise, returns the output after the final layer.
395
- """
396
-
397
- patches = eeg.unfold(
398
- dimension=2,
399
- size=self.patch_size,
400
- step=self.patch_size - self.patch_overlap,
401
- )
402
- batch_size, channel, n_patches, _ = patches.shape
403
-
404
- if pos is None:
405
- if self.default_pos is None:
406
- raise ValueError(
407
- "No positions provided and no default positions available. Please provide channel positions."
408
- )
409
- pos = self.default_pos.expand(batch_size, -1, -1).to(eeg.device)
410
-
411
- pos = FourierEmb4D.add_time_patch(pos, n_patches)
412
- pos_embed = self.ln(self.fourier4d(pos) + self.mlp4d(pos))
413
-
414
- # Patch embedding: (batch, channels, n_patches, patch_size) -> (batch, channels, n_patches, embed_dim)
415
- patch_embeddings = self.to_patch_embedding(patches)
416
-
417
- # Flatten spatial and temporal dimensions into a single sequence dimension
418
- # (batch, channels, n_patches, embed_dim) -> (batch, channels * n_patches, embed_dim)
419
- # This creates a sequence of tokens where each token represents one patch from one channel
420
- x = (
421
- rearrange(
422
- patch_embeddings,
423
- "batch chan patch emb -> batch (chan patch) emb",
424
- chan=channel,
425
- patch=n_patches,
426
- emb=self.embed_dim,
427
- )
428
- + pos_embed
429
- )
430
- x = self.transformer(x, return_output)
431
-
432
- if return_output:
433
- return x
434
-
435
- # Reshape back from flattened sequence to separate channel and temporal dimensions
436
- # (batch, channels * n_patches, embed_dim) -> (batch, channels, n_patches, embed_dim)
437
- # This recovers the spatio-temporal structure for downstream processing
438
- x = rearrange(
439
- x,
440
- "batch (chan patch) emb -> batch chan patch emb",
441
- batch=batch_size,
442
- chan=channel,
443
- patch=n_patches,
444
- emb=self.embed_dim,
445
- )
446
-
447
- if self.use_attention_pooling:
448
- x = self._attention_pooling(x)
449
-
450
- x = self.final_layer(x)
451
- return x
452
-
453
- def _attention_pooling(self, x: torch.Tensor) -> torch.Tensor:
454
- """Apply attention pooling on the sequence dimension of x.
455
-
456
- Parameters
457
- ----------
458
- x : torch.Tensor
459
- Input tensor of shape (B, C, S, E), where B is the batch size,
460
- C is the number of channels, S is the sequence length,
461
- and E is the embedding dimension. Typically the output from the transformer.
462
-
463
- Returns
464
- -------
465
- torch.Tensor
466
- Output tensor of shape (B, E) after attention pooling.
467
- """
468
-
469
- batch_size, n_channels, seq_len, embed_dim = x.shape
470
- # Flatten channel and sequence dimensions for attention pooling
471
- # (batch, channels, seq_len, embed_dim) -> (batch, channels * seq_len, embed_dim)
472
- x = rearrange(
473
- x,
474
- "batch chan seq emb -> batch (chan seq) emb",
475
- chan=n_channels,
476
- seq=seq_len,
477
- emb=embed_dim,
478
- )
479
- query_output = self.cls_query_token.expand(batch_size, -1, -1) # (B, 1, E)
480
- attention_scores = torch.matmul(query_output, x.transpose(-1, -2)) / (
481
- self.embed_dim**0.5
482
- ) # (B, 1, C*S)
483
- attention_weights = torch.softmax(attention_scores, dim=-1) # (B, 1, C*S)
484
- out = torch.matmul(attention_weights, x).squeeze(1) # (B, E)
485
- return out
486
-
487
-
488
- #################################################################################
489
- # Layers #
490
- #################################################################################
491
-
492
-
493
- class GEGLU(nn.Module):
494
- def forward(self, x: torch.Tensor) -> torch.Tensor:
495
- x, gates = x.chunk(2, dim=-1)
496
- return F.gelu(gates) * x
497
-
498
-
499
- class RMSNorm(torch.nn.Module):
500
- def __init__(self, dim: int, eps: float = 1e-6):
501
- super().__init__()
502
- self.eps = eps
503
- self.weight = nn.Parameter(torch.ones(dim))
504
-
505
- def _norm(self, x: torch.Tensor) -> torch.Tensor:
506
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
507
-
508
- def forward(self, x: torch.Tensor) -> torch.Tensor:
509
- output = self._norm(x.float()).type_as(x)
510
- return output * self.weight
511
-
512
-
513
- class FeedForward(nn.Module):
514
- def __init__(self, dim: int, hidden_dim: int, geglu: bool):
515
- super().__init__()
516
- self.net = nn.Sequential(
517
- RMSNorm(dim),
518
- nn.Linear(dim, hidden_dim * 2 if geglu else hidden_dim, bias=False),
519
- GEGLU() if geglu else nn.GELU(),
520
- nn.Linear(hidden_dim, dim, bias=False),
521
- )
522
-
523
- def forward(self, x: torch.Tensor) -> torch.Tensor:
524
- return self.net(x)
525
-
526
-
527
- #################################################################################
528
- # Attention #
529
- #################################################################################
530
-
531
-
532
- class ClassicalAttention(nn.Module):
533
- def __init__(self, heads: int, use_sdpa: bool = True):
534
- super().__init__()
535
- self.use_sdpa = use_sdpa
536
- self.heads = heads
537
-
538
- def forward(self, qkv: torch.Tensor) -> torch.Tensor:
539
- # Split concatenated QKV into separate tensors
540
- # qkv shape: (batch, seq_len, 3 * heads * head_dim)
541
- q, k, v = qkv.chunk(3, dim=-1)
542
-
543
- # Reshape for multi-head attention: split last dim into (heads, head_dim)
544
- # (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim)
545
- q, k, v = (
546
- rearrange(
547
- t,
548
- "batch seq (heads dim) -> batch heads seq dim",
549
- heads=self.heads,
550
- )
551
- for t in (q, k, v)
552
- )
553
-
554
- if self.use_sdpa: # SDPA Implementation
555
- with sdpa_kernel(
556
- [
557
- SDPBackend.FLASH_ATTENTION,
558
- SDPBackend.EFFICIENT_ATTENTION,
559
- SDPBackend.MATH,
560
- ]
561
- ):
562
- out = F.scaled_dot_product_attention(q, k, v)
563
- else: # Naive Implementation
564
- scale = q.shape[-1] ** -0.5
565
- dots = torch.matmul(q, k.transpose(-1, -2)) * scale
566
- attn = nn.Softmax(dim=-1)(dots)
567
- out = torch.matmul(attn, v)
568
-
569
- # Merge heads back: (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
570
- out = rearrange(out, "batch heads seq dim -> batch seq (heads dim)")
571
- return out
572
-
573
-
574
- class Attention(nn.Module):
575
- """
576
- Multi-head self-attention layer with RMSNorm.
577
- """
578
-
579
- def __init__(self, dim: int, heads: int = 8, head_dim: int = 64):
580
- super().__init__()
581
- inner_dim = head_dim * heads
582
- self.heads = heads
583
- self.norm = RMSNorm(dim)
584
- self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
585
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
586
-
587
- self.attend = ClassicalAttention(self.heads, use_sdpa=True)
588
-
589
- def forward(self, x):
590
- x = self.norm(x)
591
- qkv = self.to_qkv(x)
592
- out = self.attend(qkv)
593
- return self.to_out(out)
594
-
595
-
596
- #################################################################################
597
- # Transformer #
598
- #################################################################################
599
-
600
-
601
- class TransformerBackbone(nn.Module):
602
- """
603
- Transformer backbone consisting of multiple layers of attention and feed-forward networks.
604
- """
605
-
606
- def __init__(self, dim, depth, heads, head_dim, mlp_dim, geglu):
607
- super().__init__()
608
- self.dim = dim
609
- self.layers = nn.ModuleList([])
610
- for _ in range(depth):
611
- self.layers.append(
612
- nn.ModuleList(
613
- [
614
- Attention(
615
- self.dim,
616
- heads=heads,
617
- head_dim=head_dim,
618
- ),
619
- FeedForward(self.dim, mlp_dim, geglu),
620
- ]
621
- )
622
- )
623
-
624
- def forward(
625
- self, x, return_out_layers=False
626
- ) -> Union[torch.Tensor, list[torch.Tensor]]:
627
- out_layers = [x] if return_out_layers else []
628
- for attn, ff in self.layers:
629
- x = attn(x) + x
630
- x = ff(x) + x
631
- if return_out_layers:
632
- out_layers.append(x)
633
- return out_layers if return_out_layers else x
634
-
635
-
636
- ##################################################################################
637
- # 4D PE #
638
- ##################################################################################
639
-
640
-
641
- class FourierEmb4D(nn.Module):
642
- """
643
- Fourier positional embedding for 4D positions (x, y, z, t).
644
- This version allows for a reduced number of frequencies (n_freqs),
645
- and ensures the output embedding has the specified dimension.
646
-
647
- Parameters
648
- ----------
649
- dimension : int
650
- The dimension of the output embedding. Must be an even number.
651
- freqs : int
652
- The number of frequencies to use for the Fourier embedding.
653
- increment_time : float, optional
654
- The time increment to scale the time dimension. Default is 0.1.
655
- margin : float, optional
656
- The margin to add to the position coordinates to avoid boundary issues. Default is 0.4.
657
-
658
- """
659
-
660
- def __init__(
661
- self, dimension: int, freqs: int, increment_time=0.1, margin: float = 0.4
662
- ):
663
- super().__init__()
664
- self.dimension = dimension
665
- self.freqs = freqs
666
- self.increment_time = increment_time
667
- self.margin = margin
668
-
669
- def forward(self, positions_: torch.Tensor) -> torch.Tensor:
670
- positions = positions_.clone()
671
- positions[:, :, -1] *= self.increment_time
672
- input_shape = positions.shape
673
- batch_dims = list(input_shape[:-1])
674
-
675
- freqs_w = torch.arange(self.freqs).to(positions)
676
- freqs_z = freqs_w[:, None]
677
- freqs_y = freqs_z[:, None]
678
- freqs_x = freqs_y[:, None]
679
- width = 1 + 2 * self.margin
680
- positions = positions + self.margin
681
- p_x = 2 * math.pi * freqs_x / width
682
- p_y = 2 * math.pi * freqs_y / width
683
- p_z = 2 * math.pi * freqs_z / width
684
- p_w = 2 * math.pi * freqs_w / width
685
- positions = positions[..., None, None, None, None, :]
686
- loc = (
687
- positions[..., 0] * p_x
688
- + positions[..., 1] * p_y
689
- + positions[..., 2] * p_z
690
- + positions[..., 3] * p_w
691
- )
692
- batch_dims.append(-1)
693
- loc = loc.view(batch_dims)
694
-
695
- half_dim = self.dimension // 2
696
- current_dim = loc.shape[-1]
697
- if current_dim != half_dim:
698
- if current_dim > half_dim:
699
- loc = loc[..., :half_dim]
700
- else:
701
- raise ValueError(
702
- f"Input dimension ({current_dim}) is too small for target "
703
- f"embedding dimension ({self.dimension}). Expected at least {half_dim}."
704
- )
705
-
706
- emb = torch.cat([torch.cos(loc), torch.sin(loc)], dim=-1)
707
- return emb
708
-
709
- @classmethod
710
- def add_time_patch(cls, pos: torch.Tensor, num_patches: int) -> torch.Tensor:
711
- """
712
- Expand the position tensor by adding a time dimension, handling batched data.
713
-
714
- Parameters
715
- ----------
716
- pos : torch.Tensor
717
- Input tensor of shape (B, C, 3), where B is the batch size,
718
- C is the number of channels, and 3 represents x, y, z.
719
- num_patches : int
720
- The number of time patches.
721
-
722
- Returns
723
- -------
724
- torch.Tensor
725
- Output tensor of shape (B, C * num_patches, 4), where each position is repeated with each time value.
726
- """
727
- batch, nchans, _ = pos.shape
728
- # Repeat each position for each time step
729
- pos_repeated = pos.unsqueeze(2).repeat(
730
- 1, 1, num_patches, 1
731
- ) # Shape: (batch, nchans, num_patches, 3)
732
- # Generate time values with the specified increment
733
- time_values = torch.arange(
734
- 0, num_patches, 1, device=pos.device
735
- ).float() # Shape: (num_patches,)
736
- time_values = time_values.view(1, 1, num_patches, 1).expand(
737
- batch, nchans, num_patches, 1
738
- ) # (batch, nchans, num_patches, 1)
739
- # Concatenate the repeated positions with the time values along the last dimension
740
- pos_with_time = torch.cat(
741
- (pos_repeated, time_values), dim=-1
742
- ) # Shape: (batch, nchans, num_patches, 4)
743
- # Reshape to (batch, nchans * num_patches, 4)
744
- pos_with_time = pos_with_time.view(batch, nchans * num_patches, 4)
745
-
746
- return pos_with_time
747
-
748
-
749
- class RevePositionBank(torch.nn.Module):
750
- """Position bank for REVE model that maps standard EEG channel names to 3D coordinates.
751
-
752
- The position bank is cached locally in the library root to avoid repeated downloads.
753
-
754
- The coordinates come from the 92 datasets used during REVE pretraining.
755
-
756
- Parameters
757
- ----------
758
- url : str, optional
759
- URL to download the position bank JSON file. Default is the HuggingFace URL.
760
- timeout : int, optional
761
- Timeout in seconds for the HTTP request. Default is 5 seconds.
762
- cache_dir : str, optional
763
- Directory to cache the position bank. Default is the models folder within the library.
764
- """
765
-
766
- def __init__(
767
- self,
768
- url: str = "https://huggingface.co/brain-bzh/reve-positions/resolve/main/positions.json",
769
- timeout: int = 5,
770
- cache_dir: Optional[str] = None,
771
- ):
772
- super().__init__()
773
-
774
- if cache_dir is None:
775
- # Use the model root directory
776
- cache_dir = str(Path(__file__).parent)
777
-
778
- cache_file = os.path.join(cache_dir, ".cache", "reve_positions.json")
779
- os.makedirs(os.path.dirname(cache_file), exist_ok=True)
780
-
781
- config = None
782
-
783
- # Try to load from cache first
784
- if os.path.exists(cache_file):
785
- try:
786
- with open(cache_file, "r") as f:
787
- config = json.load(f)
788
- logger.info(f"Loaded position bank from cache: {cache_file}")
789
- except (json.JSONDecodeError, IOError) as e:
790
- logger.warning(f"Failed to load cache, downloading: {e}")
791
-
792
- # Download if cache miss or failed to load
793
- if config is None:
794
- try:
795
- response = requests.get(url, timeout=timeout)
796
- response.raise_for_status()
797
- config = json.loads(response.text)
798
-
799
- # Save to cache
800
- with open(cache_file, "w") as f:
801
- json.dump(config, f)
802
- logger.info(f"Downloaded and cached position bank to: {cache_file}")
803
- except (requests.RequestException, json.JSONDecodeError) as e:
804
- raise RuntimeError(
805
- f"Failed to download or parse the position bank from {url}: {e}"
806
- ) from e
807
-
808
- try:
809
- self.position_names = list(config.keys())
810
- self.mapping = {name: i for i, name in enumerate(self.position_names)}
811
- positions_data = list(config.values())
812
- positions = torch.tensor(positions_data, dtype=torch.float32)
813
- self.register_buffer("embedding", positions)
814
-
815
- # Validate shape (N, 3)
816
- if self.embedding.dim() != 2 or self.embedding.shape[1] != 3:
817
- raise ValueError(
818
- f"Position data must have shape (N, 3), but got {self.embedding.shape}"
819
- )
820
-
821
- assert self.embedding.shape == (len(self.mapping), 3), (
822
- f"Expected embedding shape ({len(self.mapping)}, 3), but got {self.embedding.shape}"
823
- )
824
-
825
- except (ValueError, TypeError, AssertionError) as e:
826
- raise RuntimeError(
827
- f"Invalid position data format in the downloaded config: {e}"
828
- ) from e
829
-
830
- def forward(self, channel_names: list[str]):
831
- indices = [self.mapping[q] for q in channel_names if q in self.mapping]
832
-
833
- if len(indices) < len(channel_names):
834
- logger.warning(
835
- f"Found {len(indices)} positions out of {len(channel_names)} channels"
836
- )
837
-
838
- indices = torch.tensor(indices, device=self.embedding.device)
839
-
840
- return self.embedding[indices]
841
-
842
- def get_all_positions(self):
843
- return self.position_names