braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/signal_jepa.py (new file)
@@ -0,0 +1,1012 @@
# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)
from __future__ import annotations

import math
from copy import deepcopy
from typing import Any, Sequence

import torch
from einops import parse_shape, rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn

from braindecode.models.base import EEGModuleMixin

_DEFAULT_CONV_LAYER_SPEC = (  # downsampling: 128Hz -> 1Hz, receptive field 1.1875s, stride 1s
    (8, 32, 8),
    (16, 2, 2),
    (32, 2, 2),
    (64, 2, 2),
    (64, 2, 2),
)

class _BaseSignalJEPA(EEGModuleMixin, nn.Module):
    """Base class for the SignalJEPA models.

    Parameters
    ----------
    feature_encoder__conv_layers_spec: list of tuple
        Tuples have shape ``(dim, k, stride)`` where:

        * ``dim`` : number of output channels of the layer (unrelated to EEG channels);
        * ``k`` : temporal length of the layer's kernel;
        * ``stride`` : temporal stride of the layer's kernel.

    drop_prob: float
        Dropout probability applied after each convolution of the feature encoder.
    feature_encoder__mode: str
        Normalisation mode. Either ``default`` or ``layer_norm``.
    feature_encoder__conv_bias: bool
        Whether the feature encoder's convolutions have a bias term.
    activation: nn.Module
        Activation layer for the feature encoder.
    pos_encoder__spat_dim: int
        Number of dimensions to use to encode the spatial position of the patch,
        i.e. the EEG channel.
    pos_encoder__time_dim: int
        Number of dimensions to use to encode the temporal position of the patch.
    pos_encoder__sfreq_features: float
        The "downsampled" sampling frequency returned by the feature encoder.
    pos_encoder__spat_kwargs: dict
        Additional keyword arguments to pass to the :class:`nn.Embedding` layer used to
        embed the channel names.
    transformer__d_model: int
        The number of expected features in the encoder/decoder inputs.
    transformer__num_encoder_layers: int
        The number of encoder layers in the transformer.
    transformer__num_decoder_layers: int
        The number of decoder layers in the transformer.
    transformer__nhead: int
        The number of heads in the multi-head attention models.
    _init_feature_encoder : bool
        Do not change the default value (used for internal purposes).
    _init_transformer : bool
        Do not change the default value (used for internal purposes).
    """

    feature_encoder: _ConvFeatureEncoder | None
    pos_encoder: _PosEncoder | None
    transformer: nn.Transformer | None

    _feature_encoder_channels: str = "n_chans"

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool,
        _init_transformer: bool,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.feature_encoder = None
        self.pos_encoder = None
        self.transformer = None
        if _init_feature_encoder:
            self.feature_encoder = _ConvFeatureEncoder(
                conv_layers_spec=feature_encoder__conv_layers_spec,
                channels=getattr(self, self._feature_encoder_channels),
                drop_prob=drop_prob,
                mode=feature_encoder__mode,
                conv_bias=feature_encoder__conv_bias,
                activation=activation,
            )

        if _init_transformer:
            ch_locs = [ch["loc"] for ch in self.chs_info]  # type: ignore
            self.pos_encoder = _PosEncoder(
                spat_dim=pos_encoder__spat_dim,
                time_dim=pos_encoder__time_dim,
                ch_locs=ch_locs,
                sfreq_features=pos_encoder__sfreq_features,
                spat_kwargs=pos_encoder__spat_kwargs,
            )
            self.transformer = nn.Transformer(
                d_model=transformer__d_model,
                nhead=transformer__nhead,
                num_encoder_layers=transformer__num_encoder_layers,
                num_decoder_layers=transformer__num_decoder_layers,
                batch_first=True,
            )


class SignalJEPA(_BaseSignalJEPA):
    """Architecture introduced in signal-JEPA for self-supervised pre-training by Guetschel et al. (2024) [1]_.

    This model is not meant for classification but for SSL pre-training.
    Its output shape depends on the input shape.
    For classification purposes, three variants of this model are available:

    * :class:`SignalJEPA_Contextual`
    * :class:`SignalJEPA_PostLocal`
    * :class:`SignalJEPA_PreLocal`

    The classification architectures can either be instantiated from scratch
    (random parameters) or from a pre-trained :class:`SignalJEPA` model.

    .. versionadded:: 0.9

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=True,
            _init_transformer=True,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = nn.Identity()

    def forward(self, X, ch_idxs: torch.Tensor | None = None):  # type: ignore
        local_features = self.feature_encoder(X)  # type: ignore
        pos_encoding = self.pos_encoder(local_features, ch_idxs=ch_idxs)  # type: ignore
        local_features += pos_encoding  # type: ignore
        contextual_features = self.transformer.encoder(local_features)  # type: ignore
        y = self.final_layer(contextual_features)  # type: ignore
        return y  # type: ignore


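# Illustrative sketch (not part of signal_jepa.py): a minimal pre-training-style
# forward pass, assuming 4-channel EEG at 128 Hz with 2 s windows and a standard
# montage to fill the channel locations; the variable names are hypothetical and
# this snippet is untested. With the default conv spec, 256 samples collapse to
# one token per channel, so the output is a (batch, n_chans * n_times_out, 64)
# token sequence, not class scores.
#
#     import mne
#     import torch
#
#     info = mne.create_info(["C3", "Cz", "C4", "Pz"], sfreq=128, ch_types="eeg")
#     info.set_montage("standard_1020")  # fills each channel's "loc" field
#     model = SignalJEPA(n_chans=4, n_times=256, sfreq=128, chs_info=info["chs"])
#     tokens = model(torch.randn(2, 4, 256))  # -> shape (2, 4, 64)
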
class SignalJEPA_Contextual(_BaseSignalJEPA):
    """Contextual downstream architecture introduced in signal-JEPA by Guetschel et al. (2024) [1]_.

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_contextual.jpg
       :align: center
       :alt: sJEPA Contextual.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
        _init_transformer: bool = True,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=_init_transformer,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = _get_separable_clf_layer(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_chans=self.n_chans,
            n_times=self.n_times,
            n_classes=self.n_outputs,
            n_spat_filters=n_spat_filters,
        )

    @classmethod
    def from_pretrained(
        cls,
        model: SignalJEPA,
        n_outputs: int,
        n_spat_filters: int = 4,
        chs_info: list[dict[str, Any]] | None = None,
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.

        Parameters
        ----------
        model: SignalJEPA
            Pre-trained model.
        n_outputs: int
            Number of classes for the new model.
        n_spat_filters: int
            Number of spatial filters.
        chs_info: list of dict | None
            Information about each individual EEG channel. This should be filled with
            ``info["chs"]``. Refer to :class:`mne.Info` for more details.
        """
        feature_encoder = model.feature_encoder
        pos_encoder = model.pos_encoder
        transformer = model.transformer
        assert feature_encoder is not None
        assert pos_encoder is not None
        assert transformer is not None

        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            chs_info=chs_info,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
            _init_transformer=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        new_model.pos_encoder = deepcopy(pos_encoder)
        new_model.transformer = deepcopy(transformer)

        if chs_info is not None:
            ch_names = [ch["ch_name"] for ch in chs_info]
            new_model.pos_encoder.set_fixed_ch_names(ch_names)

        return new_model

    def forward(self, X, ch_idxs: torch.Tensor | None = None):  # type: ignore
        local_features = self.feature_encoder(X)  # type: ignore
        pos_encoding = self.pos_encoder(local_features, ch_idxs=ch_idxs)  # type: ignore
        local_features += pos_encoding  # type: ignore
        contextual_features = self.transformer.encoder(local_features)  # type: ignore
        y = self.final_layer(contextual_features)  # type: ignore
        return y  # type: ignore


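# Illustrative sketch (not part of signal_jepa.py): transferring a pre-trained
# encoder into the contextual classifier, assuming `model` is the SignalJEPA
# instance from the sketch above (4 channels, 256 samples) and leaving
# `chs_info` at its default. The head flattens 4 tokens of width 64 into a
# 256-dimensional vector before the linear classifier, so:
#
#     clf = SignalJEPA_Contextual.from_pretrained(model, n_outputs=2)
#     logits = clf(torch.randn(2, 4, 256))  # -> shape (2, 2)
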
class SignalJEPA_PostLocal(_BaseSignalJEPA):
    """Post-local downstream architecture introduced in signal-JEPA by Guetschel et al. (2024) [1]_.

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_post-local.jpg
       :align: center
       :alt: sJEPA Post-Local.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=False,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = _get_separable_clf_layer(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_chans=self.n_chans,
            n_times=self.n_times,
            n_classes=self.n_outputs,
            n_spat_filters=n_spat_filters,
        )

    @classmethod
    def from_pretrained(
        cls, model: SignalJEPA, n_outputs: int, n_spat_filters: int = 4
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.

        Parameters
        ----------
        model: SignalJEPA
            Pre-trained model.
        n_outputs: int
            Number of classes for the new model.
        n_spat_filters: int
            Number of spatial filters.
        """
        feature_encoder = model.feature_encoder
        assert feature_encoder is not None
        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        return new_model

    def forward(self, X):
        local_features = self.feature_encoder(X)
        y = self.final_layer(local_features)
        return y


class SignalJEPA_PreLocal(_BaseSignalJEPA):
    """Pre-local downstream architecture introduced in signal-JEPA by Guetschel et al. (2024) [1]_.

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_pre-local.jpg
       :align: center
       :alt: sJEPA Pre-Local.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    _feature_encoder_channels: str = "n_spat_filters"

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
    ):
        # Must be set before ``super().__init__`` because the base class reads
        # this attribute through ``_feature_encoder_channels``.
        self.n_spat_filters = n_spat_filters
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=False,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.spatial_conv = nn.Sequential(
            Rearrange("b channels time -> b 1 channels time"),
            nn.Conv2d(1, n_spat_filters, (self.n_chans, 1)),
            Rearrange("b spat_filters 1 time -> b spat_filters time"),
        )
        out_emb_dim = _get_out_emb_dim(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_times=self.n_times,
            n_spat_filters=n_spat_filters,
        )
        self.final_layer = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(out_emb_dim, self.n_outputs),
        )

    @classmethod
    def from_pretrained(
        cls, model: SignalJEPA, n_outputs: int, n_spat_filters: int = 4
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.

        Parameters
        ----------
        model: SignalJEPA
            Pre-trained model.
        n_outputs: int
            Number of classes for the new model.
        n_spat_filters: int
            Number of spatial filters.
        """
        feature_encoder = model.feature_encoder
        assert feature_encoder is not None
        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        return new_model

    def forward(self, X):
        X = self.spatial_conv(X)
        local_features = self.feature_encoder(X)
        y = self.final_layer(local_features)
        return y


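# Illustrative sketch (not part of signal_jepa.py): the shape contrast between
# the two "local" variants, written symbolically under the assumption of the
# default conv spec (feature tokens of width emb_dim = 64):
#
#   post-local: X (b, n_chans, n_times)
#       -> feature_encoder        -> (b, n_chans * t_out, emb_dim)
#       -> separable head         -> spatial Conv3d over n_chans, then
#          Flatten + Linear       -> (b, n_outputs)
#
#   pre-local:  X (b, n_chans, n_times)
#       -> spatial_conv           -> (b, n_spat_filters, n_times)
#       -> feature_encoder        -> (b, n_spat_filters * t_out, emb_dim)
#       -> Flatten + Linear       -> (b, n_outputs)
#
# i.e. post-local mixes channels after the temporal encoder, pre-local before it.
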
class _ConvFeatureEncoder(nn.Sequential):
    """Convolutional feature encoder for EEG data.

    Computes successive 1D convolutions (with activations) over the time
    dimension of the input EEG signal.

    Inspiration from https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
    and https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py

    Parameters
    ----------
    conv_layers_spec: list of tuple
        Tuples have shape ``(dim, k, stride)`` where:

        * ``dim`` : number of output channels of the layer (unrelated to EEG channels);
        * ``k`` : temporal length of the layer's kernel;
        * ``stride`` : temporal stride of the layer's kernel.

    channels: int
        Number of input (EEG) channels; each channel is convolved independently.
    drop_prob: float
        Dropout probability applied after each convolution.
    mode: str
        Normalisation mode. Either ``default`` or ``layer_norm``.
    conv_bias: bool
        Whether the convolutions have a bias term.
    activation: nn.Module
        Activation layer inserted after each convolution block.
    """

    def __init__(
        self,
        conv_layers_spec: Sequence[tuple[int, int, int]],
        channels: int,
        drop_prob: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
    ):
        assert mode in {"default", "layer_norm"}

        input_channels = 1
        conv_layers = []
        for i, layer_spec in enumerate(conv_layers_spec):
            # Each layer_spec should be a tuple: (output_channels, kernel_size, stride)
            assert len(layer_spec) == 3, "Invalid conv definition: " + str(layer_spec)
            output_channels, kernel_size, stride = layer_spec
            conv_layers.append(
                self._get_block(
                    input_channels,
                    output_channels,
                    kernel_size,
                    stride,
                    drop_prob,
                    activation,
                    is_layer_norm=(mode == "layer_norm"),
                    is_group_norm=(mode == "default" and i == 0),
                    conv_bias=conv_bias,
                )
            )
            input_channels = output_channels
        all_layers = [
            Rearrange("b channels time -> (b channels) 1 time", channels=channels),
            *conv_layers,
            Rearrange(
                "(b channels) emb_dim time_out -> b (channels time_out) emb_dim",
                channels=channels,
            ),
        ]
        super().__init__(*all_layers)
        self.emb_dim = (
            output_channels  # last output dimension becomes the embedding dimension
        )
        self.conv_layers_spec = conv_layers_spec

    @staticmethod
    def _get_block(
        input_channels,
        output_channels,
        kernel_size,
        stride,
        drop_prob,
        activation,
        is_layer_norm=False,
        is_group_norm=False,
        conv_bias=False,
    ):
        assert not (is_layer_norm and is_group_norm), (
            "layer norm and group norm are exclusive"
        )

        conv = nn.Conv1d(
            input_channels,
            output_channels,
            kernel_size,
            stride=stride,
            bias=conv_bias,
        )
        nn.init.kaiming_normal_(conv.weight)
        if is_layer_norm:
            return nn.Sequential(
                conv,
                nn.Dropout(p=drop_prob),
                nn.Sequential(
                    Rearrange("... channels time -> ... time channels"),
                    nn.LayerNorm(output_channels, elementwise_affine=True),
                    Rearrange("... time channels -> ... channels time"),
                ),
                activation(),
            )
        elif is_group_norm:
            return nn.Sequential(
                conv,
                nn.Dropout(p=drop_prob),
                nn.GroupNorm(output_channels, output_channels, affine=True),
                activation(),
            )
        else:
            return nn.Sequential(conv, nn.Dropout(p=drop_prob), activation())

    def n_times_out(self, n_times):
        return _n_times_out(self.conv_layers_spec, n_times)


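# Illustrative check (not part of signal_jepa.py): verifying the downsampling
# claim in _DEFAULT_CONV_LAYER_SPEC with plain arithmetic. The cumulative
# stride is 8 * 2 * 2 * 2 * 2 = 128 samples, so at 128 Hz one token is emitted
# per second (128 Hz -> 1 Hz). The receptive field is
# 1 + (32-1)*1 + (2-1)*8 + (2-1)*16 + (2-1)*32 + (2-1)*64 = 152 samples,
# i.e. 152 / 128 = 1.1875 s, matching the inline comment at the top of the file.
# Using the helper defined below:
#
#     assert _n_times_out(_DEFAULT_CONV_LAYER_SPEC, 256) == 1
#     assert _n_times_out(_DEFAULT_CONV_LAYER_SPEC, 256 + 128) == 2
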
class _ChannelEmbedding(nn.Embedding):
    """Embedding layer for EEG channels.

    The difference with a regular :class:`nn.Embedding` is that the embedding
    vectors are initialized with a positional encoding of the channel locations.

    Parameters
    ----------
    channel_locations: list of (list of float or None)
        List of the n-dimensional locations of the EEG channels.
    embedding_dim: int
        Dimensionality of the embedding vectors. Must be a multiple of the number
        of dimensions of the channel locations.
    """

    def __init__(
        self, channel_locations: list[list[float] | None], embedding_dim: int, **kwargs
    ):
        self.coordinate_ranges = [
            (min(coords), max(coords))
            for coords in zip(
                *[
                    loc[3:6] if len(loc) == 12 else loc
                    for loc in channel_locations
                    if loc is not None
                ]
            )
        ]
        channel_mins, channel_maxs = zip(*self.coordinate_ranges)
        global_min = min(channel_mins)
        global_max = max(channel_maxs)
        self.max_abs_coordinate = max(abs(global_min), abs(global_max))
        self.embedding_dim_per_coordinate = embedding_dim // len(self.coordinate_ranges)
        self.channel_locations = list(channel_locations)

        assert embedding_dim % len(self.coordinate_ranges) == 0

        super().__init__(len(channel_locations), embedding_dim, **kwargs)

    def reset_parameters(self):
        for i, loc in enumerate(self.channel_locations):
            if loc is None:
                nn.init.zeros_(self.weight[i])
            else:
                for j, (x, (x0, x1)) in enumerate(zip(loc, self.coordinate_ranges)):
                    with torch.no_grad():
                        self.weight[
                            i,
                            j * self.embedding_dim_per_coordinate : (j + 1)
                            * self.embedding_dim_per_coordinate,
                        ].copy_(
                            _pos_encode_contineous(
                                x,
                                0,
                                10 * self.max_abs_coordinate,
                                self.embedding_dim_per_coordinate,
                                device=self.weight.device,
                            ),
                        )


class _PosEncoder(nn.Module):
    """Positional encoder for EEG data.

    Parameters
    ----------
    spat_dim: int
        Number of dimensions to use to encode the spatial position of the patch,
        i.e. the EEG channel.
    time_dim: int
        Number of dimensions to use to encode the temporal position of the patch.
    ch_locs: list of list of float or 2d array
        List of the n-dimensional locations of the EEG channels.
    sfreq_features: float
        The "downsampled" sampling frequency returned by the feature encoder.
    spat_kwargs: dict
        Additional keyword arguments to pass to the :class:`nn.Embedding` layer used to
        embed the channel names.
    max_seconds: float
        Maximum number of seconds to consider for the temporal encoding.
    """

    def __init__(
        self,
        spat_dim: int,
        time_dim: int,
        ch_locs,
        sfreq_features: float,
        spat_kwargs: dict | None = None,
        max_seconds: float = 600.0,  # 10 minutes
    ):
        super().__init__()
        spat_kwargs = spat_kwargs or {}
        self.spat_dim = spat_dim
        self.time_dim = time_dim
        self.max_n_times = int(max_seconds * sfreq_features)

        # Positional encoder for the spatial dimension:
        self.pos_encoder_spat = _ChannelEmbedding(
            ch_locs, spat_dim, **spat_kwargs
        )  # (batch_size, n_channels, spat_dim)

        # Pre-computed tensor for positional encoding on the time dimension:
        self.encoding_time = torch.zeros(0, dtype=torch.float32, requires_grad=False)

    def _check_encoding_time(self, n_times: int):
        if self.encoding_time.size(0) < n_times:
            self.encoding_time = self.encoding_time.new_empty((n_times, self.time_dim))
            self.encoding_time[:] = _pos_encode_time(
                n_times=n_times,
                n_dim=self.time_dim,
                max_n_times=self.max_n_times,
                device=self.encoding_time.device,
            )

    def forward(self, local_features, ch_idxs: torch.Tensor | None = None):
        """
        Parameters
        ----------
        * local_features: (batch_size, n_chans * n_times_out, emb_dim)
        * ch_idxs: (batch_size, n_chans) | None
            Indices of the channels to use, i.e. their position in the
            ``ch_names`` list passed as argument, plus one. Index 0 is reserved
            for an unknown channel.

        Returns
        -------
        pos_encoding: (batch_size, n_chans * n_times_out, emb_dim)
            The first ``spat_dim`` dimensions encode the channels' positional
            encoding and the following ``time_dim`` dimensions encode the
            temporal positional encoding.
        """
        batch_size, n_chans_times, emb_dim = local_features.shape
        if ch_idxs is None:
            ch_idxs = torch.arange(
                0,
                self.pos_encoder_spat.num_embeddings,
                device=local_features.device,
            ).repeat(batch_size, 1)

        batch_size_chs, n_chans = ch_idxs.shape
        assert emb_dim >= self.spat_dim + self.time_dim
        assert n_chans_times % n_chans == 0
        n_times = n_chans_times // n_chans

        pos_encoding = local_features.new_empty(
            (batch_size_chs, n_chans, n_times, emb_dim)
        )
        # Channel pos. encoding
        pos_encoding[:, :, :, : self.spat_dim] = self.pos_encoder_spat(ch_idxs)[
            :, :, None, :
        ]
        # Temporal pos. encoding
        self._check_encoding_time(n_times)
        _ = pos_encoding[:, :, :, self.spat_dim : self.spat_dim + self.time_dim].copy_(
            self.encoding_time[None, None, :n_times, :],
        )

        return pos_encoding.view(batch_size, n_chans_times, emb_dim)


def _n_times_out(conv_layers_spec, n_times):
    # it would be equal to n_times // ds_factor without edge effects:
    n_times_out_ = n_times
    for _, width, stride in conv_layers_spec:
        n_times_out_ = int((n_times_out_ - width) / stride) + 1
    return n_times_out_


def _get_out_emb_dim(conv_layers_spec, n_times, n_spat_filters=4):
    n_time_out = _n_times_out(conv_layers_spec, n_times)
    emb_dim = conv_layers_spec[-1][0]
    return n_spat_filters * n_time_out * emb_dim


def _get_separable_clf_layer(
    conv_layers_spec, n_chans, n_times, n_classes, n_spat_filters=4
):
    out_emb_dim = _get_out_emb_dim(
        conv_layers_spec=conv_layers_spec,
        n_times=n_times,
        n_spat_filters=n_spat_filters,
    )
    clf_layer = nn.Sequential()
    clf_layer.add_module(
        "unflatten_tokens",
        Rearrange("b (n_chans tokens) d -> b 1 n_chans tokens d", n_chans=n_chans),
    )
    clf_layer.add_module("spat_conv", nn.Conv3d(1, n_spat_filters, (n_chans, 1, 1)))
    clf_layer.add_module("flatten", nn.Flatten(start_dim=1))
    clf_layer.add_module("linear", nn.Linear(out_emb_dim, n_classes))
    return clf_layer


def _pos_encode_time(
    n_times: int,
    n_dim: int,
    max_n_times: int,
    device: torch.device = torch.device("cpu"),
):
    """1-dimensional positional encoding.

    Parameters
    ----------
    n_times: int
        Number of time samples to encode.
    n_dim: int
        Number of dimensions of the positional encoding. Must be even.
    max_n_times: int
        The largest possible number of time samples to encode.
        Used to scale the positional encoding.
    device: torch.device
        Device to put the output on.

    Returns
    -------
    pos_encoding: (n_times, n_dim)
    """
    assert n_dim % 2 == 0
    position = torch.arange(n_times, device=device).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, n_dim, 2, device=device) * (-math.log(max_n_times) / n_dim)
    )
    pos_encoding = torch.empty((n_times, n_dim), dtype=torch.float32, device=device)
    pos_encoding[:, 0::2] = torch.sin(position * div_term)
    pos_encoding[:, 1::2] = torch.cos(position * div_term)
    return pos_encoding


def _pos_encode_contineous(
    x, x_min, x_max, n_dim, device: torch.device = torch.device("cpu")
):
    """1-dimensional positional encoding.

    Parameters
    ----------
    x: float
        The position to encode.
    x_min: float
        The minimum possible value of x.
    x_max: float
        The maximum possible value of x.
    n_dim: int
        Number of dimensions of the positional encoding. Must be even.
    device: torch.device
        Device to put the output on.

    Returns
    -------
    pos_encoding: (n_dim,)
    """
    assert n_dim % 2 == 0
    div_term = torch.exp(
        (1 - torch.arange(0, n_dim, 2, device=device) / n_dim) * 2 * math.pi
    )
    pos_encoding = torch.empty((n_dim,), dtype=torch.float32, device=device)
    xx = (x - x_min) / (x_max - x_min)
    pos_encoding[0::2] = torch.sin(xx * div_term)
    pos_encoding[1::2] = torch.cos(xx * div_term)
    return pos_encoding
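
# Illustrative check (not part of signal_jepa.py): _pos_encode_time is the
# standard sinusoidal scheme from "Attention Is All You Need", with the usual
# base of 10000 replaced by max_n_times (600 s * sfreq_features = 600 under the
# defaults above). A quick property check at position 0:
#
#     enc = _pos_encode_time(n_times=8, n_dim=34, max_n_times=600)
#     assert enc.shape == (8, 34)
#     assert torch.allclose(enc[0, 0::2], torch.zeros(17))  # sin(0) = 0
#     assert torch.allclose(enc[0, 1::2], torch.ones(17))   # cos(0) = 1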