braindecode 1.2.0.dev184328194__py3-none-any.whl → 1.3.0.dev171178473__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/base.py +1 -1
- braindecode/augmentation/functional.py +154 -54
- braindecode/augmentation/transforms.py +2 -2
- braindecode/datasets/__init__.py +10 -2
- braindecode/datasets/base.py +116 -152
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +218 -0
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +2 -2
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +4 -3
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/serialization.py +18 -13
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +12 -8
- braindecode/models/atcnet.py +156 -17
- braindecode/models/attentionbasenet.py +148 -16
- braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +3 -1
- braindecode/models/ctnet.py +7 -4
- braindecode/models/deep4.py +6 -2
- braindecode/models/deepsleepnet.py +127 -5
- braindecode/models/eegconformer.py +114 -15
- braindecode/models/eeginception_erp.py +82 -7
- braindecode/models/eeginception_mi.py +2 -0
- braindecode/models/eegnet.py +64 -177
- braindecode/models/eegnex.py +113 -6
- braindecode/models/eegsimpleconv.py +2 -0
- braindecode/models/eegtcnet.py +1 -1
- braindecode/models/labram.py +188 -84
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/sccnet.py +81 -8
- braindecode/models/shallow_fbcsp.py +2 -0
- braindecode/models/signal_jepa.py +109 -27
- braindecode/models/sinc_shallow.py +10 -9
- braindecode/models/sleep_stager_blanco_2020.py +2 -0
- braindecode/models/sleep_stager_chambon_2018.py +2 -0
- braindecode/models/sparcnet.py +2 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +42 -41
- braindecode/models/tidnet.py +2 -0
- braindecode/models/tsinception.py +15 -3
- braindecode/models/usleep.py +108 -9
- braindecode/models/util.py +8 -5
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +24 -0
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/preprocess.py +42 -39
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +24 -19
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/METADATA +12 -3
- braindecode-1.3.0.dev171178473.dist-info/RECORD +106 -0
- braindecode/models/eegresnet.py +0 -362
- braindecode-1.2.0.dev184328194.dist-info/RECORD +0 -101
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/WHEEL +0 -0
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/top_level.txt +0 -0
braindecode/models/bendr.py
ADDED
@@ -0,0 +1,469 @@
+import copy
+
+import torch
+from einops.layers.torch import Rearrange
+from torch import nn
+
+from braindecode.models.base import EEGModuleMixin
+
+
+class BENDR(EEGModuleMixin, nn.Module):
+    """BENDR (BErt-inspired Neural Data Representations) from Kostas et al. (2021) [bendr]_.
+
+    :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`
+
+    .. figure:: https://www.frontiersin.org/files/Articles/653659/fnhum-15-653659-HTML/image_m/fnhum-15-653659-g001.jpg
+       :align: center
+       :alt: BENDR Architecture
+       :width: 1000px
+
+
+    The **BENDR** architecture adapts techniques used for language modeling (LM) toward the
+    development of encephalography modeling (EM) [bendr]_. It utilizes a self-supervised
+    training objective to learn compressed representations of raw EEG signals [bendr]_. The
+    model is capable of modeling completely novel raw EEG sequences recorded with differing
+    hardware and subjects, aiming for transferable performance across a variety of downstream
+    BCI and EEG classification tasks [bendr]_.
+
+    .. rubric:: Architectural Overview
+
+    BENDR is adapted from wav2vec 2.0 [wav2vec2]_ and is composed of two main stages: a
+    feature extractor (Convolutional stage) that produces BErt-inspired Neural Data
+    Representations (BENDR), followed by a transformer encoder (Contextualizer) [bendr]_.
+
+    .. rubric:: Macro Components
+
+    - `BENDR.encoder` **(Convolutional Stage/Feature Extractor)**
+        - *Operations.* A stack of six short-receptive field 1D convolutions [bendr]_. Each
+          block consists of 1D convolution, GroupNorm, and GELU activation.
+        - *Role.* Takes raw data :math:`X_{raw}` and dramatically downsamples it to a new
+          sequence of vectors (BENDR) [bendr]_. Each resulting vector has a length of 512.
+    - `BENDR.contextualizer` **(Transformer Encoder)**
+        - *Operations.* A transformer encoder that uses layered, multi-head self-attention
+          [bendr]_. It employs T-Fixup weight initialization [tfixup]_ and uses 8 layers
+          and 8 heads.
+        - *Role.* Maps the sequence of BENDR vectors to a contextualized sequence. The output
+          of a fixed start token is typically used as the aggregate representation for
+          downstream classification [bendr]_.
+    - `Contextualizer.position_encoder` **(Positional Encoding)**
+        - *Operations.* An additive (grouped) convolution layer with a receptive field of 25
+          and 16 groups [bendr]_.
+        - *Role.* Encodes position information before the input enters the transformer.
+
+    .. rubric:: How the information is encoded temporally, spatially, and spectrally
+
+    * **Temporal.**
+      The convolutional encoder uses a stack of blocks where the stride matches the receptive
+      field (e.g., 3 for the first block, 2 for subsequent blocks) [bendr]_. This process
+      downsamples the raw data by a factor of 96, resulting in an effective sampling frequency
+      of approximately 2.67 Hz.
+    * **Spatial.**
+      To maintain simplicity and reduce complexity, the convolutional stage uses **1D
+      convolutions** and elects not to mix EEG channels across the first stage [bendr]_. The
+      input includes 20 channels (19 EEG channels and one relative amplitude channel).
+    * **Spectral.**
+      The convolution operations implicitly extract features from the raw EEG signal [bendr]_.
+      The representations (BENDR) are derived from the raw waveform using convolutional
+      operations followed by sequence modeling [wav2vec2]_.
+
+    .. rubric:: Additional Mechanisms
+
+    - **Self-Supervision (Pre-training).** Uses a masked sequence learning approach (adapted
+      from wav2vec 2.0 [wav2vec2]_) where contiguous spans of BENDR sequences are masked, and
+      the model attempts to reconstruct the original underlying encoded vector based on the
+      transformer output and a set of negative distractors [bendr]_.
+    - **Regularization.** LayerDrop [layerdrop]_ and Dropout (at probabilities 0.01 and 0.15,
+      respectively) are used during pre-training [bendr]_. The implementation also uses T-Fixup
+      scaling for parameter initialization [tfixup]_.
+    - **Input Conditioning.** A fixed token (a vector filled with the value **-5**) is
+      prepended to the BENDR sequence before input to the transformer, serving as the aggregate
+      representation token [bendr]_.
+
+    Notes
+    -----
+    * The full BENDR architecture contains a large number of parameters; configuration (1)
+      involved training over **one billion parameters** [bendr]_.
+    * Randomly initialized full BENDR architecture was generally ineffective at solving
+      downstream tasks without prior self-supervised training [bendr]_.
+    * The pre-training task (contrastive predictive coding via masking) is generalizable,
+      exhibiting strong uniformity of performance across novel subjects, hardware, and
+      tasks [bendr]_.
+
+    .. warning::
+
+        **Important:** To utilize the full potential of BENDR, the model requires
+        **self-supervised pre-training** on large, unlabeled EEG datasets (like TUEG) followed
+        by subsequent fine-tuning on the specific downstream classification task [bendr]_.
+
+    Parameters
+    ----------
+    encoder_h : int, default=512
+        Hidden size (number of output channels) of the convolutional encoder. This determines
+        the dimensionality of the BENDR feature vectors produced by the encoder.
+    contextualizer_hidden : int, default=3076
+        Hidden size of the feedforward layer within each transformer block. The paper uses
+        approximately 2x the transformer dimension (3076 ~ 2 x 1536).
+    projection_head : bool, default=False
+        If True, adds a projection layer at the end of the encoder to project back to the
+        input feature size. This is used during self-supervised pre-training but typically
+        disabled during fine-tuning.
+    drop_prob : float, default=0.1
+        Dropout probability applied throughout the model. The paper recommends 0.15 for
+        pre-training and 0.0 for fine-tuning. Default is 0.1 as a compromise.
+    layer_drop : float, default=0.0
+        Probability of dropping entire transformer layers during training (LayerDrop
+        regularization [layerdrop]_). The paper uses 0.01 for pre-training and 0.0 for
+        fine-tuning.
+    activation : :class:`torch.nn.Module`, default=:class:`torch.nn.GELU`
+        Activation function used in the encoder convolutional blocks. The paper uses GELU
+        activation throughout.
+    transformer_layers : int, default=8
+        Number of transformer encoder layers in the contextualizer. The paper uses 8 layers.
+    transformer_heads : int, default=8
+        Number of attention heads in each transformer layer. The paper uses 8 heads with
+        head dimension of 192 (1536 / 8).
+    position_encoder_length : int, default=25
+        Kernel size for the convolutional positional encoding layer. The paper uses a
+        receptive field of 25 with 16 groups.
+    enc_width : tuple of int, default=(3, 2, 2, 2, 2, 2)
+        Kernel sizes for each of the 6 convolutional blocks in the encoder. Each value
+        corresponds to one block.
+    enc_downsample : tuple of int, default=(3, 2, 2, 2, 2, 2)
+        Stride values for each of the 6 convolutional blocks in the encoder. The total
+        downsampling factor is the product of all strides (3 x 2 x 2 x 2 x 2 x 2 = 96).
+    start_token : int or float, default=-5
+        Value used to fill the start token embedding that is prepended to the BENDR sequence
+        before input to the transformer. This token's output is used as the aggregate
+        representation for classification.
+    final_layer : bool, default=True
+        If True, includes a final linear classification layer that maps from encoder_h to
+        n_outputs. If False, the model outputs the contextualized features directly.
+
+    References
+    ----------
+    .. [bendr] Kostas, D., Aroca-Ouellette, S., & Rudzicz, F. (2021).
+       BENDR: Using transformers and a contrastive self-supervised learning task to learn from
+       massive amounts of EEG data.
+       Frontiers in Human Neuroscience, 15, 653659.
+       https://doi.org/10.3389/fnhum.2021.653659
+    .. [wav2vec2] Baevski, A., Zhou, Y., Mohamed, A., & Auli, M. (2020).
+       wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.
+       In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, & H. Lin (Eds),
+       Advances in Neural Information Processing Systems (Vol. 33, pp. 12449-12460).
+       https://dl.acm.org/doi/10.5555/3495724.3496768
+    .. [tfixup] Huang, T. K., Liang, S., Jha, A., & Salakhutdinov, R. (2020).
+       Improving Transformer Optimization Through Better Initialization.
+       In International Conference on Machine Learning (pp. 4475-4483). PMLR.
+       https://dl.acm.org/doi/10.5555/3524938.3525354
+    .. [layerdrop] Fan, A., Grave, E., & Joulin, A. (2020).
+       Reducing Transformer Depth on Demand with Structured Dropout.
+       International Conference on Learning Representations.
+       Retrieved from https://openreview.net/forum?id=SylO2yStDr
+    """
+
+    def __init__(
+        self,
+        # Signal related parameters
+        n_chans=None,
+        n_outputs=None,
+        n_times=None,
+        chs_info=None,
+        input_window_seconds=None,
+        sfreq=None,
+        # Model parameters
+        encoder_h=512,  # Hidden size of the encoder convolutional layers
+        contextualizer_hidden=3076,  # Feedforward hidden size in transformer
+        projection_head=False,  # Whether encoder should project back to input feature size (unused in original fine-tuning)
+        drop_prob=0.1,  # General dropout probability (paper: 0.15 for pretraining, 0.0 for fine-tuning)
+        layer_drop=0.0,  # Probability of dropping transformer layers during training (paper: 0.01 for pretraining)
+        activation=nn.GELU,  # Activation function
+        # Transformer specific parameters
+        transformer_layers=8,
+        transformer_heads=8,
+        position_encoder_length=25,  # Kernel size for positional encoding conv
+        enc_width=(3, 2, 2, 2, 2, 2),
+        enc_downsample=(3, 2, 2, 2, 2, 2),
+        start_token=-5,  # Value for start token embedding
+        final_layer=True,  # Whether to include the final linear layer
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+
+        # Keep these parameters if needed later, otherwise they are captured by the mixin
+        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+        self.encoder_h = encoder_h
+        self.contextualizer_hidden = contextualizer_hidden
+        self.include_final_layer = final_layer
+
+        # Encoder: Use parameters from __init__
+        self.encoder = _ConvEncoderBENDR(
+            in_features=self.n_chans,
+            encoder_h=encoder_h,
+            dropout=drop_prob,
+            projection_head=projection_head,
+            enc_width=enc_width,
+            enc_downsample=enc_downsample,
+            activation=activation,
+        )
+
+        self.contextualizer = _BENDRContextualizer(
+            in_features=self.encoder.encoder_h,  # Use the output feature size of the encoder
+            hidden_feedforward=contextualizer_hidden,
+            heads=transformer_heads,
+            layers=transformer_layers,
+            dropout=drop_prob,  # Use general dropout probability
+            activation=activation,
+            position_encoder=position_encoder_length,  # Pass position encoder kernel size
+            layer_drop=layer_drop,
+            start_token=start_token,  # Keep fixed start token value
+        )
+        in_features = self.encoder.encoder_h  # Set in_features for final layer
+
+        self.final_layer = None
+        if self.include_final_layer:
+            # Input to Linear will be [batch_size, encoder_h] after taking last timestep
+            linear = nn.Linear(in_features=in_features, out_features=self.n_outputs)
+            self.final_layer = nn.utils.parametrizations.weight_norm(
+                linear, name="weight", dim=1
+            )
+
+    def forward(self, x):
+        # Input x: [batch_size, n_chans, n_times]
+        encoded = self.encoder(x)
+        # encoded: [batch_size, encoder_h, n_encoded_times]
+
+        context = self.contextualizer(encoded)
+        # context: [batch_size, encoder_h, n_encoded_times + 1] (due to start token)
+
+        # Extract features - use the output corresponding to the start token (index 0)
+        # The output has shape [batch_size, features, seq_len+1]
+        # The start token was prepended at position 0 in the contextualizer before the
+        # transformer layers. After permutation back to [batch, features, seq_len+1],
+        # the start token output is at index 0.
+        # According to the paper (Kostas et al. 2021): "The transformer output of this
+        # initial position was not modified during pre-training, and only used for
+        # downstream tasks." This follows BERT's [CLS] token convention where the
+        # first token aggregates sequence information via self-attention.
+        feature = context[:, :, 0]
+        # feature: [batch_size, encoder_h]
+
+        if self.final_layer is not None:
+            feature = self.final_layer(feature)
+            # feature: [batch_size, n_outputs]
+
+        return feature
+
+
+class _ConvEncoderBENDR(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        encoder_h=512,
+        enc_width=(3, 2, 2, 2, 2, 2),
+        dropout=0.0,
+        projection_head=False,
+        enc_downsample=(3, 2, 2, 2, 2, 2),
+        activation=nn.GELU,
+    ):
+        super().__init__()
+        self.encoder_h = encoder_h
+        self.in_features = in_features
+
+        if not isinstance(enc_width, (list, tuple)):
+            enc_width = [enc_width]
+        if not isinstance(enc_downsample, (list, tuple)):
+            enc_downsample = [enc_downsample]
+        if len(enc_downsample) != len(enc_width):
+            raise ValueError(
+                "Encoder width and downsampling factors must have the same length."
+            )
+
+        # Centerable convolutions make life simpler
+        enc_width = [
+            e if e % 2 != 0 else e + 1 for e in enc_width
+        ]  # Ensure odd kernel size
+        self._downsampling = enc_downsample
+        self._width = enc_width
+
+        current_in_features = in_features
+        self.encoder = nn.Sequential()
+        for i, (width, downsample) in enumerate(zip(enc_width, enc_downsample)):
+            self.encoder.add_module(
+                "Encoder_{}".format(i),
+                nn.Sequential(
+                    nn.Conv1d(
+                        current_in_features,
+                        encoder_h,
+                        width,
+                        stride=downsample,
+                        padding=width
+                        // 2,  # Correct padding for 'same' output length before stride
+                    ),
+                    nn.Dropout2d(dropout),  # 2D dropout (matches paper specification)
+                    nn.GroupNorm(
+                        encoder_h // 2, encoder_h
+                    ),  # Consider making num_groups configurable or ensure encoder_h is divisible by 2
+                    activation(),
+                ),
+            )
+            current_in_features = encoder_h
+        if projection_head:
+            self.encoder.add_module(
+                "projection_head",
+                nn.Conv1d(encoder_h, encoder_h, 1),
+            )
+
+    def forward(self, x):
+        return self.encoder(x)
+
+
+class _BENDRContextualizer(nn.Module):
+    """Transformer-based contextualizer for BENDR."""
+
+    def __init__(
+        self,
+        in_features,
+        hidden_feedforward=3076,
+        heads=8,
+        layers=8,
+        dropout=0.1,  # Default dropout
+        activation=nn.GELU,  # Activation for transformer FF layer
+        position_encoder=25,  # Kernel size for conv positional encoding
+        layer_drop=0.0,  # Probability of dropping a whole layer
+        start_token=-5,  # Value for start token embedding
+    ):
+        super().__init__()
+        self.dropout = dropout
+        self.layer_drop = layer_drop
+        self.start_token = start_token  # Store start token value
+
+        # Paper specification: transformer_dim = 3 * encoder_h (1536 for encoder_h=512)
+        self.transformer_dim = 3 * in_features
+        self.in_features = in_features
+        # --- Positional Encoding --- (Applied before projection)
+        self.position_encoder = None
+        if position_encoder > 0:
+            conv = nn.Conv1d(
+                in_features,
+                in_features,
+                kernel_size=position_encoder,
+                padding=position_encoder // 2,
+                groups=16,  # Number of groups for depthwise separation
+            )
+            # T-Fixup positional encoding initialization (paper specification)
+            nn.init.normal_(conv.weight, mean=0, std=2 / self.transformer_dim)
+            nn.init.constant_(conv.bias, 0)
+
+            conv = nn.utils.parametrizations.weight_norm(conv, name="weight", dim=2)
+            self.relative_position = nn.Sequential(conv, activation())
+
+        # --- Input Conditioning --- (Includes projection up to transformer_dim)
+        # Rearrange, Norm, Dropout, Project, Rearrange
+        self.input_conditioning = nn.Sequential(
+            Rearrange(
+                "batch channel time -> batch time channel"
+            ),  # Batch, Time, Channels
+            nn.LayerNorm(in_features),
+            nn.Dropout(dropout),
+            nn.Linear(in_features, self.transformer_dim),  # Project up using Linear
+            # nn.Conv1d(in_features, self.transformer_dim, 1),  # Alternative: Project up using Conv1d
+            Rearrange(
+                "batch time channel -> time batch channel"
+            ),  # Time, Batch, Channels (Transformer expected format)
+        )
+
+        # --- Transformer Encoder Layers ---
+        # Paper uses T-Fixup: remove internal LayerNorm layers
+
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=self.transformer_dim,  # Use projected dimension
+            nhead=heads,
+            dim_feedforward=hidden_feedforward,
+            dropout=dropout,  # Dropout within transformer layer
+            activation=activation(),
+            batch_first=False,  # Expects (T, B, C)
+            norm_first=False,  # Standard post-norm architecture
+        )
+
+        # T-Fixup: Replace internal LayerNorms with Identity
+        encoder_layer.norm1 = nn.Identity()
+        encoder_layer.norm2 = nn.Identity()
+
+        self.transformer_layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for _ in range(layers)]
+        )
+
+        # Add final LayerNorm after all transformer layers (paper specification)
+        self.norm = nn.LayerNorm(self.transformer_dim)
+
+        # --- Output Layer --- (Project back down to in_features)
+        # Paper uses Conv1d with kernel=1 (equivalent to Linear but expects (B,C,T) input)
+        self.output_layer = nn.Conv1d(self.transformer_dim, in_features, 1)
+
+        # Initialize parameters with T-Fixup (paper specification)
+        self.apply(self._init_bert_params)
+
+    def _init_bert_params(self, module):
+        """Initialize linear layers and apply T-Fixup scaling."""
+        if isinstance(module, nn.Linear):
+            # Standard init
+            nn.init.xavier_uniform_(module.weight.data)
+            if module.bias is not None:
+                module.bias.data.zero_()
+            # Tfixup Scaling
+            module.weight.data = (
+                0.67 * len(self.transformer_layers) ** (-0.25) * module.weight.data
+            )
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Input x: [batch_size, in_features, seq_len]
+
+        # Apply relative positional encoding
+        if hasattr(self, "relative_position"):
+            pos_enc = self.relative_position(x)
+            x = x + pos_enc
+
+        # Apply input conditioning (includes projection up and rearrange)
+        x = self.input_conditioning(x)
+        # x: [seq_len, batch_size, transformer_dim]
+
+        # Prepend start token
+        if self.start_token is not None:
+            token_emb = torch.full(
+                (1, x.shape[1], x.shape[2]),
+                float(self.start_token),
+                device=x.device,
+                requires_grad=False,
+            )
+            x = torch.cat([token_emb, x], dim=0)
+            # x: [seq_len + 1, batch_size, transformer_dim]
+
+        # Apply transformer layers with layer drop
+        for layer in self.transformer_layers:
+            if not self.training or torch.rand(1) > self.layer_drop:
+                x = layer(x)
+        # x: [seq_len + 1, batch_size, transformer_dim]
+
+        # Apply final LayerNorm (T-Fixup: norm only at the end)
+        x = self.norm(x)
+        # x: [seq_len + 1, batch_size, transformer_dim]
+
+        # Permute to (B, C, T) format for Conv1d output layer
+        x = Rearrange("time batch channel -> batch channel time")(x)
+        # x: [batch_size, transformer_dim, seq_len + 1]
+
+        # Apply output projection (Conv1d expects B, C, T)
+        x = self.output_layer(x)
+        # x: [batch_size, in_features, seq_len + 1]
+
+        return x
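A minimal smoke test for the new model, assuming it is exported from `braindecode.models` like the other models in this release (otherwise import it from `braindecode.models.bendr`); the constructor arguments and tensor shapes follow the code above, while the channel count, sampling rate, window length, and output count are illustrative values only:

    import torch
    from braindecode.models import BENDR  # assumed export path

    # Illustrative input: 20 channels, 60 s at 256 Hz (example values, not from the diff)
    batch_size, n_chans, sfreq = 8, 20, 256
    n_times = 60 * sfreq

    model = BENDR(n_chans=n_chans, n_outputs=2, n_times=n_times)
    x = torch.randn(batch_size, n_chans, n_times)
    logits = model(x)  # start-token output -> weight-normed linear head
    print(logits.shape)  # torch.Size([8, 2])

    # The encoder strides (3, 2, 2, 2, 2, 2) downsample time by 3*2*2*2*2*2 = 96,
    # i.e. 256 Hz / 96 ≈ 2.67 Hz for the BENDR sequence fed to the contextualizer.

As the docstring's warning notes, useful downstream accuracy is expected only after self-supervised pre-training and fine-tuning, so a forward pass like this only checks shapes and wiring.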
braindecode/models/biot.py
CHANGED
@@ -11,13 +11,15 @@ from braindecode.models.base import EEGModuleMixin
 class BIOT(EEGModuleMixin, nn.Module):
     """BIOT from Yang et al. (2023) [Yang2023]_
 
+    :bdg-danger:`Large Brain Model`
+
     .. figure:: https://braindecode.org/dev/_static/model/biot.jpg
        :align: center
        :alt: BioT
 
     BIOT: Cross-data Biosignal Learning in the Wild.
 
-    BIOT is a large
+    BIOT is a large brain model for biosignal classification. It is
     a wrapper around the `BIOTEncoder` and `ClassificationHead` modules.
 
     It is designed for N-dimensional biosignal data such as EEG, ECG, etc.
braindecode/models/ctnet.py
CHANGED
@@ -39,16 +39,19 @@ class CTNet(EEGModuleMixin, nn.Module):
     The architecture consists of three main components:
 
     1. **Convolutional Module**:
-
-
+
+       - Apply :class:`EEGNet` to perform some feature extraction, denoted here as
+         _PatchEmbeddingEEGNet module.
 
     2. **Transformer Encoder Module**:
+
       - Utilizes multi-head self-attention mechanisms as EEGConformer but
-
+        with residual blocks.
 
     3. **Classifier Module**:
+
       - Combines features from both the convolutional module
-
+        and the Transformer encoder.
       - Flattens the combined features and applies dropout for regularization.
       - Uses a fully connected layer to produce the final classification output.
 
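Since CTNet is also built on EEGModuleMixin, a forward-pass check can follow the same pattern as for BENDR above; treat the import path and argument names below as assumptions carried over from the mixin signature, not as documented CTNet API, and the shapes as example values:

    import torch
    from braindecode.models import CTNet  # assumed export path

    model = CTNet(n_chans=22, n_outputs=4, n_times=1000)  # example values
    x = torch.randn(2, 22, 1000)
    print(model(x).shape)  # expected: torch.Size([2, 4])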
braindecode/models/deep4.py
CHANGED
@@ -19,9 +19,13 @@ from braindecode.modules import (
 class Deep4Net(EEGModuleMixin, nn.Sequential):
     """Deep ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
 
-
+    :bdg-success:`Convolution`
+
+    .. figure:: https://onlinelibrary.wiley.com/cms/asset/fc200ccc-d8c4-45b4-8577-56ce4d15999a/hbm23730-fig-0001-m.jpg
        :align: center
-       :alt:
+       :alt: Deep4Net Architecture
+       :width: 600px
+
 
     Model described in [Schirrmeister2017]_.
 
braindecode/models/deepsleepnet.py
CHANGED
@@ -8,14 +8,137 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class DeepSleepNet(EEGModuleMixin, nn.Module):
-    """
+    """DeepSleepNet from Supratak et al. (2017) [Supratak2017]_.
 
-
+    :bdg-success:`Convolution` :bdg-info:`Recurrent`
+
+    .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/master/img/deepsleepnet.png
        :align: center
        :alt: DeepSleepNet Architecture
+       :width: 700px
+
+    .. rubric:: Architectural Overview
+
+    DeepSleepNet couples **dual-path convolution neural network representation learning** with
+    **sequence residual learning** via bidirectional LSTMs.
+
+    The network have:
+
+    - (i) learns complementary, time-frequency features from each
+      30-s epoch using **two parallel CNNs** (small vs. large first-layer filters), then
+    - (ii) models **temporal dependencies across epochs** using **two-layer BiLSTMs**
+      with a **residual shortcut** from the CNN features, and finally
+    - (iii) outputs per-epoch sleep stages. This design encodes both
+      epoch-local patterns and longer-range transition rules used by human scorers.
+
+    In term of implementation:
+
+    - (i) :class:`_RepresentationLearning` two CNNs extract epoch-wise features
+      (small-filter path for temporal precision; large-filter path for frequency precision);
+    - (ii) :class:`_SequenceResidualLearning` stacked BiLSTMs with peepholes + residual shortcut
+      inject temporal context while preserving CNN evidence;
+    - (iii) :class:`_Classifier` linear readout (softmax) for the five sleep stages.
+
+    .. rubric:: Macro Components
+
+    - :class:`_RepresentationLearning` **(dual-path CNN → epoch feature)**
+
+      - *Operations.*
+        - **Small-filter CNN** 4 times:
+
+          - :class:`~torch.nn.Conv1d`
+          - :class:`~torch.nn.BatchNorm1d`
+          - :class:`~torch.nn.ReLU`
+          - :class:`~torch.nn.MaxPool1d` after.
+
+          - First conv uses **filter length ≈ Fs/2** and **stride ≈ Fs/16** to emphasize *timing* of graphoelements.
+
+        - **Large-filter CNN**:
+
+          - Same stack but first conv uses **filter length ≈ 4·Fs** and
+          - **stride ≈ Fs/2** to emphasize *frequency* content.
+
+        - Outputs from both paths are **concatenated** into the epoch embedding ``a_t``.
+
+      - *Rationale.*
+        Two first-layer scales provide a **learned, dual-scale filter bank** that trades
+        temporal vs. frequency precision without hand-crafted features.
+
+    - :class:`_SequenceResidualLearning` (:class:`~torch.nn.BiLSTM` **context + residual fusion)**
+
+      - *Operations.*
+        - **Two-layer BiLSTM** with **peephole connections** processes the sequence of epoch embeddings
+          ``{a_t}`` forward and backward; hidden states from both directions are **concatenated**.
+        - A **shortcut MLP** (fully connected + :class:`~torch.nn.BatchNorm1d` + :class:`~torch.nn.ReLU`) projects ``a_t`` to the BiLSTM output
+          dimension and is **added** (residual) to the :class:`~torch.nn.BiLSTM` output at each time step.
+      - *Role.* Encodes **stage-transition rules** and smooths predictions over time while preserving
+        salient CNN features via the residual path.
+
+    - :class:`_Classifier` **(epoch-wise prediction)**
 
-
-
+      - *Operations.*
+        - :class:`~torch.nn.Linear` to produce per-epoch class probabilities.
+
+    Original training uses two-step optimization: CNN pretraining on class-balanced data,
+    then end-to-end fine-tuning with sequential batches.
+
+    .. rubric:: Convolutional Details
+
+    - **Temporal (where time-domain patterns are learned).**
+
+      Both CNN paths use **1-D temporal convolutions**. The *small-filter* path (first kernel ≈ Fs/2,
+      stride ≈ Fs/16) captures *when* characteristic transients occur; the *large-filter* path
+      (first kernel ≈ 4·Fs, stride ≈ Fs/2) captures *which* frequency components dominate over the
+      epoch. Deeper layers use **small kernels** to refine features with fewer parameters, interleaved
+      with **max pooling** for downsampling.
+
+    - **Spatial (how channels are processed).**
+
+      The original model operates on **single-channel** raw EEG; convolutions therefore mix only
+      along time (no spatial convolution across electrodes).
+
+    - **Spectral (how frequency information emerges).**
+
+      No explicit Fourier/wavelet transform is used. The **large-filter path** serves as a
+      *frequency-sensitive* analyzer, while the **small-filter path** remains *time-sensitive*,
+      together functioning as a **two-band learned filter bank** at the first layer.
+
+    .. rubric:: Attention / Sequential Modules
+
+    - **Type.** **Bidirectional LSTM** (two layers) with **peephole connections**; forward and
+      backward streams are independent and concatenated.
+    - **Shapes.** For a sequence of ``N`` epochs, the CNN produces ``{a_t} ∈ R^{D}``;
+      BiLSTM outputs ``h_t ∈ R^{2H}``; the shortcut MLP maps ``a_t → R^{2H}`` to enable
+      **element-wise residual addition**.
+    - **Role.** Models **long-range temporal dependencies** (e.g., persisting N2 without visible
+      K-complex/spindles), stabilizing per-epoch predictions.
+
+
+    .. rubric:: Additional Mechanisms
+
+    - **Residual shortcut over sequence encoder.** Adds projected CNN features to BiLSTM outputs,
+      improving gradient flow and retaining discriminative content from representation learning.
+    - **Two-step training.**
+
+      - (i) **Pretrain** the CNN paths with class-balanced sampling;
+      - (ii) **fine-tune** the full network with sequential batches, using **lower LR** for CNNs and **higher LR** for the
+        sequence encoder.
+
+    - **State handling.** BiLSTM states are **reinitialized per subject** so that temporal context
+      does not leak across recordings.
+
+
+    .. rubric:: Usage and Configuration
+
+    - **Epoch pipeline.** Use **two parallel CNNs** with the first conv sized to **Fs/2** (small path)
+      and **4·Fs** (large path), with strides **Fs/16** and **Fs/2**, respectively; stack three more
+      conv blocks with small kernels, plus **max pooling** in each path. Concatenate path outputs
+      to form epoch embeddings.
+    - **Sequence encoder.** Apply **two-layer BiLSTM (peepholes)** over the sequence of embeddings;
+      add a **projection MLP** on the CNN features and **sum** with BiLSTM outputs (residual).
+      Finish with :class:`~torch.nn.Linear` per epoch.
+    - **Reference implementation.** See the official repository for a faithful implementation and
+      training scripts.
 
     Parameters
     ----------
@@ -32,7 +155,6 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
     drop_prob : float, default=0.5
         The dropout rate for regularization. Values should be between 0 and 1.
 
-
     References
     ----------
     .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
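To make the first-layer sizing rules quoted in the DeepSleepNet docstring concrete, here is a small standalone sketch in plain PyTorch (not braindecode code; the filter count of 64 and the 100 Hz sampling rate are illustrative assumptions) of the two parallel first convolutions:

    import torch
    from torch import nn

    fs = 100  # assumed sampling rate (Hz) of the 30-s epochs

    # Small-filter path: kernel ~ Fs/2, stride ~ Fs/16 -> precise timing of graphoelements
    small_first = nn.Conv1d(1, 64, kernel_size=fs // 2, stride=fs // 16, bias=False)
    # Large-filter path: kernel ~ 4*Fs, stride ~ Fs/2 -> coarser, frequency-oriented view
    large_first = nn.Conv1d(1, 64, kernel_size=4 * fs, stride=fs // 2, bias=False)

    x = torch.randn(8, 1, 30 * fs)  # (batch, single EEG channel, one 30-s epoch)
    print(small_first(x).shape, large_first(x).shape)

In the full model, each path continues with three more convolution blocks and max pooling, and the two outputs are concatenated into the per-epoch embedding that the BiLSTM stage consumes.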