braindecode 1.3.0.dev180851780__py3-none-any.whl → 1.3.0.dev183667303__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/augmentation/base.py +1 -1
- braindecode/datasets/__init__.py +10 -2
- braindecode/datasets/base.py +115 -151
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +2 -2
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +2 -2
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +2 -2
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/serialization.py +7 -7
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +2 -0
- braindecode/models/atcnet.py +25 -26
- braindecode/models/attentionbasenet.py +39 -32
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +2 -0
- braindecode/models/ctnet.py +6 -3
- braindecode/models/deepsleepnet.py +27 -18
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegnet.py +1 -1
- braindecode/models/labram.py +193 -87
- braindecode/models/signal_jepa.py +109 -27
- braindecode/models/sinc_shallow.py +10 -9
- braindecode/models/sstdpn.py +11 -11
- braindecode/models/summary.csv +1 -0
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +1 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +24 -0
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/preprocess.py +12 -12
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +24 -19
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/METADATA +6 -2
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/RECORD +52 -49
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev183667303.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,469 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from einops.layers.torch import Rearrange
|
|
5
|
+
from torch import nn
|
|
6
|
+
|
|
7
|
+
from braindecode.models.base import EEGModuleMixin
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BENDR(EEGModuleMixin, nn.Module):
    """BENDR (BErt-inspired Neural Data Representations) from Kostas et al. (2021) [bendr]_.

    :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`

    .. figure:: https://www.frontiersin.org/files/Articles/653659/fnhum-15-653659-HTML/image_m/fnhum-15-653659-g001.jpg
        :align: center
        :alt: BENDR Architecture
        :width: 1000px


    The **BENDR** architecture adapts techniques used for language modeling (LM) toward the
    development of encephalography modeling (EM) [bendr]_. It utilizes a self-supervised
    training objective to learn compressed representations of raw EEG signals [bendr]_. The
    model is capable of modeling completely novel raw EEG sequences recorded with differing
    hardware and subjects, aiming for transferable performance across a variety of downstream
    BCI and EEG classification tasks [bendr]_.

    .. rubric:: Architectural Overview

    BENDR is adapted from wav2vec 2.0 [wav2vec2]_ and is composed of two main stages: a
    feature extractor (Convolutional stage) that produces BErt-inspired Neural Data
    Representations (BENDR), followed by a transformer encoder (Contextualizer) [bendr]_.

    .. rubric:: Macro Components

    - `BENDR.encoder` **(Convolutional Stage/Feature Extractor)**

      - *Operations.* A stack of six short-receptive field 1D convolutions [bendr]_. Each
        block consists of a 1D convolution, channel-wise dropout, GroupNorm, and GELU
        activation.
      - *Role.* Takes raw data :math:`X_{raw}` and dramatically downsamples it to a new
        sequence of vectors (BENDR) [bendr]_. Each resulting vector has length
        ``encoder_h`` (512 by default).

    - `BENDR.contextualizer` **(Transformer Encoder)**

      - *Operations.* A transformer encoder that uses layered, multi-head self-attention
        [bendr]_. It employs T-Fixup weight initialization [tfixup]_ and uses 8 layers
        and 8 heads.
      - *Role.* Maps the sequence of BENDR vectors to a contextualized sequence. The output
        of a fixed start token is typically used as the aggregate representation for
        downstream classification [bendr]_.

    - `BENDR.contextualizer.relative_position` **(Positional Encoding)**

      - *Operations.* An additive (grouped) convolution layer with a receptive field of 25
        and 16 groups [bendr]_.
      - *Role.* Encodes position information before the input enters the transformer.

    .. rubric:: How the information is encoded temporally, spatially, and spectrally

    * **Temporal.**
      The convolutional encoder uses a stack of blocks where the stride matches the receptive
      field (e.g., 3 for the first block, 2 for subsequent blocks) [bendr]_. This process
      downsamples the raw data by a factor of 96 (3 x 2 x 2 x 2 x 2 x 2), resulting in an
      effective sampling frequency of approximately 2.67 Hz for a 256 Hz input (256 / 96).
    * **Spatial.**
      To maintain simplicity and reduce complexity, the convolutional stage uses **1D
      convolutions** over time only [bendr]_. The EEG channels are treated as the input
      feature maps of the first convolution, which combines them linearly; there is no
      separate spatial-filtering stage. In the paper, the input includes 20 channels
      (19 EEG channels and one relative amplitude channel).
    * **Spectral.**
      The convolution operations implicitly extract features from the raw EEG signal [bendr]_.
      The representations (BENDR) are derived from the raw waveform using convolutional
      operations followed by sequence modeling [wav2vec2]_.

    .. rubric:: Additional Mechanisms

    - **Self-Supervision (Pre-training).** Uses a masked sequence learning approach (adapted
      from wav2vec 2.0 [wav2vec2]_) where contiguous spans of BENDR sequences are masked, and
      the model attempts to reconstruct the original underlying encoded vector based on the
      transformer output and a set of negative distractors [bendr]_.
    - **Regularization.** LayerDrop [layerdrop]_ and Dropout (at probabilities 0.01 and 0.15,
      respectively) are used during pre-training [bendr]_. The implementation also uses T-Fixup
      scaling for parameter initialization [tfixup]_.
    - **Input Conditioning.** A fixed token (a vector filled with the value **-5**) is
      prepended to the BENDR sequence before input to the transformer, serving as the aggregate
      representation token [bendr]_.

    Notes
    -----
    * The full BENDR architecture contains a large number of parameters; configuration (1)
      involved training over **one billion parameters** [bendr]_.
    * Randomly initialized full BENDR architecture was generally ineffective at solving
      downstream tasks without prior self-supervised training [bendr]_.
    * The pre-training task (contrastive predictive coding via masking) is generalizable,
      exhibiting strong uniformity of performance across novel subjects, hardware, and
      tasks [bendr]_.

    .. warning::

        **Important:** To utilize the full potential of BENDR, the model requires
        **self-supervised pre-training** on large, unlabeled EEG datasets (like TUEG) followed
        by subsequent fine-tuning on the specific downstream classification task [bendr]_.

    Parameters
    ----------
    encoder_h : int, default=512
        Hidden size (number of output channels) of the convolutional encoder. This determines
        the dimensionality of the BENDR feature vectors produced by the encoder.
    contextualizer_hidden : int, default=3076
        Hidden size of the feedforward layer within each transformer block. The paper uses
        approximately 2x the transformer dimension (3076 ~ 2 x 1536).
    projection_head : bool, default=False
        If True, appends a 1x1 convolution (``encoder_h`` -> ``encoder_h``) at the end of the
        encoder. This is used during self-supervised pre-training but typically disabled
        during fine-tuning.
    drop_prob : float, default=0.1
        Dropout probability applied throughout the model. The paper recommends 0.15 for
        pre-training and 0.0 for fine-tuning. Default is 0.1 as a compromise.
    layer_drop : float, default=0.0
        Probability of dropping entire transformer layers during training (LayerDrop
        regularization [layerdrop]_). The paper uses 0.01 for pre-training and 0.0 for
        fine-tuning.
    activation : :class:`torch.nn.Module`, default=:class:`torch.nn.GELU`
        Activation class used in the encoder convolutional blocks, the positional encoding
        and the transformer feedforward layers. The paper uses GELU activation throughout.
    transformer_layers : int, default=8
        Number of transformer encoder layers in the contextualizer. The paper uses 8 layers.
    transformer_heads : int, default=8
        Number of attention heads in each transformer layer. The paper uses 8 heads with
        head dimension of 192 (1536 / 8).
    position_encoder_length : int, default=25
        Kernel size for the convolutional positional encoding layer. The paper uses a
        receptive field of 25 with 16 groups.
    enc_width : tuple of int, default=(3, 2, 2, 2, 2, 2)
        Kernel sizes for each of the 6 convolutional blocks in the encoder. Each value
        corresponds to one block; even values are incremented to the next odd value.
    enc_downsample : tuple of int, default=(3, 2, 2, 2, 2, 2)
        Stride values for each of the 6 convolutional blocks in the encoder. The total
        downsampling factor is the product of all strides (3 x 2 x 2 x 2 x 2 x 2 = 96).
    start_token : int or float, default=-5
        Value used to fill the start token embedding that is prepended to the BENDR sequence
        before input to the transformer. This token's output is used as the aggregate
        representation for classification.
    final_layer : bool, default=True
        If True, includes a final (weight-normalized) linear classification layer that maps
        from ``encoder_h`` to ``n_outputs``. If False, the model outputs the contextualized
        start-token features of shape ``(batch_size, encoder_h)`` directly.

    References
    ----------
    .. [bendr] Kostas, D., Aroca-Ouellette, S., & Rudzicz, F. (2021).
       BENDR: Using transformers and a contrastive self-supervised learning task to learn from
       massive amounts of EEG data.
       Frontiers in Human Neuroscience, 15, 653659.
       https://doi.org/10.3389/fnhum.2021.653659
    .. [wav2vec2] Baevski, A., Zhou, Y., Mohamed, A., & Auli, M. (2020).
       wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.
       In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, & H. Lin (Eds),
       Advances in Neural Information Processing Systems (Vol. 33, pp. 12449-12460).
       https://dl.acm.org/doi/10.5555/3495724.3496768
    .. [tfixup] Huang, T. K., Liang, S., Jha, A., & Salakhutdinov, R. (2020).
       Improving Transformer Optimization Through Better Initialization.
       In International Conference on Machine Learning (pp. 4475-4483). PMLR.
       https://dl.acm.org/doi/10.5555/3524938.3525354
    .. [layerdrop] Fan, A., Grave, E., & Joulin, A. (2020).
       Reducing Transformer Depth on Demand with Structured Dropout.
       International Conference on Learning Representations.
       Retrieved from https://openreview.net/forum?id=SylO2yStDr
    """

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        encoder_h=512,  # Hidden size of the encoder convolutional layers
        contextualizer_hidden=3076,  # Feedforward hidden size in transformer
        projection_head=False,  # Append an encoder_h -> encoder_h 1x1 conv to the encoder
        drop_prob=0.1,  # General dropout probability (paper: 0.15 for pretraining, 0.0 for fine-tuning)
        layer_drop=0.0,  # Probability of dropping transformer layers during training (paper: 0.01 for pretraining)
        activation=nn.GELU,  # Activation function
        # Transformer specific parameters
        transformer_layers=8,
        transformer_heads=8,
        position_encoder_length=25,  # Kernel size for positional encoding conv
        enc_width=(3, 2, 2, 2, 2, 2),
        enc_downsample=(3, 2, 2, 2, 2, 2),
        start_token=-5,  # Value for start token embedding
        final_layer=True,  # Whether to include the final linear layer
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )

        # The mixin stores these (read below as self.n_chans / self.n_outputs);
        # drop the local names so only the mixin's values are used.
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.encoder_h = encoder_h
        self.contextualizer_hidden = contextualizer_hidden
        self.include_final_layer = final_layer

        # Convolutional feature extractor: (B, n_chans, T) -> (B, encoder_h, T').
        self.encoder = _ConvEncoderBENDR(
            in_features=self.n_chans,
            encoder_h=encoder_h,
            dropout=drop_prob,
            projection_head=projection_head,
            enc_width=enc_width,
            enc_downsample=enc_downsample,
            activation=activation,
        )

        # Transformer contextualizer: (B, encoder_h, T') -> (B, encoder_h, T' + 1).
        self.contextualizer = _BENDRContextualizer(
            in_features=self.encoder.encoder_h,  # Use the output feature size of the encoder
            hidden_feedforward=contextualizer_hidden,
            heads=transformer_heads,
            layers=transformer_layers,
            dropout=drop_prob,  # Use general dropout probability
            activation=activation,
            position_encoder=position_encoder_length,  # Pass position encoder kernel size
            layer_drop=layer_drop,
            start_token=start_token,  # Keep fixed start token value
        )
        in_features = self.encoder.encoder_h  # Set in_features for final layer

        self.final_layer = None
        if self.include_final_layer:
            # Input to Linear is [batch_size, encoder_h]: the contextualizer output
            # at the start-token position (index 0), see ``forward``.
            linear = nn.Linear(in_features=in_features, out_features=self.n_outputs)
            self.final_layer = nn.utils.parametrizations.weight_norm(
                linear, name="weight", dim=1
            )

    def forward(self, x):
        """Classify (or embed) a batch of raw EEG windows.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(batch_size, n_chans, n_times)``.

        Returns
        -------
        torch.Tensor
            ``(batch_size, n_outputs)`` if the final layer is enabled, otherwise the
            start-token features of shape ``(batch_size, encoder_h)``.
        """
        encoded = self.encoder(x)
        # encoded: [batch_size, encoder_h, n_encoded_times]

        context = self.contextualizer(encoded)
        # context: [batch_size, encoder_h, n_encoded_times + 1] (due to start token)

        # Extract features - use the output corresponding to the start token (index 0).
        # The start token was prepended at position 0 in the contextualizer before the
        # transformer layers, and the contextualizer returns [batch, features, seq_len+1],
        # so the start token output is at index 0 of the last axis.
        # According to the paper (Kostas et al. 2021): "The transformer output of this
        # initial position was not modified during pre-training, and only used for
        # downstream tasks." This follows BERT's [CLS] token convention where the
        # first token aggregates sequence information via self-attention.
        feature = context[:, :, 0]
        # feature: [batch_size, encoder_h]

        if self.final_layer is not None:
            feature = self.final_layer(feature)
            # feature: [batch_size, n_outputs]

        return feature
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class _ConvEncoderBENDR(nn.Module):
    """Convolutional feature extractor of BENDR.

    Builds one block per ``(width, downsample)`` pair. Each block is a strided
    :class:`~torch.nn.Conv1d`, channel-wise dropout, :class:`~torch.nn.GroupNorm`
    and an activation. Every block outputs ``encoder_h`` channels, so an input of
    shape ``(batch, in_features, time)`` becomes ``(batch, encoder_h, reduced_time)``.

    Parameters
    ----------
    in_features : int
        Number of input channels of the first convolution.
    encoder_h : int
        Number of output channels of every block. ``GroupNorm`` uses
        ``encoder_h // 2`` groups, which must divide ``encoder_h``.
    enc_width : int or sequence of int
        Kernel size per block. Even values are incremented to the next odd value
        so that ``padding=width // 2`` is symmetric.
    dropout : float
        Channel-wise dropout probability applied after each convolution.
    projection_head : bool
        If True, appends a 1x1 convolution ``encoder_h -> encoder_h``.
    enc_downsample : int or sequence of int
        Stride per block; must have the same length as ``enc_width``.
    activation : type[nn.Module]
        Activation class, instantiated once per block.

    Raises
    ------
    ValueError
        If ``enc_width`` and ``enc_downsample`` have different lengths.
    """

    def __init__(
        self,
        in_features,
        encoder_h=512,
        enc_width=(3, 2, 2, 2, 2, 2),
        dropout=0.0,
        projection_head=False,
        enc_downsample=(3, 2, 2, 2, 2, 2),
        activation=nn.GELU,
    ):
        super().__init__()
        self.encoder_h = encoder_h
        self.in_features = in_features

        # A scalar configures a single encoder block.
        if not isinstance(enc_width, (list, tuple)):
            enc_width = [enc_width]
        if not isinstance(enc_downsample, (list, tuple)):
            enc_downsample = [enc_downsample]
        if len(enc_downsample) != len(enc_width):
            raise ValueError(
                "Encoder width and downsampling factors must have the same length."
            )

        # Centerable convolutions make life simpler: ensure odd kernel sizes.
        enc_width = [e if e % 2 != 0 else e + 1 for e in enc_width]
        self._downsampling = enc_downsample
        self._width = enc_width

        current_in_features = in_features
        self.encoder = nn.Sequential()
        for i, (width, downsample) in enumerate(zip(enc_width, enc_downsample)):
            self.encoder.add_module(
                f"Encoder_{i}",
                nn.Sequential(
                    nn.Conv1d(
                        current_in_features,
                        encoder_h,
                        width,
                        stride=downsample,
                        padding=width // 2,  # symmetric padding for the odd kernel
                    ),
                    # Channel-wise dropout on (B, C, T). This is what the paper's
                    # Dropout2d does on 3D input, but Dropout1d avoids PyTorch's
                    # deprecation warning for 3D input to dropout2d. It has no
                    # parameters, so state_dict keys are unchanged.
                    nn.Dropout1d(dropout),
                    # NOTE: requires encoder_h // 2 to divide encoder_h.
                    nn.GroupNorm(encoder_h // 2, encoder_h),
                    activation(),
                ),
            )
            current_in_features = encoder_h
        if projection_head:
            self.encoder.add_module(
                "projection_head",
                nn.Conv1d(encoder_h, encoder_h, 1),
            )

    def forward(self, x):
        """Map ``(batch, in_features, time)`` to ``(batch, encoder_h, reduced_time)``."""
        return self.encoder(x)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class _BENDRContextualizer(nn.Module):
    """Transformer-based contextualizer for BENDR.

    Adds a convolutional relative positional encoding to the encoder features,
    projects them up to ``3 * in_features``, optionally prepends a constant start
    token, and runs a stack of post-norm transformer encoder layers. The internal
    LayerNorms of those layers are replaced by identities (T-Fixup). A single
    final LayerNorm is then applied, and the result is projected back down to
    ``in_features``.

    Parameters
    ----------
    in_features : int
        Number of input feature channels (the encoder's ``encoder_h``).
    hidden_feedforward : int
        Feedforward size inside each transformer layer.
    heads : int
        Number of attention heads; must divide ``3 * in_features``.
    layers : int
        Number of transformer encoder layers (may be 0).
    dropout : float
        Dropout used in the input conditioning and inside the transformer layers.
    activation : type[nn.Module]
        Activation class, instantiated for the positional encoding and for the
        transformer feedforward.
    position_encoder : int
        Kernel size of the grouped (16 groups) positional-encoding convolution.
        ``in_features`` must therefore be divisible by 16. Even sizes are bumped
        to the next odd value so the convolution preserves sequence length.
        A value ``<= 0`` disables positional encoding.
    layer_drop : float
        Probability of skipping each transformer layer during training.
    start_token : int, float or None
        Constant used to fill the prepended start-token embedding; ``None``
        disables the start token.
    """

    def __init__(
        self,
        in_features,
        hidden_feedforward=3076,
        heads=8,
        layers=8,
        dropout=0.1,  # Default dropout
        activation=nn.GELU,  # Activation for transformer FF layer
        position_encoder=25,  # Kernel size for conv positional encoding
        layer_drop=0.0,  # Probability of dropping a whole layer
        start_token=-5,  # Value for start token embedding
    ):
        super().__init__()
        self.dropout = dropout
        self.layer_drop = layer_drop
        self.start_token = start_token  # Store start token value

        # Paper specification: transformer_dim = 3 * encoder_h (1536 for encoder_h=512)
        self.transformer_dim = 3 * in_features
        self.in_features = in_features

        # --- Positional Encoding --- (Applied before projection)
        # Kept for backward compatibility; the module itself is ``relative_position``.
        self.position_encoder = None
        if position_encoder > 0:
            # Centerable (odd) kernel: with padding=k // 2 an even kernel would
            # return seq_len + 1 steps and break the residual addition in forward.
            if position_encoder % 2 == 0:
                position_encoder += 1
            conv = nn.Conv1d(
                in_features,
                in_features,
                kernel_size=position_encoder,
                padding=position_encoder // 2,
                groups=16,  # Number of groups for the grouped convolution
            )
            # T-Fixup positional encoding initialization (paper specification)
            nn.init.normal_(conv.weight, mean=0, std=2 / self.transformer_dim)
            nn.init.constant_(conv.bias, 0)

            conv = nn.utils.parametrizations.weight_norm(conv, name="weight", dim=2)
            self.relative_position = nn.Sequential(conv, activation())

        # --- Input Conditioning --- (Includes projection up to transformer_dim)
        # Rearrange, Norm, Dropout, Project, Rearrange
        self.input_conditioning = nn.Sequential(
            Rearrange("batch channel time -> batch time channel"),
            nn.LayerNorm(in_features),
            nn.Dropout(dropout),
            nn.Linear(in_features, self.transformer_dim),  # Project up using Linear
            # (T, B, C): the layout expected by batch_first=False transformer layers
            Rearrange("batch time channel -> time batch channel"),
        )

        # --- Transformer Encoder Layers ---
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.transformer_dim,  # Use projected dimension
            nhead=heads,
            dim_feedforward=hidden_feedforward,
            dropout=dropout,  # Dropout within transformer layer
            activation=activation(),
            batch_first=False,  # Expects (T, B, C)
            norm_first=False,  # Standard post-norm architecture
        )

        # T-Fixup: Replace internal LayerNorms with Identity
        encoder_layer.norm1 = nn.Identity()
        encoder_layer.norm2 = nn.Identity()

        self.transformer_layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(layers)]
        )

        # Add final LayerNorm after all transformer layers (paper specification)
        self.norm = nn.LayerNorm(self.transformer_dim)

        # --- Output Layer --- (Project back down to in_features)
        # Conv1d with kernel=1 (equivalent to Linear but expects (B, C, T) input)
        self.output_layer = nn.Conv1d(self.transformer_dim, in_features, 1)

        # Initialize parameters with T-Fixup (paper specification)
        self.apply(self._init_bert_params)

    def _init_bert_params(self, module):
        """Initialize linear layers and apply T-Fixup scaling.

        Linear layers (including attention output projections, which subclass
        ``nn.Linear``) get Xavier-uniform weights, zero biases and a
        ``0.67 * N ** (-1/4)`` scaling, where ``N`` is the number of transformer
        layers. LayerNorms are reset to identity affine parameters.
        """
        if isinstance(module, nn.Linear):
            # Standard init
            nn.init.xavier_uniform_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()
            # T-Fixup scaling; guard N = 0, where 0 ** -0.25 raises ZeroDivisionError.
            n_layers = max(len(self.transformer_layers), 1)
            module.weight.data = 0.67 * n_layers ** (-0.25) * module.weight.data
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Contextualize a sequence of encoder features.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(batch_size, in_features, seq_len)``.

        Returns
        -------
        torch.Tensor
            ``(batch_size, in_features, seq_len + 1)`` when a start token is used,
            otherwise ``(batch_size, in_features, seq_len)``.
        """
        # Apply relative positional encoding (additive)
        if hasattr(self, "relative_position"):
            x = x + self.relative_position(x)

        # Apply input conditioning (includes projection up and rearrange)
        x = self.input_conditioning(x)
        # x: [seq_len, batch_size, transformer_dim]

        # Prepend start token; match the activation dtype so half/bfloat16 models
        # are not silently promoted to float32 by torch.cat.
        if self.start_token is not None:
            token_emb = torch.full(
                (1, x.shape[1], x.shape[2]),
                float(self.start_token),
                dtype=x.dtype,
                device=x.device,
                requires_grad=False,
            )
            x = torch.cat([token_emb, x], dim=0)
            # x: [seq_len + 1, batch_size, transformer_dim]

        # Apply transformer layers with layer drop (only skipped while training)
        for layer in self.transformer_layers:
            if not self.training or torch.rand(1) > self.layer_drop:
                x = layer(x)

        # Apply final LayerNorm (T-Fixup: norm only at the end)
        x = self.norm(x)

        # (T, B, C) -> (B, C, T) for the Conv1d output layer
        x = x.permute(1, 2, 0)

        # Apply output projection back to in_features
        x = self.output_layer(x)
        # x: [batch_size, in_features, seq_len + 1]

        return x
|
braindecode/models/biot.py
CHANGED
|
@@ -11,6 +11,8 @@ from braindecode.models.base import EEGModuleMixin
|
|
|
11
11
|
class BIOT(EEGModuleMixin, nn.Module):
|
|
12
12
|
"""BIOT from Yang et al. (2023) [Yang2023]_
|
|
13
13
|
|
|
14
|
+
:bdg-danger:`Large Brain Model`
|
|
15
|
+
|
|
14
16
|
.. figure:: https://braindecode.org/dev/_static/model/biot.jpg
|
|
15
17
|
:align: center
|
|
16
18
|
:alt: BioT
|
braindecode/models/ctnet.py
CHANGED
|
@@ -39,16 +39,19 @@ class CTNet(EEGModuleMixin, nn.Module):
|
|
|
39
39
|
The architecture consists of three main components:
|
|
40
40
|
|
|
41
41
|
1. **Convolutional Module**:
|
|
42
|
+
|
|
42
43
|
- Apply :class:`EEGNet` to perform some feature extraction, denoted here as
|
|
43
|
-
|
|
44
|
+
_PatchEmbeddingEEGNet module.
|
|
44
45
|
|
|
45
46
|
2. **Transformer Encoder Module**:
|
|
47
|
+
|
|
46
48
|
- Utilizes multi-head self-attention mechanisms as EEGConformer but
|
|
47
|
-
|
|
49
|
+
with residual blocks.
|
|
48
50
|
|
|
49
51
|
3. **Classifier Module**:
|
|
52
|
+
|
|
50
53
|
- Combines features from both the convolutional module
|
|
51
|
-
|
|
54
|
+
and the Transformer encoder.
|
|
52
55
|
- Flattens the combined features and applies dropout for regularization.
|
|
53
56
|
- Uses a fully connected layer to produce the final classification output.
|
|
54
57
|
|
|
@@ -43,16 +43,21 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
|
|
|
43
43
|
|
|
44
44
|
- :class:`_RepresentationLearning` **(dual-path CNN → epoch feature)**
|
|
45
45
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
46
|
+
- *Operations.*
|
|
47
|
+
- **Small-filter CNN** 4 times:
|
|
48
|
+
|
|
49
|
+
- :class:`~torch.nn.Conv1d`
|
|
50
|
+
- :class:`~torch.nn.BatchNorm1d`
|
|
51
|
+
- :class:`~torch.nn.ReLU`
|
|
52
|
+
- :class:`~torch.nn.MaxPool1d` after.
|
|
53
|
+
|
|
54
|
+
- First conv uses **filter length ≈ Fs/2** and **stride ≈ Fs/16** to emphasize *timing* of graphoelements.
|
|
55
|
+
|
|
53
56
|
- **Large-filter CNN**:
|
|
57
|
+
|
|
54
58
|
- Same stack but first conv uses **filter length ≈ 4·Fs** and
|
|
55
59
|
- **stride ≈ Fs/2** to emphasize *frequency* content.
|
|
60
|
+
|
|
56
61
|
- Outputs from both paths are **concatenated** into the epoch embedding ``a_t``.
|
|
57
62
|
|
|
58
63
|
- *Rationale.*
|
|
@@ -81,20 +86,22 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
|
|
|
81
86
|
|
|
82
87
|
- **Temporal (where time-domain patterns are learned).**
|
|
83
88
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
+
Both CNN paths use **1-D temporal convolutions**. The *small-filter* path (first kernel ≈ Fs/2,
|
|
90
|
+
stride ≈ Fs/16) captures *when* characteristic transients occur; the *large-filter* path
|
|
91
|
+
(first kernel ≈ 4·Fs, stride ≈ Fs/2) captures *which* frequency components dominate over the
|
|
92
|
+
epoch. Deeper layers use **small kernels** to refine features with fewer parameters, interleaved
|
|
93
|
+
with **max pooling** for downsampling.
|
|
89
94
|
|
|
90
95
|
- **Spatial (how channels are processed).**
|
|
91
|
-
|
|
92
|
-
|
|
96
|
+
|
|
97
|
+
The original model operates on **single-channel** raw EEG; convolutions therefore mix only
|
|
98
|
+
along time (no spatial convolution across electrodes).
|
|
93
99
|
|
|
94
100
|
- **Spectral (how frequency information emerges).**
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
101
|
+
|
|
102
|
+
No explicit Fourier/wavelet transform is used. The **large-filter path** serves as a
|
|
103
|
+
*frequency-sensitive* analyzer, while the **small-filter path** remains *time-sensitive*,
|
|
104
|
+
together functioning as a **two-band learned filter bank** at the first layer.
|
|
98
105
|
|
|
99
106
|
.. rubric:: Attention / Sequential Modules
|
|
100
107
|
|
|
@@ -112,9 +119,11 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
|
|
|
112
119
|
- **Residual shortcut over sequence encoder.** Adds projected CNN features to BiLSTM outputs,
|
|
113
120
|
improving gradient flow and retaining discriminative content from representation learning.
|
|
114
121
|
- **Two-step training.**
|
|
122
|
+
|
|
115
123
|
- (i) **Pretrain** the CNN paths with class-balanced sampling;
|
|
116
124
|
- (ii) **fine-tune** the full network with sequential batches, using **lower LR** for CNNs and **higher LR** for the
|
|
117
|
-
|
|
125
|
+
sequence encoder.
|
|
126
|
+
|
|
118
127
|
- **State handling.** BiLSTM states are **reinitialized per subject** so that temporal context
|
|
119
128
|
does not leak across recordings.
|
|
120
129
|
|
|
@@ -51,7 +51,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):
|
|
|
51
51
|
The result is rearranged to a token sequence ``(B, S_tokens, D)``, where ``D = n_filters_time``.
|
|
52
52
|
|
|
53
53
|
*Interpretability/robustness.* Temporal kernels can be inspected as FIR filters;
|
|
54
|
-
the spatial conv yields channel projections analogous to :class:`ShallowFBCSPNet
|
|
54
|
+
the spatial conv yields channel projections analogous to :class:`ShallowFBCSPNet`'s learned
|
|
55
55
|
spatial filters. Temporal pooling stabilizes statistics and reduces sequence length.
|
|
56
56
|
|
|
57
57
|
- :class:`_TransformerEncoder` **(context over temporal tokens)**
|
|
@@ -119,7 +119,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):
|
|
|
119
119
|
capacity (but higher compute); larger strides → fewer tokens and stronger inductive bias.
|
|
120
120
|
|
|
121
121
|
- **Embedding dimension = filters.** ``n_filters_time`` serves double duty as both the
|
|
122
|
-
number of temporal filters in the stem and the transformer
|
|
122
|
+
number of temporal filters in the stem and the transformer's embedding size ``D``,
|
|
123
123
|
simplifying dimensional alignment.
|
|
124
124
|
|
|
125
125
|
.. rubric:: Usage and Configuration
|