braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,891 @@
1
+ """This implementation is adapted from ETH Zurich's BioFoundation repository.
2
+
3
+ Döner, B., Ingolfsson, T. M., Benini, L., & Li, Y. (2025).
4
+ LUNA: Efficient and Topology-Agnostic Foundation Model for EEG Signal Analysis.
5
+ The Thirty-Ninth Annual Conference on Neural Information Processing Systems, NeurIPS.
6
+ Retrieved from https://openreview.net/forum?id=uazfjnFL0G
7
+
8
+ Original Authors: Berkay Döner, Thorir Mar Ingolfsson
9
+ Braindecode Adaptation: Bruno Aristimunha
10
+
11
+ The license of this file is Apache-2.0.
12
+ """
13
+
14
+ import math
15
+ from typing import Any, Dict, Optional, Tuple, Type
16
+
17
+ import torch
18
+ import torch.fft as fft
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from einops import rearrange
22
+ from rotary_embedding_torch import RotaryEmbedding
23
+
24
+ from braindecode.models.base import EEGModuleMixin
25
+ from braindecode.models.util import extract_channel_locations_from_chs_info
26
+ from braindecode.modules.layers import DropPath
27
+
28
+
29
+ class LUNA(EEGModuleMixin, nn.Module):
30
+ r"""LUNA from Döner et al [LUNA]_.
31
+
32
+ :bdg-success:`Convolution` :bdg-danger:`Foundation Model` :bdg-dark-line:`Channel`
33
+
34
+ .. figure:: https://arxiv.org/html/2510.22257v1/x1.png
35
+ :align: center
36
+ :alt: LUNA Architecture.
37
+
38
+ LUNA is a topology-invariant EEG model that processes signals from varying
39
+ numbers of channels using a channel-unification mechanism with learned queries.
40
+
41
+ The architecture consists of:
42
+ 1. Patch Feature Extraction (temporal CNN + FFT-based features)
43
+ 2. Channel-Unification Module (cross-attention with learned queries)
44
+ 3. Patch-wise Temporal Encoder (RoPE-based transformer)
45
+ 4. Decoder Heads (classification or reconstruction)
46
+
47
+ .. important::
48
+ **Pre-trained Weights Available**
49
+
50
+ This model has pre-trained weights available on the Hugging Face Hub
51
+ at `thorir/LUNA <https://huggingface.co/thorir/LUNA>`_.
52
+
53
+ Available model variants:
54
+
55
+ - **LUNA_base.safetensors** - Base model (embed_dim=64, num_queries=4, depth=8)
56
+ - **LUNA_large.safetensors** - Large model (embed_dim=96, num_queries=6, depth=10)
57
+ - **LUNA_huge.safetensors** - Huge model (embed_dim=128, num_queries=8, depth=24)
58
+
59
+ Example loading for fine-tuning:
60
+
61
+ .. code-block:: python
62
+
63
+ from huggingface_hub import hf_hub_download
64
+ from safetensors.torch import load_file
65
+ from braindecode.models import LUNA
66
+
67
+ # Download pre-trained weights
68
+ model_path = hf_hub_download(
69
+ repo_id="thorir/LUNA",
70
+ filename="LUNA_base.safetensors",
71
+ )
72
+
73
+ # Create model for classification (fine-tuning)
74
+ model = LUNA(
75
+ n_outputs=2, # Number of classes for your task
76
+ n_chans=22,
77
+ n_times=1000,
78
+ embed_dim=64,
79
+ num_queries=4,
80
+ depth=8,
81
+ )
82
+
83
+ # Load pre-trained encoder weights
84
+ state_dict = load_file(model_path)
85
+ # Apply key mapping for pretrained weights
86
+ mapping = model.mapping.copy()
87
+ mapping["cross_attn.temparature"] = "cross_attn.temperature"
88
+ mapped_state_dict = {mapping.get(k, k): v for k, v in state_dict.items()}
89
+ model.load_state_dict(mapped_state_dict, strict=False)
90
+
91
+ To push your own trained model to the Hub:
92
+
93
+ .. code-block:: python
94
+
95
+ # After training your model
96
+ model.push_to_hub(
97
+ repo_id="username/my-luna-model", commit_message="Upload trained LUNA model"
98
+ )
99
+
100
+ Requires installing ``braindecode[hug]`` for Hub integration.
101
+
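+ A minimal forward-pass sketch (the tensor shapes below are illustrative
+ assumptions consistent with the defaults, not taken from the original
+ documentation):
+
+ .. code-block:: python
+
+ import torch
+
+ model = LUNA(n_outputs=2, n_chans=22, n_times=1000)
+ x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
+ logits = model(x)  # assumed output shape: (8, 2)
+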
102
+ Parameters
103
+ ----------
104
+ patch_size : int
105
+ Number of time samples per patch. Default: 40.
106
+ num_queries : int
107
+ Number of learned queries for channel unification.
108
+ Paper uses: 4 (Base), 6 (Large), 8 (Huge). Default: 4.
109
+ embed_dim : int
110
+ Embedding dimension for patch features.
111
+ Paper uses: 64 (Base), 96 (Large), 128 (Huge). Default: 64.
112
+ depth : int
113
+ Number of transformer encoder blocks.
114
+ Paper uses: 8 (Base), 10 (Large), 24 (Huge). Default: 8.
115
+ num_heads : int
116
+ Number of attention heads in channel unification.
117
+ Default: 2.
118
+ mlp_ratio : float
119
+ Ratio of MLP hidden dimension to embedding dimension. Default: 4.0.
120
+ norm_layer : nn.Module
121
+ Normalization layer class. Default: nn.LayerNorm.
122
+ drop_path : float
123
+ Stochastic depth rate. Default: 0.0.
+ drop_prob_chan : float
+ Dropout probability in the channel-location embedder MLP. Default: 0.0.
+ attn_drop : float
+ Attention dropout rate. Default: 0.0.
+ activation : nn.Module
+ Activation class used in the embedding MLPs. Default: nn.GELU.
124
+
125
+ References
126
+ ----------
127
+ .. [LUNA] Döner, B., Ingolfsson, T. M., Benini, L., & Li, Y. (2025).
128
+ LUNA: Efficient and Topology-Agnostic Foundation Model for EEG Signal Analysis.
129
+ The Thirty-Ninth Annual Conference on Neural Information Processing Systems - NeurIPS.
130
+ Retrieved from https://openreview.net/forum?id=uazfjnFL0G
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ # Braindecode EEGModuleMixin parameters
136
+ n_outputs: Optional[int] = None,
137
+ n_chans: Optional[int] = None,
138
+ n_times: Optional[int] = None,
139
+ sfreq: Optional[float] = None,
140
+ chs_info: Optional[Any] = None,
141
+ input_window_seconds: Optional[float] = None,
142
+ # Model-specific parameters
143
+ patch_size: int = 40,
144
+ num_queries: int = 4,
145
+ embed_dim: int = 64,
146
+ depth: int = 8,
147
+ num_heads: int = 2,
148
+ mlp_ratio: float = 4.0,
149
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
150
+ drop_path: float = 0.0,
151
+ drop_prob_chan: float = 0.0,
152
+ attn_drop: float = 0.0,
153
+ activation: Type[nn.Module] = nn.GELU,
154
+ ):
155
+ super().__init__(
156
+ n_outputs=n_outputs,
157
+ n_chans=n_chans,
158
+ n_times=n_times,
159
+ sfreq=sfreq,
160
+ chs_info=chs_info,
161
+ input_window_seconds=input_window_seconds,
162
+ )
163
+ del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
164
+
165
+ # Mapping for loading pretrained weights
166
+ self.mapping = {
167
+ "channel_location_embedder.0.fc1.weight": "channel_location_embedder.fc1.weight",
168
+ "channel_location_embedder.0.fc1.bias": "channel_location_embedder.fc1.bias",
169
+ "channel_location_embedder.0.fc2.weight": "channel_location_embedder.fc2.weight",
170
+ "channel_location_embedder.0.fc2.bias": "channel_location_embedder.fc2.bias",
171
+ "channel_location_embedder.0.norm.weight": "channel_location_embedder.norm.weight",
172
+ "channel_location_embedder.0.norm.bias": "channel_location_embedder.norm.bias",
173
+ }
174
+
175
+ # Model parameters
176
+ self.num_classes = self.n_outputs if self.n_outputs else 0
177
+ self.embed_dim = embed_dim
178
+ self.num_queries = num_queries
179
+ self.patch_size = patch_size
180
+ self.patch_embed_size = embed_dim
181
+ self.num_heads = num_heads
182
+ self.depth = depth
183
+ self.drop_path = drop_path
184
+ self.attn_drop = attn_drop
185
+ self.drop_prob_chan = drop_prob_chan
186
+ self.mlp_ratio = mlp_ratio
187
+ self.activation = activation
188
+
189
+ # Layers
190
+ self.patch_embed = _PatchEmbedNetwork(
191
+ embed_dim=self.embed_dim, patch_size=self.patch_size
192
+ )
193
+ self.freq_embed = _FrequencyFeatureEmbedder(
194
+ embed_dim=self.embed_dim, patch_size=self.patch_size
195
+ )
196
+ # For weight loading, we omit the normalization here to match parameter count
197
+ self.channel_location_embedder = _Mlp(
198
+ in_features=int(self.patch_embed_size),
199
+ out_features=int(self.patch_embed_size),
200
+ hidden_features=int(self.patch_embed_size * 2),
201
+ act_layer=self.activation,
202
+ drop=self.drop_prob_chan,
203
+ )
204
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
205
+ self.cross_attn = _CrossAttentionBlock(
206
+ num_queries=self.num_queries,
207
+ input_embed_dim=self.embed_dim,
208
+ output_embed_dim=self.embed_dim,
209
+ num_heads=self.num_heads,
210
+ ff_dim=int(self.mlp_ratio * self.embed_dim),
211
+ )
212
+ self.blocks = nn.ModuleList(
213
+ [
214
+ _RotaryTransformerBlock(
215
+ dim=int(self.embed_dim * self.num_queries),
216
+ num_heads=int(self.num_heads * self.num_queries),
217
+ mlp_ratio=self.mlp_ratio,
218
+ qkv_bias=True,
219
+ drop=0.0,
220
+ attn_drop=0.0,
221
+ drop_path=self.drop_path,
222
+ norm_layer=norm_layer,
223
+ )
224
+ for i in range(depth)
225
+ ]
226
+ )
227
+ self.norm = norm_layer(int(self.embed_dim * self.num_queries))
228
+
229
+ self._channel_location_cache: Dict[int, torch.Tensor] = {}
230
+
231
+ if self.num_classes == 0:
232
+ self.decoder_head = _PatchReconstructionHeadWithQueries(
233
+ input_dim=self.patch_size,
234
+ embed_dim=self.embed_dim,
235
+ num_heads=self.num_heads,
236
+ num_queries=self.num_queries,
237
+ )
238
+ self.channel_emb = _ChannelEmbeddings(self.embed_dim)
239
+ else:
240
+ self.final_layer = _ClassificationHeadWithQueries(
241
+ input_dim=self.patch_size,
242
+ num_queries=self.num_queries,
243
+ embed_dim=self.embed_dim,
244
+ num_classes=self.num_classes,
245
+ num_heads=self.num_heads,
246
+ )
247
+ self.mask_token.requires_grad = (
248
+ False # no use of mask token for classification
249
+ )
250
+
251
+ self.initialize_weights()
252
+
253
+ def initialize_weights(self) -> None:
254
+ self.cross_attn.initialize_weights()
255
+ trunc_normal_(self.mask_token, std=0.02)
256
+ self.apply(self._init_weights)
257
+ self.fix_init_weight()
258
+
259
+ def _init_weights(self, m: nn.Module) -> None:
260
+ if isinstance(m, nn.Linear):
261
+ torch.nn.init.xavier_normal_(m.weight)
262
+ if isinstance(m, nn.Linear) and m.bias is not None:
263
+ nn.init.constant_(m.bias, 0)
264
+ elif isinstance(m, nn.LayerNorm):
265
+ nn.init.constant_(m.bias, 0)
266
+ nn.init.constant_(m.weight, 1.0)
267
+
268
+ def fix_init_weight(self) -> None:
269
+ def rescale(param: torch.Tensor, layer_id: int) -> None:
270
+ param.div_(math.sqrt(2.0 * layer_id))
271
+
272
+ for layer_id, layer in enumerate(self.blocks):
273
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
274
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
275
+
276
+ def prepare_tokens(
277
+ self,
278
+ x_signal: torch.Tensor,
279
+ channel_locations: torch.Tensor,
280
+ mask: Optional[torch.Tensor] = None,
281
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
282
+ num_channels = channel_locations.shape[1]
283
+ num_patches_per_channel = x_signal.shape[-1] // self.patch_size
284
+ x_patched = self.patch_embed(x_signal)
285
+ freq_embed = self.freq_embed(x_signal)
286
+ x_patched = x_patched + freq_embed
287
+ x_masked = x_patched.clone() # (B, N, D), N = C * num_patches_per_channel
288
+
289
+ if mask is not None:
290
+ mask_tokens = self.mask_token.repeat(
291
+ x_masked.shape[0], x_masked.shape[1], 1
292
+ ) # (B, N, D) N = C * num_patches_per_channel
293
+ mask = rearrange(
294
+ mask, "B C (S P) -> B (C S) P", P=self.patch_size
295
+ ) # (B, C, T) -> (B, N, P)
296
+ mask = (
297
+ (mask.sum(dim=-1) > 0).unsqueeze(-1).float()
298
+ ) # (B, N, 1), since a patch is either fully masked or not
299
+ x_masked = torch.where(mask.bool(), mask_tokens, x_masked)
300
+
301
+ channel_min = torch.min(channel_locations, dim=1, keepdim=True)[0]
302
+ channel_max = torch.max(channel_locations, dim=1, keepdim=True)[0]
303
+ channel_locations = (channel_locations - channel_min) / (
304
+ channel_max - channel_min + 1e-8
305
+ )
306
+
307
+ if mask is not None:
308
+ channel_locations = (
309
+ channel_locations + torch.randn_like(channel_locations) * 0.02
310
+ )
311
+
312
+ channel_locations = nerf_positional_encoding(
313
+ channel_locations, self.patch_embed_size
314
+ )
315
+ channel_locations_emb = self.channel_location_embedder(channel_locations)
316
+
317
+ x_tokenized = rearrange(x_masked, "B (C t) D -> (B t) C D", C=num_channels)
318
+ channel_locations_emb = channel_locations_emb.repeat(
319
+ num_patches_per_channel, 1, 1
320
+ )
321
+ x_tokenized = x_tokenized + channel_locations_emb
322
+
323
+ return x_tokenized, channel_locations_emb
324
+
325
+ def forward(
326
+ self,
327
+ X: torch.Tensor,
328
+ mask: Optional[torch.Tensor] = None,
329
+ channel_locations: Optional[torch.Tensor] = None,
330
+ channel_names: Optional[torch.Tensor] = None,
331
+ ) -> torch.Tensor:
332
+ """Forward pass."""
333
+ x_signal = X
334
+ B, C, _ = x_signal.shape
335
+
336
+ if channel_locations is None:
337
+ channel_locations = self.get_default_channel_locations(
338
+ batch_size=B,
339
+ num_channels=C,
340
+ device=x_signal.device,
341
+ dtype=x_signal.dtype,
342
+ )
343
+
344
+ x, channel_locations_emb = self.prepare_tokens(
345
+ x_signal, channel_locations, mask=mask
346
+ )
347
+ x, _ = self.cross_attn(x)
348
+ x = rearrange(x, "(B t) Q D -> B t (Q D)", B=B)
349
+ num_patches = x.shape[1]
350
+
351
+ for blk in self.blocks:
352
+ x = blk(x)
353
+ x_latent = self.norm(x)
354
+
355
+ if self.num_classes > 0:
356
+ return self.final_layer(x_latent)
357
+
358
+ if channel_names is None:
359
+ raise ValueError("channel_names must be provided for reconstruction tasks.")
360
+ channel_emb = self.channel_emb(channel_names)
361
+ channel_emb = channel_emb.repeat(num_patches, 1, 1)
362
+ decoder_queries = channel_locations_emb + channel_emb
363
+ return self.decoder_head(x_latent, decoder_queries)
364
+
365
+ def get_default_channel_locations(
366
+ self,
367
+ batch_size: int,
368
+ num_channels: int,
369
+ device: torch.device,
370
+ dtype: torch.dtype,
371
+ ) -> torch.Tensor:
372
+ if num_channels not in self._channel_location_cache:
373
+ template = self.build_channel_location_template(num_channels)
374
+ self._channel_location_cache[num_channels] = template
375
+ template = self._channel_location_cache[num_channels].to(
376
+ device=device, dtype=dtype
377
+ )
378
+ return template.unsqueeze(0).repeat(batch_size, 1, 1)
379
+
380
+ def build_channel_location_template(self, num_channels: int) -> torch.Tensor:
381
+ """Build channel location template for the model.
382
+
383
+ Attempts to extract channel locations from chs_info. Falls back to a default
384
+ linear spacing along the x-axis if real locations are unavailable.
385
+
386
+ Parameters
387
+ ----------
388
+ num_channels : int
389
+ Number of channels to generate locations for.
390
+
391
+ Returns
392
+ -------
393
+ torch.Tensor
394
+ Tensor of shape (num_channels, 3) with channel locations in 3D space.
395
+ """
396
+ # Try to extract channel locations from chs_info using the unified utility
397
+ channel_info = getattr(self, "_chs_info", None)
398
+ if channel_info is not None:
399
+ locs = extract_channel_locations_from_chs_info(
400
+ channel_info, num_channels=num_channels
401
+ )
402
+ if locs is not None and len(locs) == num_channels:
403
+ return torch.from_numpy(locs).float()
404
+
405
+ # Fallback: generate default linear spacing along x-axis
406
+ positions = torch.linspace(-1.0, 1.0, steps=num_channels, dtype=torch.float32)
407
+ zeros = torch.zeros_like(positions)
408
+ locs_tensor = torch.stack([positions, zeros, zeros], dim=-1)
409
+ return locs_tensor
410
+
411
+ def _get_default_channel_locations(
412
+ self,
413
+ batch_size: int,
414
+ num_channels: int,
415
+ device: torch.device,
416
+ dtype: torch.dtype,
417
+ ) -> torch.Tensor:
418
+ return self.get_default_channel_locations(
419
+ batch_size=batch_size,
420
+ num_channels=num_channels,
421
+ device=device,
422
+ dtype=dtype,
423
+ )
424
+
425
+ def _build_channel_location_template(self, num_channels: int) -> torch.Tensor:
426
+ return self.build_channel_location_template(num_channels)
427
+
428
+
429
+ def trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0) -> None:
430
+ nn.init.trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
431
+
432
+
433
+ def nerf_positional_encoding(coords: torch.Tensor, embed_size: int) -> torch.Tensor:
434
+ """
435
+ coords: (N, C, 3)
436
+ Returns: (N, C, embed_size)
437
+ """
438
+ N, C, dim = coords.shape
439
+ device = coords.device
440
+ freqs = embed_size // (2 * dim)
441
+ leftover = embed_size - freqs * 2 * dim
442
+ freq_bands = 2.0 ** torch.arange(freqs, device=device).float()
443
+ scaled_coords = coords.unsqueeze(-1) * freq_bands.view(
444
+ 1, 1, 1, -1
445
+ ) # (N, C, dim, freqs)
446
+ sin_enc = torch.sin(scaled_coords) # (N, C, dim, freqs)
447
+ cos_enc = torch.cos(scaled_coords) # (N, C, dim, freqs)
448
+ encoded = (
449
+ torch.stack([sin_enc, cos_enc], dim=-1)
450
+ .permute(0, 1, 3, 2, 4)
451
+ .reshape(N, C, freqs * dim * 2)
452
+ )
453
+ if leftover > 0:
454
+ pad = torch.zeros(N, C, leftover, device=device, dtype=coords.dtype)
455
+ encoded = torch.cat([encoded, pad], dim=-1)
456
+ return encoded
457
+
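+ # Hedged worked example (illustrative only, not from the original source):
+ # with embed_size=64 and 3-D coordinates, freqs = 64 // (2 * 3) = 10, so the
+ # encoding holds 10 * 3 * 2 = 60 sin/cos features per channel plus 4
+ # zero-padded entries to reach embed_size.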
458
+
459
+ class _ChannelEmbeddings(nn.Module):
460
+ r"""
461
+ This class creates embeddings for each EEG channel based on a predefined
462
+ mapping of channel names to indices.
463
+
464
+ The number of unique channels is determined by the union of channels
465
+ from SEED Pretraining, TUEG, and Siena datasets.
466
+
467
+ Parameters
468
+ ----------
469
+ embed_dim : int
470
+ Dimension of the channel embeddings.
471
+ number_channels : int
472
+ Number of unique EEG channels. Default is 90.
473
+
474
+ """
475
+
476
+ def __init__(self, embed_dim: int, number_channels=90) -> None:
477
+ super().__init__()
478
+ self.embeddings = nn.Embedding(number_channels, embed_dim)
479
+
480
+ def forward(self, indices: torch.Tensor) -> torch.Tensor:
481
+ return self.embeddings(indices)
482
+
483
+ def initialize_weights(self) -> None:
484
+ nn.init.normal_(self.embeddings.weight, std=2.0)
485
+
486
+
487
+ class _FrequencyFeatureEmbedder(nn.Module):
488
+ r"""
489
+ This class takes data of the form (B, C, T) and patches it along the time
+ dimension (T) into patches of size P (patch_size). For each patch, the FFT
+ magnitude and phase are concatenated and mapped to the embedding dimension,
+ giving an output of the form (B, C * S, embed_dim) where S = T // P.
492
+ """
493
+
494
+ def __init__(self, patch_size: int, embed_dim: int) -> None:
495
+ super().__init__()
496
+ self.patch_size = patch_size
497
+ self.embed_dim = embed_dim
498
+ in_features = 2 * (patch_size // 2 + 1)
499
+ self.frequency_to_embed = _Mlp(
500
+ in_features=in_features,
501
+ hidden_features=int(4 * in_features),
502
+ out_features=embed_dim,
503
+ )
504
+
505
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
506
+ B, C, T = x.size()
507
+ S = T // self.patch_size
508
+ # There is a chance that the input tensor is not divisible by the patch size
509
+ # In this case we need to pad the tensor with zeros
510
+ if T % self.patch_size != 0:
511
+ # Pad last dimension with zeros to make it divisible by patch size
512
+ pad_size = self.patch_size - (T % self.patch_size)
513
+ x = F.pad(x, (0, pad_size))
514
+ T = x.size(-1)
515
+ S = T // self.patch_size
516
+ x = x.view(B, C, S, self.patch_size)
517
+
518
+ freq_representation = fft.rfft(
519
+ x, dim=-1
520
+ ) # (B, C, num_patches, patch_size // 2 + 1)
521
+ magnitude = torch.abs(freq_representation)
522
+ phase = torch.angle(freq_representation)
523
+
524
+ # Concatenate magnitude and phase along the frequency axis (last dimension)
525
+ freq_features = torch.cat((magnitude, phase), dim=-1)
526
+ # Map frequency features to embedding dimension
527
+ embedded = self.frequency_to_embed(
528
+ freq_features
529
+ ) # (B, C, num_patches, embed_dim)
530
+ embedded = rearrange(embedded, "B C t D -> B (C t) D")
531
+ return embedded
532
+
533
+
534
+ class _RotarySelfAttentionBlock(nn.Module):
535
+ def __init__(
536
+ self,
537
+ dim: int,
538
+ num_heads: int = 8,
539
+ qkv_bias: bool = True,
540
+ qk_scale: Optional[float] = None,
541
+ attn_drop: float = 0.0,
542
+ proj_drop: float = 0.0,
543
+ ) -> None:
544
+ super().__init__()
545
+ self.dim = dim
546
+ self.num_heads = num_heads
547
+ head_dim = dim // num_heads
548
+ self.rotary_emb = RotaryEmbedding(dim=head_dim, learned_freq=False)
549
+
550
+ self.scale = qk_scale or head_dim**-0.5
551
+
552
+ self.qkv_proj = nn.Linear(dim, dim * 3, bias=qkv_bias)
553
+ self.attn_drop = attn_drop
554
+ self.attn_drop_fn = nn.Dropout(attn_drop)
555
+ self.proj = nn.Linear(dim, dim)
556
+ self.proj_drop = nn.Dropout(proj_drop)
557
+
558
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
559
+ B, N, C = x.shape
560
+ qkv = (
561
+ self.qkv_proj(x)
562
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
563
+ .permute(2, 0, 3, 1, 4)
564
+ ) # (K, B, H, N, D)
565
+ q, k, v = qkv[0], qkv[1], qkv[2]
566
+ q = self.rotary_emb.rotate_queries_or_keys(q)
567
+ k = self.rotary_emb.rotate_queries_or_keys(k)
568
+ # Calculate attention scores
569
+ attn_weights = (q @ k.transpose(-2, -1)) * self.scale # (B, H, N, N)
570
+
571
+ # Apply softmax to get attention probabilities
572
+ attn_weights = torch.softmax(attn_weights, dim=-1)
573
+
574
+ # Apply dropout
575
+ attn_weights = self.attn_drop_fn(attn_weights)
576
+
577
+ # Apply attention weights to values
578
+ attn = attn_weights @ v # (B, H, N, D)
579
+ attn = rearrange(attn, "B H N D -> B N (H D)")
580
+ return self.proj_drop(self.proj(attn))
581
+
582
+
583
+ class _FeedForwardBlock(nn.Module):
584
+ def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
585
+ super().__init__()
586
+ self.fc1 = nn.Linear(dim, hidden_dim)
587
+ self.activation = nn.GELU()
588
+ self.dropout1 = nn.Dropout(dropout)
589
+ self.dropout2 = nn.Dropout(dropout)
590
+ self.fc2 = nn.Linear(hidden_dim, dim)
591
+ self.norm = nn.LayerNorm(hidden_dim)
592
+
593
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
594
+ x = self.fc1(x)
595
+ x = self.activation(x)
596
+ x = self.dropout1(x)
597
+ x = self.norm(x)
598
+ x = self.fc2(x)
599
+ x = self.dropout2(x)
600
+ return x
601
+
602
+
603
+ class _RotaryTransformerBlock(nn.Module):
604
+ def __init__(
605
+ self,
606
+ dim: int,
607
+ num_heads: int,
608
+ mlp_ratio: float = 4.0,
609
+ qkv_bias: bool = False,
610
+ qk_scale: Optional[float] = None,
611
+ drop: float = 0.0,
612
+ attn_drop: float = 0.0,
613
+ drop_path: float = 0.0,
614
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
615
+ ) -> None:
616
+ super().__init__()
617
+ self.norm1 = norm_layer(dim)
618
+ self.attn = _RotarySelfAttentionBlock(
619
+ dim=dim,
620
+ num_heads=num_heads,
621
+ qkv_bias=qkv_bias,
622
+ qk_scale=qk_scale,
623
+ attn_drop=attn_drop,
624
+ proj_drop=drop,
625
+ )
626
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
627
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
628
+ self.norm2 = norm_layer(dim)
629
+ self.mlp = _FeedForwardBlock(
630
+ dim=dim, hidden_dim=int(dim * mlp_ratio), dropout=drop
631
+ )
632
+
633
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
634
+ x = x + self.drop_path1(self.attn(self.norm1(x)))
635
+ x = x + self.drop_path2(self.mlp(self.norm2(x)))
636
+ return x
637
+
638
+
639
+ class _PatchReconstructionHeadWithQueries(nn.Module):
640
+ def __init__(
641
+ self,
642
+ input_dim: int = 8,
643
+ embed_dim: int = 768,
644
+ num_heads: int = 8,
645
+ num_queries: int = 4,
646
+ ) -> None:
647
+ super().__init__()
648
+ self.input_dim = input_dim
649
+ self.embed_dim = embed_dim
650
+ self.reconstruction_shape = self.input_dim
651
+ self.num_queries = num_queries
652
+ # Cross-attention decoder followed by an MLP projecting back to patch (signal) space
653
+ self.decoder_pred = nn.TransformerDecoder(
654
+ nn.TransformerDecoderLayer(
655
+ embed_dim,
656
+ num_heads,
657
+ dropout=0.0,
658
+ batch_first=True,
659
+ activation="gelu",
660
+ dim_feedforward=int(embed_dim * 4),
661
+ norm_first=True,
662
+ ),
663
+ num_layers=1,
664
+ )
665
+ self.norm = nn.LayerNorm(embed_dim)
666
+ self.decoder_linear = _Mlp(
667
+ embed_dim, int(embed_dim * 4), input_dim, act_layer=nn.GELU, drop=0.0
668
+ ) # nn.Linear(embed_dim, input_dim, bias=True)
669
+
670
+ def forward(self, enc: torch.Tensor, decoder_queries: torch.Tensor) -> torch.Tensor:
671
+ """
672
+ enc: [B, num_patches, Q * D] (flattened query and embedding dims)
+ decoder_queries: [B * num_patches, num_channels, D]
+ Returns: [B, num_channels, num_patches * patch_size]
674
+ """
675
+
676
+ B, num_patches, embed_dim = enc.shape
677
+ enc = rearrange(enc, "B t (Q D) -> (B t) Q D", Q=self.num_queries)
678
+ out = self.decoder_pred(decoder_queries, enc) # (B*t, C, D)
679
+ out = self.norm(out)
680
+ out = self.decoder_linear(out) # (B*t, C, patch_size)
681
+ out = rearrange(out, "(B t) C P -> B C (t P)", B=B)
682
+ return out
683
+
684
+
685
+ class _ClassificationHeadWithQueries(nn.Module):
686
+ def __init__(
687
+ self,
688
+ input_dim: int = 8,
689
+ embed_dim: int = 768,
690
+ num_queries: int = 8,
691
+ num_heads: int = 8,
692
+ num_classes: int = 2,
693
+ drop_decoder: float = 0.15,
694
+ drop_ffn: float = 0.15,
695
+ ) -> None:
696
+ super().__init__()
697
+ self.input_dim = input_dim
698
+ self.embed_dim = int(embed_dim * num_queries)
699
+ self.reconstruction_shape = self.input_dim
700
+ self.decoder_attn = nn.MultiheadAttention(
701
+ self.embed_dim, num_heads, batch_first=True, dropout=drop_decoder
702
+ )
703
+ self.decoder_ffn = _Mlp(
704
+ in_features=self.embed_dim,
705
+ hidden_features=int(self.embed_dim * 4),
706
+ out_features=num_classes,
707
+ act_layer=nn.GELU,
708
+ drop=drop_ffn,
709
+ )
710
+
711
+ self.learned_agg = nn.Parameter(
712
+ torch.randn(1, 1, self.embed_dim), requires_grad=True
713
+ )
714
+
715
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
716
+ """
717
+ Args:
+ x: [B, num_patches, num_queries * embed_dim]
+ Returns:
+ Class logits of shape [B, num_classes].
722
+ """
723
+ B, num_patches, embed_dim = x.shape
724
+ decoder_queries = self.learned_agg.repeat(x.shape[0], 1, 1)
725
+
726
+ x = self.decoder_attn(query=decoder_queries, key=x, value=x)[0]
727
+ x = x[:, 0, :]
728
+ x = self.decoder_ffn(x)
729
+ return x
730
+
731
+
732
+ class _CrossAttentionBlock(nn.Module):
733
+ def __init__(
734
+ self,
735
+ num_queries: int,
736
+ input_embed_dim: int,
737
+ output_embed_dim: int,
738
+ num_heads: int,
739
+ dropout_p: float = 0.1,
740
+ ff_dim: int = 2048,
741
+ ) -> None:
742
+ super().__init__()
743
+ self.num_queries = num_queries
744
+ self.dropout_p = dropout_p
745
+ self.query_embed = nn.Parameter(
746
+ torch.randn(1, num_queries, input_embed_dim), requires_grad=True
747
+ ) # Learnable queries
748
+ self.cross_attention = nn.MultiheadAttention(
749
+ embed_dim=input_embed_dim,
750
+ num_heads=num_heads,
751
+ dropout=dropout_p,
752
+ batch_first=True,
753
+ )
754
+ self.temperature = nn.Parameter(torch.tensor(1.0), requires_grad=False)
755
+
756
+ self.ffn = _Mlp(
757
+ input_embed_dim,
758
+ ff_dim,
759
+ output_embed_dim,
760
+ act_layer=nn.GELU,
761
+ drop=dropout_p,
762
+ )
763
+ self.keys_norm = nn.LayerNorm(input_embed_dim)
764
+ self.values_norm = nn.LayerNorm(input_embed_dim)
765
+ self.queries_norm = nn.LayerNorm(input_embed_dim)
766
+ self.query_self_attn = nn.TransformerEncoder(
767
+ nn.TransformerEncoderLayer(
768
+ input_embed_dim,
769
+ nhead=num_heads,
770
+ activation="gelu",
771
+ dim_feedforward=ff_dim,
772
+ batch_first=True,
773
+ norm_first=True,
774
+ ),
775
+ num_layers=3,
776
+ )
777
+
778
+ def initialize_weights(self) -> None:
779
+ torch.nn.init.orthogonal_(self.query_embed, gain=1.0)
780
+ self.apply(self._init_weights)
781
+
782
+ def _init_weights(self, m: nn.Module) -> None:
783
+ if isinstance(m, nn.Linear):
784
+ # xavier normal initialization (the official JAX ViT uses xavier uniform)
785
+ torch.nn.init.xavier_normal_(m.weight)
786
+ if isinstance(m, nn.Linear) and m.bias is not None:
787
+ nn.init.constant_(m.bias, 0)
788
+ elif isinstance(m, nn.LayerNorm):
789
+ nn.init.constant_(m.bias, 0)
790
+ nn.init.constant_(m.weight, 1.0)
791
+
792
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
793
+ # x is the input with shape (batch_size*num_patches, num_channels, embed_dim)
794
+ batch_size, _, _ = x.size()
795
+ queries = self.query_embed.repeat(batch_size, 1, 1)
796
+ queries = self.queries_norm(queries)
797
+ keys = self.keys_norm(x)
798
+ values = self.values_norm(x)
799
+
800
+ attention_out, attention_scores = self.cross_attention(
801
+ query=queries, key=keys, value=values
802
+ ) # Shape: (batch_size*num_patches, num_queries, embed_dim)
803
+ attention_out = self.ffn(attention_out) + attention_out
804
+
805
+ attention_out = self.query_self_attn(attention_out)
806
+ return (
807
+ attention_out,
808
+ attention_scores,
809
+ ) # Shape: (batch_size*num_patches, num_queries, embed_dim)
810
+
811
+
812
+ class _PatchEmbedNetwork(nn.Module):
813
+ def __init__(self, embed_dim: int = 64, patch_size: int = 40) -> None:
814
+ super().__init__()
815
+ self.patch_size = patch_size
816
+ self.embed_dim = embed_dim
817
+ self.in_channels = 1
818
+ self.out_channels = int(embed_dim // 4)
819
+ self.groups = 4
820
+ self.kernel_size = int(patch_size // 2)
821
+ self.proj_in = nn.Sequential(
822
+ nn.Conv2d(
823
+ in_channels=self.in_channels,
824
+ out_channels=self.out_channels,
825
+ kernel_size=(1, self.kernel_size - 1),
826
+ stride=(1, self.kernel_size // 2),
827
+ padding=(0, self.kernel_size // 2 - 1),
828
+ ),
829
+ nn.GroupNorm(self.groups, self.out_channels),
830
+ nn.GELU(),
831
+ nn.Conv2d(
832
+ in_channels=self.out_channels,
833
+ out_channels=self.out_channels,
834
+ kernel_size=(1, 3),
835
+ stride=(1, 1),
836
+ padding=(0, 1),
837
+ ),
838
+ nn.GroupNorm(self.groups, self.out_channels),
839
+ nn.GELU(),
840
+ nn.Conv2d(
841
+ in_channels=self.out_channels,
842
+ out_channels=self.out_channels,
843
+ kernel_size=(1, 3),
844
+ stride=(1, 1),
845
+ padding=(0, 1),
846
+ ),
847
+ nn.GroupNorm(self.groups, self.out_channels),
848
+ nn.GELU(),
849
+ )
850
+
851
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
852
+ """
853
+ x: (B, C, T)
854
+ output: (B, C*S, D) where S = T//patch_size, D = embed_dim
855
+ """
856
+ x = rearrange(x, "B C (S P) -> B (C S) P", P=self.patch_size)
857
+ x = x.unsqueeze(1)
858
+ x = self.proj_in(x)
859
+ x = rearrange(x, "B E CS D -> B CS (D E)")
860
+ return x
861
+
862
+
863
+ class _Mlp(nn.Module):
864
+ r"""MLP as used in Vision Transformer, MLP-Mixer and related networks.
865
+
866
+ Code copied from timm's ``Mlp`` layer (``timm.layers.Mlp``).
867
+ """
868
+
869
+ def __init__(
870
+ self,
871
+ in_features,
872
+ hidden_features=None,
873
+ out_features=None,
874
+ act_layer=nn.GELU,
875
+ drop=0.0,
876
+ ):
877
+ super().__init__()
878
+ out_features = out_features or in_features
879
+ hidden_features = hidden_features or in_features
880
+ self.fc1 = nn.Linear(in_features, hidden_features)
881
+ self.act = act_layer()
882
+ self.fc2 = nn.Linear(hidden_features, out_features)
883
+ self.drop = nn.Dropout(drop)
884
+
885
+ def forward(self, x):
886
+ x = self.fc1(x)
887
+ x = self.act(x)
888
+ x = self.drop(x)
889
+ x = self.fc2(x)
890
+ x = self.drop(x)
891
+ return x