braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
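Note on imports: the file list shows that `braindecode/models/modules.py` and `braindecode/models/functions.py` were removed, while new top-level packages `braindecode/modules/` and `braindecode/functional/` were added. A minimal migration sketch follows; the new package locations come from the file list above, but the re-exported symbol names used here are illustrative assumptions, not confirmed by this diff:

    # Hypothetical import migration from braindecode 0.8.1 to 1.0.0.
    # The symbols named here (Ensure4d, squeeze_final_output) are assumptions.

    # 0.8.1:
    #   from braindecode.models.modules import Ensure4d
    #   from braindecode.models.functions import squeeze_final_output

    # 1.0.0:
    from braindecode.modules import Ensure4d
    from braindecode.functional import squeeze_final_output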
braindecode/models/signal_jepa.py (new file)
@@ -0,0 +1,1012 @@
+ # Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
+ #
+ # License: BSD (3-clause)
+ from __future__ import annotations
+
+ import math
+ from copy import deepcopy
+ from typing import Any, Sequence
+
+ import torch
+ from einops import parse_shape, rearrange, repeat
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+
+ _DEFAULT_CONV_LAYER_SPEC = (  # downsampling: 128Hz -> 1Hz, receptive field 1.1875s, stride 1s
+     (8, 32, 8),
+     (16, 2, 2),
+     (32, 2, 2),
+     (64, 2, 2),
+     (64, 2, 2),
+ )
+
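+ # Sanity check of the annotation above: the strides multiply to
+ # 8 * 2 * 2 * 2 * 2 = 128, so a 128 Hz signal yields one feature vector per
+ # second (stride 1 s), and the receptive field is
+ # 1 + (32-1)*1 + (2-1)*8 + (2-1)*16 + (2-1)*32 + (2-1)*64 = 152 samples,
+ # i.e. 1.1875 s at 128 Hz.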
+
+
+ class _BaseSignalJEPA(EEGModuleMixin, nn.Module):
+     """Base class for the SignalJEPA models.
+
+     Parameters
+     ----------
+     feature_encoder__conv_layers_spec: list of tuple
+         Tuples have shape ``(dim, k, stride)`` where:
+
+         * ``dim`` : number of output channels of the layer (unrelated to EEG channels);
+         * ``k`` : temporal length of the layer's kernel;
+         * ``stride`` : temporal stride of the layer's kernel.
+
+     drop_prob: float
+         Dropout probability applied after each convolution of the feature encoder.
+     feature_encoder__mode: str
+         Normalisation mode. Either ``default`` or ``layer_norm``.
+     feature_encoder__conv_bias: bool
+         Whether the feature encoder's convolutional layers include a bias term.
+     activation: nn.Module
+         Activation layer for the feature encoder.
+     pos_encoder__spat_dim: int
+         Number of dimensions to use to encode the spatial position of the patch,
+         i.e. the EEG channel.
+     pos_encoder__time_dim: int
+         Number of dimensions to use to encode the temporal position of the patch.
+     pos_encoder__sfreq_features: float
+         The "downsampled" sampling frequency returned by the feature encoder.
+     pos_encoder__spat_kwargs: dict
+         Additional keyword arguments to pass to the :class:`nn.Embedding` layer used to
+         embed the channel names.
+     transformer__d_model: int
+         The number of expected features in the encoder/decoder inputs.
+     transformer__num_encoder_layers: int
+         The number of encoder layers in the transformer.
+     transformer__num_decoder_layers: int
+         The number of decoder layers in the transformer.
+     transformer__nhead: int
+         The number of heads in the multiheadattention models.
+     _init_feature_encoder : bool
+         Do not change the default value (used for internal purposes).
+     _init_transformer : bool
+         Do not change the default value (used for internal purposes).
+     """
+
+     feature_encoder: _ConvFeatureEncoder | None
+     pos_encoder: _PosEncoder | None
+     transformer: nn.Transformer | None
+
+     _feature_encoder_channels: str = "n_chans"
+
+     def __init__(
+         self,
+         n_outputs=None,
+         n_chans=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         *,
+         # feature_encoder
+         feature_encoder__conv_layers_spec: Sequence[
+             tuple[int, int, int]
+         ] = _DEFAULT_CONV_LAYER_SPEC,
+         drop_prob: float = 0.0,
+         feature_encoder__mode: str = "default",
+         feature_encoder__conv_bias: bool = False,
+         activation: type[nn.Module] = nn.GELU,
+         # pos_encoder
+         pos_encoder__spat_dim: int = 30,
+         pos_encoder__time_dim: int = 34,
+         pos_encoder__sfreq_features: float = 1.0,
+         pos_encoder__spat_kwargs: dict | None = None,
+         # transformer
+         transformer__d_model: int = 64,
+         transformer__num_encoder_layers: int = 8,
+         transformer__num_decoder_layers: int = 4,
+         transformer__nhead: int = 8,
+         # other
+         _init_feature_encoder: bool,
+         _init_transformer: bool,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.feature_encoder = None
+         self.pos_encoder = None
+         self.transformer = None
+         if _init_feature_encoder:
+             self.feature_encoder = _ConvFeatureEncoder(
+                 conv_layers_spec=feature_encoder__conv_layers_spec,
+                 channels=getattr(self, self._feature_encoder_channels),
+                 drop_prob=drop_prob,
+                 mode=feature_encoder__mode,
+                 conv_bias=feature_encoder__conv_bias,
+                 activation=activation,
+             )
+
+         if _init_transformer:
+             ch_locs = [ch["loc"] for ch in self.chs_info]  # type: ignore
+             self.pos_encoder = _PosEncoder(
+                 spat_dim=pos_encoder__spat_dim,
+                 time_dim=pos_encoder__time_dim,
+                 ch_locs=ch_locs,
+                 sfreq_features=pos_encoder__sfreq_features,
+                 spat_kwargs=pos_encoder__spat_kwargs,
+             )
+             self.transformer = nn.Transformer(
+                 d_model=transformer__d_model,
+                 nhead=transformer__nhead,
+                 num_encoder_layers=transformer__num_encoder_layers,
+                 num_decoder_layers=transformer__num_decoder_layers,
+                 batch_first=True,
+             )
+
+
+ class SignalJEPA(_BaseSignalJEPA):
+     """Architecture introduced in signal-JEPA, Guetschel, P et al (2024) [1]_, for self-supervised pre-training.
+
+     This model is not meant for classification but for SSL pre-training.
+     Its output shape depends on the input shape.
+     For classification purposes, three variants of this model are available:
+
+     * :class:`SignalJEPA_Contextual`
+     * :class:`SignalJEPA_PostLocal`
+     * :class:`SignalJEPA_PreLocal`
+
+     The classification architectures can either be instantiated from scratch
+     (random parameters) or from a pre-trained :class:`SignalJEPA` model.
+
+     .. versionadded:: 0.9
+
+     References
+     ----------
+     .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
+        S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
+        In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
+     """
+
+     def __init__(
+         self,
+         n_outputs=None,
+         n_chans=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         *,
+         # feature_encoder
+         feature_encoder__conv_layers_spec: Sequence[
+             tuple[int, int, int]
+         ] = _DEFAULT_CONV_LAYER_SPEC,
+         drop_prob: float = 0.0,
+         feature_encoder__mode: str = "default",
+         feature_encoder__conv_bias: bool = False,
+         activation: type[nn.Module] = nn.GELU,
+         # pos_encoder
+         pos_encoder__spat_dim: int = 30,
+         pos_encoder__time_dim: int = 34,
+         pos_encoder__sfreq_features: float = 1.0,
+         pos_encoder__spat_kwargs: dict | None = None,
+         # transformer
+         transformer__d_model: int = 64,
+         transformer__num_encoder_layers: int = 8,
+         transformer__num_decoder_layers: int = 4,
+         transformer__nhead: int = 8,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+             feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
+             drop_prob=drop_prob,
+             feature_encoder__mode=feature_encoder__mode,
+             feature_encoder__conv_bias=feature_encoder__conv_bias,
+             activation=activation,
+             pos_encoder__spat_dim=pos_encoder__spat_dim,
+             pos_encoder__time_dim=pos_encoder__time_dim,
+             pos_encoder__sfreq_features=pos_encoder__sfreq_features,
+             pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
+             transformer__d_model=transformer__d_model,
+             transformer__num_encoder_layers=transformer__num_encoder_layers,
+             transformer__num_decoder_layers=transformer__num_decoder_layers,
+             transformer__nhead=transformer__nhead,
+             _init_feature_encoder=True,
+             _init_transformer=True,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         self.final_layer = nn.Identity()
+
+     def forward(self, X, ch_idxs: torch.Tensor | None = None):  # type: ignore
+         local_features = self.feature_encoder(X)  # type: ignore
+         pos_encoding = self.pos_encoder(local_features, ch_idxs=ch_idxs)  # type: ignore
+         local_features += pos_encoding  # type: ignore
+         contextual_features = self.transformer.encoder(local_features)  # type: ignore
+         y = self.final_layer(contextual_features)  # type: ignore
+         return y  # type: ignore
+
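+ # Example shapes (illustrative): with the default conv spec, 22 channels and
+ # 1536 samples (12 s at 128 Hz), X of shape (batch, 22, 1536) is encoded into
+ # local features of shape (batch, 22 * 11, 64), where 11 is the number of
+ # output tokens per channel (see _n_times_out below) and 64 the embedding
+ # dimension; the transformer encoder preserves this shape.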
+
+ class SignalJEPA_Contextual(_BaseSignalJEPA):
+     """Contextual downstream architecture introduced in signal-JEPA, Guetschel, P et al (2024) [1]_.
+
+     This architecture is one of the variants of :class:`SignalJEPA`
+     that can be used for classification purposes.
+
+     .. figure:: https://braindecode.org/dev/_static/model/sjepa_contextual.jpg
+        :align: center
+        :alt: sJEPA Contextual.
+
+     .. versionadded:: 0.9
+
+     Parameters
+     ----------
+     n_spat_filters : int
+         Number of spatial filters.
+
+     References
+     ----------
+     .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
+        S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
+        In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
+     """
+
+     def __init__(
+         self,
+         n_outputs=None,
+         n_chans=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         *,
+         n_spat_filters: int = 4,
+         # feature_encoder
+         feature_encoder__conv_layers_spec: Sequence[
+             tuple[int, int, int]
+         ] = _DEFAULT_CONV_LAYER_SPEC,
+         drop_prob: float = 0.0,
+         feature_encoder__mode: str = "default",
+         feature_encoder__conv_bias: bool = False,
+         activation: type[nn.Module] = nn.GELU,
+         # pos_encoder
+         pos_encoder__spat_dim: int = 30,
+         pos_encoder__time_dim: int = 34,
+         pos_encoder__sfreq_features: float = 1.0,
+         pos_encoder__spat_kwargs: dict | None = None,
+         # transformer
+         transformer__d_model: int = 64,
+         transformer__num_encoder_layers: int = 8,
+         transformer__num_decoder_layers: int = 4,
+         transformer__nhead: int = 8,
+         # other
+         _init_feature_encoder: bool = True,
+         _init_transformer: bool = True,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+             feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
+             drop_prob=drop_prob,
+             feature_encoder__mode=feature_encoder__mode,
+             feature_encoder__conv_bias=feature_encoder__conv_bias,
+             activation=activation,
+             pos_encoder__spat_dim=pos_encoder__spat_dim,
+             pos_encoder__time_dim=pos_encoder__time_dim,
+             pos_encoder__sfreq_features=pos_encoder__sfreq_features,
+             pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
+             transformer__d_model=transformer__d_model,
+             transformer__num_encoder_layers=transformer__num_encoder_layers,
+             transformer__num_decoder_layers=transformer__num_decoder_layers,
+             transformer__nhead=transformer__nhead,
+             _init_feature_encoder=_init_feature_encoder,
+             _init_transformer=_init_transformer,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         self.final_layer = _get_separable_clf_layer(
+             conv_layers_spec=feature_encoder__conv_layers_spec,
+             n_chans=self.n_chans,
+             n_times=self.n_times,
+             n_classes=self.n_outputs,
+             n_spat_filters=n_spat_filters,
+         )
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model: SignalJEPA,
+         n_outputs: int,
+         n_spat_filters: int = 4,
+         chs_info: list[dict[str, Any]] | None = None,
+     ):
+         """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
+
+         Parameters
+         ----------
+         model: SignalJEPA
+             Pre-trained model.
+         n_outputs: int
+             Number of classes for the new model.
+         n_spat_filters: int
+             Number of spatial filters.
+         chs_info: list of dict | None
+             Information about each individual EEG channel. This should be filled with
+             ``info["chs"]``. Refer to :class:`mne.Info` for more details.
+         """
+         feature_encoder = model.feature_encoder
+         pos_encoder = model.pos_encoder
+         transformer = model.transformer
+         assert feature_encoder is not None
+         assert pos_encoder is not None
+         assert transformer is not None
+
+         new_model = cls(
+             n_outputs=n_outputs,
+             n_chans=model.n_chans,
+             n_times=model.n_times,
+             chs_info=chs_info,
+             n_spat_filters=n_spat_filters,
+             feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
+             _init_feature_encoder=False,
+             _init_transformer=False,
+         )
+         new_model.feature_encoder = deepcopy(feature_encoder)
+         new_model.pos_encoder = deepcopy(pos_encoder)
+         new_model.transformer = deepcopy(transformer)
+
+         if chs_info is not None:
+             ch_names = [ch["ch_name"] for ch in chs_info]
+             new_model.pos_encoder.set_fixed_ch_names(ch_names)
+
+         return new_model
+
+     def forward(self, X, ch_idxs: torch.Tensor | None = None):  # type: ignore
+         local_features = self.feature_encoder(X)  # type: ignore
+         pos_encoding = self.pos_encoder(local_features, ch_idxs=ch_idxs)  # type: ignore
+         local_features += pos_encoding  # type: ignore
+         contextual_features = self.transformer.encoder(local_features)  # type: ignore
+         y = self.final_layer(contextual_features)  # type: ignore
+         return y  # type: ignore
+
+
+ class SignalJEPA_PostLocal(_BaseSignalJEPA):
+     """Post-local downstream architecture introduced in signal-JEPA, Guetschel, P et al (2024) [1]_.
+
+     This architecture is one of the variants of :class:`SignalJEPA`
+     that can be used for classification purposes.
+
+     .. figure:: https://braindecode.org/dev/_static/model/sjepa_post-local.jpg
+        :align: center
+        :alt: sJEPA Post-Local.
+
+     .. versionadded:: 0.9
+
+     Parameters
+     ----------
+     n_spat_filters : int
+         Number of spatial filters.
+
+     References
+     ----------
+     .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
+        S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
+        In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
+     """
+
+     def __init__(
+         self,
+         n_outputs=None,
+         n_chans=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         *,
+         n_spat_filters: int = 4,
+         # feature_encoder
+         feature_encoder__conv_layers_spec: Sequence[
+             tuple[int, int, int]
+         ] = _DEFAULT_CONV_LAYER_SPEC,
+         drop_prob: float = 0.0,
+         feature_encoder__mode: str = "default",
+         feature_encoder__conv_bias: bool = False,
+         activation: type[nn.Module] = nn.GELU,
+         # pos_encoder
+         pos_encoder__spat_dim: int = 30,
+         pos_encoder__time_dim: int = 34,
+         pos_encoder__sfreq_features: float = 1.0,
+         pos_encoder__spat_kwargs: dict | None = None,
+         # transformer
+         transformer__d_model: int = 64,
+         transformer__num_encoder_layers: int = 8,
+         transformer__num_decoder_layers: int = 4,
+         transformer__nhead: int = 8,
+         # other
+         _init_feature_encoder: bool = True,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+             feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
+             drop_prob=drop_prob,
+             feature_encoder__mode=feature_encoder__mode,
+             feature_encoder__conv_bias=feature_encoder__conv_bias,
+             activation=activation,
+             pos_encoder__spat_dim=pos_encoder__spat_dim,
+             pos_encoder__time_dim=pos_encoder__time_dim,
+             pos_encoder__sfreq_features=pos_encoder__sfreq_features,
+             pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
+             transformer__d_model=transformer__d_model,
+             transformer__num_encoder_layers=transformer__num_encoder_layers,
+             transformer__num_decoder_layers=transformer__num_decoder_layers,
+             transformer__nhead=transformer__nhead,
+             _init_feature_encoder=_init_feature_encoder,
+             _init_transformer=False,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         self.final_layer = _get_separable_clf_layer(
+             conv_layers_spec=feature_encoder__conv_layers_spec,
+             n_chans=self.n_chans,
+             n_times=self.n_times,
+             n_classes=self.n_outputs,
+             n_spat_filters=n_spat_filters,
+         )
+
+     @classmethod
+     def from_pretrained(
+         cls, model: SignalJEPA, n_outputs: int, n_spat_filters: int = 4
+     ):
+         """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
+
+         Parameters
+         ----------
+         model: SignalJEPA
+             Pre-trained model.
+         n_outputs: int
+             Number of classes for the new model.
+         n_spat_filters: int
+             Number of spatial filters.
+         """
+         feature_encoder = model.feature_encoder
+         assert feature_encoder is not None
+         new_model = cls(
+             n_outputs=n_outputs,
+             n_chans=model.n_chans,
+             n_times=model.n_times,
+             n_spat_filters=n_spat_filters,
+             feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
+             _init_feature_encoder=False,
+         )
+         new_model.feature_encoder = deepcopy(feature_encoder)
+         return new_model
+
+     def forward(self, X):
+         local_features = self.feature_encoder(X)
+         y = self.final_layer(local_features)
+         return y
+
+
+ class SignalJEPA_PreLocal(_BaseSignalJEPA):
+     """Pre-local downstream architecture introduced in signal-JEPA, Guetschel, P et al (2024) [1]_.
+
+     This architecture is one of the variants of :class:`SignalJEPA`
+     that can be used for classification purposes.
+
+     .. figure:: https://braindecode.org/dev/_static/model/sjepa_pre-local.jpg
+        :align: center
+        :alt: sJEPA Pre-Local.
+
+     .. versionadded:: 0.9
+
+     Parameters
+     ----------
+     n_spat_filters : int
+         Number of spatial filters.
+
+     References
+     ----------
+     .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
+        S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
+        In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
+     """
+
+     _feature_encoder_channels: str = "n_spat_filters"
+
+     def __init__(
+         self,
+         n_outputs=None,
+         n_chans=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         *,
+         n_spat_filters: int = 4,
+         # feature_encoder
+         feature_encoder__conv_layers_spec: Sequence[
+             tuple[int, int, int]
+         ] = _DEFAULT_CONV_LAYER_SPEC,
+         drop_prob: float = 0.0,
+         feature_encoder__mode: str = "default",
+         feature_encoder__conv_bias: bool = False,
+         activation: type[nn.Module] = nn.GELU,
+         # pos_encoder
+         pos_encoder__spat_dim: int = 30,
+         pos_encoder__time_dim: int = 34,
+         pos_encoder__sfreq_features: float = 1.0,
+         pos_encoder__spat_kwargs: dict | None = None,
+         # transformer
+         transformer__d_model: int = 64,
+         transformer__num_encoder_layers: int = 8,
+         transformer__num_decoder_layers: int = 4,
+         transformer__nhead: int = 8,
+         # other
+         _init_feature_encoder: bool = True,
+     ):
+         # Must be set before super().__init__(), which reads this attribute
+         # through _feature_encoder_channels.
+         self.n_spat_filters = n_spat_filters
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+             feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
+             drop_prob=drop_prob,
+             feature_encoder__mode=feature_encoder__mode,
+             feature_encoder__conv_bias=feature_encoder__conv_bias,
+             activation=activation,
+             pos_encoder__spat_dim=pos_encoder__spat_dim,
+             pos_encoder__time_dim=pos_encoder__time_dim,
+             pos_encoder__sfreq_features=pos_encoder__sfreq_features,
+             pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
+             transformer__d_model=transformer__d_model,
+             transformer__num_encoder_layers=transformer__num_encoder_layers,
+             transformer__num_decoder_layers=transformer__num_decoder_layers,
+             transformer__nhead=transformer__nhead,
+             _init_feature_encoder=_init_feature_encoder,
+             _init_transformer=False,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         self.spatial_conv = nn.Sequential(
+             Rearrange("b channels time -> b 1 channels time"),
+             nn.Conv2d(1, n_spat_filters, (self.n_chans, 1)),
+             Rearrange("b spat_filters 1 time -> b spat_filters time"),
+         )
+         out_emb_dim = _get_out_emb_dim(
+             conv_layers_spec=feature_encoder__conv_layers_spec,
+             n_times=self.n_times,
+             n_spat_filters=n_spat_filters,
+         )
+         self.final_layer = nn.Sequential(
+             nn.Flatten(start_dim=1),
+             nn.Linear(out_emb_dim, self.n_outputs),
+         )
+
+     @classmethod
+     def from_pretrained(
+         cls, model: SignalJEPA, n_outputs: int, n_spat_filters: int = 4
+     ):
+         """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
+
+         Parameters
+         ----------
+         model: SignalJEPA
+             Pre-trained model.
+         n_outputs: int
+             Number of classes for the new model.
+         n_spat_filters: int
+             Number of spatial filters.
+         """
+         feature_encoder = model.feature_encoder
+         assert feature_encoder is not None
+         new_model = cls(
+             n_outputs=n_outputs,
+             n_chans=model.n_chans,
+             n_times=model.n_times,
+             n_spat_filters=n_spat_filters,
+             feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
+             _init_feature_encoder=False,
+         )
+         new_model.feature_encoder = deepcopy(feature_encoder)
+         return new_model
+
+     def forward(self, X):
+         X = self.spatial_conv(X)
+         local_features = self.feature_encoder(X)
+         y = self.final_layer(local_features)
+         return y
+
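+ # Overview of the three classification variants above (descriptive summary):
+ #   SignalJEPA_Contextual: feature_encoder -> pos_encoder + transformer.encoder -> separable classifier
+ #   SignalJEPA_PostLocal:  feature_encoder -> separable classifier (no transformer)
+ #   SignalJEPA_PreLocal:   spatial_conv -> feature_encoder -> flatten + linear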
+
+ class _ConvFeatureEncoder(nn.Sequential):
+     """Convolutional feature encoder for EEG data.
+
+     Computes successive 1D convolutions (with activations) over the time
+     dimension of the input EEG signal.
+
+     Inspired by https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
+     and https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py
+
+     Parameters
+     ----------
+     conv_layers_spec: list of tuple
+         Tuples have shape ``(dim, k, stride)`` where:
+
+         * ``dim`` : number of output channels of the layer (unrelated to EEG channels);
+         * ``k`` : temporal length of the layer's kernel;
+         * ``stride`` : temporal stride of the layer's kernel.
+
+     channels: int
+         Number of input channels (EEG channels or spatial filters); each one
+         is encoded independently.
+     drop_prob: float
+         Dropout probability applied after each convolution.
+     mode: str
+         Normalisation mode. Either ``default`` or ``layer_norm``.
+     conv_bias: bool
+         Whether the convolutional layers include a bias term.
+     activation: nn.Module
+         Activation layer.
+     """
+
+     def __init__(
+         self,
+         conv_layers_spec: Sequence[tuple[int, int, int]],
+         channels: int,
+         drop_prob: float = 0.0,
+         mode: str = "default",
+         conv_bias: bool = False,
+         activation: type[nn.Module] = nn.GELU,
+     ):
+         assert mode in {"default", "layer_norm"}
+
+         input_channels = 1
+         conv_layers = []
+         for i, layer_spec in enumerate(conv_layers_spec):
+             # Each layer_spec should be a tuple: (output_channels, kernel_size, stride)
+             assert len(layer_spec) == 3, "Invalid conv definition: " + str(layer_spec)
+             output_channels, kernel_size, stride = layer_spec
+             conv_layers.append(
+                 self._get_block(
+                     input_channels,
+                     output_channels,
+                     kernel_size,
+                     stride,
+                     drop_prob,
+                     activation,
+                     is_layer_norm=(mode == "layer_norm"),
+                     is_group_norm=(mode == "default" and i == 0),
+                     conv_bias=conv_bias,
+                 )
+             )
+             input_channels = output_channels
+         all_layers = [
+             Rearrange("b channels time -> (b channels) 1 time", channels=channels),
+             *conv_layers,
+             Rearrange(
+                 "(b channels) emb_dim time_out -> b (channels time_out) emb_dim",
+                 channels=channels,
+             ),
+         ]
+         super().__init__(*all_layers)
+         self.emb_dim = (
+             output_channels  # last output dimension becomes the embedding dimension
+         )
+         self.conv_layers_spec = conv_layers_spec
+
+     @staticmethod
+     def _get_block(
+         input_channels,
+         output_channels,
+         kernel_size,
+         stride,
+         drop_prob,
+         activation,
+         is_layer_norm=False,
+         is_group_norm=False,
+         conv_bias=False,
+     ):
+         assert not (is_layer_norm and is_group_norm), (
+             "layer norm and group norm are exclusive"
+         )
+
+         conv = nn.Conv1d(
+             input_channels,
+             output_channels,
+             kernel_size,
+             stride=stride,
+             bias=conv_bias,
+         )
+         nn.init.kaiming_normal_(conv.weight)
+         if is_layer_norm:
+             return nn.Sequential(
+                 conv,
+                 nn.Dropout(p=drop_prob),
+                 nn.Sequential(
+                     Rearrange("... channels time -> ... time channels"),
+                     nn.LayerNorm(output_channels, elementwise_affine=True),
+                     Rearrange("... time channels -> ... channels time"),
+                 ),
+                 activation(),
+             )
+         elif is_group_norm:
+             return nn.Sequential(
+                 conv,
+                 nn.Dropout(p=drop_prob),
+                 nn.GroupNorm(output_channels, output_channels, affine=True),
+                 activation(),
+             )
+         else:
+             return nn.Sequential(conv, nn.Dropout(p=drop_prob), activation())
+
+     def n_times_out(self, n_times):
+         return _n_times_out(self.conv_layers_spec, n_times)
+
+
+ class _ChannelEmbedding(nn.Embedding):
+     """Embedding layer for EEG channels.
+
+     The difference with a regular :class:`nn.Embedding` is that the embedding
+     vectors are initialized with a positional encoding of the channel locations.
+
+     Parameters
+     ----------
+     channel_locations: list of (list of float or None)
+         List of the n-dimensional locations of the EEG channels.
+     embedding_dim: int
+         Dimensionality of the embedding vectors. Must be a multiple of the number
+         of dimensions of the channel locations.
+     """
+
+     def __init__(
+         self, channel_locations: list[list[float] | None], embedding_dim: int, **kwargs
+     ):
+         # These attributes must be set before super().__init__(), which calls
+         # reset_parameters() and therefore needs them.
+         self.coordinate_ranges = [
+             (min(coords), max(coords))
+             for coords in zip(
+                 *[
+                     loc[3:6] if len(loc) == 12 else loc
+                     for loc in channel_locations
+                     if loc is not None
+                 ]
+             )
+         ]
+         channel_mins, channel_maxs = zip(*self.coordinate_ranges)
+         global_min = min(channel_mins)
+         global_max = max(channel_maxs)
+         self.max_abs_coordinate = max(abs(global_min), abs(global_max))
+         self.embedding_dim_per_coordinate = embedding_dim // len(self.coordinate_ranges)
+         self.channel_locations = list(channel_locations)
+
+         assert embedding_dim % len(self.coordinate_ranges) == 0
+
+         super().__init__(len(channel_locations), embedding_dim, **kwargs)
+
+     def reset_parameters(self):
+         for i, loc in enumerate(self.channel_locations):
+             if loc is None:
+                 nn.init.zeros_(self.weight[i])
+             else:
+                 for j, (x, (x0, x1)) in enumerate(zip(loc, self.coordinate_ranges)):
+                     with torch.no_grad():
+                         self.weight[
+                             i,
+                             j * self.embedding_dim_per_coordinate : (j + 1)
+                             * self.embedding_dim_per_coordinate,
+                         ].copy_(
+                             _pos_encode_contineous(
+                                 x,
+                                 0,
+                                 10 * self.max_abs_coordinate,
+                                 self.embedding_dim_per_coordinate,
+                                 device=self.weight.device,
+                             ),
+                         )
+
+
+ class _PosEncoder(nn.Module):
+     """Positional encoder for EEG data.
+
+     Parameters
+     ----------
+     spat_dim: int
+         Number of dimensions to use to encode the spatial position of the patch,
+         i.e. the EEG channel.
+     time_dim: int
+         Number of dimensions to use to encode the temporal position of the patch.
+     ch_locs: list of list of float or 2d array
+         List of the n-dimensional locations of the EEG channels.
+     sfreq_features: float
+         The "downsampled" sampling frequency returned by the feature encoder.
+     spat_kwargs: dict
+         Additional keyword arguments to pass to the :class:`nn.Embedding` layer used to
+         embed the channel names.
+     max_seconds: float
+         Maximum number of seconds to consider for the temporal encoding.
+     """
+
+     def __init__(
+         self,
+         spat_dim: int,
+         time_dim: int,
+         ch_locs,
+         sfreq_features: float,
+         spat_kwargs: dict | None = None,
+         max_seconds: float = 600.0,  # 10 minutes
+     ):
+         super().__init__()
+         spat_kwargs = spat_kwargs or {}
+         self.spat_dim = spat_dim
+         self.time_dim = time_dim
+         self.max_n_times = int(max_seconds * sfreq_features)
+
+         # Positional encoder for the spatial dimension:
+         self.pos_encoder_spat = _ChannelEmbedding(
+             ch_locs, spat_dim, **spat_kwargs
+         )  # (batch_size, n_channels, spat_dim)
+
+         # Pre-computed tensor for positional encoding on the time dimension:
+         self.encoding_time = torch.zeros(0, dtype=torch.float32, requires_grad=False)
+
+     def _check_encoding_time(self, n_times: int):
+         if self.encoding_time.size(0) < n_times:
+             self.encoding_time = self.encoding_time.new_empty((n_times, self.time_dim))
+             self.encoding_time[:] = _pos_encode_time(
+                 n_times=n_times,
+                 n_dim=self.time_dim,
+                 max_n_times=self.max_n_times,
+                 device=self.encoding_time.device,
+             )
+
+     def forward(self, local_features, ch_idxs: torch.Tensor | None = None):
+         """
+         Parameters
+         ----------
+         * local_features: (batch_size, n_chans * n_times_out, emb_dim)
+         * ch_idxs: (batch_size, n_chans) | None
+             Indices of the channels to use in the ``ch_names`` list passed
+             as argument plus one. Index 0 is reserved for an unknown channel.
+
+         Returns
+         -------
+         pos_encoding: (batch_size, n_chans * n_times_out, emb_dim)
+             The first ``spat_dim`` dimensions encode the channels positional encoding
+             and the following ``time_dim`` dimensions encode the temporal positional encoding.
+         """
+         batch_size, n_chans_times, emb_dim = local_features.shape
+         if ch_idxs is None:
+             ch_idxs = torch.arange(
+                 0,
+                 self.pos_encoder_spat.num_embeddings,
+                 device=local_features.device,
+             ).repeat(batch_size, 1)
+
+         batch_size_chs, n_chans = ch_idxs.shape
+         assert emb_dim >= self.spat_dim + self.time_dim
+         assert n_chans_times % n_chans == 0
+         n_times = n_chans_times // n_chans
+
+         pos_encoding = local_features.new_empty(
+             (batch_size_chs, n_chans, n_times, emb_dim)
+         )
+         # Channel pos. encoding
+         pos_encoding[:, :, :, : self.spat_dim] = self.pos_encoder_spat(ch_idxs)[
+             :, :, None, :
+         ]
+         # Temporal pos. encoding
+         self._check_encoding_time(n_times)
+         _ = pos_encoding[:, :, :, self.spat_dim : self.spat_dim + self.time_dim].copy_(
+             self.encoding_time[None, None, :n_times, :],
+         )
+
+         return pos_encoding.view(batch_size, n_chans_times, emb_dim)
+
+
+ def _n_times_out(conv_layers_spec, n_times):
+     # it would be equal to n_times // ds_factor without edge effects:
+     n_times_out_ = n_times
+     for _, width, stride in conv_layers_spec:
+         n_times_out_ = int((n_times_out_ - width) / stride) + 1
+     return n_times_out_
+
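+ # Example: with _DEFAULT_CONV_LAYER_SPEC and n_times=1536 (12 s at 128 Hz),
+ # the layer outputs shrink as 1536 -> 189 -> 94 -> 47 -> 23 -> 11, i.e. 11
+ # tokens per channel instead of 1536 // 128 = 12, because every convolution
+ # loses a few samples at the edges.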
+
+ def _get_out_emb_dim(conv_layers_spec, n_times, n_spat_filters=4):
+     n_time_out = _n_times_out(conv_layers_spec, n_times)
+     emb_dim = conv_layers_spec[-1][0]
+     return n_spat_filters * n_time_out * emb_dim
+
+
+ def _get_separable_clf_layer(
+     conv_layers_spec, n_chans, n_times, n_classes, n_spat_filters=4
+ ):
+     out_emb_dim = _get_out_emb_dim(
+         conv_layers_spec=conv_layers_spec,
+         n_times=n_times,
+         n_spat_filters=n_spat_filters,
+     )
+     clf_layer = nn.Sequential()
+     clf_layer.add_module(
+         "unflatten_tokens",
+         Rearrange("b (n_chans tokens) d -> b 1 n_chans tokens d", n_chans=n_chans),
+     )
+     clf_layer.add_module("spat_conv", nn.Conv3d(1, n_spat_filters, (n_chans, 1, 1)))
+     clf_layer.add_module("flatten", nn.Flatten(start_dim=1))
+     clf_layer.add_module("linear", nn.Linear(out_emb_dim, n_classes))
+     return clf_layer
+
+
+ def _pos_encode_time(
+     n_times: int,
+     n_dim: int,
+     max_n_times: int,
+     device: torch.device = torch.device("cpu"),
+ ):
+     """1-dimensional positional encoding.
+
+     Parameters
+     ----------
+     n_times: int
+         Number of time samples to encode.
+     n_dim: int
+         Number of dimensions of the positional encoding. Must be even.
+     max_n_times: int
+         The largest possible number of time samples to encode.
+         Used to scale the positional encoding.
+     device: torch.device
+         Device to put the output on.
+
+     Returns
+     -------
+     pos_encoding: (n_times, n_dim)
+     """
+     assert n_dim % 2 == 0
+     position = torch.arange(n_times, device=device).unsqueeze(1)
+     div_term = torch.exp(
+         torch.arange(0, n_dim, 2, device=device) * (-math.log(max_n_times) / n_dim)
+     )
+     pos_encoding = torch.empty((n_times, n_dim), dtype=torch.float32, device=device)
+     pos_encoding[:, 0::2] = torch.sin(position * div_term)
+     pos_encoding[:, 1::2] = torch.cos(position * div_term)
+     return pos_encoding
+
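+ # The function above is the standard sinusoidal encoding: since
+ # div_term[i] = max_n_times ** (-2i / n_dim), entry [t, 2i] equals
+ # sin(t / max_n_times ** (2i / n_dim)) and entry [t, 2i + 1] is the matching cos.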
+
+ def _pos_encode_contineous(
+     x, x_min, x_max, n_dim, device: torch.device = torch.device("cpu")
+ ):
+     """1-dimensional positional encoding.
+
+     Parameters
+     ----------
+     x: float
+         The position to encode.
+     x_min: float
+         The minimum possible value of x.
+     x_max: float
+         The maximum possible value of x.
+     n_dim: int
+         Number of dimensions of the positional encoding. Must be even.
+     device: torch.device
+         Device to put the output on.
+
+     Returns
+     -------
+     pos_encoding: (n_dim,)
+     """
+     assert n_dim % 2 == 0
+     div_term = torch.exp(
+         (1 - torch.arange(0, n_dim, 2, device=device) / n_dim) * 2 * math.pi
+     )
+     pos_encoding = torch.empty((n_dim,), dtype=torch.float32, device=device)
+     xx = (x - x_min) / (x_max - x_min)
+     pos_encoding[0::2] = torch.sin(xx * div_term)
+     pos_encoding[1::2] = torch.cos(xx * div_term)
+     return pos_encoding
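
Taken together, a minimal usage sketch of the new classes, based only on the signatures in the hunk above (the channel metadata, shapes and hyperparameters are illustrative, and the `braindecode.models` re-export is assumed):

    import torch
    from braindecode.models import SignalJEPA, SignalJEPA_Contextual

    # Illustrative channel metadata; in practice taken from mne.Info["chs"].
    chs_info = [
        {"ch_name": f"C{i}", "loc": [0.0, 0.0, 0.0, float(i), 0.0, 1.0] + [0.0] * 6}
        for i in range(22)
    ]

    # 1) Backbone for self-supervised pre-training; its output shape follows
    #    the input shape (tokens = n_chans * n_times_out).
    pretrain_net = SignalJEPA(chs_info=chs_info, n_times=1536, sfreq=128)
    tokens = pretrain_net(torch.randn(8, 22, 1536))  # -> (8, 22 * 11, 64)

    # 2) Reuse the pre-trained encoder in a classification variant.
    clf_net = SignalJEPA_Contextual.from_pretrained(pretrain_net, n_outputs=4)
    logits = clf_net(torch.randn(8, 22, 1536))  # -> (8, 4)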