braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/models/signal_jepa.py (new file)
@@ -0,0 +1,1011 @@
# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)
from __future__ import annotations

import math
from copy import deepcopy
from typing import Any, Sequence

import torch
from einops.layers.torch import Rearrange
from torch import nn

from braindecode.models.base import EEGModuleMixin

_DEFAULT_CONV_LAYER_SPEC = (  # downsampling: 128Hz -> 1Hz, receptive field 1.1875s, stride 1s
    (8, 32, 8),
    (16, 2, 2),
    (32, 2, 2),
    (64, 2, 2),
    (64, 2, 2),
)

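# Illustrative sketch (editor's example, not part of the released file): a quick
# check of the arithmetic claimed in the comment above. The cumulative stride of
# the default spec is 8 * 2 * 2 * 2 * 2 = 128 samples (1 s at 128 Hz), and the
# receptive field is 152 samples, i.e. 152 / 128 = 1.1875 s.
def _demo_default_spec_geometry():
    receptive_field, hop = 1, 1
    for _dim, kernel, stride in _DEFAULT_CONV_LAYER_SPEC:
        receptive_field += (kernel - 1) * hop
        hop *= stride
    assert hop == 128 and receptive_field == 152
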
class _BaseSignalJEPA(EEGModuleMixin, nn.Module):
    """Base class for the SignalJEPA models.

    Parameters
    ----------
    feature_encoder__conv_layers_spec: list of tuple
        Each tuple has the form ``(dim, k, stride)`` where:

        * ``dim`` : number of output channels of the layer (unrelated to EEG channels);
        * ``k`` : temporal length of the layer's kernel;
        * ``stride`` : temporal stride of the layer's kernel.

    drop_prob: float
        Dropout probability applied in each feature-encoder block.
    feature_encoder__mode: str
        Normalisation mode. Either ``default`` or ``layer_norm``.
    feature_encoder__conv_bias: bool
        Whether the feature encoder's convolutions use a bias term.
    activation: nn.Module
        Activation layer for the feature encoder.
    pos_encoder__spat_dim: int
        Number of dimensions to use to encode the spatial position of the patch,
        i.e. the EEG channel.
    pos_encoder__time_dim: int
        Number of dimensions to use to encode the temporal position of the patch.
    pos_encoder__sfreq_features: float
        The "downsampled" sampling frequency returned by the feature encoder.
    pos_encoder__spat_kwargs: dict
        Additional keyword arguments to pass to the :class:`nn.Embedding` layer used to
        embed the channel names.
    transformer__d_model: int
        The number of expected features in the encoder/decoder inputs.
    transformer__num_encoder_layers: int
        The number of encoder layers in the transformer.
    transformer__num_decoder_layers: int
        The number of decoder layers in the transformer.
    transformer__nhead: int
        The number of heads in the multi-head attention layers.
    _init_feature_encoder : bool
        Do not change the default value (used for internal purposes).
    _init_transformer : bool
        Do not change the default value (used for internal purposes).
    """

    feature_encoder: _ConvFeatureEncoder | None
    pos_encoder: _PosEncoder | None
    transformer: nn.Transformer | None

    _feature_encoder_channels: str = "n_chans"

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool,
        _init_transformer: bool,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.feature_encoder = None
        self.pos_encoder = None
        self.transformer = None
        if _init_feature_encoder:
            self.feature_encoder = _ConvFeatureEncoder(
                conv_layers_spec=feature_encoder__conv_layers_spec,
                channels=getattr(self, self._feature_encoder_channels),
                drop_prob=drop_prob,
                mode=feature_encoder__mode,
                conv_bias=feature_encoder__conv_bias,
                activation=activation,
            )

        if _init_transformer:
            ch_locs = [ch["loc"] for ch in self.chs_info]  # type: ignore
            self.pos_encoder = _PosEncoder(
                spat_dim=pos_encoder__spat_dim,
                time_dim=pos_encoder__time_dim,
                ch_locs=ch_locs,
                sfreq_features=pos_encoder__sfreq_features,
                spat_kwargs=pos_encoder__spat_kwargs,
            )
            self.transformer = nn.Transformer(
                d_model=transformer__d_model,
                nhead=transformer__nhead,
                num_encoder_layers=transformer__num_encoder_layers,
                num_decoder_layers=transformer__num_decoder_layers,
                batch_first=True,
            )

class SignalJEPA(_BaseSignalJEPA):
    """Architecture introduced in signal-JEPA for self-supervised pre-training, Guetschel, P. et al. (2024) [1]_.

    This model is not meant for classification but for SSL pre-training.
    Its output shape depends on the input shape.
    For classification purposes, three variants of this model are available:

    * :class:`SignalJEPA_Contextual`
    * :class:`SignalJEPA_PostLocal`
    * :class:`SignalJEPA_PreLocal`

    The classification architectures can either be instantiated from scratch
    (random parameters) or from a pre-trained :class:`SignalJEPA` model.

    .. versionadded:: 0.9

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference. https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=True,
            _init_transformer=True,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = nn.Identity()

    def forward(self, X, ch_idxs: torch.Tensor | None = None):  # type: ignore
        local_features = self.feature_encoder(X)  # type: ignore
        pos_encoding = self.pos_encoder(local_features, ch_idxs=ch_idxs)  # type: ignore
        local_features += pos_encoding  # type: ignore
        contextual_features = self.transformer.encoder(local_features)  # type: ignore
        y = self.final_layer(contextual_features)  # type: ignore
        return y  # type: ignore

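# A minimal usage sketch (editor's example; the channel info below is made up).
# The SSL model maps (batch, n_chans, n_times) to one embedding per
# (channel, time-step) token, so the output length follows the input length.
def _demo_signal_jepa_ssl_shapes():
    chs_info = [
        {"ch_name": f"CH{i}", "loc": [0.0] * 3 + [0.1 * i, 0.2 * i, 0.3 * i] + [0.0] * 6}
        for i in range(3)
    ]
    model = SignalJEPA(chs_info=chs_info, n_times=512, sfreq=128.0)
    X = torch.randn(2, 3, 512)
    # 512 samples -> 3 output steps per channel (see _n_times_out below); emb_dim = 64
    assert model(X).shape == (2, 3 * 3, 64)
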
class SignalJEPA_Contextual(_BaseSignalJEPA):
    """Contextual downstream architecture introduced in signal-JEPA, Guetschel, P. et al. (2024) [1]_.

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_contextual.jpg
       :align: center
       :alt: sJEPA Contextual.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference. https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
        _init_transformer: bool = True,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=_init_transformer,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = _get_separable_clf_layer(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_chans=self.n_chans,
            n_times=self.n_times,
            n_classes=self.n_outputs,
            n_spat_filters=n_spat_filters,
        )

    @classmethod
    def from_pretrained(
        cls,
        model: SignalJEPA,
        n_outputs: int,
        n_spat_filters: int = 4,
        chs_info: list[dict[str, Any]] | None = None,
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.

        Parameters
        ----------
        model: SignalJEPA
            Pre-trained model.
        n_outputs: int
            Number of classes for the new model.
        n_spat_filters: int
            Number of spatial filters.
        chs_info: list of dict | None
            Information about each individual EEG channel. This should be filled with
            ``info["chs"]``. Refer to :class:`mne.Info` for more details.
        """
        feature_encoder = model.feature_encoder
        pos_encoder = model.pos_encoder
        transformer = model.transformer
        assert feature_encoder is not None
        assert pos_encoder is not None
        assert transformer is not None

        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            chs_info=chs_info,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
            _init_transformer=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        new_model.pos_encoder = deepcopy(pos_encoder)
        new_model.transformer = deepcopy(transformer)

        if chs_info is not None:
            ch_names = [ch["ch_name"] for ch in chs_info]
            new_model.pos_encoder.set_fixed_ch_names(ch_names)

        return new_model

    def forward(self, X, ch_idxs: torch.Tensor | None = None):  # type: ignore
        local_features = self.feature_encoder(X)  # type: ignore
        pos_encoding = self.pos_encoder(local_features, ch_idxs=ch_idxs)  # type: ignore
        local_features += pos_encoding  # type: ignore
        contextual_features = self.transformer.encoder(local_features)  # type: ignore
        y = self.final_layer(contextual_features)  # type: ignore
        return y  # type: ignore

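# A minimal fine-tuning sketch (editor's example, made-up values): the
# pre-trained feature encoder, positional encoder and transformer are
# deep-copied into the classifier; only the separable head is new.
def _demo_contextual_from_pretrained():
    chs_info = [
        {"ch_name": f"CH{i}", "loc": [0.0] * 3 + [0.1 * i, 0.2 * i, 0.3 * i] + [0.0] * 6}
        for i in range(3)
    ]
    pretrained = SignalJEPA(chs_info=chs_info, n_times=512, sfreq=128.0)
    clf = SignalJEPA_Contextual.from_pretrained(pretrained, n_outputs=4)
    X = torch.randn(2, 3, 512)
    assert clf(X).shape == (2, 4)
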
class SignalJEPA_PostLocal(_BaseSignalJEPA):
    """Post-local downstream architecture introduced in signal-JEPA, Guetschel, P. et al. (2024) [1]_.

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_post-local.jpg
       :align: center
       :alt: sJEPA Post-Local.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference. https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=False,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = _get_separable_clf_layer(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_chans=self.n_chans,
            n_times=self.n_times,
            n_classes=self.n_outputs,
            n_spat_filters=n_spat_filters,
        )

    @classmethod
    def from_pretrained(
        cls, model: SignalJEPA, n_outputs: int, n_spat_filters: int = 4
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.

        Parameters
        ----------
        model: SignalJEPA
            Pre-trained model.
        n_outputs: int
            Number of classes for the new model.
        n_spat_filters: int
            Number of spatial filters.
        """
        feature_encoder = model.feature_encoder
        assert feature_encoder is not None
        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        return new_model

    def forward(self, X):
        local_features = self.feature_encoder(X)
        y = self.final_layer(local_features)
        return y

class SignalJEPA_PreLocal(_BaseSignalJEPA):
    """Pre-local downstream architecture introduced in signal-JEPA, Guetschel, P. et al. (2024) [1]_.

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_pre-local.jpg
       :align: center
       :alt: sJEPA Pre-Local.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference. https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    _feature_encoder_channels: str = "n_spat_filters"

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
    ):
        self.n_spat_filters = n_spat_filters
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=False,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.spatial_conv = nn.Sequential(
            Rearrange("b channels time -> b 1 channels time"),
            nn.Conv2d(1, n_spat_filters, (self.n_chans, 1)),
            Rearrange("b spat_filters 1 time -> b spat_filters time"),
        )
        out_emb_dim = _get_out_emb_dim(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_times=self.n_times,
            n_spat_filters=n_spat_filters,
        )
        self.final_layer = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(out_emb_dim, self.n_outputs),
        )

    @classmethod
    def from_pretrained(
        cls, model: SignalJEPA, n_outputs: int, n_spat_filters: int = 4
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.

        Parameters
        ----------
        model: SignalJEPA
            Pre-trained model.
        n_outputs: int
            Number of classes for the new model.
        n_spat_filters: int
            Number of spatial filters.
        """
        feature_encoder = model.feature_encoder
        assert feature_encoder is not None
        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        return new_model

    def forward(self, X):
        X = self.spatial_conv(X)
        local_features = self.feature_encoder(X)
        y = self.final_layer(local_features)
        return y

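# Illustrative sketch (editor's example; shapes made up): the two "local"
# variants skip the transformer. PostLocal classifies the local features
# directly; PreLocal spatially filters the raw EEG first, so its feature
# encoder runs on n_spat_filters virtual channels instead of n_chans.
def _demo_local_variants():
    X = torch.randn(2, 3, 512)
    post = SignalJEPA_PostLocal(n_outputs=4, n_chans=3, n_times=512)
    pre = SignalJEPA_PreLocal(n_outputs=4, n_chans=3, n_times=512)
    assert post(X).shape == (2, 4)
    assert pre(X).shape == (2, 4)
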
class _ConvFeatureEncoder(nn.Sequential):
    """Convolutional feature encoder for EEG data.

    Computes successive 1D convolutions (with activations) over the time
    dimension of the input EEG signal.

    Inspiration from https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
    and https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py

    Parameters
    ----------
    conv_layers_spec: list of tuple
        tuples have shape ``(dim, k, stride)`` where:

        * ``dim`` : number of output channels of the layer (unrelated to EEG channels);
        * ``k`` : temporal length of the layer's kernel;
        * ``stride`` : temporal stride of the layer's kernel.

    channels: int
    drop_prob: float
    mode: str
        Normalisation mode. Either ``default`` or ``layer_norm``.
    conv_bias: bool
    activation: nn.Module
    """

    def __init__(
        self,
        conv_layers_spec: Sequence[tuple[int, int, int]],
        channels: int,
        drop_prob: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
    ):
        assert mode in {"default", "layer_norm"}

        input_channels = 1
        conv_layers = []
        for i, layer_spec in enumerate(conv_layers_spec):
            # Each layer_spec should be a tuple: (output_channels, kernel_size, stride)
            assert len(layer_spec) == 3, "Invalid conv definition: " + str(layer_spec)
            output_channels, kernel_size, stride = layer_spec
            conv_layers.append(
                self._get_block(
                    input_channels,
                    output_channels,
                    kernel_size,
                    stride,
                    drop_prob,
                    activation,
                    is_layer_norm=(mode == "layer_norm"),
                    is_group_norm=(mode == "default" and i == 0),
                    conv_bias=conv_bias,
                )
            )
            input_channels = output_channels
        all_layers = [
            Rearrange("b channels time -> (b channels) 1 time", channels=channels),
            *conv_layers,
            Rearrange(
                "(b channels) emb_dim time_out -> b (channels time_out) emb_dim",
                channels=channels,
            ),
        ]
        super().__init__(*all_layers)
        self.emb_dim = output_channels  # last output dimension becomes the embedding dimension
        self.conv_layers_spec = conv_layers_spec

    @staticmethod
    def _get_block(
        input_channels,
        output_channels,
        kernel_size,
        stride,
        drop_prob,
        activation,
        is_layer_norm=False,
        is_group_norm=False,
        conv_bias=False,
    ):
        assert not (is_layer_norm and is_group_norm), (
            "layer norm and group norm are exclusive"
        )

        conv = nn.Conv1d(
            input_channels,
            output_channels,
            kernel_size,
            stride=stride,
            bias=conv_bias,
        )
        nn.init.kaiming_normal_(conv.weight)
        if is_layer_norm:
            return nn.Sequential(
                conv,
                nn.Dropout(p=drop_prob),
                nn.Sequential(
                    Rearrange("... channels time -> ... time channels"),
                    nn.LayerNorm(output_channels, elementwise_affine=True),
                    Rearrange("... time channels -> ... channels time"),
                ),
                activation(),
            )
        elif is_group_norm:
            return nn.Sequential(
                conv,
                nn.Dropout(p=drop_prob),
                nn.GroupNorm(output_channels, output_channels, affine=True),
                activation(),
            )
        else:
            return nn.Sequential(conv, nn.Dropout(p=drop_prob), activation())

    def n_times_out(self, n_times):
        return _n_times_out(self.conv_layers_spec, n_times)

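# Illustrative sketch of the encoder's shape contract (editor's example): each
# EEG channel is folded into the batch, convolved independently, then the
# results are re-grouped as (batch, n_chans * n_times_out, emb_dim) tokens.
def _demo_conv_feature_encoder_shapes():
    encoder = _ConvFeatureEncoder(_DEFAULT_CONV_LAYER_SPEC, channels=3)
    X = torch.randn(2, 3, 512)
    tokens = encoder(X)
    assert tokens.shape == (2, 3 * encoder.n_times_out(512), encoder.emb_dim)  # (2, 9, 64)
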
class _ChannelEmbedding(nn.Embedding):
    """Embedding layer for EEG channels.

    The difference with a regular :class:`nn.Embedding` is that the embedding
    vectors are initialized with a positional encoding of the channel locations.

    Parameters
    ----------
    channel_locations: list of (list of float or None)
        List of the n-dimensional locations of the EEG channels.
    embedding_dim: int
        Dimensionality of the embedding vectors. Must be a multiple of the number
        of dimensions of the channel locations.
    """

    def __init__(
        self, channel_locations: list[list[float] | None], embedding_dim: int, **kwargs
    ):
        self.coordinate_ranges = [
            (min(coords), max(coords))
            for coords in zip(
                *[
                    loc[3:6] if len(loc) == 12 else loc
                    for loc in channel_locations
                    if loc is not None
                ]
            )
        ]
        channel_mins, channel_maxs = zip(*self.coordinate_ranges)
        global_min = min(channel_mins)
        global_max = max(channel_maxs)
        self.max_abs_coordinate = max(abs(global_min), abs(global_max))
        self.embedding_dim_per_coordinate = embedding_dim // len(self.coordinate_ranges)
        self.channel_locations = list(channel_locations)

        assert embedding_dim % len(self.coordinate_ranges) == 0

        super().__init__(len(channel_locations), embedding_dim, **kwargs)

    def reset_parameters(self):
        for i, loc in enumerate(self.channel_locations):
            if loc is None:
                nn.init.zeros_(self.weight[i])
            else:
                for j, (x, (x0, x1)) in enumerate(zip(loc, self.coordinate_ranges)):
                    with torch.no_grad():
                        self.weight[
                            i,
                            j * self.embedding_dim_per_coordinate : (j + 1)
                            * self.embedding_dim_per_coordinate,
                        ].copy_(
                            _pos_encode_contineous(
                                x,
                                0,
                                10 * self.max_abs_coordinate,
                                self.embedding_dim_per_coordinate,
                                device=self.weight.device,
                            ),
                        )

class _PosEncoder(nn.Module):
    """Positional encoder for EEG data.

    Parameters
    ----------
    spat_dim: int
        Number of dimensions to use to encode the spatial position of the patch,
        i.e. the EEG channel.
    time_dim: int
        Number of dimensions to use to encode the temporal position of the patch.
    ch_locs: list of list of float or 2d array
        List of the n-dimensional locations of the EEG channels.
    sfreq_features: float
        The "downsampled" sampling frequency returned by the feature encoder.
    spat_kwargs: dict
        Additional keyword arguments to pass to the :class:`nn.Embedding` layer used to
        embed the channel names.
    max_seconds: float
        Maximum number of seconds to consider for the temporal encoding.
    """

    def __init__(
        self,
        spat_dim: int,
        time_dim: int,
        ch_locs,
        sfreq_features: float,
        spat_kwargs: dict | None = None,
        max_seconds: float = 600.0,  # 10 minutes
    ):
        super().__init__()
        spat_kwargs = spat_kwargs or {}
        self.spat_dim = spat_dim
        self.time_dim = time_dim
        self.max_n_times = int(max_seconds * sfreq_features)

        # Positional encoder for the spatial dimension:
        self.pos_encoder_spat = _ChannelEmbedding(
            ch_locs, spat_dim, **spat_kwargs
        )  # (batch_size, n_channels, spat_dim)

        # Pre-computed tensor for positional encoding on the time dimension:
        self.encoding_time = torch.zeros(0, dtype=torch.float32, requires_grad=False)

    def _check_encoding_time(self, n_times: int):
        if self.encoding_time.size(0) < n_times:
            self.encoding_time = self.encoding_time.new_empty((n_times, self.time_dim))
            self.encoding_time[:] = _pos_encode_time(
                n_times=n_times,
                n_dim=self.time_dim,
                max_n_times=self.max_n_times,
                device=self.encoding_time.device,
            )

    def forward(self, local_features, ch_idxs: torch.Tensor | None = None):
        """
        Parameters
        ----------
        local_features: (batch_size, n_chans * n_times_out, emb_dim)
        ch_idxs: (batch_size, n_chans) | None
            Indices of the channels, i.e. positions in the channel list passed
            at initialisation, plus one. Index 0 is reserved for an unknown
            channel.

        Returns
        -------
        pos_encoding: (batch_size, n_chans * n_times_out, emb_dim)
            The first ``spat_dim`` dimensions encode the channels' positional
            encoding and the following ``time_dim`` dimensions encode the
            temporal positional encoding.
        """
        batch_size, n_chans_times, emb_dim = local_features.shape
        if ch_idxs is None:
            ch_idxs = torch.arange(
                0,
                self.pos_encoder_spat.num_embeddings,
                device=local_features.device,
            ).repeat(batch_size, 1)

        batch_size_chs, n_chans = ch_idxs.shape
        assert emb_dim >= self.spat_dim + self.time_dim
        assert n_chans_times % n_chans == 0
        n_times = n_chans_times // n_chans

        pos_encoding = local_features.new_empty(
            (batch_size_chs, n_chans, n_times, emb_dim)
        )
        # Channel pos. encoding
        pos_encoding[:, :, :, : self.spat_dim] = self.pos_encoder_spat(ch_idxs)[
            :, :, None, :
        ]
        # Temporal pos. encoding
        self._check_encoding_time(n_times)
        _ = pos_encoding[:, :, :, self.spat_dim : self.spat_dim + self.time_dim].copy_(
            self.encoding_time[None, None, :n_times, :],
        )

        return pos_encoding.view(batch_size, n_chans_times, emb_dim)

def _n_times_out(conv_layers_spec, n_times):
    # it would be equal to n_times // ds_factor without edge effects:
    n_times_out_ = n_times
    for _, width, stride in conv_layers_spec:
        n_times_out_ = int((n_times_out_ - width) / stride) + 1
    return n_times_out_


def _get_out_emb_dim(conv_layers_spec, n_times, n_spat_filters=4):
    n_time_out = _n_times_out(conv_layers_spec, n_times)
    emb_dim = conv_layers_spec[-1][0]
    return n_spat_filters * n_time_out * emb_dim

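# Worked example for the two helpers above (editor's example, default spec):
# 512 samples -> (512-32)//8+1 = 61 -> 30 -> 15 -> 7 -> 3 output steps, and the
# flattened embedding with 4 spatial filters is 4 * 3 * 64 = 768.
def _demo_token_geometry():
    assert _n_times_out(_DEFAULT_CONV_LAYER_SPEC, 512) == 3
    assert _get_out_emb_dim(_DEFAULT_CONV_LAYER_SPEC, n_times=512, n_spat_filters=4) == 768
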
def _get_separable_clf_layer(
    conv_layers_spec, n_chans, n_times, n_classes, n_spat_filters=4
):
    out_emb_dim = _get_out_emb_dim(
        conv_layers_spec=conv_layers_spec,
        n_times=n_times,
        n_spat_filters=n_spat_filters,
    )
    clf_layer = nn.Sequential()
    clf_layer.add_module(
        "unflatten_tokens",
        Rearrange("b (n_chans tokens) d -> b 1 n_chans tokens d", n_chans=n_chans),
    )
    clf_layer.add_module("spat_conv", nn.Conv3d(1, n_spat_filters, (n_chans, 1, 1)))
    clf_layer.add_module("flatten", nn.Flatten(start_dim=1))
    clf_layer.add_module("linear", nn.Linear(out_emb_dim, n_classes))
    return clf_layer

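# Illustrative sketch of the head built above (editor's example): tokens are
# unflattened to (b, 1, n_chans, tokens, d), a Conv3d mixes only the channel
# axis, and the result is flattened into a single linear readout.
def _demo_separable_clf_layer():
    head = _get_separable_clf_layer(
        _DEFAULT_CONV_LAYER_SPEC, n_chans=3, n_times=512, n_classes=4
    )
    tokens = torch.randn(2, 3 * 3, 64)  # (batch, n_chans * n_times_out, emb_dim)
    assert head(tokens).shape == (2, 4)
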
def _pos_encode_time(
    n_times: int,
    n_dim: int,
    max_n_times: int,
    device: torch.device = torch.device("cpu"),
):
    """1-dimensional positional encoding.

    Parameters
    ----------
    n_times: int
        Number of time samples to encode.
    n_dim: int
        Number of dimensions of the positional encoding. Must be even.
    max_n_times: int
        The largest possible number of time samples to encode.
        Used to scale the positional encoding.
    device: torch.device
        Device to put the output on.

    Returns
    -------
    pos_encoding: (n_times, n_dim)
    """
    assert n_dim % 2 == 0
    position = torch.arange(n_times, device=device).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, n_dim, 2, device=device) * (-math.log(max_n_times) / n_dim)
    )
    pos_encoding = torch.empty((n_times, n_dim), dtype=torch.float32, device=device)
    pos_encoding[:, 0::2] = torch.sin(position * div_term)
    pos_encoding[:, 1::2] = torch.cos(position * div_term)
    return pos_encoding

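# Illustrative sketch (editor's example): the classic transformer-style
# sinusoidal encoding; one row per time step, with all values bounded in [-1, 1].
def _demo_pos_encode_time():
    enc = _pos_encode_time(n_times=5, n_dim=8, max_n_times=600)
    assert enc.shape == (5, 8)
    assert torch.all(enc.abs() <= 1)
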
def _pos_encode_contineous(
    x, x_min, x_max, n_dim, device: torch.device = torch.device("cpu")
):
    """1-dimensional positional encoding.

    Parameters
    ----------
    x: float
        The position to encode.
    x_min: float
        The minimum possible value of x.
    x_max: float
        The maximum possible value of x.
    n_dim: int
        Number of dimensions of the positional encoding. Must be even.
    device: torch.device
        Device to put the output on.

    Returns
    -------
    pos_encoding: (n_dim,)
    """
    assert n_dim % 2 == 0
    div_term = torch.exp(
        (1 - torch.arange(0, n_dim, 2, device=device) / n_dim) * 2 * math.pi
    )
    pos_encoding = torch.empty((n_dim,), dtype=torch.float32, device=device)
    xx = (x - x_min) / (x_max - x_min)
    pos_encoding[0::2] = torch.sin(xx * div_term)
    pos_encoding[1::2] = torch.cos(xx * div_term)
    return pos_encoding