braindecode 1.3.0.dev168011974__py3-none-any.whl → 1.3.0.dev171178473__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic.

Files changed (53)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/datasets/__init__.py +10 -2
  3. braindecode/datasets/base.py +116 -152
  4. braindecode/datasets/bcicomp.py +4 -4
  5. braindecode/datasets/bids.py +3 -3
  6. braindecode/datasets/experimental.py +2 -2
  7. braindecode/datasets/mne.py +3 -5
  8. braindecode/datasets/moabb.py +2 -2
  9. braindecode/datasets/nmt.py +2 -2
  10. braindecode/datasets/sleep_physio_challe_18.py +2 -2
  11. braindecode/datasets/sleep_physionet.py +2 -2
  12. braindecode/datasets/tuh.py +2 -2
  13. braindecode/datasets/xy.py +2 -2
  14. braindecode/datautil/serialization.py +7 -7
  15. braindecode/eegneuralnet.py +2 -0
  16. braindecode/functional/functions.py +6 -2
  17. braindecode/functional/initialization.py +2 -3
  18. braindecode/models/__init__.py +6 -0
  19. braindecode/models/atcnet.py +26 -27
  20. braindecode/models/attentionbasenet.py +39 -32
  21. braindecode/models/base.py +280 -2
  22. braindecode/models/bendr.py +469 -0
  23. braindecode/models/biot.py +2 -0
  24. braindecode/models/ctnet.py +6 -3
  25. braindecode/models/deepsleepnet.py +27 -18
  26. braindecode/models/eegconformer.py +2 -2
  27. braindecode/models/eeginception_erp.py +31 -25
  28. braindecode/models/eegnet.py +1 -1
  29. braindecode/models/labram.py +188 -84
  30. braindecode/models/patchedtransformer.py +640 -0
  31. braindecode/models/signal_jepa.py +109 -27
  32. braindecode/models/sinc_shallow.py +10 -9
  33. braindecode/models/sstdpn.py +869 -0
  34. braindecode/models/summary.csv +3 -0
  35. braindecode/models/usleep.py +26 -21
  36. braindecode/models/util.py +3 -0
  37. braindecode/modules/attention.py +10 -10
  38. braindecode/modules/blocks.py +3 -3
  39. braindecode/modules/filter.py +2 -3
  40. braindecode/modules/layers.py +18 -17
  41. braindecode/preprocessing/__init__.py +24 -0
  42. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  43. braindecode/preprocessing/preprocess.py +12 -12
  44. braindecode/preprocessing/util.py +166 -0
  45. braindecode/preprocessing/windowers.py +24 -19
  46. braindecode/samplers/base.py +8 -8
  47. braindecode/version.py +1 -1
  48. {braindecode-1.3.0.dev168011974.dist-info → braindecode-1.3.0.dev171178473.dist-info}/METADATA +6 -2
  49. {braindecode-1.3.0.dev168011974.dist-info → braindecode-1.3.0.dev171178473.dist-info}/RECORD +53 -48
  50. {braindecode-1.3.0.dev168011974.dist-info → braindecode-1.3.0.dev171178473.dist-info}/WHEEL +0 -0
  51. {braindecode-1.3.0.dev168011974.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/LICENSE.txt +0 -0
  52. {braindecode-1.3.0.dev168011974.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/NOTICE.txt +0 -0
  53. {braindecode-1.3.0.dev168011974.dist-info → braindecode-1.3.0.dev171178473.dist-info}/top_level.txt +0 -0
@@ -35,51 +35,57 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
  - :class:`_InceptionModule1` **(multi-scale temporal + spatial mixing)**

  - *Operations.*
- - `EEGInceptionERP.c1`: :class:`torch.nn.Conv2d` ``k=(64,1)``, stride ``(1,1)``, *same* pad on input reshaped to ``(B,1,128,8)`` → BN → activation → dropout.
- - `EEGInceptionERP.d1`: :class:`torch.nn.Conv2d` (depthwise) ``k=(1,8)``, *valid* pad over channels → BN → activation → dropout.
- - `EEGInceptionERP.c2`: :class:`torch.nn.Conv2d` ``k=(32,1)`` → BN → activation → dropout; then `EEGInceptionERP.d2` depthwise ``k=(1,8)`` → BN → activation → dropout.
- - `EEGInceptionERP.c3`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout; then `EEGInceptionERP.d3` depthwise ``k=(1,8)`` → BN → activation → dropout.
- - `EEGInceptionERP.n1`: :class:`torch.nn.Concat` over branch features.
- - `EEGInceptionERP.a1`: :class:`torch.nn.AvgPool2d` ``pool=(4,1)``, stride ``(4,1)`` for temporal downsampling.
+
+ - `EEGInceptionERP.c1`: :class:`torch.nn.Conv2d` ``k=(64,1)``, stride ``(1,1)``, *same* pad on input reshaped to ``(B,1,128,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.d1`: :class:`torch.nn.Conv2d` (depthwise) ``k=(1,8)``, *valid* pad over channels → BN → activation → dropout.
+ - `EEGInceptionERP.c2`: :class:`torch.nn.Conv2d` ``k=(32,1)`` → BN → activation → dropout; then `EEGInceptionERP.d2` depthwise ``k=(1,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c3`: :class:`torch.nn.Conv2d` ``k=(16,1)`` BN → activation → dropout; then `EEGInceptionERP.d3` depthwise ``k=(1,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.n1`: :class:`torch.nn.Concat` over branch features.
+ - `EEGInceptionERP.a1`: :class:`torch.nn.AvgPool2d` ``pool=(4,1)``, stride ``(4,1)`` for temporal downsampling.

  *Interpretability/robustness.* Depthwise `1 x n_chans` layers act as learnable montage-wide spatial filters per temporal scale; pooling stabilizes against jitter.

  - :class:`_InceptionModule2` **(refinement at coarser timebase)**

  - *Operations.*
- - `EEGInceptionERP.c4`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout.
- - `EEGInceptionERP.c5`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout.
- - `EEGInceptionERP.c6`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout.
- - `EEGInceptionERP.n2`: :class:`torch.nn.Concat` (merge C4-C6 outputs).
- - `EEGInceptionERP.a2`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``, stride ``(2,1)``.
- - `EEGInceptionERP.c7`: :class:`torch.nn.Conv2d` ``k=(8,1)`` BN → activation → dropout; then `EEGInceptionERP.a3`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
- - `EEGInceptionERP.c8`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout; then `EEGInceptionERP.a4`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
+
+ - `EEGInceptionERP.c4`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c5`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c6`: :class:`torch.nn.Conv2d` ``k=(4,1)`` BN → activation → dropout.
+ - `EEGInceptionERP.n2`: :class:`torch.nn.Concat` (merge C4-C6 outputs).
+ - `EEGInceptionERP.a2`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``, stride ``(2,1)``.
+ - `EEGInceptionERP.c7`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout; then `EEGInceptionERP.a3`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
+ - `EEGInceptionERP.c8`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout; then `EEGInceptionERP.a4`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.

  *Role.* Adds higher-level, shorter-window evidence while progressively compressing temporal dimension.

  - :class:`_OutputModule` **(aggregation + readout)**

  - *Operations.*
- - :class:`torch.nn.Flatten`
- - :class:`torch.nn.Linear` ``(features → 2)``
+
+ - :class:`torch.nn.Flatten`
+ - :class:`torch.nn.Linear` ``(features → 2)``

  .. rubric:: Convolutional Details

  - **Temporal (where time-domain patterns are learned).**
- First module uses 1D temporal kernels along the 128-sample axis: ``64``, ``32``, ``16``
- (≈500, 250, 125 ms at 128 Hz). After ``pool=(4,1)``, the second module applies ``16``,
- ``8``, ``4`` (≈125, 62.5, 31.25 ms at the pooled rate). All strides are ``1`` in convs;
- temporal resolution changes only via average pooling.
+
+ First module uses 1D temporal kernels along the 128-sample axis: ``64``, ``32``, ``16``
+ (≈500, 250, 125 ms at 128 Hz). After ``pool=(4,1)``, the second module applies ``16``,
+ ``8``, ``4`` (≈125, 62.5, 31.25 ms at the pooled rate). All strides are ``1`` in convs;
+ temporal resolution changes only via average pooling.

  - **Spatial (how electrodes are processed).**
- Depthwise convs with ``k=(1,8)`` span all channels and are applied **per temporal branch**,
- yielding scale-specific channel projections (no cross-branch mixing until concatenation).
- There is no full 2D mixing kernel; spatial mixing is factorized and lightweight.
+
+ Depthwise convs with ``k=(1,8)`` span all channels and are applied **per temporal branch**,
+ yielding scale-specific channel projections (no cross-branch mixing until concatenation).
+ There is no full 2D mixing kernel; spatial mixing is factorized and lightweight.

  - **Spectral (how frequency information is captured).**
- No explicit transform; multiple temporal kernels form a *learned filter bank* over
- ERP-relevant bands. Successive pooling acts as low-pass integration to emphasize sustained
- post-stimulus components.
+
+ No explicit transform; multiple temporal kernels form a *learned filter bank* over
+ ERP-relevant bands. Successive pooling acts as low-pass integration to emphasize sustained
+ post-stimulus components.

  .. rubric:: Additional Mechanisms

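For orientation, a minimal PyTorch sketch of one such branch (temporal conv, BN, activation, dropout, then a depthwise conv spanning all electrodes), with illustrative filter counts rather than the package's exact module::

    import torch
    from torch import nn

    # Illustrative sizes from the docstring above: 128 samples, 8 channels.
    branch = nn.Sequential(
        # temporal conv along the 128-sample axis; "same" padding keeps the length
        nn.Conv2d(1, 8, kernel_size=(64, 1), padding="same"),
        nn.BatchNorm2d(8),
        nn.ELU(),
        nn.Dropout(0.5),
        # depthwise conv spanning all 8 electrodes ("valid" padding collapses them)
        nn.Conv2d(8, 16, kernel_size=(1, 8), groups=8),
        nn.BatchNorm2d(16),
        nn.ELU(),
        nn.Dropout(0.5),
    )

    x = torch.randn(2, 1, 128, 8)   # (batch, 1, n_times, n_chans)
    print(branch(x).shape)          # torch.Size([2, 16, 128, 1])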
@@ -57,7 +57,7 @@ class EEGNet(EEGModuleMixin, nn.Sequential):

  - **Temporal.** The initial temporal convs serve as a *learned filter bank*:
  long 1-D kernels (implemented as 2-D with singleton spatial extent) emphasize oscillatory bands and transients.
- Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each features spectrum [Lawhern2018]_.
+ Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each feature's spectrum [Lawhern2018]_.

  - **Spatial.** The depthwise spatial conv spans the full channel axis (kernel height = #electrodes; temporal size = 1).
  With ``groups = F1``, each temporal filter learns its own set of ``D`` spatial projections—akin to CSP, learned end-to-end and
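Because that first stage is linear before BN/ELU, a learned temporal kernel can be inspected like any FIR filter. A minimal sketch with a stand-in conv layer (not the trained weights), reading the magnitude response of each filter::

    import numpy as np
    from torch import nn

    # Stand-in for the F1 temporal filters (hypothetical sizes, not EEGNet's trained weights).
    temporal_conv = nn.Conv2d(1, 8, kernel_size=(1, 64), bias=False)
    sfreq = 128.0

    kernels = temporal_conv.weight.detach().numpy()[:, 0, 0, :]   # (F1, kernel_length)
    freqs = np.fft.rfftfreq(kernels.shape[-1], d=1.0 / sfreq)     # frequency axis in Hz
    response = np.abs(np.fft.rfft(kernels, axis=-1))              # FIR magnitude response per filter
    print(freqs.shape, response.shape)                            # (33,), (8, 33)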
@@ -22,6 +22,8 @@ from braindecode.modules import MLP, DropPath
  class Labram(EEGModuleMixin, nn.Module):
  """Labram from Jiang, W B et al (2024) [Jiang2024]_.

+ :bdg-danger:`Large Brain Model`
+
  .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
  :align: center
  :alt: Labram Architecture.
@@ -43,31 +45,45 @@ class Labram(EEGModuleMixin, nn.Module):
  equals True. The original implementation uses (batch, n_chans, n_patches,
  patch_size) as input with static segmentation of the input data.

- The models have the following sequence of steps:
- if neural tokenizer:
- - SegmentPatch: Segment the input data in patches;
- - TemporalConv: Apply a temporal convolution to the segmented data;
- - Residual adding cls, temporal and position embeddings (optional);
- - WindowsAttentionBlock: Apply a windows attention block to the data;
- - LayerNorm: Apply layer normalization to the data;
- - Linear: An head linear layer to transformer the data into classes.
-
- else:
- - PatchEmbed: Apply a patch embedding to the input data;
- - Residual adding cls, temporal and position embeddings (optional);
- - WindowsAttentionBlock: Apply a windows attention block to the data;
- - LayerNorm: Apply layer normalization to the data;
- - Linear: An head linear layer to transformer the data into classes.
+ The models have the following sequence of steps::
+
+ if neural tokenizer:
+ - SegmentPatch: Segment the input data in patches;
+ - TemporalConv: Apply a temporal convolution to the segmented data;
+ - Residual adding cls, temporal and position embeddings (optional);
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
+ - LayerNorm: Apply layer normalization to the data;
+ - Linear: An head linear layer to transformer the data into classes.
+
+ else:
+ - PatchEmbed: Apply a patch embedding to the input data;
+ - Residual adding cls, temporal and position embeddings (optional);
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
+ - LayerNorm: Apply layer normalization to the data;
+ - Linear: An head linear layer to transformer the data into classes.

  .. versionadded:: 0.9

+
+ Examples
+ --------
+ Load pre-trained weights::
+
+ >>> import torch
+ >>> from braindecode.models import Labram
+ >>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
+ >>> url = "https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt"
+ >>> state = torch.hub.load_state_dict_from_url(url, progress=True)
+ >>> model.load_state_dict(state)
+
+
  Parameters
  ----------
  patch_size : int
  The size of the patch to be used in the patch embedding.
  emb_size : int
  The dimension of the embedding.
- in_channels : int
+ in_conv_channels : int
  The number of convolutional input channels.
  out_channels : int
  The number of convolutional output channels.
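The keyword rename above (``in_channels`` to ``in_conv_channels``) shows up at construction time; a minimal sketch with illustrative shapes, spelling out the renamed argument::

    import torch
    from braindecode.models import Labram

    model = Labram(
        n_times=1600,
        n_chans=64,
        n_outputs=4,
        in_conv_channels=1,   # formerly `in_channels`; only used in Neural Decoder mode
    )

    x = torch.randn(2, 64, 1600)   # (batch, n_chans, n_times)
    out = model(x)                 # expected classification output shape: (2, 4)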
@@ -79,8 +95,10 @@ class Labram(EEGModuleMixin, nn.Module):
  The expansion ratio of the mlp layer
  qkv_bias : bool (default=False)
  If True, add a learnable bias to the query, key, and value tensors.
- qk_norm : Pytorch Normalize layer (default=None)
- If not None, apply LayerNorm to the query and key tensors
+ qk_norm : Pytorch Normalize layer (default=nn.LayerNorm)
+ If not None, apply LayerNorm to the query and key tensors.
+ Default is nn.LayerNorm for better weight transfer from original LaBraM.
+ Set to None to disable Q,K normalization.
  qk_scale : float (default=None)
  If not None, use this value as the scale factor. If None,
  use head_dim**-0.5, where head_dim = dim // num_heads.
@@ -92,9 +110,10 @@ class Labram(EEGModuleMixin, nn.Module):
  Dropout rate for the attention weights used on DropPath.
  norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
  The normalization layer to be used.
- init_values : float (default=None)
+ init_values : float (default=0.1)
  If not None, use this value to initialize the gamma_1 and gamma_2
- parameters.
+ parameters for residual scaling. Default is 0.1 for better weight
+ transfer from original LaBraM. Set to None to disable.
  use_abs_pos_emb : bool (default=True)
  If True, use absolute position embedding.
  use_mean_pooling : bool (default=True)
@@ -135,19 +154,19 @@ class Labram(EEGModuleMixin, nn.Module):
  input_window_seconds=None,
  patch_size=200,
  emb_size=200,
- in_channels=1,
+ in_conv_channels=1,
  out_channels=8,
  n_layers=12,
  att_num_heads=10,
  mlp_ratio=4.0,
  qkv_bias=False,
- qk_norm=None,
+ qk_norm=nn.LayerNorm,
  qk_scale=None,
  drop_prob=0.0,
  attn_drop_prob=0.0,
  drop_path_prob=0.0,
  norm_layer=nn.LayerNorm,
- init_values=None,
+ init_values=0.1,
  use_abs_pos_emb=True,
  use_mean_pooling=True,
  init_scale=0.001,
@@ -183,15 +202,15 @@ class Labram(EEGModuleMixin, nn.Module):
  self.patch_size = patch_size
  self.n_path = self.n_times // self.patch_size

- if neural_tokenizer and in_channels != 1:
+ if neural_tokenizer and in_conv_channels != 1:
  warn(
  "The model is in Neural Tokenizer mode, but the variable "
- + "`in_channels` is different from the default values."
- + "`in_channels` is only needed for the Neural Decoder mode."
- + "in_channels is not used in the Neural Tokenizer mode.",
+ + "`in_conv_channels` is different from the default values."
+ + "`in_conv_channels` is only needed for the Neural Decoder mode."
+ + "in_conv_channels is not used in the Neural Tokenizer mode.",
  UserWarning,
  )
- in_channels = 1
+ in_conv_channels = 1
  # If you can use the model in Neural Tokenizer mode,
  # temporal conv layer will be use over the patched dataset
  if neural_tokenizer:
@@ -228,7 +247,7 @@ class Labram(EEGModuleMixin, nn.Module):
  _PatchEmbed(
  n_times=self.n_times,
  patch_size=patch_size,
- in_channels=in_channels,
+ in_channels=in_conv_channels,
  emb_dim=self.emb_size,
  ),
  )
@@ -373,8 +392,7 @@ class Labram(EEGModuleMixin, nn.Module):
  Parameters
  ----------
  x : torch.Tensor
- The input data with shape (batch, n_chans, n_patches, patch size),
- if neural decoder or (batch, n_chans, n_times), if neural tokenizer.
+ The input data with shape (batch, n_chans, n_times).
  input_chans : int
  The number of input channels.
  return_patch_tokens : bool
@@ -387,37 +405,72 @@ class Labram(EEGModuleMixin, nn.Module):
  x : torch.Tensor
  The output of the model.
  """
+ batch_size = x.shape[0]
+
  if self.neural_tokenizer:
- batch_size, nch, n_patch, temporal = self.patch_embed.segment_patch(x).shape
+ # For neural tokenizer: input is (batch, n_chans, n_times)
+ # patch_embed returns (batch, n_chans, emb_dim)
+ x = self.patch_embed(x)
+ # x shape: (batch, n_chans, emb_dim)
+ n_patch = self.n_chans
+ temporal = self.emb_size
  else:
- batch_size, nch, n_patch = self.patch_embed(x).shape
- x = self.patch_embed(x)
+ # For neural decoder: input is (batch, n_chans, n_times)
+ # patch_embed returns (batch, n_patchs, emb_dim)
+ x = self.patch_embed(x)
+ # x shape: (batch, n_patchs, emb_dim)
+ batch_size, n_patch, temporal = x.shape
+
  # add the [CLS] token to the embedded patch tokens
  cls_tokens = self.cls_token.expand(batch_size, -1, -1)

+ # Concatenate cls token with patch/channel embeddings
  x = torch.cat((cls_tokens, x), dim=1)

  # Positional Embedding
- if input_chans is not None:
- pos_embed_used = self.position_embedding[:, input_chans]
- else:
- pos_embed_used = self.position_embedding
-
  if self.position_embedding is not None:
- pos_embed = self._adj_position_embedding(
- pos_embed_used=pos_embed_used, batch_size=batch_size
- )
+ if self.neural_tokenizer:
+ # In tokenizer mode, use channel-based position embedding
+ if input_chans is not None:
+ pos_embed_used = self.position_embedding[:, input_chans]
+ else:
+ pos_embed_used = self.position_embedding
+
+ pos_embed = self._adj_position_embedding(
+ pos_embed_used=pos_embed_used, batch_size=batch_size
+ )
+ else:
+ # In decoder mode, we have different number of patches
+ # Adapt position embedding for n_patch patches
+ # Use the first n_patch+1 positions from position_embedding
+ n_pos = min(self.position_embedding.shape[1], n_patch + 1)
+ pos_embed_used = self.position_embedding[:, :n_pos, :]
+ pos_embed = pos_embed_used.expand(batch_size, -1, -1)
+
  x += pos_embed

  # The time embedding is added across the channels after the [CLS] token
  if self.neural_tokenizer:
  num_ch = self.n_chans
+ time_embed = self._adj_temporal_embedding(
+ num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
+ )
+ x[:, 1:, :] += time_embed
  else:
- num_ch = n_patch
- time_embed = self._adj_temporal_embedding(
- num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
- )
- x[:, 1:, :] += time_embed
+ # In decoder mode, we have n_patch patches and don't need to expand
+ # Just broadcast the temporal embedding
+ if temporal is None:
+ temporal = self.emb_size
+
+ # Get temporal embeddings for n_patch patches
+ n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
+ time_embed = self.temporal_embedding[
+ :, 1 : n_time_tokens + 1, :
+ ] # (1, n_patch, emb_dim)
+ time_embed = time_embed.expand(
+ batch_size, -1, -1
+ ) # (batch, n_patch, emb_dim)
+ x[:, 1:, :] += time_embed

  x = self.pos_drop(x)

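The decoder-mode branch above slices the first ``n_patch + 1`` positions of the position embedding and broadcasts them over the batch; a standalone shape check with illustrative sizes (not the trained parameters)::

    import torch

    batch_size, n_patch, emb_dim = 2, 8, 200
    position_embedding = torch.zeros(1, 65, emb_dim)    # (1, max_positions, emb_dim)

    n_pos = min(position_embedding.shape[1], n_patch + 1)
    pos_embed = position_embedding[:, :n_pos, :].expand(batch_size, -1, -1)

    x = torch.zeros(batch_size, n_patch + 1, emb_dim)   # [CLS] token + patch tokens
    x = x + pos_embed
    print(pos_embed.shape)                              # torch.Size([2, 9, 200])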
@@ -428,10 +481,10 @@ class Labram(EEGModuleMixin, nn.Module):
  if self.fc_norm is not None:
  if return_all_tokens:
  return self.fc_norm(x)
- temporal = x[:, 1:, :]
+ tokens = x[:, 1:, :]
  if return_patch_tokens:
- return self.fc_norm(temporal)
- return self.fc_norm(temporal.mean(1))
+ return self.fc_norm(tokens)
+ return self.fc_norm(tokens.mean(1))
  else:
  if return_all_tokens:
  return x
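The renamed ``tokens`` path implements a mean-pooled readout: drop the [CLS] token, average the remaining tokens, then normalize. A toy illustration with made-up sizes::

    import torch
    from torch import nn

    fc_norm = nn.LayerNorm(200)
    x = torch.randn(2, 9, 200)        # (batch, 1 + n_tokens, emb_dim)
    tokens = x[:, 1:, :]              # drop [CLS] -> (2, 8, 200)
    pooled = fc_norm(tokens.mean(1))  # (2, 200), fed to the classification head
    print(pooled.shape)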
@@ -505,14 +558,16 @@ class Labram(EEGModuleMixin, nn.Module):
  def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
  """
  Adjust the dimensions of the time embedding to match the
- number of channels.
+ number of channels or patches.

  Parameters
  ----------
  num_ch : int
- The number of channels or number of code books vectors.
+ The number of channels or number of patches.
  batch_size : int
  Batch size of the input data.
+ dim_embed : int
+ The embedding dimension (temporal feature dimension).

  Returns
  -------
@@ -523,17 +578,24 @@ class Labram(EEGModuleMixin, nn.Module):
  if dim_embed is None:
  cut_dimension = self.patch_size
  else:
- cut_dimension = dim_embed
- # first step will be match the time_embed to the number of channels
- temporal_embedding = self.temporal_embedding[:, 1:cut_dimension, :]
+ cut_dimension = min(dim_embed, self.temporal_embedding.shape[1] - 1)
+
+ # Get the temporal embedding: (1, temporal_embedding_dim, emb_size)
+ # Slice to cut_dimension: (1, cut_dimension, emb_size)
+ temporal_embedding = self.temporal_embedding[:, 1 : cut_dimension + 1, :]
+
  # Add a new dimension to the time embedding
- # e.g. (batch, 62, 200) -> (batch, 1, 62, 200)
+ # e.g. (1, 5, 200) -> (1, 1, 5, 200)
  temporal_embedding = temporal_embedding.unsqueeze(1)
- # Expand the time embedding to match the number of channels
- # or number of patches from
+
+ # Expand the time embedding to match the number of channels or patches
+ # (1, 1, cut_dimension, 200) -> (batch_size, num_ch, cut_dimension, 200)
  temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
+
  # Flatten the intermediate dimensions
+ # (batch_size, num_ch, cut_dimension, 200) -> (batch_size, num_ch * cut_dimension, 200)
  temporal_embedding = temporal_embedding.flatten(1, 2)
+
  return temporal_embedding

  def _adj_position_embedding(self, pos_embed_used, batch_size):
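The slice/unsqueeze/expand/flatten chain above is easiest to follow as a shape walk-through; a standalone sketch with toy sizes (not the trained parameters)::

    import torch

    temporal_embedding = torch.zeros(1, 17, 200)   # (1, 1 + max_time_positions, emb_size)
    batch_size, num_ch, dim_embed = 2, 8, 5

    cut_dimension = min(dim_embed, temporal_embedding.shape[1] - 1)   # 5
    t = temporal_embedding[:, 1 : cut_dimension + 1, :]               # (1, 5, 200)
    t = t.unsqueeze(1)                                                # (1, 1, 5, 200)
    t = t.expand(batch_size, num_ch, -1, -1)                          # (2, 8, 5, 200)
    t = t.flatten(1, 2)                                               # (2, 40, 200)
    print(t.shape)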
@@ -679,25 +741,27 @@ class _SegmentPatch(nn.Module):


  class _PatchEmbed(nn.Module):
- """EEG to Patch Embedding.
+ """EEG to Patch Embedding for Neural Decoder mode.

  This code is used when we want to apply the patch embedding
- after the codebook layer.
+ after the codebook layer (Neural Decoder mode).
+
+ The input is expected to be in the format (Batch, n_channels, n_times),
+ but the original LaBraM expects pre-patched data (Batch, n_channels, n_patches, patch_size).
+ This class reshapes the input to the pre-patched format, then applies a 2D
+ convolution to project this pre-patched data to the embedding dimension,
+ and finally flattens across channels to produce a unified embedding.

  Parameters:
  -----------
  n_times: int (default=2000)
- Number of temporal components of the input tensor.
+ Number of temporal components of the input tensor (used for dimension calculation).
  patch_size: int (default=200)
  Size of the patch, default is 1-seconds with 200Hz.
  in_channels: int (default=1)
- Number of input channels for to be used in the convolution.
+ Number of input channels (from VQVAE codebook).
  emb_dim: int (default=200)
- Number of out_channes to be used in the convolution, here,
- we used the same as patch_size.
- n_codebooks: int (default=62)
- Number of patches to be used in the convolution, here,
- we used the same as n_times // patch_size.
+ Number of output embedding dimension.
  """

  def __init__(
@@ -707,10 +771,13 @@ class _PatchEmbed(nn.Module):
  self.n_times = n_times
  self.patch_size = patch_size
  self.patch_shape = (1, self.n_times // self.patch_size)
- n_patchs = n_codebooks * (self.n_times // self.patch_size)
-
- self.n_patchs = n_patchs
+ self.n_patchs = self.n_times // self.patch_size
+ self.emb_dim = emb_dim
+ self.in_channels = in_channels

+ # 2D Conv to project the pre-patched data
+ # Input: (Batch, in_channels, n_patches, patch_size)
+ # After proj: (Batch, emb_dim, n_patches, 1)
  self.proj = nn.Conv2d(
  in_channels=in_channels,
  out_channels=emb_dim,
@@ -718,27 +785,64 @@ class _PatchEmbed(nn.Module):
  stride=(1, self.patch_size),
  )

- self.merge_transpose = Rearrange(
- "Batch ch patch spatch -> Batch patch spatch ch",
- )
-
  def forward(self, x):
  """
- Apply the convolution to the input tensor.
- then merge the output tensor to the desired shape.
+ Apply the temporal projection to the input tensor after grouping channels.

- Parameters:
- -----------
- x: torch.Tensor
- Input tensor of shape (Batch, Channels, n_patchs, patch_size).
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor of shape (Batch, n_channels, n_times) or
+ (Batch, n_channels, n_patches, patch_size).

- Return:
+ Returns
  -------
- x: torch.Tensor
- Output tensor of shape (Batch, n_patchs, patch_size, channels).
+ torch.Tensor
+ Output tensor of shape (Batch, n_patchs, emb_dim).
  """
+ if x.ndim == 4:
+ batch_size, n_channels, n_patchs, patch_len = x.shape
+ if patch_len != self.patch_size:
+ raise ValueError(
+ "When providing a 4D tensor, the last dimension "
+ f"({patch_len}) must match patch_size ({self.patch_size})."
+ )
+ n_times = n_patchs * patch_len
+ x = x.reshape(batch_size, n_channels, n_times)
+ elif x.ndim == 3:
+ batch_size, n_channels, n_times = x.shape
+ else:
+ raise ValueError(
+ "Input must be either 3D (batch, channels, times) or "
+ "4D (batch, channels, n_patches, patch_size)."
+ )
+
+ if n_times % self.patch_size != 0:
+ raise ValueError(
+ f"n_times ({n_times}) must be divisible by patch_size ({self.patch_size})."
+ )
+ if n_channels % self.in_channels != 0:
+ raise ValueError(
+ "The input channel dimension "
+ f"({n_channels}) must be divisible by in_channels ({self.in_channels})."
+ )
+
+ group_size = n_channels // self.in_channels
+
+ # Reshape so Conv2d sees `in_channels` feature maps and uses the grouped
+ # EEG channels as the spatial height dimension.
+ # Shape after view: (Batch, in_channels, group_size, n_times)
+ x = x.view(batch_size, self.in_channels, group_size, n_times)
+
+ # Apply the temporal projection per group.
+ # Output shape: (Batch, emb_dim, group_size, n_patchs)
  x = self.proj(x)
- x = self.merge_transpose(x)
+
+ # THIS IS braindecode's MODIFICATION:
+ # Average over the grouped channel dimension and permute to (Batch, n_patchs, emb_dim)
+ x = x.mean(dim=2)
+ x = x.transpose(1, 2).contiguous()
+
  return x

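The grouped reshape, temporal projection, and channel averaging above can be checked end-to-end in isolation. A standalone sketch with illustrative sizes; ``kernel_size=(1, patch_size)`` is assumed here to match the documented output shapes::

    import torch
    from torch import nn

    batch_size, n_channels, n_times = 2, 64, 1600
    patch_size, emb_dim, in_channels = 200, 200, 1

    # Assumed projection layer matching the shape comments above.
    proj = nn.Conv2d(in_channels, emb_dim,
                     kernel_size=(1, patch_size), stride=(1, patch_size))

    x = torch.randn(batch_size, n_channels, n_times)
    group_size = n_channels // in_channels                     # 64
    x = x.view(batch_size, in_channels, group_size, n_times)   # (2, 1, 64, 1600)
    x = proj(x)                                                 # (2, 200, 64, 8)
    x = x.mean(dim=2).transpose(1, 2).contiguous()              # (2, 8, 200) = (batch, n_patchs, emb_dim)
    print(x.shape)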