braindecode 1.2.0.dev184328194__py3-none-any.whl → 1.3.0.dev171178473__py3-none-any.whl
- braindecode/augmentation/base.py +1 -1
- braindecode/augmentation/functional.py +154 -54
- braindecode/augmentation/transforms.py +2 -2
- braindecode/datasets/__init__.py +10 -2
- braindecode/datasets/base.py +116 -152
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +218 -0
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +2 -2
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +4 -3
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/serialization.py +18 -13
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +12 -8
- braindecode/models/atcnet.py +156 -17
- braindecode/models/attentionbasenet.py +148 -16
- braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +3 -1
- braindecode/models/ctnet.py +7 -4
- braindecode/models/deep4.py +6 -2
- braindecode/models/deepsleepnet.py +127 -5
- braindecode/models/eegconformer.py +114 -15
- braindecode/models/eeginception_erp.py +82 -7
- braindecode/models/eeginception_mi.py +2 -0
- braindecode/models/eegnet.py +64 -177
- braindecode/models/eegnex.py +113 -6
- braindecode/models/eegsimpleconv.py +2 -0
- braindecode/models/eegtcnet.py +1 -1
- braindecode/models/labram.py +188 -84
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/sccnet.py +81 -8
- braindecode/models/shallow_fbcsp.py +2 -0
- braindecode/models/signal_jepa.py +109 -27
- braindecode/models/sinc_shallow.py +10 -9
- braindecode/models/sleep_stager_blanco_2020.py +2 -0
- braindecode/models/sleep_stager_chambon_2018.py +2 -0
- braindecode/models/sparcnet.py +2 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +42 -41
- braindecode/models/tidnet.py +2 -0
- braindecode/models/tsinception.py +15 -3
- braindecode/models/usleep.py +108 -9
- braindecode/models/util.py +8 -5
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +24 -0
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/preprocess.py +42 -39
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +24 -19
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/METADATA +12 -3
- braindecode-1.3.0.dev171178473.dist-info/RECORD +106 -0
- braindecode/models/eegresnet.py +0 -362
- braindecode-1.2.0.dev184328194.dist-info/RECORD +0 -101
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/WHEEL +0 -0
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/top_level.txt +0 -0
braindecode/models/eegnex.py
CHANGED

@@ -16,9 +16,122 @@ from braindecode.modules import Conv2dWithConstraint, LinearWithConstraint
 class EEGNeX(EEGModuleMixin, nn.Module):
     """EEGNeX model from Chen et al. (2024) [eegnex]_.
 
+    :bdg-success:`Convolution`
+
     .. figure:: https://braindecode.org/dev/_static/model/eegnex.jpg
        :align: center
        :alt: EEGNeX Architecture
+       :width: 620px
+
+    .. rubric:: Architectural Overview
+
+    EEGNeX is a **purely convolutional** architecture that refines the EEGNet-style stem
+    and deepens the temporal stack with **dilated temporal convolutions**. The end-to-end
+    flow is:
+
+    - (i) **Block-1/2**: two temporal convolutions ``(1 x L)`` with BN refine a
+      learned FIR-like *temporal filter bank* (no pooling yet);
+    - (ii) **Block-3**: depthwise **spatial** convolution across electrodes
+      ``(n_chans x 1)`` with max-norm constraint, followed by ELU → AvgPool (time) → Dropout;
+    - (iii) **Block-4/5**: two additional **temporal** convolutions with increasing **dilation**
+      to expand the receptive field; the last block applies ELU → AvgPool → Dropout → Flatten;
+    - (iv) **Classifier**: a max-norm–constrained linear layer.
+
+    The published work positions EEGNeX as a compact, conv-only alternative that consistently
+    outperforms prior baselines across MOABB-style benchmarks, with the popular
+    “EEGNeX-8,32” shorthand denoting *8 temporal filters* and *kernel length 32*.
+
+
+    .. rubric:: Macro Components
+
+    - **Block-1 / Block-2 — Temporal filter (learned).**
+
+      - *Operations.*
+        - :class:`torch.nn.Conv2d` with kernels ``(1, L)``
+        - :class:`torch.nn.BatchNorm2d` (no nonlinearity until Block-3, mirroring a linear FIR analysis stage).
+          These layers set up frequency-selective detectors before spatial mixing.
+
+      - *Interpretability.* Kernels can be inspected as FIR filters; two stacked temporal
+        convs allow longer effective kernels without parameter blow-up.
+
+    - **Block-3 — Spatial projection + condensation.**
+
+      - *Operations.*
+        - :class:`braindecode.modules.Conv2dWithConstraint` with kernel``(n_chans, 1)``
+          and ``groups = filter_2`` (depthwise across filters)
+        - :class:`torch.nn.BatchNorm2d`
+        - :class:`torch.nn.ELU`
+        - :class:`torch.nn.AvgPool2d` (time)
+        - :class:`torch.nn.Dropout`.
+
+      **Role**: Learns per-filter spatial patterns over the **full montage** while temporal
+      pooling stabilizes and compresses features; max-norm encourages well-behaved spatial
+      weights similar to EEGNet practice.
+
+    - **Block-4 / Block-5 — Dilated temporal integration.**
+
+      - *Operations.*
+        - :class:`torch.nn.Conv2d` with kernels ``(1, k)`` and **dilations**
+          (e.g., 2 then 4);
+        - :class:`torch.nn.BatchNorm2d`
+        - :class:`torch.nn.ELU`
+        - :class:`torch.nn.AvgPool2d` (time)
+        - :class:`torch.nn.Dropout`
+        - :class:`torch.nn.Flatten`.
+
+      **Role**: Expands the temporal receptive field efficiently to capture rhythms and
+      long-range context after condensation.
+
+    - **Final Classifier — Max-norm linear.**
+
+      - *Operations.*
+        - :class:`braindecode.modules.LinearWithConstraint` maps the flattened
+          vector to the target classes; the max-norm constraint regularizes the readout.
+
+
+    .. rubric:: Convolutional Details
+
+    - **Temporal (where time-domain patterns are learned).**
+      Blocks 1-2 learn the primary filter bank (oscillations/transients), while Blocks 4-5
+      use **dilation** to integrate over longer horizons without extra pooling. The final
+      AvgPool in Block-5 sets the output token rate and helps noise suppression.
+
+    - **Spatial (how electrodes are processed).**
+      A *single* depthwise spatial conv (Block-3) spans the entire electrode set
+      (kernel ``(n_chans, 1)``), producing per-temporal-filter topographies; no cross-filter
+      mixing occurs at this stage, aiding interpretability.
+
+    - **Spectral (how frequency content is captured).**
+      Frequency selectivity emerges from the learned temporal kernels; dilation broadens effective
+      bandwidth coverage by composing multiple scales.
+
+    .. rubric:: Additional Mechanisms
+
+    - **EEGNeX-8,32 naming.** “8,32” indicates *8 temporal filters* and *kernel length 32*,
+      reflecting the paper's ablation path from EEGNet-8,2 toward thicker temporal kernels
+      and a deeper conv stack.
+    - **Max-norm constraints.** Spatial (Block-3) and final linear layers use max-norm
+      regularization—standard in EEG CNNs—to reduce overfitting and encourage stable spatial
+      patterns.
+
+    .. rubric:: Usage and Configuration
+
+    - **Kernel schedule.** Start with the canonical **EEGNeX-8,32** (``filter_1=8``,
+      ``kernel_block_1_2=32``) and keep **Block-3** depth multiplier modest (e.g., 2) to match
+      the paper's “pure conv” profile.
+    - **Pooling vs. dilation.** Use pooling in Blocks 3 and 5 to control compute and variance;
+      increase dilations (Blocks 4-5) to widen temporal context when windows are short.
+    - **Regularization.** Combine dropout (Blocks 3 & 5) with max-norm on spatial and
+      classifier layers; prefer ELU activations for stable training on small EEG datasets.
+
+
+    - The braindecode implementation follows the paper's conv-only design with five blocks
+      and reproduces the depthwise spatial step and dilated temporal stack. See the class
+      reference for exact kernel sizes, dilations, and pooling defaults. You can check the
+      original implementation at [EEGNexCode]_.
+
+    .. versionadded:: 1.1
+
 
     Parameters
     ----------
@@ -45,12 +158,6 @@ class EEGNeX(EEGModuleMixin, nn.Module):
     avg_pool_block5 : tuple[int, int], optional
         Pooling size for block 5. Default is (1, 8).
 
-    Notes
-    -----
-    This implementation is not guaranteed to be correct, has not been checked
-    by original authors, only reimplemented from the paper description and
-    source code in tensorflow [EEGNexCode]_.
-
     References
     ----------
     .. [eegnex] Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2024).
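The usage notes in the EEGNeX docstring above reference the canonical EEGNeX-8,32 configuration. A minimal sketch of instantiating it through braindecode follows; the parameter names `filter_1` and `kernel_block_1_2` are taken from the docstring, while the remaining keyword names (`n_chans`, `n_outputs`, `n_times`) follow the usual braindecode `EEGModuleMixin` signature, and the sizes are illustrative only:

    import torch
    from braindecode.models import EEGNeX

    # Canonical "EEGNeX-8,32": 8 temporal filters, temporal kernel length 32.
    model = EEGNeX(
        n_chans=22,        # EEG electrodes
        n_outputs=4,       # target classes
        n_times=1000,      # samples per trial window
        filter_1=8,
        kernel_block_1_2=32,
    )

    x = torch.randn(8, 22, 1000)   # (batch, n_chans, n_times)
    logits = model(x)              # -> (batch, n_outputs)
    print(logits.shape)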
braindecode/models/eegsimpleconv.py
CHANGED

@@ -21,6 +21,8 @@ from braindecode.models.base import EEGModuleMixin
 class EEGSimpleConv(EEGModuleMixin, torch.nn.Module):
     """EEGSimpleConv from Ouahidi, YE et al. (2023) [Yassine2023]_.
 
+    :bdg-success:`Convolution`
+
     .. figure:: https://raw.githubusercontent.com/elouayas/EEGSimpleConv/refs/heads/main/architecture.png
        :align: center
        :alt: EEGSimpleConv Architecture
braindecode/models/eegtcnet.py
CHANGED

@@ -157,7 +157,7 @@ class EEGTCNet(EEGModuleMixin, nn.Module):
 class _EEGNetTC(nn.Module):
     """EEGNet Temporal Convolutional Network (TCN) block.
 
-    The main difference from our
+    The main difference from our :class:`EEGNet` (braindecode) implementation is the
     kernel and dimensional order. Because of this, we decided to keep this
     implementation in a future issue; we will re-evaluate if it is necessary
     to maintain this separate implementation.
braindecode/models/labram.py
CHANGED

@@ -22,6 +22,8 @@ from braindecode.modules import MLP, DropPath
 class Labram(EEGModuleMixin, nn.Module):
     """Labram from Jiang, W B et al (2024) [Jiang2024]_.
 
+    :bdg-danger:`Large Brain Model`
+
     .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
        :align: center
        :alt: Labram Architecture.
@@ -43,31 +45,45 @@ class Labram(EEGModuleMixin, nn.Module):
     equals True. The original implementation uses (batch, n_chans, n_patches,
     patch_size) as input with static segmentation of the input data.
 
-    The models have the following sequence of steps
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    The models have the following sequence of steps::
+
+        if neural tokenizer:
+            - SegmentPatch: Segment the input data in patches;
+            - TemporalConv: Apply a temporal convolution to the segmented data;
+            - Residual adding cls, temporal and position embeddings (optional);
+            - WindowsAttentionBlock: Apply a windows attention block to the data;
+            - LayerNorm: Apply layer normalization to the data;
+            - Linear: An head linear layer to transformer the data into classes.
+
+        else:
+            - PatchEmbed: Apply a patch embedding to the input data;
+            - Residual adding cls, temporal and position embeddings (optional);
+            - WindowsAttentionBlock: Apply a windows attention block to the data;
+            - LayerNorm: Apply layer normalization to the data;
+            - Linear: An head linear layer to transformer the data into classes.
 
     .. versionadded:: 0.9
 
+
+    Examples
+    --------
+    Load pre-trained weights::
+
+    >>> import torch
+    >>> from braindecode.models import Labram
+    >>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
+    >>> url = "https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt"
+    >>> state = torch.hub.load_state_dict_from_url(url, progress=True)
+    >>> model.load_state_dict(state)
+
+
     Parameters
     ----------
     patch_size : int
         The size of the patch to be used in the patch embedding.
     emb_size : int
         The dimension of the embedding.
-
+    in_conv_channels : int
         The number of convolutional input channels.
     out_channels : int
         The number of convolutional output channels.
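The docstring hunk above lists "SegmentPatch: Segment the input data in patches" as the first neural-tokenizer step, i.e. a static segmentation of the raw signal into fixed-length patches. A hedged, standalone illustration of that reshaping (shapes match the docstring's 200-sample patches; `unfold` is used here only for illustration and is not necessarily how the package implements the step):

    import torch

    batch, n_chans, n_times, patch_size = 2, 64, 1600, 200
    x = torch.randn(batch, n_chans, n_times)

    # Non-overlapping windows along the time axis:
    # (batch, n_chans, n_times) -> (batch, n_chans, n_patches, patch_size)
    patches = x.unfold(dimension=-1, size=patch_size, step=patch_size)
    print(patches.shape)  # torch.Size([2, 64, 8, 200])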
@@ -79,8 +95,10 @@ class Labram(EEGModuleMixin, nn.Module):
         The expansion ratio of the mlp layer
     qkv_bias : bool (default=False)
         If True, add a learnable bias to the query, key, and value tensors.
-    qk_norm : Pytorch Normalize layer (default=
-        If not None, apply LayerNorm to the query and key tensors
+    qk_norm : Pytorch Normalize layer (default=nn.LayerNorm)
+        If not None, apply LayerNorm to the query and key tensors.
+        Default is nn.LayerNorm for better weight transfer from original LaBraM.
+        Set to None to disable Q,K normalization.
     qk_scale : float (default=None)
         If not None, use this value as the scale factor. If None,
         use head_dim**-0.5, where head_dim = dim // num_heads.
@@ -92,9 +110,10 @@ class Labram(EEGModuleMixin, nn.Module):
         Dropout rate for the attention weights used on DropPath.
     norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
         The normalization layer to be used.
-    init_values : float (default=
+    init_values : float (default=0.1)
         If not None, use this value to initialize the gamma_1 and gamma_2
-        parameters.
+        parameters for residual scaling. Default is 0.1 for better weight
+        transfer from original LaBraM. Set to None to disable.
     use_abs_pos_emb : bool (default=True)
         If True, use absolute position embedding.
     use_mean_pooling : bool (default=True)
@@ -135,19 +154,19 @@ class Labram(EEGModuleMixin, nn.Module):
         input_window_seconds=None,
         patch_size=200,
         emb_size=200,
-
+        in_conv_channels=1,
         out_channels=8,
         n_layers=12,
         att_num_heads=10,
         mlp_ratio=4.0,
         qkv_bias=False,
-        qk_norm=
+        qk_norm=nn.LayerNorm,
         qk_scale=None,
         drop_prob=0.0,
         attn_drop_prob=0.0,
         drop_path_prob=0.0,
         norm_layer=nn.LayerNorm,
-        init_values=
+        init_values=0.1,
         use_abs_pos_emb=True,
         use_mean_pooling=True,
         init_scale=0.001,
@@ -183,15 +202,15 @@ class Labram(EEGModuleMixin, nn.Module):
         self.patch_size = patch_size
         self.n_path = self.n_times // self.patch_size
 
-        if neural_tokenizer and
+        if neural_tokenizer and in_conv_channels != 1:
             warn(
                 "The model is in Neural Tokenizer mode, but the variable "
-                + "`
-                + "`
-                + "
+                + "`in_conv_channels` is different from the default values."
+                + "`in_conv_channels` is only needed for the Neural Decoder mode."
+                + "in_conv_channels is not used in the Neural Tokenizer mode.",
                 UserWarning,
             )
-
+            in_conv_channels = 1
         # If you can use the model in Neural Tokenizer mode,
         # temporal conv layer will be use over the patched dataset
         if neural_tokenizer:
@@ -228,7 +247,7 @@ class Labram(EEGModuleMixin, nn.Module):
             _PatchEmbed(
                 n_times=self.n_times,
                 patch_size=patch_size,
-                in_channels=
+                in_channels=in_conv_channels,
                 emb_dim=self.emb_size,
             ),
         )
@@ -373,8 +392,7 @@ class Labram(EEGModuleMixin, nn.Module):
         Parameters
         ----------
         x : torch.Tensor
-            The input data with shape (batch, n_chans,
-            if neural decoder or (batch, n_chans, n_times), if neural tokenizer.
+            The input data with shape (batch, n_chans, n_times).
         input_chans : int
             The number of input channels.
         return_patch_tokens : bool
@@ -387,37 +405,72 @@ class Labram(EEGModuleMixin, nn.Module):
         x : torch.Tensor
             The output of the model.
         """
+        batch_size = x.shape[0]
+
         if self.neural_tokenizer:
-
+            # For neural tokenizer: input is (batch, n_chans, n_times)
+            # patch_embed returns (batch, n_chans, emb_dim)
+            x = self.patch_embed(x)
+            # x shape: (batch, n_chans, emb_dim)
+            n_patch = self.n_chans
+            temporal = self.emb_size
         else:
-
-
+            # For neural decoder: input is (batch, n_chans, n_times)
+            # patch_embed returns (batch, n_patchs, emb_dim)
+            x = self.patch_embed(x)
+            # x shape: (batch, n_patchs, emb_dim)
+            batch_size, n_patch, temporal = x.shape
+
         # add the [CLS] token to the embedded patch tokens
         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
 
+        # Concatenate cls token with patch/channel embeddings
         x = torch.cat((cls_tokens, x), dim=1)
 
         # Positional Embedding
-        if input_chans is not None:
-            pos_embed_used = self.position_embedding[:, input_chans]
-        else:
-            pos_embed_used = self.position_embedding
-
         if self.position_embedding is not None:
-
-
-
+            if self.neural_tokenizer:
+                # In tokenizer mode, use channel-based position embedding
+                if input_chans is not None:
+                    pos_embed_used = self.position_embedding[:, input_chans]
+                else:
+                    pos_embed_used = self.position_embedding
+
+                pos_embed = self._adj_position_embedding(
+                    pos_embed_used=pos_embed_used, batch_size=batch_size
+                )
+            else:
+                # In decoder mode, we have different number of patches
+                # Adapt position embedding for n_patch patches
+                # Use the first n_patch+1 positions from position_embedding
+                n_pos = min(self.position_embedding.shape[1], n_patch + 1)
+                pos_embed_used = self.position_embedding[:, :n_pos, :]
+                pos_embed = pos_embed_used.expand(batch_size, -1, -1)
+
            x += pos_embed
 
         # The time embedding is added across the channels after the [CLS] token
         if self.neural_tokenizer:
             num_ch = self.n_chans
+            time_embed = self._adj_temporal_embedding(
+                num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
+            )
+            x[:, 1:, :] += time_embed
         else:
-
-
-
-
-
+            # In decoder mode, we have n_patch patches and don't need to expand
+            # Just broadcast the temporal embedding
+            if temporal is None:
+                temporal = self.emb_size
+
+            # Get temporal embeddings for n_patch patches
+            n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
+            time_embed = self.temporal_embedding[
+                :, 1 : n_time_tokens + 1, :
+            ]  # (1, n_patch, emb_dim)
+            time_embed = time_embed.expand(
+                batch_size, -1, -1
+            )  # (batch, n_patch, emb_dim)
+            x[:, 1:, :] += time_embed
 
         x = self.pos_drop(x)
 
@@ -428,10 +481,10 @@ class Labram(EEGModuleMixin, nn.Module):
         if self.fc_norm is not None:
             if return_all_tokens:
                 return self.fc_norm(x)
-
+            tokens = x[:, 1:, :]
             if return_patch_tokens:
-                return self.fc_norm(
-            return self.fc_norm(
+                return self.fc_norm(tokens)
+            return self.fc_norm(tokens.mean(1))
         else:
             if return_all_tokens:
                 return x
@@ -505,14 +558,16 @@ class Labram(EEGModuleMixin, nn.Module):
     def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
         """
         Adjust the dimensions of the time embedding to match the
-        number of channels.
+        number of channels or patches.
 
         Parameters
        ----------
         num_ch : int
-            The number of channels or number of
+            The number of channels or number of patches.
         batch_size : int
             Batch size of the input data.
+        dim_embed : int
+            The embedding dimension (temporal feature dimension).
 
         Returns
         -------
@@ -523,17 +578,24 @@ class Labram(EEGModuleMixin, nn.Module):
         if dim_embed is None:
             cut_dimension = self.patch_size
         else:
-            cut_dimension = dim_embed
-
-
+            cut_dimension = min(dim_embed, self.temporal_embedding.shape[1] - 1)
+
+        # Get the temporal embedding: (1, temporal_embedding_dim, emb_size)
+        # Slice to cut_dimension: (1, cut_dimension, emb_size)
+        temporal_embedding = self.temporal_embedding[:, 1 : cut_dimension + 1, :]
+
         # Add a new dimension to the time embedding
-        # e.g. (
+        # e.g. (1, 5, 200) -> (1, 1, 5, 200)
         temporal_embedding = temporal_embedding.unsqueeze(1)
-
-        #
+
+        # Expand the time embedding to match the number of channels or patches
+        # (1, 1, cut_dimension, 200) -> (batch_size, num_ch, cut_dimension, 200)
         temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
+
         # Flatten the intermediate dimensions
+        # (batch_size, num_ch, cut_dimension, 200) -> (batch_size, num_ch * cut_dimension, 200)
        temporal_embedding = temporal_embedding.flatten(1, 2)
+
         return temporal_embedding
 
     def _adj_position_embedding(self, pos_embed_used, batch_size):
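The `_adj_temporal_embedding` changes above are purely a shape adaptation: slice the learned temporal embedding, then broadcast it over every channel (or patch) before flattening. A self-contained restatement of that shape flow with dummy tensors (sizes are illustrative, not defaults taken from the package):

    import torch

    batch_size, num_ch, cut_dimension, emb_size = 2, 64, 5, 200
    temporal_embedding = torch.randn(1, cut_dimension, emb_size)

    # (1, cut_dimension, emb_size) -> (1, 1, cut_dimension, emb_size)
    t = temporal_embedding.unsqueeze(1)
    # Broadcast over batch and channels/patches (a view, no copy):
    # -> (batch_size, num_ch, cut_dimension, emb_size)
    t = t.expand(batch_size, num_ch, -1, -1)
    # Merge the channel and patch axes:
    # -> (batch_size, num_ch * cut_dimension, emb_size)
    t = t.flatten(1, 2)
    print(t.shape)  # torch.Size([2, 320, 200])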
@@ -679,25 +741,27 @@ class _SegmentPatch(nn.Module):
 
 
 class _PatchEmbed(nn.Module):
-    """EEG to Patch Embedding.
+    """EEG to Patch Embedding for Neural Decoder mode.
 
     This code is used when we want to apply the patch embedding
-    after the codebook layer.
+    after the codebook layer (Neural Decoder mode).
+
+    The input is expected to be in the format (Batch, n_channels, n_times),
+    but the original LaBraM expects pre-patched data (Batch, n_channels, n_patches, patch_size).
+    This class reshapes the input to the pre-patched format, then applies a 2D
+    convolution to project this pre-patched data to the embedding dimension,
+    and finally flattens across channels to produce a unified embedding.
 
     Parameters:
     -----------
     n_times: int (default=2000)
-        Number of temporal components of the input tensor.
+        Number of temporal components of the input tensor (used for dimension calculation).
     patch_size: int (default=200)
         Size of the patch, default is 1-seconds with 200Hz.
     in_channels: int (default=1)
-        Number of input channels
+        Number of input channels (from VQVAE codebook).
     emb_dim: int (default=200)
-        Number of
-        we used the same as patch_size.
-    n_codebooks: int (default=62)
-        Number of patches to be used in the convolution, here,
-        we used the same as n_times // patch_size.
+        Number of output embedding dimension.
     """
 
     def __init__(
@@ -707,10 +771,13 @@ class _PatchEmbed(nn.Module):
         self.n_times = n_times
         self.patch_size = patch_size
         self.patch_shape = (1, self.n_times // self.patch_size)
-        n_patchs =
-
-        self.
+        self.n_patchs = self.n_times // self.patch_size
+        self.emb_dim = emb_dim
+        self.in_channels = in_channels
 
+        # 2D Conv to project the pre-patched data
+        # Input: (Batch, in_channels, n_patches, patch_size)
+        # After proj: (Batch, emb_dim, n_patches, 1)
         self.proj = nn.Conv2d(
             in_channels=in_channels,
             out_channels=emb_dim,
@@ -718,27 +785,64 @@ class _PatchEmbed(nn.Module):
             stride=(1, self.patch_size),
         )
 
-        self.merge_transpose = Rearrange(
-            "Batch ch patch spatch -> Batch patch spatch ch",
-        )
-
     def forward(self, x):
         """
-        Apply the
-        then merge the output tensor to the desired shape.
+        Apply the temporal projection to the input tensor after grouping channels.
 
-        Parameters
-
-        x: torch.Tensor
-            Input tensor of shape (Batch,
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (Batch, n_channels, n_times) or
+            (Batch, n_channels, n_patches, patch_size).
 
-
+        Returns
         -------
-
-            Output tensor of shape (Batch, n_patchs,
+        torch.Tensor
+            Output tensor of shape (Batch, n_patchs, emb_dim).
         """
+        if x.ndim == 4:
+            batch_size, n_channels, n_patchs, patch_len = x.shape
+            if patch_len != self.patch_size:
+                raise ValueError(
+                    "When providing a 4D tensor, the last dimension "
+                    f"({patch_len}) must match patch_size ({self.patch_size})."
+                )
+            n_times = n_patchs * patch_len
+            x = x.reshape(batch_size, n_channels, n_times)
+        elif x.ndim == 3:
+            batch_size, n_channels, n_times = x.shape
+        else:
+            raise ValueError(
+                "Input must be either 3D (batch, channels, times) or "
+                "4D (batch, channels, n_patches, patch_size)."
+            )
+
+        if n_times % self.patch_size != 0:
+            raise ValueError(
+                f"n_times ({n_times}) must be divisible by patch_size ({self.patch_size})."
+            )
+        if n_channels % self.in_channels != 0:
+            raise ValueError(
+                "The input channel dimension "
+                f"({n_channels}) must be divisible by in_channels ({self.in_channels})."
+            )
+
+        group_size = n_channels // self.in_channels
+
+        # Reshape so Conv2d sees `in_channels` feature maps and uses the grouped
+        # EEG channels as the spatial height dimension.
+        # Shape after view: (Batch, in_channels, group_size, n_times)
+        x = x.view(batch_size, self.in_channels, group_size, n_times)
+
+        # Apply the temporal projection per group.
+        # Output shape: (Batch, emb_dim, group_size, n_patchs)
         x = self.proj(x)
-
+
+        # THIS IS braindecode's MODIFICATION:
+        # Average over the grouped channel dimension and permute to (Batch, n_patchs, emb_dim)
+        x = x.mean(dim=2)
+        x = x.transpose(1, 2).contiguous()
+
         return x
 
 
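The comments added to `_PatchEmbed.forward` above describe a pure shape pipeline: regroup the EEG channels, patch the time axis with a strided 2D convolution, then average the grouped dimension. A hedged sketch of that flow with default-like sizes (`in_channels=1`, `patch_size=200`, `emb_dim=200` as in the diffed defaults; the kernel size equal to `(1, patch_size)` is an assumption inferred from the stride shown in the hunk):

    import torch
    import torch.nn as nn

    batch, n_channels, n_times = 2, 64, 1600
    in_channels, patch_size, emb_dim = 1, 200, 200

    proj = nn.Conv2d(
        in_channels, emb_dim,
        kernel_size=(1, patch_size), stride=(1, patch_size),
    )

    x = torch.randn(batch, n_channels, n_times)
    group_size = n_channels // in_channels

    # (batch, n_channels, n_times) -> (batch, in_channels, group_size, n_times)
    x = x.view(batch, in_channels, group_size, n_times)
    # The strided conv patches the time axis:
    # -> (batch, emb_dim, group_size, n_times // patch_size)
    x = proj(x)
    # Average the grouped channels, then put patches before features:
    # -> (batch, n_patches, emb_dim)
    x = x.mean(dim=2).transpose(1, 2)
    print(x.shape)  # torch.Size([2, 8, 200])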