braindecode 1.3.0.dev181065563__py3-none-any.whl → 1.3.0.dev183667303__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic.
- braindecode/augmentation/base.py +1 -1
- braindecode/augmentation/functional.py +154 -54
- braindecode/augmentation/transforms.py +2 -2
- braindecode/datasets/__init__.py +10 -2
- braindecode/datasets/base.py +116 -152
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +2 -2
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +2 -2
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +4 -3
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/serialization.py +18 -13
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +6 -0
- braindecode/models/atcnet.py +32 -33
- braindecode/models/attentionbasenet.py +39 -32
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +3 -1
- braindecode/models/ctnet.py +6 -3
- braindecode/models/deepsleepnet.py +27 -18
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegnet.py +1 -1
- braindecode/models/labram.py +193 -87
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/signal_jepa.py +109 -27
- braindecode/models/sinc_shallow.py +10 -9
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +9 -6
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +3 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +24 -0
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/preprocess.py +23 -14
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +24 -19
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev183667303.dist-info}/METADATA +6 -2
- braindecode-1.3.0.dev183667303.dist-info/RECORD +106 -0
- braindecode-1.3.0.dev181065563.dist-info/RECORD +0 -101
- {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev183667303.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev183667303.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev183667303.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev181065563.dist-info → braindecode-1.3.0.dev183667303.dist-info}/top_level.txt +0 -0
braindecode/models/patchedtransformer.py (added)
@@ -0,0 +1,640 @@
# Authors: Jose Mauricio <josemaurici3991@gmail.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com>
# License: BSD (3-clause)
from __future__ import annotations

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from braindecode.models.base import EEGModuleMixin


class PBT(EEGModuleMixin, nn.Module):
    r"""Patched Brain Transformer (PBT) model from Klein et al. (2025) [pbt]_.

    :bdg-danger:`Large Brain Models`

    This implementation is based on https://github.com/timonkl/PatchedBrainTransformer/

    .. figure:: https://raw.githubusercontent.com/timonkl/PatchedBrainTransformer/refs/heads/main/PBT_sketch.png
        :align: center
        :alt: Patched Brain Transformer Architecture
        :width: 680px

    PBT tokenizes EEG trials into per-channel patches, linearly projects each
    patch to a model embedding dimension, prepends a classification token, and
    adds channel-aware positional embeddings. The token sequence is processed
    by a Transformer encoder stack and classification is performed from the
    classification token.

    .. rubric:: Macro Components

    - ``PBT.patch_signal`` **(patch extraction)**

      *Operations.* The pre-processed EEG signal :math:`X \in \mathbb{R}^{C \times T}`
      (with :math:`C = \text{n_chans}` and :math:`T = \text{n_times}`) is divided into
      non-overlapping patches of size :math:`d_{\text{input}}` along the time axis.
      This process yields :math:`N` total patches, calculated as
      :math:`N = C \left\lfloor \frac{T}{D} \right\rfloor` (where :math:`D = d_{\text{input}}`).
      When time shifts are applied, :math:`N` decreases to
      :math:`N = C \left\lfloor \frac{T - T_{\text{aug}}}{D} \right\rfloor`.

      *Role.* Tokenizes EEG trials into fixed-size, per-channel patches so the model
      remains adaptive to different numbers of channels and recording lengths. The
      process is inspired by Vision Transformers [visualtransformer]_ and
      adapted to the GPT context following [efficient-batchpacking]_.

    - ``PBT.patching_projection`` **(patch embedding)**

      *Operations.* The linear layer ``PBT.patching_projection`` maps the tokens from dimension
      :math:`d_{\text{input}}` to the Transformer embedding dimension :math:`d_{\text{model}}`.
      Patches :math:`X_P` are projected as :math:`X_E = X_P W_E^\top`, where
      :math:`W_E \in \mathbb{R}^{d_{\text{model}} \times D}`. In this configuration
      :math:`d_{\text{model}} = 2D` with :math:`D = d_{\text{input}}`.

      *Interpretability.* Learns periodic structures similar to frequency filters in
      the first convolutional layers of CNNs (for example :class:`~braindecode.models.EEGNet`).
      The learned filters frequently focus on the high-frequency range (20-40 Hz),
      which correlates with beta and gamma waves linked to higher concentration levels.

    - ``PBT.cls_token`` **(classification token)**

      *Operations.* A classification token :math:`[c_{\text{ls}}] \in \mathbb{R}^{1 \times d_{\text{model}}}`
      is prepended to the projected patch sequence :math:`X_E`. The CLS token can optionally
      be learnable (see ``learnable_cls``).

      *Role.* Acts as a dedicated readout token that aggregates information through the
      Transformer encoder stack.

    - ``PBT.pos_embedding`` **(positional embedding)**

      *Operations.* Positional indices are generated by ``PBT.patch_signal``, an instance
      of :class:`~braindecode.models.patchedtransformer._Patcher`, and mapped to vectors
      through :class:`~torch.nn.Embedding`. The embedding table
      :math:`W_{\text{pos}} \in \mathbb{R}^{(N+1) \times d_{\text{model}}}` is added to the token
      sequence, yielding :math:`X_{\text{pos}} = [c_{\text{ls}}, X_E] + W_{\text{pos}}`.

      *Role/Interpretability.* Introduces spatial and temporal dependence to counter the
      position invariance of the Transformer encoder. The learned positional embedding
      exposes spatial relationships, often revealing a symmetric pattern in central regions
      (C1-C6) associated with the motor cortex.

    - ``PBT.transformer_encoder`` **(sequence processing and attention)**

      *Operations.* The token sequence passes through :math:`n_{\text{blocks}}` Transformer
      encoder layers. Each block combines a Multi-Head Self-Attention (MHSA) module with
      ``num_heads`` attention heads and a Feed-Forward Network (FFN). Both MHSA
      and FFN use pre-norm residual connections with Layer Normalization inside the blocks
      and apply dropout (``drop_prob``) within the Transformer components.

      *Role/Robustness.* Self-attention enables every token to consider all others, capturing
      global temporal and spatial dependencies immediately and adaptively. This architecture
      accommodates arbitrary numbers of patches and channels, supporting pre-training across
      diverse datasets.

    - ``PBT.final_layer`` **(readout)**

      *Operations.* A linear layer operates on the processed CLS token only, and the model
      predicts class probabilities as :math:`y = \operatorname{softmax}([c_{\text{ls}}] W_{\text{class}}^\top + b_{\text{class}})`.

      *Role.* Performs the final classification from the information aggregated into the CLS
      token after the Transformer encoder stack.

    .. rubric:: Convolutional Details

    PBT omits convolutional layers; equivalent feature extraction is carried out by the patch
    pipeline and attention stack.

    * **Temporal.** Tokenization slices the EEG into fixed windows of size :math:`D = d_{\text{input}}`
      (for the default configuration, :math:`D=64` samples :math:`\approx 0.256\,\text{s}` at
      :math:`250\,\text{Hz}`), while ``PBT.patching_projection`` learns periodic patterns within each
      patch. The Transformer encoder then models long- and short-range temporal dependencies through
      self-attention.

    * **Spatial.** Patches are channel-specific, keeping the architecture adaptive to any electrode
      montage. Channel-aware positional encodings :math:`W_{\text{pos}}` capture relationships between
      nearby sensors; learned embeddings often form symmetric motifs across motor cortex electrodes
      (C1–C6), and self-attention propagates information across all channels jointly.

    * **Spectral.** ``PBT.patching_projection`` acts similarly to the first convolutional layer in
      :class:`~braindecode.models.EEGNet`, learning frequency-selective filters without an explicit
      Fourier transform. The highest-energy filters typically reside between :math:`20` and
      :math:`40\,\text{Hz}`, aligning with beta/gamma rhythms tied to focused motor imagery.

    .. rubric:: Attention / Sequential Modules

    * **Attention Details.** ``PBT.transformer_encoder`` stacks :math:`n_{\text{blocks}}` Transformer
      encoder layers with Multi-Head Self-Attention. Every token attends to all others, enabling
      immediate global integration across time and channels and supporting heterogeneous datasets.
      Attention rollout visualisations highlight strong activations over motor cortex electrodes
      (C3, C4, Cz) during motor imagery decoding.

    .. warning::

        **Important:** Like the other Large Brain Models in Braindecode, :class:`PBT` is
        designed for large-scale pre-training and fine-tuning. Training from
        scratch on small datasets may lead to suboptimal results. Cross-dataset
        pre-training followed by fine-tuning is recommended to leverage the
        full potential of this architecture.

    Parameters
    ----------
    d_input : int, optional
        Size (in samples) of each patch (token) extracted along the time axis. Default: 64.
    d_model : int, optional
        Transformer embedding dimensionality. Default: 128.
    n_blocks : int, optional
        Number of Transformer encoder layers. Default: 4.
    num_heads : int, optional
        Number of attention heads. Default: 4.
    drop_prob : float, optional
        Dropout probability used in Transformer components. Default: 0.1.
    learnable_cls : bool, optional
        Whether the classification token is learnable. Default: True.
    bias_transformer : bool, optional
        Whether to use bias in Transformer linear layers. Default: False.
    activation : nn.Module, optional
        Activation function class to use in Transformer feed-forward layers.
        Default: :class:`torch.nn.GELU`.

    References
    ----------
    .. [pbt] Klein, T., Minakowski, P., & Sager, S. (2025).
       Flexible Patched Brain Transformer model for EEG decoding.
       Scientific Reports, 15(1), 1-12.
       https://www.nature.com/articles/s41598-025-86294-3
    .. [visualtransformer] Dosovitskiy, A., Beyer, L., Kolesnikov, A.,
       Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M.,
       Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J. & Houlsby,
       N. (2021). An Image is Worth 16x16 Words: Transformers for Image
       Recognition at Scale. International Conference on Learning
       Representations (ICLR).
    .. [efficient-batchpacking] Krell, M. M., Kosec, M., Perez, S. P., &
       Fitzgibbon, A. (2021). Efficient sequence packing without
       cross-contamination: Accelerating large language models without
       impacting performance. arXiv preprint arXiv:2107.02027.
    """

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        d_input: int = 64,
        d_model: int = 128,
        n_blocks: int = 4,
        num_heads: int = 4,
        drop_prob: float = 0.1,
        learnable_cls: bool = True,
        bias_transformer: bool = False,
        activation: nn.Module = nn.GELU,
    ) -> None:
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        # Store hyperparameters
        self.d_input = d_input
        self.d_model = d_model
        self.n_blocks = n_blocks
        self.num_heads = num_heads
        self.drop_prob = drop_prob

        # number of patch tokens; the positional table adds one extra slot for the CLS token
        self.num_embeddings = (self.n_chans * self.n_times) // self.d_input

        # Classification token (learnable or fixed zero)
        if learnable_cls:
            self.cls_token = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.002)
        else:
            # non-learnable zeroed tensor
            self.cls_token = torch.full(
                size=(1, 1, self.d_model),
                fill_value=0,
                requires_grad=False,
                dtype=torch.float32,
            )

        # Patches the signal and produces channel-aware positional indices
        self.patch_signal = _Patcher(
            n_chans=self.n_chans, n_times=self.n_times, d_input=self.d_input
        )

        # Linear patch projection from token raw-size -> d_model
        self.patching_projection = nn.Linear(
            in_features=self.d_input, out_features=self.d_model, bias=False
        )

        # actual embedding table mapping indices -> d_model
        self.pos_embedding = nn.Embedding(
            num_embeddings=self.num_embeddings + 1, embedding_dim=self.d_model
        )

        # Transformer encoder stack
        self.transformer_encoder = _TransformerEncoder(
            n_blocks=n_blocks,
            d_model=self.d_model,
            n_head=num_heads,
            drop_prob=drop_prob,
            bias=bias_transformer,
            activation=activation,
        )

        # classification head applied to the CLS token
        self.final_layer = nn.Linear(
            in_features=d_model, out_features=self.n_outputs, bias=True
        )

        # initialize weights
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        """Weight initialization following the original implementation."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.002)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Forward pass, following the original code logic:

        - patch the input into per-channel tokens of size ``d_input``
        - project each token to ``d_model``, prepend the CLS token and add
          the channel-aware positional embeddings
        - run the Transformer encoder and classify from the CLS output
          through ``final_layer``

        Parameters
        ----------
        X : torch.Tensor
            Input tensor with shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output logits with shape (batch_size, n_outputs).
        """
        X, int_pos = self.patch_signal(X)

        tokens = self.patching_projection(X)

        cls_token = self.cls_token.expand(X.size(0), 1, -1)
        # kept from the original implementation; currently unused
        cls_idx = torch.zeros((X.size(0), 1), dtype=torch.long, device=X.device)

        tokens = torch.cat([cls_token, tokens], dim=1)
        pos_emb = self.pos_embedding(int_pos)
        transformer_out = self.transformer_encoder(tokens + pos_emb)

        return self.final_layer(transformer_out[:, 0])


class _LayerNorm(nn.Module):
    """Layer normalization with optional bias.

    Simple wrapper around :func:`torch.nn.functional.layer_norm` exposing a
    learnable scale and optional bias.

    Parameters
    ----------
    ndim : int
        Number of features (normalized shape).
    bias : bool
        Whether to include a learnable bias term.
    """

    def __init__(self, ndim: int, bias: bool) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply layer normalization.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor where the last dimension has size ``ndim``.

        Returns
        -------
        torch.Tensor
            Normalized tensor with the same shape as the input.
        """
        return F.layer_norm(
            x,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )


class _MHSA(nn.Module):
    """Multi-head self-attention (MHSA) block.

    Implements standard multi-head attention using PyTorch's
    :func:`torch.nn.functional.scaled_dot_product_attention`, which can
    dispatch to fused (FlashAttention-style) kernels when available.

    Parameters
    ----------
    d_model : int
        Dimensionality of the model / embeddings.
    n_head : int
        Number of attention heads.
    bias : bool
        Whether linear layers use bias.
    drop_prob : float, optional
        Dropout probability applied to the attention weights and the residual projection.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        bias: bool,
        drop_prob: float = 0.0,
    ) -> None:
        super().__init__()

        assert d_model % n_head == 0, "d_model must be divisible by n_head"

        # qkv and output projection
        self.attn = nn.Linear(d_model, 3 * d_model, bias=bias)
        self.proj = nn.Linear(d_model, d_model, bias=bias)

        # dropout modules (attention dropout itself is applied inside scaled_dot_product_attention)
        self.attn_drop_prob = nn.Dropout(drop_prob)
        self.resid_drop_prob = nn.Dropout(drop_prob)

        self.n_head = n_head
        self.d_model = d_model
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for MHSA.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, T, C) where C == d_model.

        Returns
        -------
        torch.Tensor
            Output tensor of shape (B, T, C).
        """
        B, T, C = x.size()

        # project to q, k, v and reshape for multi-head attention
        q, k, v = self.attn(x).split(self.d_model, dim=2)

        # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        y = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.drop_prob if self.training else 0.0,
            is_causal=False,
        )

        # re-assemble heads -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # final linear projection + residual dropout
        y = self.resid_drop_prob(self.proj(y))
        return y


class _FeedForward(nn.Module):
    """Position-wise feed-forward network from the Transformer.

    Implements the two-layer MLP with a configurable activation (GELU by
    default) and dropout used in Transformer architectures.

    Parameters
    ----------
    d_model : int
        Input and output dimensionality.
    dim_feedforward : int, optional
        Hidden dimensionality of the feed-forward layer. Must be provided;
        a ValueError is raised when None.
    drop_prob : float, optional
        Dropout probability.
    bias : bool, optional
        Whether linear layers use bias.
    activation : nn.Module, optional
        Activation function class. Default: :class:`torch.nn.GELU`.
    """

    def __init__(
        self,
        d_model: int,
        dim_feedforward: Optional[int] = None,
        drop_prob: float = 0.0,
        bias: bool = False,
        activation: nn.Module = nn.GELU,
    ) -> None:
        super().__init__()

        if dim_feedforward is None:
            raise ValueError("dim_feedforward must be provided")

        self.proj_in = nn.Linear(d_model, dim_feedforward, bias=bias)
        self.activation = activation()
        self.proj = nn.Linear(dim_feedforward, d_model, bias=bias)
        self.drop_prob = nn.Dropout(drop_prob)
        self.drop_prob1 = nn.Dropout(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the feed-forward block."""
        x = self.proj_in(x)
        x = self.activation(x)
        x = self.drop_prob1(x)
        x = self.proj(x)
        x = self.drop_prob(x)
        return x


class _TransformerEncoderLayer(nn.Module):
    """Single Transformer encoder layer (pre-norm) combining MHSA and feed-forward.

    The block follows the pattern::

        x <- x + MHSA(_LayerNorm(x))
        x <- x + FF(_LayerNorm(x))
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        drop_prob: float = 0.0,
        dim_feedforward: Optional[int] = None,
        bias: bool = False,
        activation: nn.Module = nn.GELU,
    ) -> None:
        super().__init__()

        if dim_feedforward is None:
            dim_feedforward = 4 * d_model
            # note: preserve the original behaviour (print) from the provided code
            print(
                "dim_feedforward is set to 4*d_model, the default in Vaswani et al. (Attention is all you need)"
            )

        self.layer_norm_att = _LayerNorm(d_model, bias=bias)
        self.mhsa = _MHSA(d_model, n_head, bias, drop_prob=drop_prob)
        self.layer_norm_ff = _LayerNorm(d_model, bias=bias)
        self.feed_forward = _FeedForward(
            d_model=d_model,
            dim_feedforward=dim_feedforward,
            drop_prob=drop_prob,
            bias=bias,
            activation=activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Execute one encoder layer.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (B, T, d_model).

        Returns
        -------
        torch.Tensor
            Output of the same shape as the input.
        """
        x = x + self.mhsa(self.layer_norm_att(x))
        x = x + self.feed_forward(self.layer_norm_ff(x))
        return x


class _TransformerEncoder(nn.Module):
    """Stack of Transformer encoder layers.

    Parameters
    ----------
    n_blocks : int
        Number of encoder layers to stack.
    d_model : int
        Dimensionality of embeddings.
    n_head : int
        Number of attention heads per layer.
    drop_prob : float
        Dropout probability.
    bias : bool
        Whether linear layers use bias.
    activation : nn.Module, optional
        Activation function class for the feed-forward layers. Default: :class:`torch.nn.GELU`.
    """

    def __init__(
        self,
        n_blocks: int,
        d_model: int,
        n_head: int,
        drop_prob: float,
        bias: bool,
        activation: nn.Module = nn.GELU,
    ) -> None:
        super().__init__()

        self.encoder_block = nn.ModuleList(
            [
                _TransformerEncoderLayer(
                    d_model=d_model,
                    n_head=n_head,
                    drop_prob=drop_prob,
                    dim_feedforward=None,
                    bias=bias,
                    activation=activation,
                )
                for _ in range(n_blocks)
            ]
        )

        # GPT2-like initialization for linear layers
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_blocks))

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward through all encoder blocks sequentially."""
        for block in self.encoder_block:
            x = block(x)
        return x


class _Patcher(nn.Module):
    """Patch extraction helper.

    This module "patchifies" the input EEG trial in a ViT-like manner: the
    flattened (channels x times) signal is split into non-overlapping tokens
    of ``d_input`` samples.

    Parameters
    ----------
    n_chans : int
        Number of EEG channels.
    n_times : int
        Number of time samples per trial.
    d_input : int, optional
        Size (in samples) of each patch (token).
    """

    def __init__(self, n_chans=2, n_times=1000, d_input=64) -> None:
        super().__init__()
        self.d_input = d_input
        self.num_tokens = (n_chans * n_times) // d_input

    def forward(self, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the patched input together with the positional indices
        (shared across the batch).

        Parameters
        ----------
        X : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        x_patched : torch.Tensor
            Tensor of shape (B, (C*T)//d_input, d_input) containing the
            patched signal values.
        position : torch.LongTensor
            Tensor of shape ((C*T)//d_input + 1,) containing the positional
            token indices (including the slot for the CLS token).
        """
        X_flat = X.view(X.size(0), -1)
        x_patched = X_flat[:, : self.num_tokens * self.d_input].view(
            X.size(0), self.num_tokens, self.d_input
        )

        position = torch.arange(x_patched.size(1) + 1, device=X.device)

        return x_patched, position