braindecode 1.3.0.dev177069446__py3-none-any.whl
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0

@@ -0,0 +1,1122 @@
# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)
from __future__ import annotations

import math
from copy import deepcopy
from pathlib import Path
from typing import Any, Optional, Sequence

import torch
from einops.layers.torch import Rearrange
from torch import nn

from braindecode.models.base import EEGModuleMixin

_DEFAULT_CONV_LAYER_SPEC = (  # downsampling: 128Hz -> 1Hz, receptive field 1.1875s, stride 1s
    (8, 32, 8),
    (16, 2, 2),
    (32, 2, 2),
    (64, 2, 2),
    (64, 2, 2),
)
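
# --- Editor's illustrative sketch (not part of the released file) ---
# The comment on ``_DEFAULT_CONV_LAYER_SPEC`` claims a 128x temporal
# downsampling (128 Hz -> 1 Hz) and a 1.1875 s receptive field. Both follow
# from the ``(dim, k, stride)`` tuples via the standard receptive-field
# recursion; nothing outside the spec above is assumed.
def _check_default_spec(spec=_DEFAULT_CONV_LAYER_SPEC, sfreq=128.0):
    total_stride, receptive_field = 1, 1
    for _dim, k, stride in spec:
        receptive_field += (k - 1) * total_stride  # each kernel widens the RF
        total_stride *= stride  # strides compound multiplicatively
    # total_stride == 128 and receptive_field == 152 samples == 1.1875 s
    return total_stride, receptive_field / sfreq


assert _check_default_spec() == (128, 1.1875)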


class _BaseSignalJEPA(EEGModuleMixin, nn.Module):
    r"""Base class for the SignalJEPA models.

    Parameters
    ----------
    feature_encoder__conv_layers_spec: list of tuple
        Tuples have shape ``(dim, k, stride)`` where:

        * ``dim`` : number of output channels of the layer (unrelated to EEG channels);
        * ``k`` : temporal length of the layer's kernel;
        * ``stride`` : temporal stride of the layer's kernel.

    drop_prob: float
        Dropout probability applied after each convolution of the feature encoder.
    feature_encoder__mode: str
        Normalisation mode. Either ``default`` or ``layer_norm``.
    feature_encoder__conv_bias: bool
        Whether the convolutions of the feature encoder use a bias term.
    activation: nn.Module
        Activation layer for the feature encoder.
    pos_encoder__spat_dim: int
        Number of dimensions to use to encode the spatial position of the patch,
        i.e. the EEG channel.
    pos_encoder__time_dim: int
        Number of dimensions to use to encode the temporal position of the patch.
    pos_encoder__sfreq_features: float
        The "downsampled" sampling frequency returned by the feature encoder.
    pos_encoder__spat_kwargs: dict
        Additional keyword arguments to pass to the :class:`nn.Embedding` layer used to
        embed the channel names.
    transformer__d_model: int
        The number of expected features in the encoder/decoder inputs.
    transformer__num_encoder_layers: int
        The number of encoder layers in the transformer.
    transformer__num_decoder_layers: int
        The number of decoder layers in the transformer.
    transformer__nhead: int
        The number of heads in the multi-head attention layers.
    _init_feature_encoder : bool
        Do not change the default value (used for internal purposes).
    _init_transformer : bool
        Do not change the default value (used for internal purposes).
    """

    feature_encoder: _ConvFeatureEncoder | None
    pos_encoder: _PosEncoder | None
    transformer: nn.Transformer | None

    _feature_encoder_channels: str = "n_chans"

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool,
        _init_transformer: bool,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.feature_encoder = None
        self.pos_encoder = None
        self.transformer = None
        if _init_feature_encoder:
            self.feature_encoder = _ConvFeatureEncoder(
                conv_layers_spec=feature_encoder__conv_layers_spec,
                channels=getattr(self, self._feature_encoder_channels),
                drop_prob=drop_prob,
                mode=feature_encoder__mode,
                conv_bias=feature_encoder__conv_bias,
                activation=activation,
            )

        if _init_transformer:
            ch_locs = [ch["loc"] for ch in self.chs_info]  # type: ignore
            self.pos_encoder = _PosEncoder(
                spat_dim=pos_encoder__spat_dim,
                time_dim=pos_encoder__time_dim,
                ch_locs=ch_locs,
                sfreq_features=pos_encoder__sfreq_features,
                spat_kwargs=pos_encoder__spat_kwargs,
            )
            self.transformer = nn.Transformer(
                d_model=transformer__d_model,
                nhead=transformer__nhead,
                num_encoder_layers=transformer__num_encoder_layers,
                num_decoder_layers=transformer__num_decoder_layers,
                batch_first=True,
            )


class SignalJEPA(_BaseSignalJEPA):
    r"""Architecture introduced in signal-JEPA for self-supervised pre-training, from Guetschel et al. (2024) [1]_.

    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`

    This model is not meant for classification but for SSL pre-training.
    Its output shape depends on the input shape.
    For classification purposes, three variants of this model are available:

    * :class:`SignalJEPA_Contextual`
    * :class:`SignalJEPA_PostLocal`
    * :class:`SignalJEPA_PreLocal`

    The classification architectures can either be instantiated from scratch
    (random parameters) or from a pre-trained :class:`SignalJEPA` model.

    .. versionadded:: 0.9

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=True,
            _init_transformer=True,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = nn.Identity()

    def forward(self, X, ch_idxs: torch.Tensor | None = None):  # type: ignore
        local_features = self.feature_encoder(X)  # type: ignore
        pos_encoding = self.pos_encoder(local_features, ch_idxs=ch_idxs)  # type: ignore
        local_features += pos_encoding  # type: ignore
        contextual_features = self.transformer.encoder(local_features)  # type: ignore
        y = self.final_layer(contextual_features)  # type: ignore
        return y  # type: ignore
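
# --- Editor's illustrative sketch (not part of the released file) ---
# Minimal forward pass through the SSL backbone. For brevity, ``chs_info`` is
# mocked here with plain dicts carrying only ``ch_name`` and a 3-D ``loc``
# (example coordinates); real code would pass ``mne.Info["chs"]``. Assumes
# the mixin only needs these two keys for this model.
def _demo_signal_jepa_forward():
    chs_info = [
        {"ch_name": "C3", "loc": [-0.065, 0.0, 0.06]},
        {"ch_name": "Cz", "loc": [0.0, 0.0, 0.09]},
        {"ch_name": "C4", "loc": [0.065, 0.0, 0.06]},
    ]
    model = SignalJEPA(chs_info=chs_info, n_times=512, sfreq=128.0)
    X = torch.randn(2, 3, 512)  # (batch, n_chans, n_times)
    out = model(X)
    # One token per (channel, output time step): 3 channels * 3 steps, each
    # embedded in transformer__d_model = 64 dimensions.
    return out.shape  # torch.Size([2, 9, 64])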


class SignalJEPA_Contextual(_BaseSignalJEPA):
    r"""Contextual downstream architecture introduced in signal-JEPA, from Guetschel et al. (2024) [1]_.

    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_contextual.jpg
       :align: center
       :alt: sJEPA Contextual.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
        _init_transformer: bool = True,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=_init_transformer,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = _get_separable_clf_layer(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_chans=self.n_chans,
            n_times=self.n_times,
            n_classes=self.n_outputs,
            n_spat_filters=n_spat_filters,
        )

    @classmethod
    def from_pretrained(
        cls,
        model: Optional[SignalJEPA | str | Path] = None,  # type: ignore
        n_outputs: Optional[int] = None,  # type: ignore
        n_spat_filters: int = 4,
        chs_info: Optional[list[dict[str, Any]]] = None,  # type: ignore
        **kwargs,
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from the Hub.

        Parameters
        ----------
        model: SignalJEPA, str, Path, or None
            Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
            (for Hub-style loading), or None (for Hub loading via kwargs).
        n_outputs: int or None
            Number of classes for the new model. Required when loading from a SignalJEPA model,
            optional when loading from the Hub (read from the stored config).
        n_spat_filters: int
            Number of spatial filters.
        chs_info: list of dict | None
            Information about each individual EEG channel. This should be filled with
            ``info["chs"]``. Refer to :class:`mne.Info` for more details.
        **kwargs
            Additional keyword arguments passed to the parent class for Hub loading.
        """
        # Check if this is a Hub-style load (from a repo id or directory path)
        if isinstance(model, (str, Path)) or (model is None and kwargs):
            if isinstance(model, (str, Path)):
                # model is actually the repo_id or directory path
                return super().from_pretrained(model, **kwargs)
            else:
                # model is None, treat as a Hub-style load
                return super().from_pretrained(**kwargs)

        # This is the original SignalJEPA transfer learning case
        if not isinstance(model, SignalJEPA):
            raise TypeError(
                f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
            )
        if n_outputs is None:
            raise ValueError(
                "n_outputs must be provided when loading from a SignalJEPA model"
            )

        feature_encoder = model.feature_encoder
        pos_encoder = model.pos_encoder
        transformer = model.transformer
        assert feature_encoder is not None
        assert pos_encoder is not None
        assert transformer is not None

        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            chs_info=chs_info,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
            _init_transformer=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        new_model.pos_encoder = deepcopy(pos_encoder)
        new_model.transformer = deepcopy(transformer)

        if chs_info is not None:
            ch_names = [ch["ch_name"] for ch in chs_info]
            new_model.pos_encoder.set_fixed_ch_names(ch_names)

        return new_model

    def forward(self, X, ch_idxs: torch.Tensor | None = None):  # type: ignore
        local_features = self.feature_encoder(X)  # type: ignore
        pos_encoding = self.pos_encoder(local_features, ch_idxs=ch_idxs)  # type: ignore
        local_features += pos_encoding  # type: ignore
        contextual_features = self.transformer.encoder(local_features)  # type: ignore
        y = self.final_layer(contextual_features)  # type: ignore
        return y  # type: ignore
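
# --- Editor's illustrative sketch (not part of the released file) ---
# Intended transfer-learning flow: pre-train a ``SignalJEPA`` backbone, then
# let ``from_pretrained`` deep-copy its feature encoder, positional encoder
# and transformer into the classifier variant. Channel dicts and window
# length are the same example values as in the sketch above.
def _demo_contextual_transfer():
    chs_info = [
        {"ch_name": "C3", "loc": [-0.065, 0.0, 0.06]},
        {"ch_name": "Cz", "loc": [0.0, 0.0, 0.09]},
        {"ch_name": "C4", "loc": [0.065, 0.0, 0.06]},
    ]
    pretrained = SignalJEPA(chs_info=chs_info, n_times=512, sfreq=128.0)
    # ... SSL pre-training of ``pretrained`` would happen here ...
    clf = SignalJEPA_Contextual.from_pretrained(pretrained, n_outputs=4)
    logits = clf(torch.randn(2, 3, 512))
    return logits.shape  # torch.Size([2, 4])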


class SignalJEPA_PostLocal(_BaseSignalJEPA):
    r"""Post-local downstream architecture introduced in signal-JEPA, from Guetschel et al. (2024) [1]_.

    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_post-local.jpg
       :align: center
       :alt: sJEPA Post-Local.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=False,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = _get_separable_clf_layer(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_chans=self.n_chans,
            n_times=self.n_times,
            n_classes=self.n_outputs,
            n_spat_filters=n_spat_filters,
        )

    @classmethod
    def from_pretrained(
        cls,
        model: SignalJEPA | str | Path | None = None,  # type: ignore
        n_outputs: int | None = None,  # type: ignore
        n_spat_filters: int = 4,
        **kwargs,
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from the Hub.

        Parameters
        ----------
        model: SignalJEPA, str, Path, or None
            Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
            (for Hub-style loading), or None (for Hub loading via kwargs).
        n_outputs: int or None
            Number of classes for the new model. Required when loading from a SignalJEPA model,
            optional when loading from the Hub (read from the stored config).
        n_spat_filters: int
            Number of spatial filters.
        **kwargs
            Additional keyword arguments passed to the parent class for Hub loading.
        """
        # Check if this is a Hub-style load (from a repo id or directory path)
        if isinstance(model, (str, Path)) or (model is None and kwargs):
            if isinstance(model, (str, Path)):
                # model is actually the repo_id or directory path
                return super().from_pretrained(model, **kwargs)
            else:
                # model is None, treat as a Hub-style load
                return super().from_pretrained(**kwargs)

        # This is the original SignalJEPA transfer learning case
        if not isinstance(model, SignalJEPA):
            raise TypeError(
                f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
            )
        if n_outputs is None:
            raise ValueError(
                "n_outputs must be provided when loading from a SignalJEPA model"
            )

        feature_encoder = model.feature_encoder
        assert feature_encoder is not None
        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        return new_model

    def forward(self, X):
        local_features = self.feature_encoder(X)
        y = self.final_layer(local_features)
        return y


class SignalJEPA_PreLocal(_BaseSignalJEPA):
    r"""Pre-local downstream architecture introduced in signal-JEPA, from Guetschel et al. (2024) [1]_.

    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_pre-local.jpg
       :align: center
       :alt: sJEPA Pre-Local.

    .. versionadded:: 0.9

    .. important::
       **Pre-trained Weights Available**

       This model has pre-trained weights available on the Hugging Face Hub.
       You can load them using:

       .. code-block:: python

          from braindecode.models import SignalJEPA_PreLocal

          # Load pre-trained model from Hugging Face Hub
          model = SignalJEPA_PreLocal.from_pretrained(
              "braindecode/SignalJEPA-PreLocal-pretrained"
          )

       To push your own trained model to the Hub:

       .. code-block:: python

          # After training your model
          model.push_to_hub(
              repo_id="username/my-sjepa-model",
              commit_message="Upload trained SignalJEPA model",
          )

       Requires installing ``braindecode[hug]`` for Hub integration.

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    _feature_encoder_channels: str = "n_spat_filters"

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
    ):
        # Must be set before super().__init__ so that the base class can read
        # it through ``_feature_encoder_channels``.
        self.n_spat_filters = n_spat_filters
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=False,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.spatial_conv = nn.Sequential(
            Rearrange("b channels time -> b 1 channels time"),
            nn.Conv2d(1, n_spat_filters, (self.n_chans, 1)),
            Rearrange("b spat_filters 1 time -> b spat_filters time"),
        )
        out_emb_dim = _get_out_emb_dim(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_times=self.n_times,
            n_spat_filters=n_spat_filters,
        )
        self.final_layer = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(out_emb_dim, self.n_outputs),
        )

    @classmethod
    def from_pretrained(
        cls,
        model: SignalJEPA | str | Path | None = None,  # type: ignore
        n_outputs: int | None = None,  # type: ignore
        n_spat_filters: int = 4,
        **kwargs,
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from the Hub.

        Parameters
        ----------
        model: SignalJEPA, str, Path, or None
            Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
            (for Hub-style loading), or None (for Hub loading via kwargs).
        n_outputs: int or None
            Number of classes for the new model. Required when loading from a SignalJEPA model,
            optional when loading from the Hub (read from the stored config).
        n_spat_filters: int
            Number of spatial filters.
        **kwargs
            Additional keyword arguments passed to the parent class for Hub loading.
        """
        # Check if this is a Hub-style load (from a repo id or directory path)
        if isinstance(model, (str, Path)) or (model is None and kwargs):
            if isinstance(model, (str, Path)):
                # model is actually the repo_id or directory path
                return super().from_pretrained(model, **kwargs)
            else:
                # model is None, treat as a Hub-style load
                return super().from_pretrained(**kwargs)

        # This is the original SignalJEPA transfer learning case
        if not isinstance(model, SignalJEPA):
            raise TypeError(
                f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
            )
        if n_outputs is None:
            raise ValueError(
                "n_outputs must be provided when loading from a SignalJEPA model"
            )

        feature_encoder = model.feature_encoder
        assert feature_encoder is not None
        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        return new_model

    def forward(self, X):
        X = self.spatial_conv(X)
        local_features = self.feature_encoder(X)
        y = self.final_layer(local_features)
        return y


class _ConvFeatureEncoder(nn.Sequential):
    r"""Convolutional feature encoder for EEG data.

    Computes successive 1D convolutions (with activations) over the time
    dimension of the input EEG signal.

    Inspired by https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
    and https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py

    Parameters
    ----------
    conv_layers_spec: list of tuple
        Tuples have shape ``(dim, k, stride)`` where:

        * ``dim`` : number of output channels of the layer (unrelated to EEG channels);
        * ``k`` : temporal length of the layer's kernel;
        * ``stride`` : temporal stride of the layer's kernel.

    channels: int
        Number of channels in the input signals; each channel is encoded
        independently by the same convolutional stack.
    drop_prob: float
        Dropout probability applied after each convolution.
    mode: str
        Normalisation mode. Either ``default`` or ``layer_norm``.
    conv_bias: bool
        Whether the convolutions use a bias term.
    activation: nn.Module
        Activation layer class.
    """

    def __init__(
        self,
        conv_layers_spec: Sequence[tuple[int, int, int]],
        channels: int,
        drop_prob: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
    ):
        assert mode in {"default", "layer_norm"}

        input_channels = 1
        conv_layers = []
        for i, layer_spec in enumerate(conv_layers_spec):
            # Each layer_spec should be a tuple: (output_channels, kernel_size, stride)
            assert len(layer_spec) == 3, "Invalid conv definition: " + str(layer_spec)
            output_channels, kernel_size, stride = layer_spec
            conv_layers.append(
                self._get_block(
                    input_channels,
                    output_channels,
                    kernel_size,
                    stride,
                    drop_prob,
                    activation,
                    is_layer_norm=(mode == "layer_norm"),
                    is_group_norm=(mode == "default" and i == 0),
                    conv_bias=conv_bias,
                )
            )
            input_channels = output_channels
        all_layers = [
            Rearrange("b channels time -> (b channels) 1 time", channels=channels),
            *conv_layers,
            Rearrange(
                "(b channels) emb_dim time_out -> b (channels time_out) emb_dim",
                channels=channels,
            ),
        ]
        super().__init__(*all_layers)
        self.emb_dim = (
            output_channels  # last output dimension becomes the embedding dimension
        )
        self.conv_layers_spec = conv_layers_spec

    @staticmethod
    def _get_block(
        input_channels,
        output_channels,
        kernel_size,
        stride,
        drop_prob,
        activation,
        is_layer_norm=False,
        is_group_norm=False,
        conv_bias=False,
    ):
        assert not (is_layer_norm and is_group_norm), (
            "layer norm and group norm are exclusive"
        )

        conv = nn.Conv1d(
            input_channels,
            output_channels,
            kernel_size,
            stride=stride,
            bias=conv_bias,
        )
        nn.init.kaiming_normal_(conv.weight)
        if is_layer_norm:
            return nn.Sequential(
                conv,
                nn.Dropout(p=drop_prob),
                nn.Sequential(
                    Rearrange("... channels time -> ... time channels"),
                    nn.LayerNorm(output_channels, elementwise_affine=True),
                    Rearrange("... time channels -> ... channels time"),
                ),
                activation(),
            )
        elif is_group_norm:
            return nn.Sequential(
                conv,
                nn.Dropout(p=drop_prob),
                nn.GroupNorm(output_channels, output_channels, affine=True),
                activation(),
            )
        else:
            return nn.Sequential(conv, nn.Dropout(p=drop_prob), activation())

    def n_times_out(self, n_times):
        return _n_times_out(self.conv_layers_spec, n_times)
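
# --- Editor's illustrative sketch (not part of the released file) ---
# The encoder treats every channel independently: (b, C, T) is folded into
# (b*C, 1, T), convolved over time, then unfolded into one token per
# (channel, output time step). Batch and window sizes are example values.
def _demo_feature_encoder():
    enc = _ConvFeatureEncoder(_DEFAULT_CONV_LAYER_SPEC, channels=3)
    X = torch.randn(2, 3, 512)  # (batch, n_chans, n_times)
    tokens = enc(X)  # one 64-dim token per (channel, output time step)
    return tokens.shape  # torch.Size([2, 9, 64])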


class _ChannelEmbedding(nn.Embedding):
    r"""Embedding layer for EEG channels.

    The difference with a regular :class:`nn.Embedding` is that the embedding
    vectors are initialised with a positional encoding of the channel locations.

    Parameters
    ----------
    channel_locations: list of (list of float or None)
        List of the n-dimensional locations of the EEG channels.
    embedding_dim: int
        Dimensionality of the embedding vectors. Must be a multiple of the number
        of dimensions of the channel locations.
    """

    def __init__(
        self, channel_locations: list[list[float] | None], embedding_dim: int, **kwargs
    ):
        self.coordinate_ranges = [
            (min(coords), max(coords))
            for coords in zip(
                *[
                    loc[3:6] if len(loc) == 12 else loc
                    for loc in channel_locations
                    if loc is not None
                ]
            )
        ]
        channel_mins, channel_maxs = zip(*self.coordinate_ranges)
        global_min = min(channel_mins)
        global_max = max(channel_maxs)
        self.max_abs_coordinate = max(abs(global_min), abs(global_max))
        self.embedding_dim_per_coordinate = embedding_dim // len(self.coordinate_ranges)
        self.channel_locations = list(channel_locations)

        assert embedding_dim % len(self.coordinate_ranges) == 0

        super().__init__(len(channel_locations), embedding_dim, **kwargs)

    def reset_parameters(self):
        for i, loc in enumerate(self.channel_locations):
            if loc is None:
                nn.init.zeros_(self.weight[i])
            else:
                for j, (x, (x0, x1)) in enumerate(zip(loc, self.coordinate_ranges)):
                    with torch.no_grad():
                        self.weight[
                            i,
                            j * self.embedding_dim_per_coordinate : (j + 1)
                            * self.embedding_dim_per_coordinate,
                        ].copy_(
                            _pos_encode_contineous(
                                x,
                                0,
                                10 * self.max_abs_coordinate,
                                self.embedding_dim_per_coordinate,
                                device=self.weight.device,
                            ),
                        )


class _PosEncoder(nn.Module):
    r"""Positional encoder for EEG data.

    Parameters
    ----------
    spat_dim: int
        Number of dimensions to use to encode the spatial position of the patch,
        i.e. the EEG channel.
    time_dim: int
        Number of dimensions to use to encode the temporal position of the patch.
    ch_locs: list of list of float or 2d array
        List of the n-dimensional locations of the EEG channels.
    sfreq_features: float
        The "downsampled" sampling frequency returned by the feature encoder.
    spat_kwargs: dict
        Additional keyword arguments to pass to the :class:`nn.Embedding` layer used to
        embed the channel names.
    max_seconds: float
        Maximum number of seconds to consider for the temporal encoding.
    """

    def __init__(
        self,
        spat_dim: int,
        time_dim: int,
        ch_locs,
        sfreq_features: float,
        spat_kwargs: dict | None = None,
        max_seconds: float = 600.0,  # 10 minutes
    ):
        super().__init__()
        spat_kwargs = spat_kwargs or {}
        self.spat_dim = spat_dim
        self.time_dim = time_dim
        self.max_n_times = int(max_seconds * sfreq_features)

        # Positional encoder for the spatial dimension:
        self.pos_encoder_spat = _ChannelEmbedding(
            ch_locs, spat_dim, **spat_kwargs
        )  # (batch_size, n_channels, spat_dim)

        # Pre-computed tensor for positional encoding on the time dimension:
        self.encoding_time = torch.zeros(0, dtype=torch.float32, requires_grad=False)

    def _check_encoding_time(self, n_times: int):
        if self.encoding_time.size(0) < n_times:
            self.encoding_time = self.encoding_time.new_empty((n_times, self.time_dim))
            self.encoding_time[:] = _pos_encode_time(
                n_times=n_times,
                n_dim=self.time_dim,
                max_n_times=self.max_n_times,
                device=self.encoding_time.device,
            )

    def forward(self, local_features, ch_idxs: torch.Tensor | None = None):
        """
        Parameters
        ----------
        * local_features: (batch_size, n_chans * n_times_out, emb_dim)
        * ch_idxs: (batch_size, n_chans) | None
            Indices of the channels to use in the ``ch_names`` list passed
            as argument plus one. Index 0 is reserved for an unknown channel.

        Returns
        -------
        pos_encoding: (batch_size, n_chans * n_times_out, emb_dim)
            The first ``spat_dim`` dimensions carry the channel positional encoding
            and the following ``time_dim`` dimensions the temporal positional encoding.
        """
        batch_size, n_chans_times, emb_dim = local_features.shape
        if ch_idxs is None:
            ch_idxs = torch.arange(
                0,
                self.pos_encoder_spat.num_embeddings,
                device=local_features.device,
            ).repeat(batch_size, 1)

        batch_size_chs, n_chans = ch_idxs.shape
        assert emb_dim >= self.spat_dim + self.time_dim
        assert n_chans_times % n_chans == 0
        n_times = n_chans_times // n_chans

        pos_encoding = local_features.new_empty(
            (batch_size_chs, n_chans, n_times, emb_dim)
        )
        # Channel pos. encoding
        pos_encoding[:, :, :, : self.spat_dim] = self.pos_encoder_spat(ch_idxs)[
            :, :, None, :
        ]
        # Temporal pos. encoding
        self._check_encoding_time(n_times)
        _ = pos_encoding[:, :, :, self.spat_dim : self.spat_dim + self.time_dim].copy_(
            self.encoding_time[None, None, :n_times, :],
        )

        return pos_encoding.view(batch_size, n_chans_times, emb_dim)


def _n_times_out(conv_layers_spec, n_times):
    # it would be equal to n_times // ds_factor without edge effects:
    n_times_out_ = n_times
    for _, width, stride in conv_layers_spec:
        n_times_out_ = int((n_times_out_ - width) / stride) + 1
    return n_times_out_
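
# Editor's note (worked example): with the default spec and a 512-sample
# window, the recursion above gives 512 -> 61 -> 30 -> 15 -> 7 -> 3 output
# steps, i.e. _n_times_out(_DEFAULT_CONV_LAYER_SPEC, 512) == 3, slightly
# fewer than 512 // 128 = 4 because the convolutions are unpadded.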


def _get_out_emb_dim(conv_layers_spec, n_times, n_spat_filters=4):
    n_time_out = _n_times_out(conv_layers_spec, n_times)
    emb_dim = conv_layers_spec[-1][0]
    return n_spat_filters * n_time_out * emb_dim


def _get_separable_clf_layer(
    conv_layers_spec, n_chans, n_times, n_classes, n_spat_filters=4
):
    out_emb_dim = _get_out_emb_dim(
        conv_layers_spec=conv_layers_spec,
        n_times=n_times,
        n_spat_filters=n_spat_filters,
    )
    clf_layer = nn.Sequential()
    clf_layer.add_module(
        "unflatten_tokens",
        Rearrange("b (n_chans tokens) d -> b 1 n_chans tokens d", n_chans=n_chans),
    )
    clf_layer.add_module("spat_conv", nn.Conv3d(1, n_spat_filters, (n_chans, 1, 1)))
    clf_layer.add_module("flatten", nn.Flatten(start_dim=1))
    clf_layer.add_module("linear", nn.Linear(out_emb_dim, n_classes))
    return clf_layer
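
# --- Editor's illustrative sketch (not part of the released file) ---
# Shape walkthrough of the separable head for n_chans=3, n_times=512 and
# n_classes=4 (example values): tokens (b, 3*3, 64) are unflattened to
# (b, 1, 3, 3, 64), the Conv3d mixes the 3 EEG channels into 4 spatial
# filters -> (b, 4, 1, 3, 64), and the flatten/linear maps the
# 4 * 3 * 64 = 768 features to the 4 class logits.
def _demo_clf_layer():
    head = _get_separable_clf_layer(
        _DEFAULT_CONV_LAYER_SPEC, n_chans=3, n_times=512, n_classes=4
    )
    return head(torch.randn(2, 9, 64)).shape  # torch.Size([2, 4])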


def _pos_encode_time(
    n_times: int,
    n_dim: int,
    max_n_times: int,
    device: torch.device = torch.device("cpu"),
):
    """1-dimensional positional encoding.

    Parameters
    ----------
    n_times: int
        Number of time samples to encode.
    n_dim: int
        Number of dimensions of the positional encoding. Must be even.
    max_n_times: int
        The largest possible number of time samples to encode.
        Used to scale the positional encoding.
    device: torch.device
        Device to put the output on.

    Returns
    -------
    pos_encoding: (n_times, n_dim)
    """
    assert n_dim % 2 == 0
    position = torch.arange(n_times, device=device).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, n_dim, 2, device=device) * (-math.log(max_n_times) / n_dim)
    )
    pos_encoding = torch.empty((n_times, n_dim), dtype=torch.float32, device=device)
    pos_encoding[:, 0::2] = torch.sin(position * div_term)
    pos_encoding[:, 1::2] = torch.cos(position * div_term)
    return pos_encoding
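
# --- Editor's illustrative sketch (not part of the released file) ---
# This is the classic sinusoidal encoding, with ``max_n_times`` taking the
# role of the usual 10_000 base so that the slowest wavelength matches the
# longest sequence the model should see. Example values below.
def _demo_pos_encode_time():
    pe = _pos_encode_time(n_times=8, n_dim=4, max_n_times=600)
    # Row t is [sin(t), cos(t), sin(t / sqrt(600)), cos(t / sqrt(600))].
    return pe.shape  # torch.Size([8, 4])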


def _pos_encode_contineous(
    x, x_min, x_max, n_dim, device: torch.device = torch.device("cpu")
):
    """1-dimensional positional encoding of a continuous value.

    Parameters
    ----------
    x: float
        The position to encode.
    x_min: float
        The minimum possible value of x.
    x_max: float
        The maximum possible value of x.
    n_dim: int
        Number of dimensions of the positional encoding. Must be even.
    device: torch.device
        Device to put the output on.

    Returns
    -------
    pos_encoding: (n_dim,)
    """
    assert n_dim % 2 == 0
    div_term = torch.exp(
        (1 - torch.arange(0, n_dim, 2, device=device) / n_dim) * 2 * math.pi
    )
    pos_encoding = torch.empty((n_dim,), dtype=torch.float32, device=device)
    xx = (x - x_min) / (x_max - x_min)
    pos_encoding[0::2] = torch.sin(xx * div_term)
    pos_encoding[1::2] = torch.cos(xx * div_term)
    return pos_encoding