braindecode 1.3.0.dev173577785__py3-none-any.whl → 1.3.0.dev173767962__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/augmentation/base.py +1 -1
- braindecode/datasets/__init__.py +10 -2
- braindecode/datasets/base.py +115 -151
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +2 -2
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +2 -2
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +2 -2
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/serialization.py +7 -7
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +4 -0
- braindecode/models/atcnet.py +26 -27
- braindecode/models/attentionbasenet.py +39 -32
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +2 -0
- braindecode/models/ctnet.py +6 -3
- braindecode/models/deepsleepnet.py +27 -18
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegnet.py +1 -1
- braindecode/models/labram.py +188 -84
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/signal_jepa.py +109 -27
- braindecode/models/sinc_shallow.py +10 -9
- braindecode/models/sstdpn.py +11 -11
- braindecode/models/summary.csv +2 -0
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +2 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/preprocess.py +12 -12
- braindecode/preprocessing/windowers.py +24 -19
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev173577785.dist-info → braindecode-1.3.0.dev173767962.dist-info}/METADATA +4 -2
- {braindecode-1.3.0.dev173577785.dist-info → braindecode-1.3.0.dev173767962.dist-info}/RECORD +50 -48
- {braindecode-1.3.0.dev173577785.dist-info → braindecode-1.3.0.dev173767962.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev173577785.dist-info → braindecode-1.3.0.dev173767962.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev173577785.dist-info → braindecode-1.3.0.dev173767962.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev173577785.dist-info → braindecode-1.3.0.dev173767962.dist-info}/top_level.txt +0 -0
|
@@ -35,51 +35,57 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
|
|
|
35
35
|
- :class:`_InceptionModule1` **(multi-scale temporal + spatial mixing)**
|
|
36
36
|
|
|
37
37
|
- *Operations.*
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
38
|
+
|
|
39
|
+
- `EEGInceptionERP.c1`: :class:`torch.nn.Conv2d` ``k=(64,1)``, stride ``(1,1)``, *same* pad on input reshaped to ``(B,1,128,8)`` → BN → activation → dropout.
|
|
40
|
+
- `EEGInceptionERP.d1`: :class:`torch.nn.Conv2d` (depthwise) ``k=(1,8)``, *valid* pad over channels → BN → activation → dropout.
|
|
41
|
+
- `EEGInceptionERP.c2`: :class:`torch.nn.Conv2d` ``k=(32,1)`` → BN → activation → dropout; then `EEGInceptionERP.d2` depthwise ``k=(1,8)`` → BN → activation → dropout.
|
|
42
|
+
- `EEGInceptionERP.c3`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout; then `EEGInceptionERP.d3` depthwise ``k=(1,8)`` → BN → activation → dropout.
|
|
43
|
+
- `EEGInceptionERP.n1`: concatenation of the branch features.
|
|
44
|
+
- `EEGInceptionERP.a1`: :class:`torch.nn.AvgPool2d` ``pool=(4,1)``, stride ``(4,1)`` for temporal downsampling.
|
|
44
45
|
|
|
45
46
|
*Interpretability/robustness.* Depthwise `1 x n_chans` layers act as learnable montage-wide spatial filters per temporal scale; pooling stabilizes against jitter.
|
|
46
47
|
|
|
47
48
|
- :class:`_InceptionModule2` **(refinement at coarser timebase)**
|
|
48
49
|
|
|
49
50
|
- *Operations.*
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
51
|
+
|
|
52
|
+
- `EEGInceptionERP.c4`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout.
|
|
53
|
+
- `EEGInceptionERP.c5`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout.
|
|
54
|
+
- `EEGInceptionERP.c6`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout.
|
|
55
|
+
- `EEGInceptionERP.n2`: concatenation of the C4-C6 branch outputs.
|
|
56
|
+
- `EEGInceptionERP.a2`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``, stride ``(2,1)``.
|
|
57
|
+
- `EEGInceptionERP.c7`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout; then `EEGInceptionERP.a3`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
|
|
58
|
+
- `EEGInceptionERP.c8`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout; then `EEGInceptionERP.a4`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
|
|
57
59
|
|
|
58
60
|
*Role.* Adds higher-level, shorter-window evidence while progressively compressing temporal dimension.
|
|
59
61
|
|
|
60
62
|
- :class:`_OutputModule` **(aggregation + readout)**
|
|
61
63
|
|
|
62
64
|
- *Operations.*
|
|
63
|
-
|
|
64
|
-
|
|
65
|
+
|
|
66
|
+
- :class:`torch.nn.Flatten`
|
|
67
|
+
- :class:`torch.nn.Linear` ``(features → 2)``
|
|
65
68
|
|
|
66
69
|
.. rubric:: Convolutional Details
|
|
67
70
|
|
|
68
71
|
- **Temporal (where time-domain patterns are learned).**
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
72
|
+
|
|
73
|
+
First module uses 1D temporal kernels along the 128-sample axis: ``64``, ``32``, ``16``
|
|
74
|
+
(≈500, 250, 125 ms at 128 Hz). After ``pool=(4,1)``, the second module applies ``16``,
|
|
75
|
+
``8``, ``4`` (≈500, 250, 125 ms at the pooled 32 Hz rate). All strides are ``1`` in convs;
|
|
76
|
+
temporal resolution changes only via average pooling.
|
|
73
77
|
|
|
74
78
|
- **Spatial (how electrodes are processed).**
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
79
|
+
|
|
80
|
+
Depthwise convs with ``k=(1,8)`` span all channels and are applied **per temporal branch**,
|
|
81
|
+
yielding scale-specific channel projections (no cross-branch mixing until concatenation).
|
|
82
|
+
There is no full 2D mixing kernel; spatial mixing is factorized and lightweight.
|
|
78
83
|
|
|
79
84
|
- **Spectral (how frequency information is captured).**
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
85
|
+
|
|
86
|
+
No explicit transform; multiple temporal kernels form a *learned filter bank* over
|
|
87
|
+
ERP-relevant bands. Successive pooling acts as low-pass integration to emphasize sustained
|
|
88
|
+
post-stimulus components.
|
|
83
89
|
|
|
84
90
|
.. rubric:: Additional Mechanisms
|
|
85
91
|
|
braindecode/models/eegnet.py
CHANGED
|
@@ -57,7 +57,7 @@ class EEGNet(EEGModuleMixin, nn.Sequential):
|
|
|
57
57
|
|
|
58
58
|
- **Temporal.** The initial temporal convs serve as a *learned filter bank*:
|
|
59
59
|
long 1-D kernels (implemented as 2-D with singleton spatial extent) emphasize oscillatory bands and transients.
|
|
60
|
-
Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each feature
|
|
60
|
+
Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each feature's spectrum [Lawhern2018]_.
|
|
61
61
|
|
|
62
62
|
- **Spatial.** The depthwise spatial conv spans the full channel axis (kernel height = #electrodes; temporal size = 1).
|
|
63
63
|
With ``groups = F1``, each temporal filter learns its own set of ``D`` spatial projections—akin to CSP, learned end-to-end and
|
braindecode/models/labram.py
CHANGED
|
@@ -22,6 +22,8 @@ from braindecode.modules import MLP, DropPath
|
|
|
22
22
|
class Labram(EEGModuleMixin, nn.Module):
|
|
23
23
|
"""Labram from Jiang, W B et al (2024) [Jiang2024]_.
|
|
24
24
|
|
|
25
|
+
:bdg-danger:`Large Brain Model`
|
|
26
|
+
|
|
25
27
|
.. figure:: https://arxiv.org/html/2405.18765v1/x1.png
|
|
26
28
|
:align: center
|
|
27
29
|
:alt: Labram Architecture.
|
|
@@ -43,31 +45,45 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
43
45
|
equals True. The original implementation uses (batch, n_chans, n_patches,
|
|
44
46
|
patch_size) as input with static segmentation of the input data.
|
|
45
47
|
|
|
46
|
-
The models have the following sequence of steps
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
48
|
+
The models have the following sequence of steps::
|
|
49
|
+
|
|
50
|
+
if neural tokenizer:
|
|
51
|
+
- SegmentPatch: Segment the input data into patches;
|
|
52
|
+
- TemporalConv: Apply a temporal convolution to the segmented data;
|
|
53
|
+
- Residual adding cls, temporal and position embeddings (optional);
|
|
54
|
+
- WindowsAttentionBlock: Apply a windows attention block to the data;
|
|
55
|
+
- LayerNorm: Apply layer normalization to the data;
|
|
56
|
+
- Linear: A linear head layer that maps the data to class scores.
|
|
57
|
+
|
|
58
|
+
else:
|
|
59
|
+
- PatchEmbed: Apply a patch embedding to the input data;
|
|
60
|
+
- Residual adding cls, temporal and position embeddings (optional);
|
|
61
|
+
- WindowsAttentionBlock: Apply a windows attention block to the data;
|
|
62
|
+
- LayerNorm: Apply layer normalization to the data;
|
|
63
|
+
- Linear: A linear head layer that maps the data to class scores.
|
|
61
64
|
|
|
62
65
|
.. versionadded:: 0.9
|
|
63
66
|
|
|
67
|
+
|
|
68
|
+
Examples
|
|
69
|
+
--------
|
|
70
|
+
Load pre-trained weights::
|
|
71
|
+
|
|
72
|
+
>>> import torch
|
|
73
|
+
>>> from braindecode.models import Labram
|
|
74
|
+
>>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
|
|
75
|
+
>>> url = "https://huggingface.co/braindecode/Labram-Braindecode/resolve/main/braindecode_labram_base.pt"
|
|
76
|
+
>>> state = torch.hub.load_state_dict_from_url(url, progress=True)
|
|
77
|
+
>>> model.load_state_dict(state)
|
|
78
|
+
|
|
79
|
+
|
|
64
80
|
Parameters
|
|
65
81
|
----------
|
|
66
82
|
patch_size : int
|
|
67
83
|
The size of the patch to be used in the patch embedding.
|
|
68
84
|
emb_size : int
|
|
69
85
|
The dimension of the embedding.
|
|
70
|
-
|
|
86
|
+
in_conv_channels : int
|
|
71
87
|
The number of convolutional input channels.
|
|
72
88
|
out_channels : int
|
|
73
89
|
The number of convolutional output channels.
|
|
@@ -79,8 +95,10 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
79
95
|
The expansion ratio of the mlp layer
|
|
80
96
|
qkv_bias : bool (default=False)
|
|
81
97
|
If True, add a learnable bias to the query, key, and value tensors.
|
|
82
|
-
qk_norm : Pytorch Normalize layer (default=
|
|
83
|
-
If not None, apply LayerNorm to the query and key tensors
|
|
98
|
+
qk_norm : Pytorch Normalize layer (default=nn.LayerNorm)
|
|
99
|
+
If not None, apply LayerNorm to the query and key tensors.
|
|
100
|
+
Default is nn.LayerNorm for better weight transfer from the original LaBraM.
|
|
101
|
+
Set to None to disable Q,K normalization.
|
|
84
102
|
qk_scale : float (default=None)
|
|
85
103
|
If not None, use this value as the scale factor. If None,
|
|
86
104
|
use head_dim**-0.5, where head_dim = dim // num_heads.
|
|
@@ -92,9 +110,10 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
92
110
|
Dropout rate for the attention weights used on DropPath.
|
|
93
111
|
norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
|
|
94
112
|
The normalization layer to be used.
|
|
95
|
-
init_values : float (default=
|
|
113
|
+
init_values : float (default=0.1)
|
|
96
114
|
If not None, use this value to initialize the gamma_1 and gamma_2
|
|
97
|
-
parameters.
|
|
115
|
+
parameters for residual scaling. Default is 0.1 for better weight
|
|
116
|
+
transfer from the original LaBraM. Set to None to disable.
|
|
98
117
|
use_abs_pos_emb : bool (default=True)
|
|
99
118
|
If True, use absolute position embedding.
|
|
100
119
|
use_mean_pooling : bool (default=True)
|
|
@@ -135,19 +154,19 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
135
154
|
input_window_seconds=None,
|
|
136
155
|
patch_size=200,
|
|
137
156
|
emb_size=200,
|
|
138
|
-
|
|
157
|
+
in_conv_channels=1,
|
|
139
158
|
out_channels=8,
|
|
140
159
|
n_layers=12,
|
|
141
160
|
att_num_heads=10,
|
|
142
161
|
mlp_ratio=4.0,
|
|
143
162
|
qkv_bias=False,
|
|
144
|
-
qk_norm=
|
|
163
|
+
qk_norm=nn.LayerNorm,
|
|
145
164
|
qk_scale=None,
|
|
146
165
|
drop_prob=0.0,
|
|
147
166
|
attn_drop_prob=0.0,
|
|
148
167
|
drop_path_prob=0.0,
|
|
149
168
|
norm_layer=nn.LayerNorm,
|
|
150
|
-
init_values=
|
|
169
|
+
init_values=0.1,
|
|
151
170
|
use_abs_pos_emb=True,
|
|
152
171
|
use_mean_pooling=True,
|
|
153
172
|
init_scale=0.001,
|
|
@@ -183,15 +202,15 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
183
202
|
self.patch_size = patch_size
|
|
184
203
|
self.n_path = self.n_times // self.patch_size
|
|
185
204
|
|
|
186
|
-
if neural_tokenizer and
|
|
205
|
+
if neural_tokenizer and in_conv_channels != 1:
|
|
187
206
|
warn(
|
|
188
207
|
"The model is in Neural Tokenizer mode, but the variable "
|
|
189
|
-
+ "`
|
|
190
|
-
+ "`
|
|
191
|
-
+ "
|
|
208
|
+
+ "`in_conv_channels` is different from the default values."
|
|
209
|
+
+ "`in_conv_channels` is only needed for the Neural Decoder mode."
|
|
210
|
+
+ "in_conv_channels is not used in the Neural Tokenizer mode.",
|
|
192
211
|
UserWarning,
|
|
193
212
|
)
|
|
194
|
-
|
|
213
|
+
in_conv_channels = 1
|
|
195
214
|
# If you can use the model in Neural Tokenizer mode,
|
|
196
215
|
# the temporal conv layer will be used over the patched dataset
|
|
197
216
|
if neural_tokenizer:
|
|
@@ -228,7 +247,7 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
228
247
|
_PatchEmbed(
|
|
229
248
|
n_times=self.n_times,
|
|
230
249
|
patch_size=patch_size,
|
|
231
|
-
in_channels=
|
|
250
|
+
in_channels=in_conv_channels,
|
|
232
251
|
emb_dim=self.emb_size,
|
|
233
252
|
),
|
|
234
253
|
)
|
|
@@ -373,8 +392,7 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
373
392
|
Parameters
|
|
374
393
|
----------
|
|
375
394
|
x : torch.Tensor
|
|
376
|
-
The input data with shape (batch, n_chans,
|
|
377
|
-
if neural decoder or (batch, n_chans, n_times), if neural tokenizer.
|
|
395
|
+
The input data with shape (batch, n_chans, n_times).
|
|
378
396
|
input_chans : int
|
|
379
397
|
The number of input channels.
|
|
380
398
|
return_patch_tokens : bool
|
|
@@ -387,37 +405,72 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
387
405
|
x : torch.Tensor
|
|
388
406
|
The output of the model.
|
|
389
407
|
"""
|
|
408
|
+
batch_size = x.shape[0]
|
|
409
|
+
|
|
390
410
|
if self.neural_tokenizer:
|
|
391
|
-
|
|
411
|
+
# For neural tokenizer: input is (batch, n_chans, n_times)
|
|
412
|
+
# patch_embed returns (batch, n_chans, emb_dim)
|
|
413
|
+
x = self.patch_embed(x)
|
|
414
|
+
# x shape: (batch, n_chans, emb_dim)
|
|
415
|
+
n_patch = self.n_chans
|
|
416
|
+
temporal = self.emb_size
|
|
392
417
|
else:
|
|
393
|
-
|
|
394
|
-
|
|
418
|
+
# For neural decoder: input is (batch, n_chans, n_times)
|
|
419
|
+
# patch_embed returns (batch, n_patchs, emb_dim)
|
|
420
|
+
x = self.patch_embed(x)
|
|
421
|
+
# x shape: (batch, n_patchs, emb_dim)
|
|
422
|
+
batch_size, n_patch, temporal = x.shape
|
|
423
|
+
|
|
395
424
|
# add the [CLS] token to the embedded patch tokens
|
|
396
425
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
|
397
426
|
|
|
427
|
+
# Concatenate cls token with patch/channel embeddings
|
|
398
428
|
x = torch.cat((cls_tokens, x), dim=1)
|
|
399
429
|
|
|
400
430
|
# Positional Embedding
|
|
401
|
-
if input_chans is not None:
|
|
402
|
-
pos_embed_used = self.position_embedding[:, input_chans]
|
|
403
|
-
else:
|
|
404
|
-
pos_embed_used = self.position_embedding
|
|
405
|
-
|
|
406
431
|
if self.position_embedding is not None:
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
432
|
+
if self.neural_tokenizer:
|
|
433
|
+
# In tokenizer mode, use channel-based position embedding
|
|
434
|
+
if input_chans is not None:
|
|
435
|
+
pos_embed_used = self.position_embedding[:, input_chans]
|
|
436
|
+
else:
|
|
437
|
+
pos_embed_used = self.position_embedding
|
|
438
|
+
|
|
439
|
+
pos_embed = self._adj_position_embedding(
|
|
440
|
+
pos_embed_used=pos_embed_used, batch_size=batch_size
|
|
441
|
+
)
|
|
442
|
+
else:
|
|
443
|
+
# In decoder mode, we have different number of patches
|
|
444
|
+
# Adapt position embedding for n_patch patches
|
|
445
|
+
# Use the first n_patch+1 positions from position_embedding
|
|
446
|
+
n_pos = min(self.position_embedding.shape[1], n_patch + 1)
|
|
447
|
+
pos_embed_used = self.position_embedding[:, :n_pos, :]
|
|
448
|
+
pos_embed = pos_embed_used.expand(batch_size, -1, -1)
|
|
449
|
+
|
|
410
450
|
x += pos_embed
|
|
411
451
|
|
|
412
452
|
# The time embedding is added across the channels after the [CLS] token
|
|
413
453
|
if self.neural_tokenizer:
|
|
414
454
|
num_ch = self.n_chans
|
|
455
|
+
time_embed = self._adj_temporal_embedding(
|
|
456
|
+
num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
|
|
457
|
+
)
|
|
458
|
+
x[:, 1:, :] += time_embed
|
|
415
459
|
else:
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
460
|
+
# In decoder mode, we have n_patch patches and don't need to expand
|
|
461
|
+
# Just broadcast the temporal embedding
|
|
462
|
+
if temporal is None:
|
|
463
|
+
temporal = self.emb_size
|
|
464
|
+
|
|
465
|
+
# Get temporal embeddings for n_patch patches
|
|
466
|
+
n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
|
|
467
|
+
time_embed = self.temporal_embedding[
|
|
468
|
+
:, 1 : n_time_tokens + 1, :
|
|
469
|
+
] # (1, n_patch, emb_dim)
|
|
470
|
+
time_embed = time_embed.expand(
|
|
471
|
+
batch_size, -1, -1
|
|
472
|
+
) # (batch, n_patch, emb_dim)
|
|
473
|
+
x[:, 1:, :] += time_embed
|
|
421
474
|
|
|
422
475
|
x = self.pos_drop(x)
|
|
423
476
|
|
|
@@ -428,10 +481,10 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
428
481
|
if self.fc_norm is not None:
|
|
429
482
|
if return_all_tokens:
|
|
430
483
|
return self.fc_norm(x)
|
|
431
|
-
|
|
484
|
+
tokens = x[:, 1:, :]
|
|
432
485
|
if return_patch_tokens:
|
|
433
|
-
return self.fc_norm(
|
|
434
|
-
return self.fc_norm(
|
|
486
|
+
return self.fc_norm(tokens)
|
|
487
|
+
return self.fc_norm(tokens.mean(1))
|
|
435
488
|
else:
|
|
436
489
|
if return_all_tokens:
|
|
437
490
|
return x
|
|
@@ -505,14 +558,16 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
505
558
|
def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
|
|
506
559
|
"""
|
|
507
560
|
Adjust the dimensions of the time embedding to match the
|
|
508
|
-
number of channels.
|
|
561
|
+
number of channels or patches.
|
|
509
562
|
|
|
510
563
|
Parameters
|
|
511
564
|
----------
|
|
512
565
|
num_ch : int
|
|
513
|
-
The number of channels or number of
|
|
566
|
+
The number of channels or number of patches.
|
|
514
567
|
batch_size : int
|
|
515
568
|
Batch size of the input data.
|
|
569
|
+
dim_embed : int
|
|
570
|
+
The embedding dimension (temporal feature dimension).
|
|
516
571
|
|
|
517
572
|
Returns
|
|
518
573
|
-------
|
|
@@ -523,17 +578,24 @@ class Labram(EEGModuleMixin, nn.Module):
|
|
|
523
578
|
if dim_embed is None:
|
|
524
579
|
cut_dimension = self.patch_size
|
|
525
580
|
else:
|
|
526
|
-
cut_dimension = dim_embed
|
|
527
|
-
|
|
528
|
-
|
|
581
|
+
cut_dimension = min(dim_embed, self.temporal_embedding.shape[1] - 1)
|
|
582
|
+
|
|
583
|
+
# Get the temporal embedding: (1, temporal_embedding_dim, emb_size)
|
|
584
|
+
# Slice to cut_dimension: (1, cut_dimension, emb_size)
|
|
585
|
+
temporal_embedding = self.temporal_embedding[:, 1 : cut_dimension + 1, :]
|
|
586
|
+
|
|
529
587
|
# Add a new dimension to the time embedding
|
|
530
|
-
# e.g. (
|
|
588
|
+
# e.g. (1, 5, 200) -> (1, 1, 5, 200)
|
|
531
589
|
temporal_embedding = temporal_embedding.unsqueeze(1)
|
|
532
|
-
|
|
533
|
-
#
|
|
590
|
+
|
|
591
|
+
# Expand the time embedding to match the number of channels or patches
|
|
592
|
+
# (1, 1, cut_dimension, 200) -> (batch_size, num_ch, cut_dimension, 200)
|
|
534
593
|
temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
|
|
594
|
+
|
|
535
595
|
# Flatten the intermediate dimensions
|
|
596
|
+
# (batch_size, num_ch, cut_dimension, 200) -> (batch_size, num_ch * cut_dimension, 200)
|
|
536
597
|
temporal_embedding = temporal_embedding.flatten(1, 2)
|
|
598
|
+
|
|
537
599
|
return temporal_embedding
|
|
538
600
|
|
|
539
601
|
def _adj_position_embedding(self, pos_embed_used, batch_size):
|
|
@@ -679,25 +741,27 @@ class _SegmentPatch(nn.Module):
|
|
|
679
741
|
|
|
680
742
|
|
|
681
743
|
class _PatchEmbed(nn.Module):
|
|
682
|
-
"""EEG to Patch Embedding.
|
|
744
|
+
"""EEG to Patch Embedding for Neural Decoder mode.
|
|
683
745
|
|
|
684
746
|
This code is used when we want to apply the patch embedding
|
|
685
|
-
after the codebook layer.
|
|
747
|
+
after the codebook layer (Neural Decoder mode).
|
|
748
|
+
|
|
749
|
+
The input is expected to be in the format (Batch, n_channels, n_times),
|
|
750
|
+
but the original LaBraM expects pre-patched data (Batch, n_channels, n_patches, patch_size).
|
|
751
|
+
This class reshapes the input to the pre-patched format, then applies a 2D
|
|
752
|
+
convolution to project this pre-patched data to the embedding dimension,
|
|
753
|
+
and finally averages across the grouped channels to produce a unified embedding.
|
|
686
754
|
|
|
687
755
|
Parameters:
|
|
688
756
|
-----------
|
|
689
757
|
n_times: int (default=2000)
|
|
690
|
-
Number of temporal components of the input tensor.
|
|
758
|
+
Number of temporal components of the input tensor (used for dimension calculation).
|
|
691
759
|
patch_size: int (default=200)
|
|
692
760
|
Size of the patch; the default corresponds to 1 second at 200 Hz.
|
|
693
761
|
in_channels: int (default=1)
|
|
694
|
-
Number of input channels
|
|
762
|
+
Number of input channels (from VQVAE codebook).
|
|
695
763
|
emb_dim: int (default=200)
|
|
696
|
-
Number of
|
|
697
|
-
we used the same as patch_size.
|
|
698
|
-
n_codebooks: int (default=62)
|
|
699
|
-
Number of patches to be used in the convolution, here,
|
|
700
|
-
we used the same as n_times // patch_size.
|
|
764
|
+
Size of the output embedding dimension.
|
|
701
765
|
"""
|
|
702
766
|
|
|
703
767
|
def __init__(
|
|
@@ -707,10 +771,13 @@ class _PatchEmbed(nn.Module):
|
|
|
707
771
|
self.n_times = n_times
|
|
708
772
|
self.patch_size = patch_size
|
|
709
773
|
self.patch_shape = (1, self.n_times // self.patch_size)
|
|
710
|
-
n_patchs =
|
|
711
|
-
|
|
712
|
-
self.
|
|
774
|
+
self.n_patchs = self.n_times // self.patch_size
|
|
775
|
+
self.emb_dim = emb_dim
|
|
776
|
+
self.in_channels = in_channels
|
|
713
777
|
|
|
778
|
+
# 2D Conv to project the pre-patched data
|
|
779
|
+
# Input: (Batch, in_channels, n_patches, patch_size)
|
|
780
|
+
# After proj: (Batch, emb_dim, n_patches, 1)
|
|
714
781
|
self.proj = nn.Conv2d(
|
|
715
782
|
in_channels=in_channels,
|
|
716
783
|
out_channels=emb_dim,
|
|
@@ -718,27 +785,64 @@ class _PatchEmbed(nn.Module):
|
|
|
718
785
|
stride=(1, self.patch_size),
|
|
719
786
|
)
|
|
720
787
|
|
|
721
|
-
self.merge_transpose = Rearrange(
|
|
722
|
-
"Batch ch patch spatch -> Batch patch spatch ch",
|
|
723
|
-
)
|
|
724
|
-
|
|
725
788
|
def forward(self, x):
|
|
726
789
|
"""
|
|
727
|
-
Apply the
|
|
728
|
-
then merge the output tensor to the desired shape.
|
|
790
|
+
Apply the temporal projection to the input tensor after grouping channels.
|
|
729
791
|
|
|
730
|
-
Parameters
|
|
731
|
-
|
|
732
|
-
x: torch.Tensor
|
|
733
|
-
Input tensor of shape (Batch,
|
|
792
|
+
Parameters
|
|
793
|
+
----------
|
|
794
|
+
x : torch.Tensor
|
|
795
|
+
Input tensor of shape (Batch, n_channels, n_times) or
|
|
796
|
+
(Batch, n_channels, n_patches, patch_size).
|
|
734
797
|
|
|
735
|
-
|
|
798
|
+
Returns
|
|
736
799
|
-------
|
|
737
|
-
|
|
738
|
-
Output tensor of shape (Batch, n_patchs,
|
|
800
|
+
torch.Tensor
|
|
801
|
+
Output tensor of shape (Batch, n_patchs, emb_dim).
|
|
739
802
|
"""
|
|
803
|
+
if x.ndim == 4:
|
|
804
|
+
batch_size, n_channels, n_patchs, patch_len = x.shape
|
|
805
|
+
if patch_len != self.patch_size:
|
|
806
|
+
raise ValueError(
|
|
807
|
+
"When providing a 4D tensor, the last dimension "
|
|
808
|
+
f"({patch_len}) must match patch_size ({self.patch_size})."
|
|
809
|
+
)
|
|
810
|
+
n_times = n_patchs * patch_len
|
|
811
|
+
x = x.reshape(batch_size, n_channels, n_times)
|
|
812
|
+
elif x.ndim == 3:
|
|
813
|
+
batch_size, n_channels, n_times = x.shape
|
|
814
|
+
else:
|
|
815
|
+
raise ValueError(
|
|
816
|
+
"Input must be either 3D (batch, channels, times) or "
|
|
817
|
+
"4D (batch, channels, n_patches, patch_size)."
|
|
818
|
+
)
|
|
819
|
+
|
|
820
|
+
if n_times % self.patch_size != 0:
|
|
821
|
+
raise ValueError(
|
|
822
|
+
f"n_times ({n_times}) must be divisible by patch_size ({self.patch_size})."
|
|
823
|
+
)
|
|
824
|
+
if n_channels % self.in_channels != 0:
|
|
825
|
+
raise ValueError(
|
|
826
|
+
"The input channel dimension "
|
|
827
|
+
f"({n_channels}) must be divisible by in_channels ({self.in_channels})."
|
|
828
|
+
)
|
|
829
|
+
|
|
830
|
+
group_size = n_channels // self.in_channels
|
|
831
|
+
|
|
832
|
+
# Reshape so Conv2d sees `in_channels` feature maps and uses the grouped
|
|
833
|
+
# EEG channels as the spatial height dimension.
|
|
834
|
+
# Shape after view: (Batch, in_channels, group_size, n_times)
|
|
835
|
+
x = x.view(batch_size, self.in_channels, group_size, n_times)
|
|
836
|
+
|
|
837
|
+
# Apply the temporal projection per group.
|
|
838
|
+
# Output shape: (Batch, emb_dim, group_size, n_patchs)
|
|
740
839
|
x = self.proj(x)
|
|
741
|
-
|
|
840
|
+
|
|
841
|
+
# THIS IS braindecode's MODIFICATION:
|
|
842
|
+
# Average over the grouped channel dimension and permute to (Batch, n_patchs, emb_dim)
|
|
843
|
+
x = x.mean(dim=2)
|
|
844
|
+
x = x.transpose(1, 2).contiguous()
|
|
845
|
+
|
|
742
846
|
return x
|
|
743
847
|
|
|
744
848
|
|