braindecode 1.3.0.dev180851780__py3-none-any.whl → 1.3.0.dev183667303__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of braindecode might be problematic.
- braindecode/augmentation/base.py +1 -1
- braindecode/datasets/__init__.py +10 -2
- braindecode/datasets/base.py +115 -151
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +2 -2
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +2 -2
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +2 -2
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/serialization.py +7 -7
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +2 -0
- braindecode/models/atcnet.py +25 -26
- braindecode/models/attentionbasenet.py +39 -32
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +2 -0
- braindecode/models/ctnet.py +6 -3
- braindecode/models/deepsleepnet.py +27 -18
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegnet.py +1 -1
- braindecode/models/labram.py +193 -87
- braindecode/models/signal_jepa.py +109 -27
- braindecode/models/sinc_shallow.py +10 -9
- braindecode/models/sstdpn.py +11 -11
- braindecode/models/summary.csv +1 -0
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +1 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +24 -0
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/preprocess.py +12 -12
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +24 -19
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/METADATA +6 -2
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/RECORD +52 -49
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/top_level.txt +0 -0
@@ -35,51 +35,57 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
  - :class:`_InceptionModule1` **(multi-scale temporal + spatial mixing)**

  - *Operations.*
- [6 old lines not captured]
+
+ - `EEGInceptionERP.c1`: :class:`torch.nn.Conv2d` ``k=(64,1)``, stride ``(1,1)``, *same* pad on input reshaped to ``(B,1,128,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.d1`: :class:`torch.nn.Conv2d` (depthwise) ``k=(1,8)``, *valid* pad over channels → BN → activation → dropout.
+ - `EEGInceptionERP.c2`: :class:`torch.nn.Conv2d` ``k=(32,1)`` → BN → activation → dropout; then `EEGInceptionERP.d2` depthwise ``k=(1,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c3`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout; then `EEGInceptionERP.d3` depthwise ``k=(1,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.n1`: :class:`torch.nn.Concat` over branch features.
+ - `EEGInceptionERP.a1`: :class:`torch.nn.AvgPool2d` ``pool=(4,1)``, stride ``(4,1)`` for temporal downsampling.

  *Interpretability/robustness.* Depthwise `1 x n_chans` layers act as learnable montage-wide spatial filters per temporal scale; pooling stabilizes against jitter.

  - :class:`_InceptionModule2` **(refinement at coarser timebase)**

  - *Operations.*
- [7 old lines not captured]
+
+ - `EEGInceptionERP.c4`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c5`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c6`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.n2`: :class:`torch.nn.Concat` (merge C4-C6 outputs).
+ - `EEGInceptionERP.a2`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``, stride ``(2,1)``.
+ - `EEGInceptionERP.c7`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout; then `EEGInceptionERP.a3`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
+ - `EEGInceptionERP.c8`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout; then `EEGInceptionERP.a4`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.

  *Role.* Adds higher-level, shorter-window evidence while progressively compressing temporal dimension.

  - :class:`_OutputModule` **(aggregation + readout)**

  - *Operations.*
- [2 old lines not captured]
+
+ - :class:`torch.nn.Flatten`
+ - :class:`torch.nn.Linear` ``(features → 2)``

  .. rubric:: Convolutional Details

  - **Temporal (where time-domain patterns are learned).**
- [4 old lines not captured]
+
+ First module uses 1D temporal kernels along the 128-sample axis: ``64``, ``32``, ``16``
+ (≈500, 250, 125 ms at 128 Hz). After ``pool=(4,1)``, the second module applies ``16``,
+ ``8``, ``4`` (≈125, 62.5, 31.25 ms at the pooled rate). All strides are ``1`` in convs;
+ temporal resolution changes only via average pooling.

  - **Spatial (how electrodes are processed).**
- [3 old lines not captured]
+
+ Depthwise convs with ``k=(1,8)`` span all channels and are applied **per temporal branch**,
+ yielding scale-specific channel projections (no cross-branch mixing until concatenation).
+ There is no full 2D mixing kernel; spatial mixing is factorized and lightweight.

  - **Spectral (how frequency information is captured).**
- [3 old lines not captured]
+
+ No explicit transform; multiple temporal kernels form a *learned filter bank* over
+ ERP-relevant bands. Successive pooling acts as low-pass integration to emphasize sustained
+ post-stimulus components.

  .. rubric:: Additional Mechanisms

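To make the revised EEGInceptionERP docstring above concrete, here is a minimal sketch that instantiates the model in the 8-channel, 128 Hz, 1-second configuration implied by the ``(B,1,128,8)`` reshape and the 64/32/16-sample kernels; the constructor arguments ``n_chans``, ``n_outputs``, ``n_times`` and ``sfreq`` are the standard braindecode ones, and the printed output shape is an expectation, not something shown in this diff::

    import torch
    from braindecode.models import EEGInceptionERP

    # 8 electrodes, 1 s of EEG sampled at 128 Hz, binary ERP decoding.
    model = EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128)

    x = torch.randn(4, 8, 128)   # (batch, n_chans, n_times)
    y = model(x)
    print(y.shape)               # expected: torch.Size([4, 2])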
braindecode/models/eegnet.py  CHANGED

@@ -57,7 +57,7 @@ class EEGNet(EEGModuleMixin, nn.Sequential):

  - **Temporal.** The initial temporal convs serve as a *learned filter bank*:
  long 1-D kernels (implemented as 2-D with singleton spatial extent) emphasize oscillatory bands and transients.
- Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each feature…
+ Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each feature's spectrum [Lawhern2018]_.

  - **Spatial.** The depthwise spatial conv spans the full channel axis (kernel height = #electrodes; temporal size = 1).
  With ``groups = F1``, each temporal filter learns its own set of ``D`` spatial projections—akin to CSP, learned end-to-end and
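The corrected sentence above points out that the first temporal stage is linear before BN/ELU, so its kernels can be read as FIR filters. A hedged sketch of that analysis follows; it assumes the class is importable as ``EEGNet`` (as in the hunk header) and locates the temporal convolution by type rather than by attribute name, since the layer names are not shown in this hunk::

    import torch
    from braindecode.models import EEGNet

    sfreq = 250.0  # assumed sampling rate of the data the model will see
    model = EEGNet(n_chans=22, n_outputs=4, n_times=1000)

    # The first Conv2d in the Sequential is the temporal filter bank described above.
    temporal_conv = next(m for m in model.modules() if isinstance(m, torch.nn.Conv2d))
    kernels = temporal_conv.weight.detach().squeeze()   # (F1, kernel_length)

    # Treat each kernel as an FIR filter and inspect its magnitude response.
    magnitude = torch.fft.rfft(kernels, n=512).abs()    # (F1, 257)
    freqs_hz = torch.fft.rfftfreq(512, d=1 / sfreq)     # frequency axis in Hz
    print(magnitude.shape, freqs_hz.shape)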
braindecode/models/labram.py  CHANGED

@@ -2,6 +2,7 @@
  Labram module.
  Authors: Wei-Bang Jiang
  Bruno Aristimunha <b.aristimunha@gmail.com>
+ Matthew Chen <matt.chen4260@gmail.com>
  License: BSD 3 clause
  """

@@ -22,12 +23,14 @@ from braindecode.modules import MLP, DropPath
  class Labram(EEGModuleMixin, nn.Module):
      """Labram from Jiang, W B et al (2024) [Jiang2024]_.

+     :bdg-danger:`Large Brain Model`
+
      .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
          :align: center
          :alt: Labram Architecture.

      Large Brain Model for Learning Generic Representations with Tremendous
-     EEG Data in BCI from [Jiang2024]_
+     EEG Data in BCI from [Jiang2024]_.

      This is an **adaptation** of the code [Code2024]_ from the Labram model.

@@ -35,7 +38,8 @@ class Labram(EEGModuleMixin, nn.Module):
  BEiTv2 [BeiTv2]_.

  The models can be used in two modes:
- [1 old line not captured]
+
+ - Neural Tokenizer: Design to get an embedding layers (e.g. classification).
  - Neural Decoder: To extract the ampliture and phase outputs with a VQSNP.

  The braindecode's modification is to allow the model to be used in

@@ -43,31 +47,45 @@ class Labram(EEGModuleMixin, nn.Module):
  equals True. The original implementation uses (batch, n_chans, n_patches,
  patch_size) as input with static segmentation of the input data.

- The models have the following sequence of steps
- [14 old lines not captured]
+ The models have the following sequence of steps::
+
+     if neural tokenizer:
+         - SegmentPatch: Segment the input data in patches;
+         - TemporalConv: Apply a temporal convolution to the segmented data;
+         - Residual adding cls, temporal and position embeddings (optional);
+         - WindowsAttentionBlock: Apply a windows attention block to the data;
+         - LayerNorm: Apply layer normalization to the data;
+         - Linear: An head linear layer to transformer the data into classes.
+
+     else:
+         - PatchEmbed: Apply a patch embedding to the input data;
+         - Residual adding cls, temporal and position embeddings (optional);
+         - WindowsAttentionBlock: Apply a windows attention block to the data;
+         - LayerNorm: Apply layer normalization to the data;
+         - Linear: An head linear layer to transformer the data into classes.

  .. versionadded:: 0.9

+
+ Examples
+ --------
+ Load pre-trained weights::
+
+     >>> import torch
+     >>> from braindecode.models import Labram
+     >>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
+     >>> url = "https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt"
+     >>> state = torch.hub.load_state_dict_from_url(url, progress=True)
+     >>> model.load_state_dict(state)
+
+
  Parameters
  ----------
  patch_size : int
      The size of the patch to be used in the patch embedding.
  emb_size : int
      The dimension of the embedding.
- [1 old line not captured]
+ in_conv_channels : int
      The number of convolutional input channels.
  out_channels : int
      The number of convolutional output channels.

@@ -79,8 +97,10 @@ class Labram(EEGModuleMixin, nn.Module):
      The expansion ratio of the mlp layer
  qkv_bias : bool (default=False)
      If True, add a learnable bias to the query, key, and value tensors.
- qk_norm : Pytorch Normalize layer (default=…
-     If not None, apply LayerNorm to the query and key tensors
+ qk_norm : Pytorch Normalize layer (default=nn.LayerNorm)
+     If not None, apply LayerNorm to the query and key tensors.
+     Default is nn.LayerNorm for better weight transfer from original LaBraM.
+     Set to None to disable Q,K normalization.
  qk_scale : float (default=None)
      If not None, use this value as the scale factor. If None,
      use head_dim**-0.5, where head_dim = dim // num_heads.

@@ -92,9 +112,10 @@ class Labram(EEGModuleMixin, nn.Module):
      Dropout rate for the attention weights used on DropPath.
  norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
      The normalization layer to be used.
- init_values : float (default=…
+ init_values : float (default=0.1)
      If not None, use this value to initialize the gamma_1 and gamma_2
-     parameters.
+     parameters for residual scaling. Default is 0.1 for better weight
+     transfer from original LaBraM. Set to None to disable.
  use_abs_pos_emb : bool (default=True)
      If True, use absolute position embedding.
  use_mean_pooling : bool (default=True)

@@ -102,7 +123,7 @@ class Labram(EEGModuleMixin, nn.Module):
  init_scale : float (default=0.001)
      The initial scale to be used in the parameters of the model.
  neural_tokenizer : bool (default=True)
-     The model can be used in two modes: Neural…
+     The model can be used in two modes: Neural Tokenizer or Neural Decoder.
  attn_head_dim : bool (default=None)
      The head dimension to be used in the attention layer, to be used only
      during pre-training.

@@ -135,19 +156,19 @@ class Labram(EEGModuleMixin, nn.Module):
  input_window_seconds=None,
  patch_size=200,
  emb_size=200,
- [1 old line not captured]
+ in_conv_channels=1,
  out_channels=8,
  n_layers=12,
  att_num_heads=10,
  mlp_ratio=4.0,
  qkv_bias=False,
- qk_norm=…
+ qk_norm=nn.LayerNorm,
  qk_scale=None,
  drop_prob=0.0,
  attn_drop_prob=0.0,
  drop_path_prob=0.0,
  norm_layer=nn.LayerNorm,
- init_values=…
+ init_values=0.1,
  use_abs_pos_emb=True,
  use_mean_pooling=True,
  init_scale=0.001,
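Taken together, the new defaults in this hunk (``qk_norm=nn.LayerNorm``, ``init_values=0.1``, and ``in_conv_channels=1``) mean a plain instantiation now mirrors the original LaBraM configuration without extra arguments. A minimal sketch, reusing the argument values from the docstring example above; the ``(batch, n_outputs)`` output shape is the usual braindecode convention and is stated here as an expectation::

    import torch
    from braindecode.models import Labram

    # Defaults already include qk_norm=nn.LayerNorm and init_values=0.1.
    model = Labram(n_times=1600, n_chans=64, n_outputs=4)

    x = torch.randn(2, 64, 1600)   # (batch, n_chans, n_times)
    logits = model(x)
    print(logits.shape)            # expected: torch.Size([2, 4])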
@@ -183,15 +204,15 @@ class Labram(EEGModuleMixin, nn.Module):
  self.patch_size = patch_size
  self.n_path = self.n_times // self.patch_size

- if neural_tokenizer and…
+ if neural_tokenizer and in_conv_channels != 1:
      warn(
          "The model is in Neural Tokenizer mode, but the variable "
-         + "`…
-         + "`…
-         + "…
+         + "`in_conv_channels` is different from the default values."
+         + "`in_conv_channels` is only needed for the Neural Decoder mode."
+         + "in_conv_channels is not used in the Neural Tokenizer mode.",
          UserWarning,
      )
- [1 old line not captured]
+     in_conv_channels = 1
  # If you can use the model in Neural Tokenizer mode,
  # temporal conv layer will be use over the patched dataset
  if neural_tokenizer:

@@ -228,7 +249,7 @@ class Labram(EEGModuleMixin, nn.Module):
  _PatchEmbed(
      n_times=self.n_times,
      patch_size=patch_size,
-     in_channels=…
+     in_channels=in_conv_channels,
      emb_dim=self.emb_size,
  ),
  )

@@ -373,8 +394,7 @@ class Labram(EEGModuleMixin, nn.Module):
  Parameters
  ----------
  x : torch.Tensor
-     The input data with shape (batch, n_chans,…
-     if neural decoder or (batch, n_chans, n_times), if neural tokenizer.
+     The input data with shape (batch, n_chans, n_times).
  input_chans : int
      The number of input channels.
  return_patch_tokens : bool

@@ -387,37 +407,72 @@ class Labram(EEGModuleMixin, nn.Module):
  x : torch.Tensor
      The output of the model.
  """
+ batch_size = x.shape[0]
+
  if self.neural_tokenizer:
- [1 old line not captured]
+     # For neural tokenizer: input is (batch, n_chans, n_times)
+     # patch_embed returns (batch, n_chans, emb_dim)
+     x = self.patch_embed(x)
+     # x shape: (batch, n_chans, emb_dim)
+     n_patch = self.n_chans
+     temporal = self.emb_size
  else:
- [2 old lines not captured]
+     # For neural decoder: input is (batch, n_chans, n_times)
+     # patch_embed returns (batch, n_patchs, emb_dim)
+     x = self.patch_embed(x)
+     # x shape: (batch, n_patchs, emb_dim)
+     batch_size, n_patch, temporal = x.shape
+
  # add the [CLS] token to the embedded patch tokens
  cls_tokens = self.cls_token.expand(batch_size, -1, -1)

+ # Concatenate cls token with patch/channel embeddings
  x = torch.cat((cls_tokens, x), dim=1)

  # Positional Embedding
- if input_chans is not None:
-     pos_embed_used = self.position_embedding[:, input_chans]
- else:
-     pos_embed_used = self.position_embedding
- [1 old line not captured]
  if self.position_embedding is not None:
- [3 old lines not captured]
+     if self.neural_tokenizer:
+         # In tokenizer mode, use channel-based position embedding
+         if input_chans is not None:
+             pos_embed_used = self.position_embedding[:, input_chans]
+         else:
+             pos_embed_used = self.position_embedding
+
+         pos_embed = self._adj_position_embedding(
+             pos_embed_used=pos_embed_used, batch_size=batch_size
+         )
+     else:
+         # In decoder mode, we have different number of patches
+         # Adapt position embedding for n_patch patches
+         # Use the first n_patch+1 positions from position_embedding
+         n_pos = min(self.position_embedding.shape[1], n_patch + 1)
+         pos_embed_used = self.position_embedding[:, :n_pos, :]
+         pos_embed = pos_embed_used.expand(batch_size, -1, -1)
+
  x += pos_embed

  # The time embedding is added across the channels after the [CLS] token
  if self.neural_tokenizer:
      num_ch = self.n_chans
+     time_embed = self._adj_temporal_embedding(
+         num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
+     )
+     x[:, 1:, :] += time_embed
  else:
- [5 old lines not captured]
+     # In decoder mode, we have n_patch patches and don't need to expand
+     # Just broadcast the temporal embedding
+     if temporal is None:
+         temporal = self.emb_size
+
+     # Get temporal embeddings for n_patch patches
+     n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
+     time_embed = self.temporal_embedding[
+         :, 1 : n_time_tokens + 1, :
+     ] # (1, n_patch, emb_dim)
+     time_embed = time_embed.expand(
+         batch_size, -1, -1
+     ) # (batch, n_patch, emb_dim)
+     x[:, 1:, :] += time_embed

  x = self.pos_drop(x)

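The decoder branch added above slices the stored position embedding to the number of patches actually present and broadcasts it over the batch. The same shape logic on bare tensors, with illustrative sizes that are not taken from a real checkpoint::

    import torch

    batch_size, n_patch, emb_dim = 2, 8, 200
    position_embedding = torch.zeros(1, 65, emb_dim)   # (1, max_positions, emb_dim)

    # Keep one position per patch plus one for the [CLS] token, as in the hunk above.
    n_pos = min(position_embedding.shape[1], n_patch + 1)
    pos_embed = position_embedding[:, :n_pos, :].expand(batch_size, -1, -1)
    print(pos_embed.shape)   # torch.Size([2, 9, 200])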
@@ -428,10 +483,10 @@ class Labram(EEGModuleMixin, nn.Module):
  if self.fc_norm is not None:
      if return_all_tokens:
          return self.fc_norm(x)
- [1 old line not captured]
+     tokens = x[:, 1:, :]
      if return_patch_tokens:
-         return self.fc_norm(…
-     return self.fc_norm(…
+         return self.fc_norm(tokens)
+     return self.fc_norm(tokens.mean(1))
  else:
      if return_all_tokens:
          return x
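The refactor above pulls the patch tokens (everything after the [CLS] token) into a ``tokens`` variable before the mean-pooled readout. A toy illustration of that readout with made-up sizes::

    import torch

    x = torch.randn(2, 1 + 64, 200)   # (batch, 1 + n_patch, emb_dim); index 0 is [CLS]
    tokens = x[:, 1:, :]              # drop the [CLS] token
    pooled = tokens.mean(dim=1)       # average over patches -> (batch, emb_dim)
    print(pooled.shape)               # torch.Size([2, 200])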
@@ -505,14 +560,16 @@ class Labram(EEGModuleMixin, nn.Module):
  def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
      """
      Adjust the dimensions of the time embedding to match the
-     number of channels.
+     number of channels or patches.

      Parameters
      ----------
      num_ch : int
-         The number of channels or number of…
+         The number of channels or number of patches.
      batch_size : int
          Batch size of the input data.
+     dim_embed : int
+         The embedding dimension (temporal feature dimension).

      Returns
      -------

@@ -523,17 +580,24 @@ class Labram(EEGModuleMixin, nn.Module):
      if dim_embed is None:
          cut_dimension = self.patch_size
      else:
-         cut_dimension = dim_embed
- [2 old lines not captured]
+         cut_dimension = min(dim_embed, self.temporal_embedding.shape[1] - 1)
+
+     # Get the temporal embedding: (1, temporal_embedding_dim, emb_size)
+     # Slice to cut_dimension: (1, cut_dimension, emb_size)
+     temporal_embedding = self.temporal_embedding[:, 1 : cut_dimension + 1, :]
+
      # Add a new dimension to the time embedding
-     # e.g. (…
+     # e.g. (1, 5, 200) -> (1, 1, 5, 200)
      temporal_embedding = temporal_embedding.unsqueeze(1)
- [1 old line not captured]
-     # …
+
+     # Expand the time embedding to match the number of channels or patches
+     # (1, 1, cut_dimension, 200) -> (batch_size, num_ch, cut_dimension, 200)
      temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
+
      # Flatten the intermediate dimensions
+     # (batch_size, num_ch, cut_dimension, 200) -> (batch_size, num_ch * cut_dimension, 200)
      temporal_embedding = temporal_embedding.flatten(1, 2)
+
      return temporal_embedding

  def _adj_position_embedding(self, pos_embed_used, batch_size):
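The comments added to ``_adj_temporal_embedding`` spell out its shape flow: slice, unsqueeze, expand over channels or patches, then flatten. The same sequence on bare tensors, using the illustrative ``(1, 5, 200)`` sizes from the new comments::

    import torch

    batch_size, num_ch = 4, 6
    temporal_embedding = torch.zeros(1, 9, 200)             # (1, n_positions, emb_size)

    cut_dimension = 5
    emb = temporal_embedding[:, 1 : cut_dimension + 1, :]   # (1, 5, 200)
    emb = emb.unsqueeze(1)                                  # (1, 1, 5, 200)
    emb = emb.expand(batch_size, num_ch, -1, -1)            # (4, 6, 5, 200)
    emb = emb.flatten(1, 2)                                 # (4, 30, 200)
    print(emb.shape)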
@@ -679,25 +743,27 @@ class _SegmentPatch(nn.Module):


  class _PatchEmbed(nn.Module):
-     """EEG to Patch Embedding.
+     """EEG to Patch Embedding for Neural Decoder mode.

      This code is used when we want to apply the patch embedding
-     after the codebook layer.
+     after the codebook layer (Neural Decoder mode).
+
+     The input is expected to be in the format (Batch, n_channels, n_times),
+     but the original LaBraM expects pre-patched data (Batch, n_channels, n_patches, patch_size).
+     This class reshapes the input to the pre-patched format, then applies a 2D
+     convolution to project this pre-patched data to the embedding dimension,
+     and finally flattens across channels to produce a unified embedding.

      Parameters:
      -----------
      n_times: int (default=2000)
-         Number of temporal components of the input tensor.
+         Number of temporal components of the input tensor (used for dimension calculation).
      patch_size: int (default=200)
          Size of the patch, default is 1-seconds with 200Hz.
      in_channels: int (default=1)
-         Number of input channels
+         Number of input channels (from VQVAE codebook).
      emb_dim: int (default=200)
-         Number of…
-         we used the same as patch_size.
-     n_codebooks: int (default=62)
-         Number of patches to be used in the convolution, here,
-         we used the same as n_times // patch_size.
+         Number of output embedding dimension.
      """

      def __init__(

@@ -707,10 +773,13 @@ class _PatchEmbed(nn.Module):
      self.n_times = n_times
      self.patch_size = patch_size
      self.patch_shape = (1, self.n_times // self.patch_size)
-     n_patchs =…
- [1 old line not captured]
-     self.…
+     self.n_patchs = self.n_times // self.patch_size
+     self.emb_dim = emb_dim
+     self.in_channels = in_channels

+     # 2D Conv to project the pre-patched data
+     # Input: (Batch, in_channels, n_patches, patch_size)
+     # After proj: (Batch, emb_dim, n_patches, 1)
      self.proj = nn.Conv2d(
          in_channels=in_channels,
          out_channels=emb_dim,

@@ -718,27 +787,64 @@ class _PatchEmbed(nn.Module):
          stride=(1, self.patch_size),
      )

-     self.merge_transpose = Rearrange(
-         "Batch ch patch spatch -> Batch patch spatch ch",
-     )
- [1 old line not captured]
  def forward(self, x):
      """
-     Apply the…
-     then merge the output tensor to the desired shape.
+     Apply the temporal projection to the input tensor after grouping channels.

-     Parameters
- [1 old line not captured]
-     x: torch.Tensor
-         Input tensor of shape (Batch,…
+     Parameters
+     ----------
+     x : torch.Tensor
+         Input tensor of shape (Batch, n_channels, n_times) or
+         (Batch, n_channels, n_patches, patch_size).

- [1 old line not captured]
+     Returns
      -------
- [1 old line not captured]
-     Output tensor of shape (Batch, n_patchs,…
+     torch.Tensor
+         Output tensor of shape (Batch, n_patchs, emb_dim).
      """
+     if x.ndim == 4:
+         batch_size, n_channels, n_patchs, patch_len = x.shape
+         if patch_len != self.patch_size:
+             raise ValueError(
+                 "When providing a 4D tensor, the last dimension "
+                 f"({patch_len}) must match patch_size ({self.patch_size})."
+             )
+         n_times = n_patchs * patch_len
+         x = x.reshape(batch_size, n_channels, n_times)
+     elif x.ndim == 3:
+         batch_size, n_channels, n_times = x.shape
+     else:
+         raise ValueError(
+             "Input must be either 3D (batch, channels, times) or "
+             "4D (batch, channels, n_patches, patch_size)."
+         )
+
+     if n_times % self.patch_size != 0:
+         raise ValueError(
+             f"n_times ({n_times}) must be divisible by patch_size ({self.patch_size})."
+         )
+     if n_channels % self.in_channels != 0:
+         raise ValueError(
+             "The input channel dimension "
+             f"({n_channels}) must be divisible by in_channels ({self.in_channels})."
+         )
+
+     group_size = n_channels // self.in_channels
+
+     # Reshape so Conv2d sees `in_channels` feature maps and uses the grouped
+     # EEG channels as the spatial height dimension.
+     # Shape after view: (Batch, in_channels, group_size, n_times)
+     x = x.view(batch_size, self.in_channels, group_size, n_times)
+
+     # Apply the temporal projection per group.
+     # Output shape: (Batch, emb_dim, group_size, n_patchs)
      x = self.proj(x)
- [1 old line not captured]
+
+     # THIS IS braindecode's MODIFICATION:
+     # Average over the grouped channel dimension and permute to (Batch, n_patchs, emb_dim)
+     x = x.mean(dim=2)
+     x = x.transpose(1, 2).contiguous()
+
      return x

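The rewritten ``_PatchEmbed.forward`` above reduces to: view the signal as ``(batch, in_channels, group_size, n_times)``, project patches with a ``Conv2d`` whose kernel and stride span ``(1, patch_size)`` (the kernel-size line itself sits outside this hunk, so that detail is inferred from the added shape comments), then average the grouped dimension and transpose. A self-contained sketch of that shape flow with illustrative sizes::

    import torch
    from torch import nn

    batch_size, n_channels, n_times = 2, 62, 1600
    patch_size, emb_dim, in_channels = 200, 200, 1

    proj = nn.Conv2d(in_channels, emb_dim,
                     kernel_size=(1, patch_size), stride=(1, patch_size))

    x = torch.randn(batch_size, n_channels, n_times)
    group_size = n_channels // in_channels                    # 62
    x = x.view(batch_size, in_channels, group_size, n_times)  # (2, 1, 62, 1600)
    x = proj(x)                                               # (2, 200, 62, 8)
    x = x.mean(dim=2)                                         # (2, 200, 8)
    x = x.transpose(1, 2).contiguous()                        # (2, 8, 200) = (B, n_patchs, emb_dim)
    print(x.shape)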