braindecode 1.2.0.dev184328194__py3-none-any.whl → 1.3.0.dev171178473__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (70)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/augmentation/functional.py +154 -54
  3. braindecode/augmentation/transforms.py +2 -2
  4. braindecode/datasets/__init__.py +10 -2
  5. braindecode/datasets/base.py +116 -152
  6. braindecode/datasets/bcicomp.py +4 -4
  7. braindecode/datasets/bids.py +3 -3
  8. braindecode/datasets/experimental.py +218 -0
  9. braindecode/datasets/mne.py +3 -5
  10. braindecode/datasets/moabb.py +2 -2
  11. braindecode/datasets/nmt.py +2 -2
  12. braindecode/datasets/sleep_physio_challe_18.py +4 -3
  13. braindecode/datasets/sleep_physionet.py +2 -2
  14. braindecode/datasets/tuh.py +2 -2
  15. braindecode/datasets/xy.py +2 -2
  16. braindecode/datautil/serialization.py +18 -13
  17. braindecode/eegneuralnet.py +2 -0
  18. braindecode/functional/functions.py +6 -2
  19. braindecode/functional/initialization.py +2 -3
  20. braindecode/models/__init__.py +12 -8
  21. braindecode/models/atcnet.py +156 -17
  22. braindecode/models/attentionbasenet.py +148 -16
  23. braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
  24. braindecode/models/base.py +280 -2
  25. braindecode/models/bendr.py +469 -0
  26. braindecode/models/biot.py +3 -1
  27. braindecode/models/ctnet.py +7 -4
  28. braindecode/models/deep4.py +6 -2
  29. braindecode/models/deepsleepnet.py +127 -5
  30. braindecode/models/eegconformer.py +114 -15
  31. braindecode/models/eeginception_erp.py +82 -7
  32. braindecode/models/eeginception_mi.py +2 -0
  33. braindecode/models/eegnet.py +64 -177
  34. braindecode/models/eegnex.py +113 -6
  35. braindecode/models/eegsimpleconv.py +2 -0
  36. braindecode/models/eegtcnet.py +1 -1
  37. braindecode/models/labram.py +188 -84
  38. braindecode/models/patchedtransformer.py +640 -0
  39. braindecode/models/sccnet.py +81 -8
  40. braindecode/models/shallow_fbcsp.py +2 -0
  41. braindecode/models/signal_jepa.py +109 -27
  42. braindecode/models/sinc_shallow.py +10 -9
  43. braindecode/models/sleep_stager_blanco_2020.py +2 -0
  44. braindecode/models/sleep_stager_chambon_2018.py +2 -0
  45. braindecode/models/sparcnet.py +2 -0
  46. braindecode/models/sstdpn.py +869 -0
  47. braindecode/models/summary.csv +42 -41
  48. braindecode/models/tidnet.py +2 -0
  49. braindecode/models/tsinception.py +15 -3
  50. braindecode/models/usleep.py +108 -9
  51. braindecode/models/util.py +8 -5
  52. braindecode/modules/attention.py +10 -10
  53. braindecode/modules/blocks.py +3 -3
  54. braindecode/modules/filter.py +2 -3
  55. braindecode/modules/layers.py +18 -17
  56. braindecode/preprocessing/__init__.py +24 -0
  57. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  58. braindecode/preprocessing/preprocess.py +42 -39
  59. braindecode/preprocessing/util.py +166 -0
  60. braindecode/preprocessing/windowers.py +24 -19
  61. braindecode/samplers/base.py +8 -8
  62. braindecode/version.py +1 -1
  63. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/METADATA +12 -3
  64. braindecode-1.3.0.dev171178473.dist-info/RECORD +106 -0
  65. braindecode/models/eegresnet.py +0 -362
  66. braindecode-1.2.0.dev184328194.dist-info/RECORD +0 -101
  67. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/WHEEL +0 -0
  68. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/LICENSE.txt +0 -0
  69. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/NOTICE.txt +0 -0
  70. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/top_level.txt +0 -0
@@ -16,9 +16,122 @@ from braindecode.modules import Conv2dWithConstraint, LinearWithConstraint
16
16
  class EEGNeX(EEGModuleMixin, nn.Module):
17
17
  """EEGNeX model from Chen et al. (2024) [eegnex]_.
18
18
 
19
+ :bdg-success:`Convolution`
20
+
19
21
  .. figure:: https://braindecode.org/dev/_static/model/eegnex.jpg
20
22
  :align: center
21
23
  :alt: EEGNeX Architecture
24
+ :width: 620px
25
+
26
+ .. rubric:: Architectural Overview
27
+
28
+ EEGNeX is a **purely convolutional** architecture that refines the EEGNet-style stem
29
+ and deepens the temporal stack with **dilated temporal convolutions**. The end-to-end
30
+ flow is:
31
+
32
+ - (i) **Block-1/2**: two temporal convolutions ``(1 x L)`` with BN refine a
33
+ learned FIR-like *temporal filter bank* (no pooling yet);
34
+ - (ii) **Block-3**: depthwise **spatial** convolution across electrodes
35
+ ``(n_chans x 1)`` with max-norm constraint, followed by ELU → AvgPool (time) → Dropout;
36
+ - (iii) **Block-4/5**: two additional **temporal** convolutions with increasing **dilation**
37
+ to expand the receptive field; the last block applies ELU → AvgPool → Dropout → Flatten;
38
+ - (iv) **Classifier**: a max-norm–constrained linear layer.
39
+
40
+ The published work positions EEGNeX as a compact, conv-only alternative that consistently
41
+ outperforms prior baselines across MOABB-style benchmarks, with the popular
42
+ “EEGNeX-8,32” shorthand denoting *8 temporal filters* and *kernel length 32*.
43
+
44
+
45
+ .. rubric:: Macro Components
46
+
47
+ - **Block-1 / Block-2 — Temporal filter (learned).**
48
+
49
+ - *Operations.*
50
+ - :class:`torch.nn.Conv2d` with kernels ``(1, L)``
51
+ - :class:`torch.nn.BatchNorm2d` (no nonlinearity until Block-3, mirroring a linear FIR analysis stage).
52
+ These layers set up frequency-selective detectors before spatial mixing.
53
+
54
+ - *Interpretability.* Kernels can be inspected as FIR filters; two stacked temporal
55
+ convs allow longer effective kernels without parameter blow-up.
56
+
57
+ - **Block-3 — Spatial projection + condensation.**
58
+
59
+ - *Operations.*
60
+ - :class:`braindecode.modules.Conv2dWithConstraint` with kernel ``(n_chans, 1)``
61
+ and ``groups = filter_2`` (depthwise across filters)
62
+ - :class:`torch.nn.BatchNorm2d`
63
+ - :class:`torch.nn.ELU`
64
+ - :class:`torch.nn.AvgPool2d` (time)
65
+ - :class:`torch.nn.Dropout`.
66
+
67
+ **Role**: Learns per-filter spatial patterns over the **full montage** while temporal
68
+ pooling stabilizes and compresses features; max-norm encourages well-behaved spatial
69
+ weights similar to EEGNet practice.
70
+
71
+ - **Block-4 / Block-5 — Dilated temporal integration.**
72
+
73
+ - *Operations.*
74
+ - :class:`torch.nn.Conv2d` with kernels ``(1, k)`` and **dilations**
75
+ (e.g., 2 then 4);
76
+ - :class:`torch.nn.BatchNorm2d`
77
+ - :class:`torch.nn.ELU`
78
+ - :class:`torch.nn.AvgPool2d` (time)
79
+ - :class:`torch.nn.Dropout`
80
+ - :class:`torch.nn.Flatten`.
81
+
82
+ **Role**: Expands the temporal receptive field efficiently to capture rhythms and
83
+ long-range context after condensation.
84
+
85
+ - **Final Classifier — Max-norm linear.**
86
+
87
+ - *Operations.*
88
+ - :class:`braindecode.modules.LinearWithConstraint` maps the flattened
89
+ vector to the target classes; the max-norm constraint regularizes the readout.
90
+
91
+
92
+ .. rubric:: Convolutional Details
93
+
94
+ - **Temporal (where time-domain patterns are learned).**
95
+ Blocks 1-2 learn the primary filter bank (oscillations/transients), while Blocks 4-5
96
+ use **dilation** to integrate over longer horizons without extra pooling. The final
97
+ AvgPool in Block-5 sets the output token rate and helps noise suppression.
98
+
99
+ - **Spatial (how electrodes are processed).**
100
+ A *single* depthwise spatial conv (Block-3) spans the entire electrode set
101
+ (kernel ``(n_chans, 1)``), producing per-temporal-filter topographies; no cross-filter
102
+ mixing occurs at this stage, aiding interpretability.
103
+
104
+ - **Spectral (how frequency content is captured).**
105
+ Frequency selectivity emerges from the learned temporal kernels; dilation broadens effective
106
+ bandwidth coverage by composing multiple scales.
107
+
108
+ .. rubric:: Additional Mechanisms
109
+
110
+ - **EEGNeX-8,32 naming.** “8,32” indicates *8 temporal filters* and *kernel length 32*,
111
+ reflecting the paper's ablation path from EEGNet-8,2 toward thicker temporal kernels
112
+ and a deeper conv stack.
113
+ - **Max-norm constraints.** Spatial (Block-3) and final linear layers use max-norm
114
+ regularization—standard in EEG CNNs—to reduce overfitting and encourage stable spatial
115
+ patterns.
116
+
117
+ .. rubric:: Usage and Configuration
118
+
119
+ - **Kernel schedule.** Start with the canonical **EEGNeX-8,32** (``filter_1=8``,
120
+ ``kernel_block_1_2=32``) and keep **Block-3** depth multiplier modest (e.g., 2) to match
121
+ the paper's “pure conv” profile.
122
+ - **Pooling vs. dilation.** Use pooling in Blocks 3 and 5 to control compute and variance;
123
+ increase dilations (Blocks 4-5) to widen temporal context when windows are short.
124
+ - **Regularization.** Combine dropout (Blocks 3 & 5) with max-norm on spatial and
125
+ classifier layers; prefer ELU activations for stable training on small EEG datasets.
126
+
127
+
128
+ - The braindecode implementation follows the paper's conv-only design with five blocks
129
+ and reproduces the depthwise spatial step and dilated temporal stack. See the class
130
+ reference for exact kernel sizes, dilations, and pooling defaults. You can check the
131
+ original implementation at [EEGNexCode]_.
132
+
133
+ .. versionadded:: 1.1
134
+
22
135
 
23
136
  Parameters
24
137
  ----------
@@ -45,12 +158,6 @@ class EEGNeX(EEGModuleMixin, nn.Module):
45
158
  avg_pool_block5 : tuple[int, int], optional
46
159
  Pooling size for block 5. Default is (1, 8).
47
160
 
48
- Notes
49
- -----
50
- This implementation is not guaranteed to be correct, has not been checked
51
- by original authors, only reimplemented from the paper description and
52
- source code in tensorflow [EEGNexCode]_.
53
-
54
161
  References
55
162
  ----------
56
163
  .. [eegnex] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
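
For orientation, a minimal usage sketch of the EEGNeX model documented above (illustrative only, not part of the diff; it assumes the standard braindecode signature with ``n_chans``, ``n_outputs`` and ``n_times``, and the ``filter_1``/``kernel_block_1_2`` names quoted in the docstring):

    >>> import torch
    >>> from braindecode.models import EEGNeX
    >>> # Per the docstring, the canonical "EEGNeX-8,32" corresponds to
    >>> # filter_1=8 temporal filters and kernel_block_1_2=32.
    >>> model = EEGNeX(n_chans=22, n_outputs=4, n_times=1000)
    >>> x = torch.randn(8, 22, 1000)  # (batch, channels, times)
    >>> model(x).shape
    torch.Size([8, 4])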
@@ -21,6 +21,8 @@ from braindecode.models.base import EEGModuleMixin
21
21
  class EEGSimpleConv(EEGModuleMixin, torch.nn.Module):
22
22
  """EEGSimpleConv from Ouahidi, YE et al. (2023) [Yassine2023]_.
23
23
 
24
+ :bdg-success:`Convolution`
25
+
24
26
  .. figure:: https://raw.githubusercontent.com/elouayas/EEGSimpleConv/refs/heads/main/architecture.png
25
27
  :align: center
26
28
  :alt: EEGSimpleConv Architecture
@@ -157,7 +157,7 @@ class EEGTCNet(EEGModuleMixin, nn.Module):
157
157
  class _EEGNetTC(nn.Module):
158
158
  """EEGNet Temporal Convolutional Network (TCN) block.
159
159
 
160
- The main difference from our EEGNetV4 (braindecode) implementation is the
160
+ The main difference from our :class:`EEGNet` (braindecode) implementation is the
161
161
  kernel and dimensional order. Because of this, we decided to keep this
162
162
  implementation in a future issue; we will re-evaluate if it is necessary
163
163
  to maintain this separate implementation.
@@ -22,6 +22,8 @@ from braindecode.modules import MLP, DropPath
22
22
  class Labram(EEGModuleMixin, nn.Module):
23
23
  """Labram from Jiang, W B et al (2024) [Jiang2024]_.
24
24
 
25
+ :bdg-danger:`Large Brain Model`
26
+
25
27
  .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
26
28
  :align: center
27
29
  :alt: Labram Architecture.
@@ -43,31 +45,45 @@ class Labram(EEGModuleMixin, nn.Module):
43
45
  equals True. The original implementation uses (batch, n_chans, n_patches,
44
46
  patch_size) as input with static segmentation of the input data.
45
47
 
46
- The models have the following sequence of steps:
47
- if neural tokenizer:
48
- - SegmentPatch: Segment the input data in patches;
49
- - TemporalConv: Apply a temporal convolution to the segmented data;
50
- - Residual adding cls, temporal and position embeddings (optional);
51
- - WindowsAttentionBlock: Apply a windows attention block to the data;
52
- - LayerNorm: Apply layer normalization to the data;
53
- - Linear: An head linear layer to transformer the data into classes.
54
-
55
- else:
56
- - PatchEmbed: Apply a patch embedding to the input data;
57
- - Residual adding cls, temporal and position embeddings (optional);
58
- - WindowsAttentionBlock: Apply a windows attention block to the data;
59
- - LayerNorm: Apply layer normalization to the data;
60
- - Linear: An head linear layer to transformer the data into classes.
48
+ The models have the following sequence of steps::
49
+
50
+ if neural tokenizer:
51
+ - SegmentPatch: Segment the input data in patches;
52
+ - TemporalConv: Apply a temporal convolution to the segmented data;
53
+ - Residual adding cls, temporal and position embeddings (optional);
54
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
55
+ - LayerNorm: Apply layer normalization to the data;
56
+ - Linear: A head linear layer to transform the data into classes.
57
+
58
+ else:
59
+ - PatchEmbed: Apply a patch embedding to the input data;
60
+ - Residual adding cls, temporal and position embeddings (optional);
61
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
62
+ - LayerNorm: Apply layer normalization to the data;
63
+ - Linear: A head linear layer to transform the data into classes.
61
64
 
62
65
  .. versionadded:: 0.9
63
66
 
67
+
68
+ Examples
69
+ --------
70
+ Load pre-trained weights::
71
+
72
+ >>> import torch
73
+ >>> from braindecode.models import Labram
74
+ >>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
75
+ >>> url = "https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt"
76
+ >>> state = torch.hub.load_state_dict_from_url(url, progress=True)
77
+ >>> model.load_state_dict(state)
78
+
79
+
64
80
  Parameters
65
81
  ----------
66
82
  patch_size : int
67
83
  The size of the patch to be used in the patch embedding.
68
84
  emb_size : int
69
85
  The dimension of the embedding.
70
- in_channels : int
86
+ in_conv_channels : int
71
87
  The number of convolutional input channels.
72
88
  out_channels : int
73
89
  The number of convolutional output channels.
@@ -79,8 +95,10 @@ class Labram(EEGModuleMixin, nn.Module):
79
95
  The expansion ratio of the mlp layer
80
96
  qkv_bias : bool (default=False)
81
97
  If True, add a learnable bias to the query, key, and value tensors.
82
- qk_norm : Pytorch Normalize layer (default=None)
83
- If not None, apply LayerNorm to the query and key tensors
98
+ qk_norm : Pytorch Normalize layer (default=nn.LayerNorm)
99
+ If not None, apply LayerNorm to the query and key tensors.
100
+ Default is nn.LayerNorm for better weight transfer from original LaBraM.
101
+ Set to None to disable Q,K normalization.
84
102
  qk_scale : float (default=None)
85
103
  If not None, use this value as the scale factor. If None,
86
104
  use head_dim**-0.5, where head_dim = dim // num_heads.
@@ -92,9 +110,10 @@ class Labram(EEGModuleMixin, nn.Module):
92
110
  Dropout rate for the attention weights used on DropPath.
93
111
  norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
94
112
  The normalization layer to be used.
95
- init_values : float (default=None)
113
+ init_values : float (default=0.1)
96
114
  If not None, use this value to initialize the gamma_1 and gamma_2
97
- parameters.
115
+ parameters for residual scaling. Default is 0.1 for better weight
116
+ transfer from original LaBraM. Set to None to disable.
98
117
  use_abs_pos_emb : bool (default=True)
99
118
  If True, use absolute position embedding.
100
119
  use_mean_pooling : bool (default=True)
@@ -135,19 +154,19 @@ class Labram(EEGModuleMixin, nn.Module):
135
154
  input_window_seconds=None,
136
155
  patch_size=200,
137
156
  emb_size=200,
138
- in_channels=1,
157
+ in_conv_channels=1,
139
158
  out_channels=8,
140
159
  n_layers=12,
141
160
  att_num_heads=10,
142
161
  mlp_ratio=4.0,
143
162
  qkv_bias=False,
144
- qk_norm=None,
163
+ qk_norm=nn.LayerNorm,
145
164
  qk_scale=None,
146
165
  drop_prob=0.0,
147
166
  attn_drop_prob=0.0,
148
167
  drop_path_prob=0.0,
149
168
  norm_layer=nn.LayerNorm,
150
- init_values=None,
169
+ init_values=0.1,
151
170
  use_abs_pos_emb=True,
152
171
  use_mean_pooling=True,
153
172
  init_scale=0.001,
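
The constructor-default changes in this hunk (``qk_norm=nn.LayerNorm`` and ``init_values=0.1``, chosen for easier weight transfer from the original LaBraM) can be reverted explicitly if the previous behaviour is needed; an illustrative sketch, not part of the diff:

    >>> from braindecode.models import Labram
    >>> # new defaults: qk_norm=nn.LayerNorm, init_values=0.1
    >>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
    >>> # pass None for both to recover the pre-1.3 defaults
    >>> legacy = Labram(n_times=1600, n_chans=64, n_outputs=4,
    ...                 qk_norm=None, init_values=None)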
@@ -183,15 +202,15 @@ class Labram(EEGModuleMixin, nn.Module):
183
202
  self.patch_size = patch_size
184
203
  self.n_path = self.n_times // self.patch_size
185
204
 
186
- if neural_tokenizer and in_channels != 1:
205
+ if neural_tokenizer and in_conv_channels != 1:
187
206
  warn(
188
207
  "The model is in Neural Tokenizer mode, but the variable "
189
- + "`in_channels` is different from the default values."
190
- + "`in_channels` is only needed for the Neural Decoder mode."
191
- + "in_channels is not used in the Neural Tokenizer mode.",
208
+ + "`in_conv_channels` is different from the default values."
209
+ + "`in_conv_channels` is only needed for the Neural Decoder mode."
210
+ + "in_conv_channels is not used in the Neural Tokenizer mode.",
192
211
  UserWarning,
193
212
  )
194
- in_channels = 1
213
+ in_conv_channels = 1
195
214
  # If you can use the model in Neural Tokenizer mode,
196
215
  # temporal conv layer will be use over the patched dataset
197
216
  if neural_tokenizer:
@@ -228,7 +247,7 @@ class Labram(EEGModuleMixin, nn.Module):
228
247
  _PatchEmbed(
229
248
  n_times=self.n_times,
230
249
  patch_size=patch_size,
231
- in_channels=in_channels,
250
+ in_channels=in_conv_channels,
232
251
  emb_dim=self.emb_size,
233
252
  ),
234
253
  )
@@ -373,8 +392,7 @@ class Labram(EEGModuleMixin, nn.Module):
373
392
  Parameters
374
393
  ----------
375
394
  x : torch.Tensor
376
- The input data with shape (batch, n_chans, n_patches, patch size),
377
- if neural decoder or (batch, n_chans, n_times), if neural tokenizer.
395
+ The input data with shape (batch, n_chans, n_times).
378
396
  input_chans : int
379
397
  The number of input channels.
380
398
  return_patch_tokens : bool
@@ -387,37 +405,72 @@ class Labram(EEGModuleMixin, nn.Module):
387
405
  x : torch.Tensor
388
406
  The output of the model.
389
407
  """
408
+ batch_size = x.shape[0]
409
+
390
410
  if self.neural_tokenizer:
391
- batch_size, nch, n_patch, temporal = self.patch_embed.segment_patch(x).shape
411
+ # For neural tokenizer: input is (batch, n_chans, n_times)
412
+ # patch_embed returns (batch, n_chans, emb_dim)
413
+ x = self.patch_embed(x)
414
+ # x shape: (batch, n_chans, emb_dim)
415
+ n_patch = self.n_chans
416
+ temporal = self.emb_size
392
417
  else:
393
- batch_size, nch, n_patch = self.patch_embed(x).shape
394
- x = self.patch_embed(x)
418
+ # For neural decoder: input is (batch, n_chans, n_times)
419
+ # patch_embed returns (batch, n_patchs, emb_dim)
420
+ x = self.patch_embed(x)
421
+ # x shape: (batch, n_patchs, emb_dim)
422
+ batch_size, n_patch, temporal = x.shape
423
+
395
424
  # add the [CLS] token to the embedded patch tokens
396
425
  cls_tokens = self.cls_token.expand(batch_size, -1, -1)
397
426
 
427
+ # Concatenate cls token with patch/channel embeddings
398
428
  x = torch.cat((cls_tokens, x), dim=1)
399
429
 
400
430
  # Positional Embedding
401
- if input_chans is not None:
402
- pos_embed_used = self.position_embedding[:, input_chans]
403
- else:
404
- pos_embed_used = self.position_embedding
405
-
406
431
  if self.position_embedding is not None:
407
- pos_embed = self._adj_position_embedding(
408
- pos_embed_used=pos_embed_used, batch_size=batch_size
409
- )
432
+ if self.neural_tokenizer:
433
+ # In tokenizer mode, use channel-based position embedding
434
+ if input_chans is not None:
435
+ pos_embed_used = self.position_embedding[:, input_chans]
436
+ else:
437
+ pos_embed_used = self.position_embedding
438
+
439
+ pos_embed = self._adj_position_embedding(
440
+ pos_embed_used=pos_embed_used, batch_size=batch_size
441
+ )
442
+ else:
443
+ # In decoder mode, we have different number of patches
444
+ # Adapt position embedding for n_patch patches
445
+ # Use the first n_patch+1 positions from position_embedding
446
+ n_pos = min(self.position_embedding.shape[1], n_patch + 1)
447
+ pos_embed_used = self.position_embedding[:, :n_pos, :]
448
+ pos_embed = pos_embed_used.expand(batch_size, -1, -1)
449
+
410
450
  x += pos_embed
411
451
 
412
452
  # The time embedding is added across the channels after the [CLS] token
413
453
  if self.neural_tokenizer:
414
454
  num_ch = self.n_chans
455
+ time_embed = self._adj_temporal_embedding(
456
+ num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
457
+ )
458
+ x[:, 1:, :] += time_embed
415
459
  else:
416
- num_ch = n_patch
417
- time_embed = self._adj_temporal_embedding(
418
- num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
419
- )
420
- x[:, 1:, :] += time_embed
460
+ # In decoder mode, we have n_patch patches and don't need to expand
461
+ # Just broadcast the temporal embedding
462
+ if temporal is None:
463
+ temporal = self.emb_size
464
+
465
+ # Get temporal embeddings for n_patch patches
466
+ n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
467
+ time_embed = self.temporal_embedding[
468
+ :, 1 : n_time_tokens + 1, :
469
+ ] # (1, n_patch, emb_dim)
470
+ time_embed = time_embed.expand(
471
+ batch_size, -1, -1
472
+ ) # (batch, n_patch, emb_dim)
473
+ x[:, 1:, :] += time_embed
421
474
 
422
475
  x = self.pos_drop(x)
423
476
 
@@ -428,10 +481,10 @@ class Labram(EEGModuleMixin, nn.Module):
428
481
  if self.fc_norm is not None:
429
482
  if return_all_tokens:
430
483
  return self.fc_norm(x)
431
- temporal = x[:, 1:, :]
484
+ tokens = x[:, 1:, :]
432
485
  if return_patch_tokens:
433
- return self.fc_norm(temporal)
434
- return self.fc_norm(temporal.mean(1))
486
+ return self.fc_norm(tokens)
487
+ return self.fc_norm(tokens.mean(1))
435
488
  else:
436
489
  if return_all_tokens:
437
490
  return x
@@ -505,14 +558,16 @@ class Labram(EEGModuleMixin, nn.Module):
505
558
  def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
506
559
  """
507
560
  Adjust the dimensions of the time embedding to match the
508
- number of channels.
561
+ number of channels or patches.
509
562
 
510
563
  Parameters
511
564
  ----------
512
565
  num_ch : int
513
- The number of channels or number of code books vectors.
566
+ The number of channels or number of patches.
514
567
  batch_size : int
515
568
  Batch size of the input data.
569
+ dim_embed : int
570
+ The embedding dimension (temporal feature dimension).
516
571
 
517
572
  Returns
518
573
  -------
@@ -523,17 +578,24 @@ class Labram(EEGModuleMixin, nn.Module):
523
578
  if dim_embed is None:
524
579
  cut_dimension = self.patch_size
525
580
  else:
526
- cut_dimension = dim_embed
527
- # first step will be match the time_embed to the number of channels
528
- temporal_embedding = self.temporal_embedding[:, 1:cut_dimension, :]
581
+ cut_dimension = min(dim_embed, self.temporal_embedding.shape[1] - 1)
582
+
583
+ # Get the temporal embedding: (1, temporal_embedding_dim, emb_size)
584
+ # Slice to cut_dimension: (1, cut_dimension, emb_size)
585
+ temporal_embedding = self.temporal_embedding[:, 1 : cut_dimension + 1, :]
586
+
529
587
  # Add a new dimension to the time embedding
530
- # e.g. (batch, 62, 200) -> (batch, 1, 62, 200)
588
+ # e.g. (1, 5, 200) -> (1, 1, 5, 200)
531
589
  temporal_embedding = temporal_embedding.unsqueeze(1)
532
- # Expand the time embedding to match the number of channels
533
- # or number of patches from
590
+
591
+ # Expand the time embedding to match the number of channels or patches
592
+ # (1, 1, cut_dimension, 200) -> (batch_size, num_ch, cut_dimension, 200)
534
593
  temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
594
+
535
595
  # Flatten the intermediate dimensions
596
+ # (batch_size, num_ch, cut_dimension, 200) -> (batch_size, num_ch * cut_dimension, 200)
536
597
  temporal_embedding = temporal_embedding.flatten(1, 2)
598
+
537
599
  return temporal_embedding
538
600
 
539
601
  def _adj_position_embedding(self, pos_embed_used, batch_size):
@@ -679,25 +741,27 @@ class _SegmentPatch(nn.Module):
679
741
 
680
742
 
681
743
  class _PatchEmbed(nn.Module):
682
- """EEG to Patch Embedding.
744
+ """EEG to Patch Embedding for Neural Decoder mode.
683
745
 
684
746
  This code is used when we want to apply the patch embedding
685
- after the codebook layer.
747
+ after the codebook layer (Neural Decoder mode).
748
+
749
+ The input is expected to be in the format (Batch, n_channels, n_times),
750
+ but the original LaBraM expects pre-patched data (Batch, n_channels, n_patches, patch_size).
751
+ This class reshapes the input to the pre-patched format, then applies a 2D
752
+ convolution to project this pre-patched data to the embedding dimension,
753
+ and finally flattens across channels to produce a unified embedding.
686
754
 
687
755
  Parameters:
688
756
  -----------
689
757
  n_times: int (default=2000)
690
- Number of temporal components of the input tensor.
758
+ Number of temporal components of the input tensor (used for dimension calculation).
691
759
  patch_size: int (default=200)
692
760
  Size of the patch, default is 1-seconds with 200Hz.
693
761
  in_channels: int (default=1)
694
- Number of input channels for to be used in the convolution.
762
+ Number of input channels (from VQVAE codebook).
695
763
  emb_dim: int (default=200)
696
- Number of out_channes to be used in the convolution, here,
697
- we used the same as patch_size.
698
- n_codebooks: int (default=62)
699
- Number of patches to be used in the convolution, here,
700
- we used the same as n_times // patch_size.
764
+ Dimension of the output embedding.
701
765
  """
702
766
 
703
767
  def __init__(
@@ -707,10 +771,13 @@ class _PatchEmbed(nn.Module):
707
771
  self.n_times = n_times
708
772
  self.patch_size = patch_size
709
773
  self.patch_shape = (1, self.n_times // self.patch_size)
710
- n_patchs = n_codebooks * (self.n_times // self.patch_size)
711
-
712
- self.n_patchs = n_patchs
774
+ self.n_patchs = self.n_times // self.patch_size
775
+ self.emb_dim = emb_dim
776
+ self.in_channels = in_channels
713
777
 
778
+ # 2D Conv to project the pre-patched data
779
+ # Input: (Batch, in_channels, n_patches, patch_size)
780
+ # After proj: (Batch, emb_dim, n_patches, 1)
714
781
  self.proj = nn.Conv2d(
715
782
  in_channels=in_channels,
716
783
  out_channels=emb_dim,
@@ -718,27 +785,64 @@ class _PatchEmbed(nn.Module):
718
785
  stride=(1, self.patch_size),
719
786
  )
720
787
 
721
- self.merge_transpose = Rearrange(
722
- "Batch ch patch spatch -> Batch patch spatch ch",
723
- )
724
-
725
788
  def forward(self, x):
726
789
  """
727
- Apply the convolution to the input tensor.
728
- then merge the output tensor to the desired shape.
790
+ Apply the temporal projection to the input tensor after grouping channels.
729
791
 
730
- Parameters:
731
- -----------
732
- x: torch.Tensor
733
- Input tensor of shape (Batch, Channels, n_patchs, patch_size).
792
+ Parameters
793
+ ----------
794
+ x : torch.Tensor
795
+ Input tensor of shape (Batch, n_channels, n_times) or
796
+ (Batch, n_channels, n_patches, patch_size).
734
797
 
735
- Return:
798
+ Returns
736
799
  -------
737
- x: torch.Tensor
738
- Output tensor of shape (Batch, n_patchs, patch_size, channels).
800
+ torch.Tensor
801
+ Output tensor of shape (Batch, n_patchs, emb_dim).
739
802
  """
803
+ if x.ndim == 4:
804
+ batch_size, n_channels, n_patchs, patch_len = x.shape
805
+ if patch_len != self.patch_size:
806
+ raise ValueError(
807
+ "When providing a 4D tensor, the last dimension "
808
+ f"({patch_len}) must match patch_size ({self.patch_size})."
809
+ )
810
+ n_times = n_patchs * patch_len
811
+ x = x.reshape(batch_size, n_channels, n_times)
812
+ elif x.ndim == 3:
813
+ batch_size, n_channels, n_times = x.shape
814
+ else:
815
+ raise ValueError(
816
+ "Input must be either 3D (batch, channels, times) or "
817
+ "4D (batch, channels, n_patches, patch_size)."
818
+ )
819
+
820
+ if n_times % self.patch_size != 0:
821
+ raise ValueError(
822
+ f"n_times ({n_times}) must be divisible by patch_size ({self.patch_size})."
823
+ )
824
+ if n_channels % self.in_channels != 0:
825
+ raise ValueError(
826
+ "The input channel dimension "
827
+ f"({n_channels}) must be divisible by in_channels ({self.in_channels})."
828
+ )
829
+
830
+ group_size = n_channels // self.in_channels
831
+
832
+ # Reshape so Conv2d sees `in_channels` feature maps and uses the grouped
833
+ # EEG channels as the spatial height dimension.
834
+ # Shape after view: (Batch, in_channels, group_size, n_times)
835
+ x = x.view(batch_size, self.in_channels, group_size, n_times)
836
+
837
+ # Apply the temporal projection per group.
838
+ # Output shape: (Batch, emb_dim, group_size, n_patchs)
740
839
  x = self.proj(x)
741
- x = self.merge_transpose(x)
840
+
841
+ # THIS IS braindecode's MODIFICATION:
842
+ # Average over the grouped channel dimension and permute to (Batch, n_patchs, emb_dim)
843
+ x = x.mean(dim=2)
844
+ x = x.transpose(1, 2).contiguous()
845
+
742
846
  return x
743
847
 
744
848