braindecode 1.3.0.dev180851780__py3-none-any.whl → 1.3.0.dev181594385__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (66)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/datasets/__init__.py +10 -2
  3. braindecode/datasets/base.py +115 -151
  4. braindecode/datasets/bcicomp.py +4 -4
  5. braindecode/datasets/bids.py +3 -3
  6. braindecode/datasets/experimental.py +2 -2
  7. braindecode/datasets/mne.py +3 -5
  8. braindecode/datasets/moabb.py +2 -2
  9. braindecode/datasets/nmt.py +2 -2
  10. braindecode/datasets/sleep_physio_challe_18.py +2 -2
  11. braindecode/datasets/sleep_physionet.py +2 -2
  12. braindecode/datasets/tuh.py +2 -2
  13. braindecode/datasets/xy.py +2 -2
  14. braindecode/datautil/serialization.py +7 -7
  15. braindecode/eegneuralnet.py +2 -0
  16. braindecode/functional/functions.py +6 -2
  17. braindecode/functional/initialization.py +2 -3
  18. braindecode/models/__init__.py +2 -0
  19. braindecode/models/atcnet.py +26 -27
  20. braindecode/models/attentionbasenet.py +39 -32
  21. braindecode/models/attn_sleep.py +2 -0
  22. braindecode/models/base.py +280 -2
  23. braindecode/models/bendr.py +469 -0
  24. braindecode/models/biot.py +2 -0
  25. braindecode/models/contrawr.py +2 -0
  26. braindecode/models/ctnet.py +8 -3
  27. braindecode/models/deepsleepnet.py +28 -19
  28. braindecode/models/eegconformer.py +2 -2
  29. braindecode/models/eeginception_erp.py +31 -25
  30. braindecode/models/eegitnet.py +2 -0
  31. braindecode/models/eegminer.py +2 -0
  32. braindecode/models/eegnet.py +1 -1
  33. braindecode/models/eegtcnet.py +2 -0
  34. braindecode/models/fbcnet.py +2 -0
  35. braindecode/models/fblightconvnet.py +2 -0
  36. braindecode/models/fbmsnet.py +2 -0
  37. braindecode/models/ifnet.py +2 -0
  38. braindecode/models/labram.py +193 -87
  39. braindecode/models/msvtnet.py +2 -0
  40. braindecode/models/patchedtransformer.py +1 -1
  41. braindecode/models/signal_jepa.py +111 -27
  42. braindecode/models/sinc_shallow.py +12 -9
  43. braindecode/models/sstdpn.py +11 -11
  44. braindecode/models/summary.csv +1 -0
  45. braindecode/models/syncnet.py +2 -0
  46. braindecode/models/tcn.py +2 -0
  47. braindecode/models/usleep.py +26 -21
  48. braindecode/models/util.py +1 -0
  49. braindecode/modules/attention.py +10 -10
  50. braindecode/modules/blocks.py +3 -3
  51. braindecode/modules/filter.py +2 -3
  52. braindecode/modules/layers.py +18 -17
  53. braindecode/preprocessing/__init__.py +24 -0
  54. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  55. braindecode/preprocessing/preprocess.py +12 -12
  56. braindecode/preprocessing/util.py +166 -0
  57. braindecode/preprocessing/windowers.py +24 -19
  58. braindecode/samplers/base.py +8 -8
  59. braindecode/version.py +1 -1
  60. {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev181594385.dist-info}/METADATA +6 -2
  61. braindecode-1.3.0.dev181594385.dist-info/RECORD +106 -0
  62. braindecode-1.3.0.dev180851780.dist-info/RECORD +0 -103
  63. {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev181594385.dist-info}/WHEEL +0 -0
  64. {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev181594385.dist-info}/licenses/LICENSE.txt +0 -0
  65. {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev181594385.dist-info}/licenses/NOTICE.txt +0 -0
  66. {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev181594385.dist-info}/top_level.txt +0 -0
@@ -35,51 +35,57 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
  - :class:`_InceptionModule1` **(multi-scale temporal + spatial mixing)**

  - *Operations.*
- - `EEGInceptionERP.c1`: :class:`torch.nn.Conv2d` ``k=(64,1)``, stride ``(1,1)``, *same* pad on input reshaped to ``(B,1,128,8)`` → BN → activation → dropout.
- - `EEGInceptionERP.d1`: :class:`torch.nn.Conv2d` (depthwise) ``k=(1,8)``, *valid* pad over channels → BN → activation → dropout.
- - `EEGInceptionERP.c2`: :class:`torch.nn.Conv2d` ``k=(32,1)`` → BN → activation → dropout; then `EEGInceptionERP.d2` depthwise ``k=(1,8)`` → BN → activation → dropout.
- - `EEGInceptionERP.c3`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout; then `EEGInceptionERP.d3` depthwise ``k=(1,8)`` → BN → activation → dropout.
- - `EEGInceptionERP.n1`: :class:`torch.nn.Concat` over branch features.
- - `EEGInceptionERP.a1`: :class:`torch.nn.AvgPool2d` ``pool=(4,1)``, stride ``(4,1)`` for temporal downsampling.
+
+ - `EEGInceptionERP.c1`: :class:`torch.nn.Conv2d` ``k=(64,1)``, stride ``(1,1)``, *same* pad on input reshaped to ``(B,1,128,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.d1`: :class:`torch.nn.Conv2d` (depthwise) ``k=(1,8)``, *valid* pad over channels → BN → activation → dropout.
+ - `EEGInceptionERP.c2`: :class:`torch.nn.Conv2d` ``k=(32,1)`` → BN → activation → dropout; then `EEGInceptionERP.d2` depthwise ``k=(1,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c3`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout; then `EEGInceptionERP.d3` depthwise ``k=(1,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.n1`: :class:`torch.nn.Concat` over branch features.
+ - `EEGInceptionERP.a1`: :class:`torch.nn.AvgPool2d` ``pool=(4,1)``, stride ``(4,1)`` for temporal downsampling.

  *Interpretability/robustness.* Depthwise `1 x n_chans` layers act as learnable montage-wide spatial filters per temporal scale; pooling stabilizes against jitter.

  - :class:`_InceptionModule2` **(refinement at coarser timebase)**

  - *Operations.*
- - `EEGInceptionERP.c4`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout.
- - `EEGInceptionERP.c5`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout.
- - `EEGInceptionERP.c6`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout.
- - `EEGInceptionERP.n2`: :class:`torch.nn.Concat` (merge C4-C6 outputs).
- - `EEGInceptionERP.a2`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``, stride ``(2,1)``.
- - `EEGInceptionERP.c7`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout; then `EEGInceptionERP.a3`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
- - `EEGInceptionERP.c8`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout; then `EEGInceptionERP.a4`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
+
+ - `EEGInceptionERP.c4`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c5`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c6`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.n2`: :class:`torch.nn.Concat` (merge C4-C6 outputs).
+ - `EEGInceptionERP.a2`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``, stride ``(2,1)``.
+ - `EEGInceptionERP.c7`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout; then `EEGInceptionERP.a3`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
+ - `EEGInceptionERP.c8`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout; then `EEGInceptionERP.a4`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.

  *Role.* Adds higher-level, shorter-window evidence while progressively compressing temporal dimension.

  - :class:`_OutputModule` **(aggregation + readout)**

  - *Operations.*
- - :class:`torch.nn.Flatten`
- - :class:`torch.nn.Linear` ``(features → 2)``
+
+ - :class:`torch.nn.Flatten`
+ - :class:`torch.nn.Linear` ``(features → 2)``

  .. rubric:: Convolutional Details

  - **Temporal (where time-domain patterns are learned).**
- First module uses 1D temporal kernels along the 128-sample axis: ``64``, ``32``, ``16``
- (≈500, 250, 125 ms at 128 Hz). After ``pool=(4,1)``, the second module applies ``16``,
- ``8``, ``4`` (≈125, 62.5, 31.25 ms at the pooled rate). All strides are ``1`` in convs;
- temporal resolution changes only via average pooling.
+
+ First module uses 1D temporal kernels along the 128-sample axis: ``64``, ``32``, ``16``
+ (≈500, 250, 125 ms at 128 Hz). After ``pool=(4,1)``, the second module applies ``16``,
+ ``8``, ``4`` (≈125, 62.5, 31.25 ms at the pooled rate). All strides are ``1`` in convs;
+ temporal resolution changes only via average pooling.

  - **Spatial (how electrodes are processed).**
- Depthwise convs with ``k=(1,8)`` span all channels and are applied **per temporal branch**,
- yielding scale-specific channel projections (no cross-branch mixing until concatenation).
- There is no full 2D mixing kernel; spatial mixing is factorized and lightweight.
+
+ Depthwise convs with ``k=(1,8)`` span all channels and are applied **per temporal branch**,
+ yielding scale-specific channel projections (no cross-branch mixing until concatenation).
+ There is no full 2D mixing kernel; spatial mixing is factorized and lightweight.

  - **Spectral (how frequency information is captured).**
- No explicit transform; multiple temporal kernels form a *learned filter bank* over
- ERP-relevant bands. Successive pooling acts as low-pass integration to emphasize sustained
- post-stimulus components.
+
+ No explicit transform; multiple temporal kernels form a *learned filter bank* over
+ ERP-relevant bands. Successive pooling acts as low-pass integration to emphasize sustained
+ post-stimulus components.

  .. rubric:: Additional Mechanisms

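The hunk above only re-nests the *Operations* lists in the ``EEGInceptionERP`` docstring. For orientation, a minimal forward-pass sketch matching the docstring's running example (8 channels, 128 samples at 128 Hz, two classes); the constructor arguments follow the usual braindecode conventions and are illustrative, not part of the diff::

    import torch
    from braindecode.models import EEGInceptionERP

    # 1-second window at 128 Hz over 8 channels, matching the
    # (B, 1, 128, 8) reshape described in the docstring.
    model = EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128)
    x = torch.randn(2, 8, 128)   # (batch, n_chans, n_times)
    logits = model(x)            # expected shape: (2, 2)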
@@ -11,6 +11,8 @@ from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
  class EEGITNet(EEGModuleMixin, nn.Sequential):
  """EEG-ITNet from Salami, et al (2022) [Salami2022]_

+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
+
  .. figure:: https://braindecode.org/dev/_static/model/eegitnet.jpg
  :align: center
  :alt: EEG-ITNet Architecture
@@ -21,6 +21,8 @@ _eeg_miner_methods = ["mag", "corr", "plv"]
  class EEGMiner(EEGModuleMixin, nn.Module):
  """EEGMiner from Ludwig et al (2024) [eegminer]_.

+ :bdg-success:`Convolution` :bdg-warning:`Interpretability`
+
  .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036010/revision2/jnead44d7f1_hr.jpg
  :align: center
  :alt: EEGMiner Architecture
@@ -57,7 +57,7 @@ class EEGNet(EEGModuleMixin, nn.Sequential):

  - **Temporal.** The initial temporal convs serve as a *learned filter bank*:
  long 1-D kernels (implemented as 2-D with singleton spatial extent) emphasize oscillatory bands and transients.
- Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each features spectrum [Lawhern2018]_.
+ Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each feature's spectrum [Lawhern2018]_.

  - **Spatial.** The depthwise spatial conv spans the full channel axis (kernel height = #electrodes; temporal size = 1).
  With ``groups = F1``, each temporal filter learns its own set of ``D`` spatial projections—akin to CSP, learned end-to-end and
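Because the temporal stage is linear before BN/ELU, its kernels can indeed be read as FIR filters, as the corrected sentence says. A hedged sketch of that analysis, shown here with ``EEGNetv4`` (adjust the class name to the one your braindecode build exports) and assuming the temporal filter bank is the first ``Conv2d`` registered; the 250 Hz sampling rate is illustrative::

    import numpy as np
    from torch import nn
    from braindecode.models import EEGNetv4

    model = EEGNetv4(n_chans=22, n_outputs=4, n_times=1000)

    # First Conv2d found: kernels of shape (F1, 1, 1, kernel_length).
    first_conv = next(m for m in model.modules() if isinstance(m, nn.Conv2d))
    kernels = first_conv.weight.detach().squeeze().numpy()   # (F1, kernel_length)

    # Treat each kernel as an FIR filter; zero-pad the FFT for a smoother response.
    sfreq = 250.0
    freqs = np.fft.rfftfreq(256, d=1.0 / sfreq)
    magnitude = np.abs(np.fft.rfft(kernels, n=256, axis=-1))  # (F1, len(freqs))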
@@ -15,6 +15,8 @@ from braindecode.modules import Chomp1d, MaxNormLinear
  class EEGTCNet(EEGModuleMixin, nn.Module):
  """EEGTCNet model from Ingolfsson et al. (2020) [ingolfsson2020]_.

+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
+
  .. figure:: https://braindecode.org/dev/_static/model/eegtcnet.jpg
  :align: center
  :alt: EEGTCNet Architecture
@@ -31,6 +31,8 @@ _valid_layers = {
  class FBCNet(EEGModuleMixin, nn.Module):
  """FBCNet from Mane, R et al (2021) [fbcnet2021]_.

+ :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
  .. figure:: https://raw.githubusercontent.com/ravikiran-mane/FBCNet/refs/heads/master/FBCNet-V2.png
  :align: center
  :alt: FBCNet Architecture
@@ -18,6 +18,8 @@ from braindecode.modules import (
  class FBLightConvNet(EEGModuleMixin, nn.Module):
  """LightConvNet from Ma, X et al (2023) [lightconvnet]_.

+ :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
  .. figure:: https://raw.githubusercontent.com/Ma-Xinzhi/LightConvNet/refs/heads/main/network_architecture.png
  :align: center
  :alt: LightConvNet Neural Network
@@ -19,6 +19,8 @@ from braindecode.modules import (
  class FBMSNet(EEGModuleMixin, nn.Module):
  """FBMSNet from Liu et al (2022) [fbmsnet]_.

+ :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
  .. figure:: https://raw.githubusercontent.com/Want2Vanish/FBMSNet/refs/heads/main/FBMSNet.png
  :align: center
  :alt: FBMSNet Architecture
@@ -31,6 +31,8 @@ from braindecode.modules import (
  class IFNet(EEGModuleMixin, nn.Module):
  """IFNetV2 from Wang J et al (2023) [ifnet]_.

+ :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
  .. figure:: https://raw.githubusercontent.com/Jiaheng-Wang/IFNet/main/IFNet.png
  :align: center
  :alt: IFNetV2 Architecture
@@ -2,6 +2,7 @@
  Labram module.
  Authors: Wei-Bang Jiang
  Bruno Aristimunha <b.aristimunha@gmail.com>
+ Matthew Chen <matt.chen4260@gmail.com>
  License: BSD 3 clause
  """

@@ -22,12 +23,14 @@ from braindecode.modules import MLP, DropPath
  class Labram(EEGModuleMixin, nn.Module):
  """Labram from Jiang, W B et al (2024) [Jiang2024]_.

+ :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`
+
  .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
  :align: center
  :alt: Labram Architecture.

  Large Brain Model for Learning Generic Representations with Tremendous
- EEG Data in BCI from [Jiang2024]_
+ EEG Data in BCI from [Jiang2024]_.

  This is an **adaptation** of the code [Code2024]_ from the Labram model.

@@ -35,7 +38,8 @@ class Labram(EEGModuleMixin, nn.Module):
  BEiTv2 [BeiTv2]_.

  The models can be used in two modes:
- - Neural Tokenizor: Design to get an embedding layers (e.g. classification).
+
+ - Neural Tokenizer: Design to get an embedding layers (e.g. classification).
  - Neural Decoder: To extract the ampliture and phase outputs with a VQSNP.

  The braindecode's modification is to allow the model to be used in
@@ -43,31 +47,45 @@ class Labram(EEGModuleMixin, nn.Module):
  equals True. The original implementation uses (batch, n_chans, n_patches,
  patch_size) as input with static segmentation of the input data.

- The models have the following sequence of steps:
- if neural tokenizer:
- - SegmentPatch: Segment the input data in patches;
- - TemporalConv: Apply a temporal convolution to the segmented data;
- - Residual adding cls, temporal and position embeddings (optional);
- - WindowsAttentionBlock: Apply a windows attention block to the data;
- - LayerNorm: Apply layer normalization to the data;
- - Linear: An head linear layer to transformer the data into classes.
-
- else:
- - PatchEmbed: Apply a patch embedding to the input data;
- - Residual adding cls, temporal and position embeddings (optional);
- - WindowsAttentionBlock: Apply a windows attention block to the data;
- - LayerNorm: Apply layer normalization to the data;
- - Linear: An head linear layer to transformer the data into classes.
+ The models have the following sequence of steps::
+
+ if neural tokenizer:
+ - SegmentPatch: Segment the input data in patches;
+ - TemporalConv: Apply a temporal convolution to the segmented data;
+ - Residual adding cls, temporal and position embeddings (optional);
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
+ - LayerNorm: Apply layer normalization to the data;
+ - Linear: An head linear layer to transformer the data into classes.
+
+ else:
+ - PatchEmbed: Apply a patch embedding to the input data;
+ - Residual adding cls, temporal and position embeddings (optional);
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
+ - LayerNorm: Apply layer normalization to the data;
+ - Linear: An head linear layer to transformer the data into classes.

  .. versionadded:: 0.9

+
+ Examples
+ --------
+ Load pre-trained weights::
+
+ >>> import torch
+ >>> from braindecode.models import Labram
+ >>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
+ >>> url = "https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt"
+ >>> state = torch.hub.load_state_dict_from_url(url, progress=True)
+ >>> model.load_state_dict(state)
+
+
  Parameters
  ----------
  patch_size : int
  The size of the patch to be used in the patch embedding.
  emb_size : int
  The dimension of the embedding.
- in_channels : int
+ in_conv_channels : int
  The number of convolutional input channels.
  out_channels : int
  The number of convolutional output channels.
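Alongside the pre-trained-weights example added above, a minimal random-input sketch of the default Neural Tokenizer mode; the shapes follow the updated forward docstring and the specific sizes are illustrative::

    import torch
    from braindecode.models import Labram

    model = Labram(n_times=1600, n_chans=64, n_outputs=4)
    x = torch.randn(2, 64, 1600)   # (batch, n_chans, n_times)
    logits = model(x)              # expected shape: (2, 4)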
@@ -79,8 +97,10 @@ class Labram(EEGModuleMixin, nn.Module):
  The expansion ratio of the mlp layer
  qkv_bias : bool (default=False)
  If True, add a learnable bias to the query, key, and value tensors.
- qk_norm : Pytorch Normalize layer (default=None)
- If not None, apply LayerNorm to the query and key tensors
+ qk_norm : Pytorch Normalize layer (default=nn.LayerNorm)
+ If not None, apply LayerNorm to the query and key tensors.
+ Default is nn.LayerNorm for better weight transfer from original LaBraM.
+ Set to None to disable Q,K normalization.
  qk_scale : float (default=None)
  If not None, use this value as the scale factor. If None,
  use head_dim**-0.5, where head_dim = dim // num_heads.
@@ -92,9 +112,10 @@ class Labram(EEGModuleMixin, nn.Module):
  Dropout rate for the attention weights used on DropPath.
  norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
  The normalization layer to be used.
- init_values : float (default=None)
+ init_values : float (default=0.1)
  If not None, use this value to initialize the gamma_1 and gamma_2
- parameters.
+ parameters for residual scaling. Default is 0.1 for better weight
+ transfer from original LaBraM. Set to None to disable.
  use_abs_pos_emb : bool (default=True)
  If True, use absolute position embedding.
  use_mean_pooling : bool (default=True)
@@ -102,7 +123,7 @@ class Labram(EEGModuleMixin, nn.Module):
  init_scale : float (default=0.001)
  The initial scale to be used in the parameters of the model.
  neural_tokenizer : bool (default=True)
- The model can be used in two modes: Neural Tokenizor or Neural Decoder.
+ The model can be used in two modes: Neural Tokenizer or Neural Decoder.
  attn_head_dim : bool (default=None)
  The head dimension to be used in the attention layer, to be used only
  during pre-training.
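The two default changes documented above (``qk_norm=nn.LayerNorm``, ``init_values=0.1``) ease loading of original LaBraM weights; the previous behaviour remains available by passing ``None`` explicitly. A sketch with illustrative sizes::

    from braindecode.models import Labram

    # New defaults: Q/K LayerNorm enabled, residual scaling initialised at 0.1.
    model_new = Labram(n_times=1600, n_chans=64, n_outputs=4)

    # Reproduce the pre-change behaviour by disabling both explicitly.
    model_old = Labram(n_times=1600, n_chans=64, n_outputs=4,
                       qk_norm=None, init_values=None)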
@@ -135,19 +156,19 @@ class Labram(EEGModuleMixin, nn.Module):
  input_window_seconds=None,
  patch_size=200,
  emb_size=200,
- in_channels=1,
+ in_conv_channels=1,
  out_channels=8,
  n_layers=12,
  att_num_heads=10,
  mlp_ratio=4.0,
  qkv_bias=False,
- qk_norm=None,
+ qk_norm=nn.LayerNorm,
  qk_scale=None,
  drop_prob=0.0,
  attn_drop_prob=0.0,
  drop_path_prob=0.0,
  norm_layer=nn.LayerNorm,
- init_values=None,
+ init_values=0.1,
  use_abs_pos_emb=True,
  use_mean_pooling=True,
  init_scale=0.001,
@@ -183,15 +204,15 @@ class Labram(EEGModuleMixin, nn.Module):
  self.patch_size = patch_size
  self.n_path = self.n_times // self.patch_size

- if neural_tokenizer and in_channels != 1:
+ if neural_tokenizer and in_conv_channels != 1:
  warn(
  "The model is in Neural Tokenizer mode, but the variable "
- + "`in_channels` is different from the default values."
- + "`in_channels` is only needed for the Neural Decoder mode."
- + "in_channels is not used in the Neural Tokenizer mode.",
+ + "`in_conv_channels` is different from the default values."
+ + "`in_conv_channels` is only needed for the Neural Decoder mode."
+ + "in_conv_channels is not used in the Neural Tokenizer mode.",
  UserWarning,
  )
- in_channels = 1
+ in_conv_channels = 1
  # If you can use the model in Neural Tokenizer mode,
  # temporal conv layer will be use over the patched dataset
  if neural_tokenizer:
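The rename from ``in_channels`` to ``in_conv_channels`` shown above only affects callers who passed the argument by keyword, and the argument only matters in Neural Decoder mode. A migration sketch with illustrative sizes::

    from braindecode.models import Labram

    # Before this release: Labram(..., in_channels=1)
    # From this release on:
    model = Labram(n_times=1600, n_chans=64, n_outputs=4,
                   neural_tokenizer=False, in_conv_channels=1)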
@@ -228,7 +249,7 @@ class Labram(EEGModuleMixin, nn.Module):
  _PatchEmbed(
  n_times=self.n_times,
  patch_size=patch_size,
- in_channels=in_channels,
+ in_channels=in_conv_channels,
  emb_dim=self.emb_size,
  ),
  )
@@ -373,8 +394,7 @@ class Labram(EEGModuleMixin, nn.Module):
  Parameters
  ----------
  x : torch.Tensor
- The input data with shape (batch, n_chans, n_patches, patch size),
- if neural decoder or (batch, n_chans, n_times), if neural tokenizer.
+ The input data with shape (batch, n_chans, n_times).
  input_chans : int
  The number of input channels.
  return_patch_tokens : bool
@@ -387,37 +407,72 @@ class Labram(EEGModuleMixin, nn.Module):
  x : torch.Tensor
  The output of the model.
  """
+ batch_size = x.shape[0]
+
  if self.neural_tokenizer:
- batch_size, nch, n_patch, temporal = self.patch_embed.segment_patch(x).shape
+ # For neural tokenizer: input is (batch, n_chans, n_times)
+ # patch_embed returns (batch, n_chans, emb_dim)
+ x = self.patch_embed(x)
+ # x shape: (batch, n_chans, emb_dim)
+ n_patch = self.n_chans
+ temporal = self.emb_size
  else:
- batch_size, nch, n_patch = self.patch_embed(x).shape
- x = self.patch_embed(x)
+ # For neural decoder: input is (batch, n_chans, n_times)
+ # patch_embed returns (batch, n_patchs, emb_dim)
+ x = self.patch_embed(x)
+ # x shape: (batch, n_patchs, emb_dim)
+ batch_size, n_patch, temporal = x.shape
+
  # add the [CLS] token to the embedded patch tokens
  cls_tokens = self.cls_token.expand(batch_size, -1, -1)

+ # Concatenate cls token with patch/channel embeddings
  x = torch.cat((cls_tokens, x), dim=1)

  # Positional Embedding
- if input_chans is not None:
- pos_embed_used = self.position_embedding[:, input_chans]
- else:
- pos_embed_used = self.position_embedding
-
  if self.position_embedding is not None:
- pos_embed = self._adj_position_embedding(
- pos_embed_used=pos_embed_used, batch_size=batch_size
- )
+ if self.neural_tokenizer:
+ # In tokenizer mode, use channel-based position embedding
+ if input_chans is not None:
+ pos_embed_used = self.position_embedding[:, input_chans]
+ else:
+ pos_embed_used = self.position_embedding
+
+ pos_embed = self._adj_position_embedding(
+ pos_embed_used=pos_embed_used, batch_size=batch_size
+ )
+ else:
+ # In decoder mode, we have different number of patches
+ # Adapt position embedding for n_patch patches
+ # Use the first n_patch+1 positions from position_embedding
+ n_pos = min(self.position_embedding.shape[1], n_patch + 1)
+ pos_embed_used = self.position_embedding[:, :n_pos, :]
+ pos_embed = pos_embed_used.expand(batch_size, -1, -1)
+
  x += pos_embed

  # The time embedding is added across the channels after the [CLS] token
  if self.neural_tokenizer:
  num_ch = self.n_chans
+ time_embed = self._adj_temporal_embedding(
+ num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
+ )
+ x[:, 1:, :] += time_embed
  else:
- num_ch = n_patch
- time_embed = self._adj_temporal_embedding(
- num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
- )
- x[:, 1:, :] += time_embed
+ # In decoder mode, we have n_patch patches and don't need to expand
+ # Just broadcast the temporal embedding
+ if temporal is None:
+ temporal = self.emb_size
+
+ # Get temporal embeddings for n_patch patches
+ n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
+ time_embed = self.temporal_embedding[
+ :, 1 : n_time_tokens + 1, :
+ ] # (1, n_patch, emb_dim)
+ time_embed = time_embed.expand(
+ batch_size, -1, -1
+ ) # (batch, n_patch, emb_dim)
+ x[:, 1:, :] += time_embed

  x = self.pos_drop(x)

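The new decoder branch above slices the stored position table to at most ``n_patch + 1`` entries and broadcasts it over the batch. A standalone toy-tensor sketch of that slice-and-expand pattern (sizes are illustrative, not the model's real buffers)::

    import torch

    batch_size, n_patch, emb_dim = 2, 8, 200
    position_embedding = torch.zeros(1, 65, emb_dim)   # (1, max_positions, emb_dim)

    n_pos = min(position_embedding.shape[1], n_patch + 1)
    pos_embed = position_embedding[:, :n_pos, :].expand(batch_size, -1, -1)
    print(pos_embed.shape)   # torch.Size([2, 9, 200])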
@@ -428,10 +483,10 @@ class Labram(EEGModuleMixin, nn.Module):
  if self.fc_norm is not None:
  if return_all_tokens:
  return self.fc_norm(x)
- temporal = x[:, 1:, :]
+ tokens = x[:, 1:, :]
  if return_patch_tokens:
- return self.fc_norm(temporal)
- return self.fc_norm(temporal.mean(1))
+ return self.fc_norm(tokens)
+ return self.fc_norm(tokens.mean(1))
  else:
  if return_all_tokens:
  return x
@@ -505,14 +560,16 @@ class Labram(EEGModuleMixin, nn.Module):
  def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
  """
  Adjust the dimensions of the time embedding to match the
- number of channels.
+ number of channels or patches.

  Parameters
  ----------
  num_ch : int
- The number of channels or number of code books vectors.
+ The number of channels or number of patches.
  batch_size : int
  Batch size of the input data.
+ dim_embed : int
+ The embedding dimension (temporal feature dimension).

  Returns
  -------
@@ -523,17 +580,24 @@ class Labram(EEGModuleMixin, nn.Module):
  if dim_embed is None:
  cut_dimension = self.patch_size
  else:
- cut_dimension = dim_embed
- # first step will be match the time_embed to the number of channels
- temporal_embedding = self.temporal_embedding[:, 1:cut_dimension, :]
+ cut_dimension = min(dim_embed, self.temporal_embedding.shape[1] - 1)
+
+ # Get the temporal embedding: (1, temporal_embedding_dim, emb_size)
+ # Slice to cut_dimension: (1, cut_dimension, emb_size)
+ temporal_embedding = self.temporal_embedding[:, 1 : cut_dimension + 1, :]
+
  # Add a new dimension to the time embedding
- # e.g. (batch, 62, 200) -> (batch, 1, 62, 200)
+ # e.g. (1, 5, 200) -> (1, 1, 5, 200)
  temporal_embedding = temporal_embedding.unsqueeze(1)
- # Expand the time embedding to match the number of channels
- # or number of patches from
+
+ # Expand the time embedding to match the number of channels or patches
+ # (1, 1, cut_dimension, 200) -> (batch_size, num_ch, cut_dimension, 200)
  temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
+
  # Flatten the intermediate dimensions
+ # (batch_size, num_ch, cut_dimension, 200) -> (batch_size, num_ch * cut_dimension, 200)
  temporal_embedding = temporal_embedding.flatten(1, 2)
+
  return temporal_embedding

  def _adj_position_embedding(self, pos_embed_used, batch_size):
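The reshaping that ``_adj_temporal_embedding`` performs is easiest to follow on toy tensors; this sketch mirrors the shape comments added in the hunk above (sizes are illustrative)::

    import torch

    batch_size, num_ch, cut_dimension, emb_size = 2, 62, 5, 200
    temporal_embedding = torch.zeros(1, cut_dimension, emb_size)

    t = temporal_embedding.unsqueeze(1)         # (1, 1, 5, 200)
    t = t.expand(batch_size, num_ch, -1, -1)    # (2, 62, 5, 200)
    t = t.flatten(1, 2)                         # (2, 310, 200)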
@@ -679,25 +743,27 @@ class _SegmentPatch(nn.Module):


  class _PatchEmbed(nn.Module):
- """EEG to Patch Embedding.
+ """EEG to Patch Embedding for Neural Decoder mode.

  This code is used when we want to apply the patch embedding
- after the codebook layer.
+ after the codebook layer (Neural Decoder mode).
+
+ The input is expected to be in the format (Batch, n_channels, n_times),
+ but the original LaBraM expects pre-patched data (Batch, n_channels, n_patches, patch_size).
+ This class reshapes the input to the pre-patched format, then applies a 2D
+ convolution to project this pre-patched data to the embedding dimension,
+ and finally flattens across channels to produce a unified embedding.

  Parameters:
  -----------
  n_times: int (default=2000)
- Number of temporal components of the input tensor.
+ Number of temporal components of the input tensor (used for dimension calculation).
  patch_size: int (default=200)
  Size of the patch, default is 1-seconds with 200Hz.
  in_channels: int (default=1)
- Number of input channels for to be used in the convolution.
+ Number of input channels (from VQVAE codebook).
  emb_dim: int (default=200)
- Number of out_channes to be used in the convolution, here,
- we used the same as patch_size.
- n_codebooks: int (default=62)
- Number of patches to be used in the convolution, here,
- we used the same as n_times // patch_size.
+ Number of output embedding dimension.
  """

  def __init__(
@@ -707,10 +773,13 @@ class _PatchEmbed(nn.Module):
  self.n_times = n_times
  self.patch_size = patch_size
  self.patch_shape = (1, self.n_times // self.patch_size)
- n_patchs = n_codebooks * (self.n_times // self.patch_size)
-
- self.n_patchs = n_patchs
+ self.n_patchs = self.n_times // self.patch_size
+ self.emb_dim = emb_dim
+ self.in_channels = in_channels

+ # 2D Conv to project the pre-patched data
+ # Input: (Batch, in_channels, n_patches, patch_size)
+ # After proj: (Batch, emb_dim, n_patches, 1)
  self.proj = nn.Conv2d(
  in_channels=in_channels,
  out_channels=emb_dim,
@@ -718,27 +787,64 @@ class _PatchEmbed(nn.Module):
  stride=(1, self.patch_size),
  )

- self.merge_transpose = Rearrange(
- "Batch ch patch spatch -> Batch patch spatch ch",
- )
-
  def forward(self, x):
  """
- Apply the convolution to the input tensor.
- then merge the output tensor to the desired shape.
+ Apply the temporal projection to the input tensor after grouping channels.

- Parameters:
- -----------
- x: torch.Tensor
- Input tensor of shape (Batch, Channels, n_patchs, patch_size).
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor of shape (Batch, n_channels, n_times) or
+ (Batch, n_channels, n_patches, patch_size).

- Return:
+ Returns
  -------
- x: torch.Tensor
- Output tensor of shape (Batch, n_patchs, patch_size, channels).
+ torch.Tensor
+ Output tensor of shape (Batch, n_patchs, emb_dim).
  """
+ if x.ndim == 4:
+ batch_size, n_channels, n_patchs, patch_len = x.shape
+ if patch_len != self.patch_size:
+ raise ValueError(
+ "When providing a 4D tensor, the last dimension "
+ f"({patch_len}) must match patch_size ({self.patch_size})."
+ )
+ n_times = n_patchs * patch_len
+ x = x.reshape(batch_size, n_channels, n_times)
+ elif x.ndim == 3:
+ batch_size, n_channels, n_times = x.shape
+ else:
+ raise ValueError(
+ "Input must be either 3D (batch, channels, times) or "
+ "4D (batch, channels, n_patches, patch_size)."
+ )
+
+ if n_times % self.patch_size != 0:
+ raise ValueError(
+ f"n_times ({n_times}) must be divisible by patch_size ({self.patch_size})."
+ )
+ if n_channels % self.in_channels != 0:
+ raise ValueError(
+ "The input channel dimension "
+ f"({n_channels}) must be divisible by in_channels ({self.in_channels})."
+ )
+
+ group_size = n_channels // self.in_channels
+
+ # Reshape so Conv2d sees `in_channels` feature maps and uses the grouped
+ # EEG channels as the spatial height dimension.
+ # Shape after view: (Batch, in_channels, group_size, n_times)
+ x = x.view(batch_size, self.in_channels, group_size, n_times)
+
+ # Apply the temporal projection per group.
+ # Output shape: (Batch, emb_dim, group_size, n_patchs)
  x = self.proj(x)
- x = self.merge_transpose(x)
+
+ # THIS IS braindecode's MODIFICATION:
+ # Average over the grouped channel dimension and permute to (Batch, n_patchs, emb_dim)
+ x = x.mean(dim=2)
+ x = x.transpose(1, 2).contiguous()
+
  return x


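The new ``_PatchEmbed.forward`` above accepts raw ``(batch, n_channels, n_times)`` input, groups channels, projects each patch with a strided ``Conv2d``, and averages over the grouped channel axis. A toy shape walk-through of the same steps using plain ``torch`` (sizes are illustrative)::

    import torch
    from torch import nn

    batch_size, n_channels, n_times = 2, 64, 1600
    in_channels, emb_dim, patch_size = 1, 200, 200

    x = torch.randn(batch_size, n_channels, n_times)
    group_size = n_channels // in_channels                     # 64
    x = x.view(batch_size, in_channels, group_size, n_times)   # (2, 1, 64, 1600)

    proj = nn.Conv2d(in_channels, emb_dim,
                     kernel_size=(1, patch_size), stride=(1, patch_size))
    x = proj(x)                                                # (2, 200, 64, 8)
    x = x.mean(dim=2).transpose(1, 2).contiguous()             # (2, 8, 200)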
@@ -13,6 +13,8 @@ from braindecode.models.base import EEGModuleMixin
  class MSVTNet(EEGModuleMixin, nn.Module):
  """MSVTNet model from Liu K et al (2024) from [msvt2024]_.

+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Small Attention`
+
  This model implements a multi-scale convolutional transformer network
  for EEG signal classification, as described in [msvt2024]_.