braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/signal_jepa.py  @@ -0,0 +1,1122 @@
# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)
from __future__ import annotations

import math
from copy import deepcopy
from pathlib import Path
from typing import Any, Optional, Sequence

import torch
from einops.layers.torch import Rearrange
from torch import nn

from braindecode.models.base import EEGModuleMixin

_DEFAULT_CONV_LAYER_SPEC = (  # downsampling: 128Hz -> 1Hz, receptive field 1.1875s, stride 1s
    (8, 32, 8),
    (16, 2, 2),
    (32, 2, 2),
    (64, 2, 2),
    (64, 2, 2),
)

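# Arithmetic sketch (illustrative helper, not part of the models themselves): with the
# default spec, the strides multiply to 8 * 2 * 2 * 2 * 2 = 128 samples (1 s at 128 Hz)
# and the stacked kernels give a receptive field of 152 samples (1.1875 s at 128 Hz),
# matching the comment above.
def _default_spec_arithmetic_sketch(spec=_DEFAULT_CONV_LAYER_SPEC):
    total_stride = 1
    receptive_field = 1
    for _dim, kernel, stride in spec:
        # each layer widens the receptive field by (kernel - 1) * (stride so far)
        receptive_field += (kernel - 1) * total_stride
        total_stride *= stride
    assert (total_stride, receptive_field) == (128, 152)
    return total_stride, receptive_field

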
class _BaseSignalJEPA(EEGModuleMixin, nn.Module):
    r"""Base class for the SignalJEPA models.

    Parameters
    ----------
    feature_encoder__conv_layers_spec: list of tuple
        Tuples have shape ``(dim, k, stride)`` where:

        * ``dim`` : number of output channels of the layer (unrelated to EEG channels);
        * ``k`` : temporal length of the layer's kernel;
        * ``stride`` : temporal stride of the layer's kernel.

    drop_prob: float
        Dropout probability applied after each convolution of the feature encoder.
    feature_encoder__mode: str
        Normalisation mode. Either ``default`` or ``layer_norm``.
    feature_encoder__conv_bias: bool
        Whether the convolutions of the feature encoder use a bias term.
    activation: nn.Module
        Activation layer for the feature encoder.
    pos_encoder__spat_dim: int
        Number of dimensions to use to encode the spatial position of the patch,
        i.e. the EEG channel.
    pos_encoder__time_dim: int
        Number of dimensions to use to encode the temporal position of the patch.
    pos_encoder__sfreq_features: float
        The "downsampled" sampling frequency returned by the feature encoder.
    pos_encoder__spat_kwargs: dict
        Additional keyword arguments to pass to the :class:`nn.Embedding` layer used to
        embed the channel names.
    transformer__d_model: int
        The number of expected features in the encoder/decoder inputs.
    transformer__num_encoder_layers: int
        The number of encoder layers in the transformer.
    transformer__num_decoder_layers: int
        The number of decoder layers in the transformer.
    transformer__nhead: int
        The number of heads in the multi-head attention models.
    _init_feature_encoder : bool
        Do not change the default value (used for internal purposes).
    _init_transformer : bool
        Do not change the default value (used for internal purposes).
    """

    feature_encoder: _ConvFeatureEncoder | None
    pos_encoder: _PosEncoder | None
    transformer: nn.Transformer | None

    _feature_encoder_channels: str = "n_chans"

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool,
        _init_transformer: bool,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.feature_encoder = None
        self.pos_encoder = None
        self.transformer = None
        if _init_feature_encoder:
            self.feature_encoder = _ConvFeatureEncoder(
                conv_layers_spec=feature_encoder__conv_layers_spec,
                channels=getattr(self, self._feature_encoder_channels),
                drop_prob=drop_prob,
                mode=feature_encoder__mode,
                conv_bias=feature_encoder__conv_bias,
                activation=activation,
            )

        if _init_transformer:
            ch_locs = [ch["loc"] for ch in self.chs_info]  # type: ignore
            self.pos_encoder = _PosEncoder(
                spat_dim=pos_encoder__spat_dim,
                time_dim=pos_encoder__time_dim,
                ch_locs=ch_locs,
                sfreq_features=pos_encoder__sfreq_features,
                spat_kwargs=pos_encoder__spat_kwargs,
            )
            self.transformer = nn.Transformer(
                d_model=transformer__d_model,
                nhead=transformer__nhead,
                num_encoder_layers=transformer__num_encoder_layers,
                num_decoder_layers=transformer__num_decoder_layers,
                batch_first=True,
            )


class SignalJEPA(_BaseSignalJEPA):
    r"""Architecture introduced in signal-JEPA for self-supervised pre-training, Guetschel et al. (2024) [1]_.

    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`

    This model is not meant for classification but for SSL pre-training.
    Its output shape depends on the input shape.
    For classification purposes, three variants of this model are available:

    * :class:`SignalJEPA_Contextual`
    * :class:`SignalJEPA_PostLocal`
    * :class:`SignalJEPA_PreLocal`

    The classification architectures can either be instantiated from scratch
    (random parameters) or from a pre-trained :class:`SignalJEPA` model.

    .. versionadded:: 0.9

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=True,
            _init_transformer=True,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = nn.Identity()

    def forward(self, X, ch_idxs: torch.Tensor | None = None):  # type: ignore
        local_features = self.feature_encoder(X)  # type: ignore
        pos_encoding = self.pos_encoder(local_features, ch_idxs=ch_idxs)  # type: ignore
        local_features += pos_encoding  # type: ignore
        contextual_features = self.transformer.encoder(local_features)  # type: ignore
        y = self.final_layer(contextual_features)  # type: ignore
        return y  # type: ignore

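# Minimal usage sketch for the SSL encoder above. Illustrative only: the channel
# dictionaries below are a hypothetical stand-in for ``info["chs"]`` (only "ch_name" and
# "loc" are shown); in practice `chs_info` would come from an :class:`mne.Info` object.
def _signal_jepa_usage_sketch():
    chs_info = [
        {"ch_name": "C3", "loc": [-0.065, 0.0, 0.06]},
        {"ch_name": "Cz", "loc": [0.0, 0.0, 0.09]},
        {"ch_name": "C4", "loc": [0.065, 0.0, 0.06]},
    ]
    model = SignalJEPA(chs_info=chs_info, n_times=512, sfreq=128.0)
    X = torch.randn(2, 3, 512)  # (batch, n_chans, n_times)
    tokens = model(X)
    # one token per (channel, output time step): 512 samples -> 3 steps with the
    # default conv spec, hence shape (2, 3 * 3, 64)
    return tokens.shape

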

class SignalJEPA_Contextual(_BaseSignalJEPA):
    r"""Contextual downstream architecture introduced in signal-JEPA, Guetschel et al. (2024) [1]_.

    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_contextual.jpg
       :align: center
       :alt: sJEPA Contextual.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
        _init_transformer: bool = True,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=_init_transformer,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = _get_separable_clf_layer(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_chans=self.n_chans,
            n_times=self.n_times,
            n_classes=self.n_outputs,
            n_spat_filters=n_spat_filters,
        )

    @classmethod
    def from_pretrained(
        cls,
        model: Optional[SignalJEPA | str | Path] = None,  # type: ignore
        n_outputs: Optional[int] = None,  # type: ignore
        n_spat_filters: int = 4,
        chs_info: Optional[list[dict[str, Any]]] = None,  # type: ignore
        **kwargs,
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from the Hub.

        Parameters
        ----------
        model: SignalJEPA, str, Path, or None
            Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
            (for Hub-style loading), or None (for Hub loading via kwargs).
        n_outputs: int or None
            Number of classes for the new model. Required when loading from a SignalJEPA model,
            optional when loading from the Hub (it will be read from the config).
        n_spat_filters: int
            Number of spatial filters.
        chs_info: list of dict | None
            Information about each individual EEG channel. This should be filled with
            ``info["chs"]``. Refer to :class:`mne.Info` for more details.
        **kwargs
            Additional keyword arguments passed to the parent class for Hub loading.
        """
        # Check if this is a Hub-style load (from a directory path)
        if isinstance(model, (str, Path)) or (model is None and kwargs):
            # This is a Hub load, delegate to the parent class
            if isinstance(model, (str, Path)):
                # model is actually the repo_id or directory path
                return super().from_pretrained(model, **kwargs)
            else:
                # model is None, treat as a Hub-style load
                return super().from_pretrained(**kwargs)

        # This is the original SignalJEPA transfer learning case
        if not isinstance(model, SignalJEPA):
            raise TypeError(
                f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
            )
        if n_outputs is None:
            raise ValueError(
                "n_outputs must be provided when loading from a SignalJEPA model"
            )

        feature_encoder = model.feature_encoder
        pos_encoder = model.pos_encoder
        transformer = model.transformer
        assert feature_encoder is not None
        assert pos_encoder is not None
        assert transformer is not None

        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            chs_info=chs_info,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
            _init_transformer=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        new_model.pos_encoder = deepcopy(pos_encoder)
        new_model.transformer = deepcopy(transformer)

        if chs_info is not None:
            ch_names = [ch["ch_name"] for ch in chs_info]
            new_model.pos_encoder.set_fixed_ch_names(ch_names)

        return new_model

    def forward(self, X, ch_idxs: torch.Tensor | None = None):  # type: ignore
        local_features = self.feature_encoder(X)  # type: ignore
        pos_encoding = self.pos_encoder(local_features, ch_idxs=ch_idxs)  # type: ignore
        local_features += pos_encoding  # type: ignore
        contextual_features = self.transformer.encoder(local_features)  # type: ignore
        y = self.final_layer(contextual_features)  # type: ignore
        return y  # type: ignore

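# Transfer-learning sketch (illustrative): turn a pre-trained SSL encoder into a
# classifier. `pretrained` is a placeholder for an already trained SignalJEPA instance;
# `n_outputs=4` is an arbitrary example.
def _contextual_finetuning_sketch(pretrained: SignalJEPA):
    clf = SignalJEPA_Contextual.from_pretrained(
        pretrained,  # feature_encoder, pos_encoder and transformer weights are copied
        n_outputs=4,
        n_spat_filters=4,
    )
    X = torch.randn(2, pretrained.n_chans, pretrained.n_times)
    return clf(X)  # logits of shape (2, 4)

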

class SignalJEPA_PostLocal(_BaseSignalJEPA):
    r"""Post-local downstream architecture introduced in signal-JEPA, Guetschel et al. (2024) [1]_.

    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_post-local.jpg
       :align: center
       :alt: sJEPA Post-Local.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=False,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.final_layer = _get_separable_clf_layer(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_chans=self.n_chans,
            n_times=self.n_times,
            n_classes=self.n_outputs,
            n_spat_filters=n_spat_filters,
        )

    @classmethod
    def from_pretrained(
        cls,
        model: SignalJEPA | str | Path = None,  # type: ignore
        n_outputs: int = None,  # type: ignore
        n_spat_filters: int = 4,
        **kwargs,
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from the Hub.

        Parameters
        ----------
        model: SignalJEPA, str, Path, or None
            Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
            (for Hub-style loading), or None (for Hub loading via kwargs).
        n_outputs: int or None
            Number of classes for the new model. Required when loading from a SignalJEPA model,
            optional when loading from the Hub (it will be read from the config).
        n_spat_filters: int
            Number of spatial filters.
        **kwargs
            Additional keyword arguments passed to the parent class for Hub loading.
        """
        # Check if this is a Hub-style load (from a directory path)
        if isinstance(model, (str, Path)) or (model is None and kwargs):
            # This is a Hub load, delegate to the parent class
            if isinstance(model, (str, Path)):
                # model is actually the repo_id or directory path
                return super().from_pretrained(model, **kwargs)
            else:
                # model is None, treat as a Hub-style load
                return super().from_pretrained(**kwargs)

        # This is the original SignalJEPA transfer learning case
        if not isinstance(model, SignalJEPA):
            raise TypeError(
                f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
            )
        if n_outputs is None:
            raise ValueError(
                "n_outputs must be provided when loading from a SignalJEPA model"
            )

        feature_encoder = model.feature_encoder
        assert feature_encoder is not None
        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        return new_model

    def forward(self, X):
        local_features = self.feature_encoder(X)
        y = self.final_layer(local_features)
        return y


class SignalJEPA_PreLocal(_BaseSignalJEPA):
    r"""Pre-local downstream architecture introduced in signal-JEPA, Guetschel et al. (2024) [1]_.

    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`

    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.

    .. figure:: https://braindecode.org/dev/_static/model/sjepa_pre-local.jpg
       :align: center
       :alt: sJEPA Pre-Local.

    .. versionadded:: 0.9

    .. important::
        **Pre-trained Weights Available**

        This model has pre-trained weights available on the Hugging Face Hub.
        You can load them using:

        .. code-block:: python

            from braindecode.models import SignalJEPA_PreLocal

            # Load pre-trained model from Hugging Face Hub
            model = SignalJEPA_PreLocal.from_pretrained(
                "braindecode/SignalJEPA-PreLocal-pretrained"
            )

        To push your own trained model to the Hub:

        .. code-block:: python

            # After training your model
            model.push_to_hub(
                repo_id="username/my-sjepa-model",
                commit_message="Upload trained SignalJEPA model",
            )

        Requires installing ``braindecode[hug]`` for Hub integration.

    Parameters
    ----------
    n_spat_filters : int
        Number of spatial filters.

    References
    ----------
    .. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
       S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
       In 9th Graz Brain-Computer Interface Conference, https://www.doi.org/10.3217/978-3-99161-014-4-003
    """

    _feature_encoder_channels: str = "n_spat_filters"

    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        *,
        n_spat_filters: int = 4,
        # feature_encoder
        feature_encoder__conv_layers_spec: Sequence[
            tuple[int, int, int]
        ] = _DEFAULT_CONV_LAYER_SPEC,
        drop_prob: float = 0.0,
        feature_encoder__mode: str = "default",
        feature_encoder__conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
        # pos_encoder
        pos_encoder__spat_dim: int = 30,
        pos_encoder__time_dim: int = 34,
        pos_encoder__sfreq_features: float = 1.0,
        pos_encoder__spat_kwargs: dict | None = None,
        # transformer
        transformer__d_model: int = 64,
        transformer__num_encoder_layers: int = 8,
        transformer__num_decoder_layers: int = 4,
        transformer__nhead: int = 8,
        # other
        _init_feature_encoder: bool = True,
    ):
        # set before super().__init__ because the base class reads this attribute
        # through _feature_encoder_channels to size the feature encoder
        self.n_spat_filters = n_spat_filters
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            feature_encoder__conv_layers_spec=feature_encoder__conv_layers_spec,
            drop_prob=drop_prob,
            feature_encoder__mode=feature_encoder__mode,
            feature_encoder__conv_bias=feature_encoder__conv_bias,
            activation=activation,
            pos_encoder__spat_dim=pos_encoder__spat_dim,
            pos_encoder__time_dim=pos_encoder__time_dim,
            pos_encoder__sfreq_features=pos_encoder__sfreq_features,
            pos_encoder__spat_kwargs=pos_encoder__spat_kwargs,
            transformer__d_model=transformer__d_model,
            transformer__num_encoder_layers=transformer__num_encoder_layers,
            transformer__num_decoder_layers=transformer__num_decoder_layers,
            transformer__nhead=transformer__nhead,
            _init_feature_encoder=_init_feature_encoder,
            _init_transformer=False,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        self.spatial_conv = nn.Sequential(
            Rearrange("b channels time -> b 1 channels time"),
            nn.Conv2d(1, n_spat_filters, (self.n_chans, 1)),
            Rearrange("b spat_filters 1 time -> b spat_filters time"),
        )
        out_emb_dim = _get_out_emb_dim(
            conv_layers_spec=feature_encoder__conv_layers_spec,
            n_times=self.n_times,
            n_spat_filters=n_spat_filters,
        )
        self.final_layer = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(out_emb_dim, self.n_outputs),
        )

    @classmethod
    def from_pretrained(
        cls,
        model: SignalJEPA | str | Path = None,  # type: ignore
        n_outputs: int = None,  # type: ignore
        n_spat_filters: int = 4,
        **kwargs,
    ):
        """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from the Hub.

        Parameters
        ----------
        model: SignalJEPA, str, Path, or None
            Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
            (for Hub-style loading), or None (for Hub loading via kwargs).
        n_outputs: int or None
            Number of classes for the new model. Required when loading from a SignalJEPA model,
            optional when loading from the Hub (it will be read from the config).
        n_spat_filters: int
            Number of spatial filters.
        **kwargs
            Additional keyword arguments passed to the parent class for Hub loading.
        """
        # Check if this is a Hub-style load (from a directory path)
        if isinstance(model, (str, Path)) or (model is None and kwargs):
            # This is a Hub load, delegate to the parent class
            if isinstance(model, (str, Path)):
                # model is actually the repo_id or directory path
                return super().from_pretrained(model, **kwargs)
            else:
                # model is None, treat as a Hub-style load
                return super().from_pretrained(**kwargs)

        # This is the original SignalJEPA transfer learning case
        if not isinstance(model, SignalJEPA):
            raise TypeError(
                f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
            )
        if n_outputs is None:
            raise ValueError(
                "n_outputs must be provided when loading from a SignalJEPA model"
            )

        feature_encoder = model.feature_encoder
        assert feature_encoder is not None
        new_model = cls(
            n_outputs=n_outputs,
            n_chans=model.n_chans,
            n_times=model.n_times,
            n_spat_filters=n_spat_filters,
            feature_encoder__conv_layers_spec=feature_encoder.conv_layers_spec,
            _init_feature_encoder=False,
        )
        new_model.feature_encoder = deepcopy(feature_encoder)
        return new_model

    def forward(self, X):
        X = self.spatial_conv(X)
        local_features = self.feature_encoder(X)
        y = self.final_layer(local_features)
        return y

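# Shape sketch (illustrative): the pre-local variant learns a spatial filter bank before
# the channel-independent feature encoder, so the encoder sees `n_spat_filters` virtual
# channels instead of the raw EEG channels. The sizes below are arbitrary examples.
def _pre_local_shape_sketch():
    model = SignalJEPA_PreLocal(n_outputs=2, n_chans=22, n_times=512, sfreq=128.0)
    X = torch.randn(8, 22, 512)  # (batch, n_chans, n_times)
    virtual = model.spatial_conv(X)  # -> (8, 4, 512): four virtual channels
    tokens = model.feature_encoder(virtual)  # -> (8, 4 * 3, 64) with the default spec
    logits = model(X)  # -> (8, 2)
    return virtual.shape, tokens.shape, logits.shape

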

class _ConvFeatureEncoder(nn.Sequential):
    r"""Convolutional feature encoder for EEG data.

    Computes successive 1D convolutions (with activations) over the time
    dimension of the input EEG signal.

    Inspiration from https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
    and https://github.com/SPOClab-ca/BENDR/blob/main/dn3_ext.py

    Parameters
    ----------
    conv_layers_spec: list of tuple
        Tuples have shape ``(dim, k, stride)`` where:

        * ``dim`` : number of output channels of the layer (unrelated to EEG channels);
        * ``k`` : temporal length of the layer's kernel;
        * ``stride`` : temporal stride of the layer's kernel.

    channels: int
        Number of input channels (EEG channels or virtual channels), each encoded
        independently with the same convolutions.
    drop_prob: float
        Dropout probability applied after each convolution.
    mode: str
        Normalisation mode. Either ``default`` or ``layer_norm``.
    conv_bias: bool
        Whether the convolutions use a bias term.
    activation: nn.Module
        Activation layer applied after each convolution.
    """

    def __init__(
        self,
        conv_layers_spec: Sequence[tuple[int, int, int]],
        channels: int,
        drop_prob: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        activation: type[nn.Module] = nn.GELU,
    ):
        assert mode in {"default", "layer_norm"}

        input_channels = 1
        conv_layers = []
        for i, layer_spec in enumerate(conv_layers_spec):
            # Each layer_spec should be a tuple: (output_channels, kernel_size, stride)
            assert len(layer_spec) == 3, "Invalid conv definition: " + str(layer_spec)
            output_channels, kernel_size, stride = layer_spec
            conv_layers.append(
                self._get_block(
                    input_channels,
                    output_channels,
                    kernel_size,
                    stride,
                    drop_prob,
                    activation,
                    is_layer_norm=(mode == "layer_norm"),
                    is_group_norm=(mode == "default" and i == 0),
                    conv_bias=conv_bias,
                )
            )
            input_channels = output_channels
        all_layers = [
            Rearrange("b channels time -> (b channels) 1 time", channels=channels),
            *conv_layers,
            Rearrange(
                "(b channels) emb_dim time_out -> b (channels time_out) emb_dim",
                channels=channels,
            ),
        ]
        super().__init__(*all_layers)
        self.emb_dim = (
            output_channels  # last output dimension becomes the embedding dimension
        )
        self.conv_layers_spec = conv_layers_spec

    @staticmethod
    def _get_block(
        input_channels,
        output_channels,
        kernel_size,
        stride,
        drop_prob,
        activation,
        is_layer_norm=False,
        is_group_norm=False,
        conv_bias=False,
    ):
        assert not (is_layer_norm and is_group_norm), (
            "layer norm and group norm are exclusive"
        )

        conv = nn.Conv1d(
            input_channels,
            output_channels,
            kernel_size,
            stride=stride,
            bias=conv_bias,
        )
        nn.init.kaiming_normal_(conv.weight)
        if is_layer_norm:
            return nn.Sequential(
                conv,
                nn.Dropout(p=drop_prob),
                nn.Sequential(
                    Rearrange("... channels time -> ... time channels"),
                    nn.LayerNorm(output_channels, elementwise_affine=True),
                    Rearrange("... time channels -> ... channels time"),
                ),
                activation(),
            )
        elif is_group_norm:
            return nn.Sequential(
                conv,
                nn.Dropout(p=drop_prob),
                nn.GroupNorm(output_channels, output_channels, affine=True),
                activation(),
            )
        else:
            return nn.Sequential(conv, nn.Dropout(p=drop_prob), activation())

    def n_times_out(self, n_times):
        return _n_times_out(self.conv_layers_spec, n_times)

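# Shape sketch (illustrative): the encoder folds channels into the batch dimension so
# that every channel is passed through the same 1D convolutions, then channels and
# output time steps are flattened together into one token sequence.
def _conv_feature_encoder_shape_sketch():
    encoder = _ConvFeatureEncoder(_DEFAULT_CONV_LAYER_SPEC, channels=3)
    X = torch.randn(2, 3, 512)  # (batch, channels, time)
    tokens = encoder(X)  # (batch, channels * time_out, emb_dim)
    assert tokens.shape == (2, 3 * encoder.n_times_out(512), encoder.emb_dim)
    return tokens.shape  # (2, 9, 64) with the default spec

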

class _ChannelEmbedding(nn.Embedding):
    r"""Embedding layer for EEG channels.

    The difference with a regular :class:`nn.Embedding` is that the embedding
    vectors are initialized with a positional encoding of the channel locations.

    Parameters
    ----------
    channel_locations: list of (list of float or None)
        List of the n-dimensional locations of the EEG channels.
    embedding_dim: int
        Dimensionality of the embedding vectors. Must be a multiple of the number
        of dimensions of the channel locations.
    """

    def __init__(
        self, channel_locations: list[list[float] | None], embedding_dim: int, **kwargs
    ):
        self.coordinate_ranges = [
            (min(coords), max(coords))
            for coords in zip(
                *[
                    loc[3:6] if len(loc) == 12 else loc
                    for loc in channel_locations
                    if loc is not None
                ]
            )
        ]
        channel_mins, channel_maxs = zip(*self.coordinate_ranges)
        global_min = min(channel_mins)
        global_max = max(channel_maxs)
        self.max_abs_coordinate = max(abs(global_min), abs(global_max))
        self.embedding_dim_per_coordinate = embedding_dim // len(self.coordinate_ranges)
        self.channel_locations = list(channel_locations)

        assert embedding_dim % len(self.coordinate_ranges) == 0

        super().__init__(len(channel_locations), embedding_dim, **kwargs)

    def reset_parameters(self):
        for i, loc in enumerate(self.channel_locations):
            if loc is None:
                nn.init.zeros_(self.weight[i])
            else:
                for j, (x, (x0, x1)) in enumerate(zip(loc, self.coordinate_ranges)):
                    with torch.no_grad():
                        self.weight[
                            i,
                            j * self.embedding_dim_per_coordinate : (j + 1)
                            * self.embedding_dim_per_coordinate,
                        ].copy_(
                            _pos_encode_contineous(
                                x,
                                0,
                                10 * self.max_abs_coordinate,
                                self.embedding_dim_per_coordinate,
                                device=self.weight.device,
                            ),
                        )


class _PosEncoder(nn.Module):
    r"""Positional encoder for EEG data.

    Parameters
    ----------
    spat_dim: int
        Number of dimensions to use to encode the spatial position of the patch,
        i.e. the EEG channel.
    time_dim: int
        Number of dimensions to use to encode the temporal position of the patch.
    ch_locs: list of list of float or 2d array
        List of the n-dimensional locations of the EEG channels.
    sfreq_features: float
        The "downsampled" sampling frequency returned by the feature encoder.
    spat_kwargs: dict
        Additional keyword arguments to pass to the :class:`nn.Embedding` layer used to
        embed the channel names.
    max_seconds: float
        Maximum number of seconds to consider for the temporal encoding.
    """

    def __init__(
        self,
        spat_dim: int,
        time_dim: int,
        ch_locs,
        sfreq_features: float,
        spat_kwargs: dict | None = None,
        max_seconds: float = 600.0,  # 10 minutes
    ):
        super().__init__()
        spat_kwargs = spat_kwargs or {}
        self.spat_dim = spat_dim
        self.time_dim = time_dim
        self.max_n_times = int(max_seconds * sfreq_features)

        # Positional encoder for the spatial dimension:
        self.pos_encoder_spat = _ChannelEmbedding(
            ch_locs, spat_dim, **spat_kwargs
        )  # (batch_size, n_channels, spat_dim)

        # Pre-computed tensor for positional encoding on the time dimension:
        self.encoding_time = torch.zeros(0, dtype=torch.float32, requires_grad=False)

    def _check_encoding_time(self, n_times: int):
        if self.encoding_time.size(0) < n_times:
            self.encoding_time = self.encoding_time.new_empty((n_times, self.time_dim))
            self.encoding_time[:] = _pos_encode_time(
                n_times=n_times,
                n_dim=self.time_dim,
                max_n_times=self.max_n_times,
                device=self.encoding_time.device,
            )

    def forward(self, local_features, ch_idxs: torch.Tensor | None = None):
        """
        Parameters
        ----------
        * local_features: (batch_size, n_chans * n_times_out, emb_dim)
        * ch_idxs: (batch_size, n_chans) | None
            Indices of the channels to use in the ``ch_names`` list passed
            as argument plus one. Index 0 is reserved for an unknown channel.

        Returns
        -------
        pos_encoding: (batch_size, n_chans * n_times_out, emb_dim)
            The first ``spat_dim`` dimensions encode the channel positional encoding
            and the following ``time_dim`` dimensions encode the temporal positional encoding.
        """
        batch_size, n_chans_times, emb_dim = local_features.shape
        if ch_idxs is None:
            ch_idxs = torch.arange(
                0,
                self.pos_encoder_spat.num_embeddings,
                device=local_features.device,
            ).repeat(batch_size, 1)

        batch_size_chs, n_chans = ch_idxs.shape
        assert emb_dim >= self.spat_dim + self.time_dim
        assert n_chans_times % n_chans == 0
        n_times = n_chans_times // n_chans

        pos_encoding = local_features.new_empty(
            (batch_size_chs, n_chans, n_times, emb_dim)
        )
        # Channel pos. encoding
        pos_encoding[:, :, :, : self.spat_dim] = self.pos_encoder_spat(ch_idxs)[
            :, :, None, :
        ]
        # Temporal pos. encoding
        self._check_encoding_time(n_times)
        _ = pos_encoding[:, :, :, self.spat_dim : self.spat_dim + self.time_dim].copy_(
            self.encoding_time[None, None, :n_times, :],
        )

        return pos_encoding.view(batch_size, n_chans_times, emb_dim)

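# Sketch (illustrative): with the defaults, the 30 spatial + 34 temporal dimensions
# exactly fill the 64-dimensional token embedding produced by the default feature
# encoder, which is also the transformer's d_model. The channel locations below are
# hypothetical 3D positions.
def _pos_encoder_shape_sketch():
    ch_locs = [[-0.065, 0.0, 0.06], [0.0, 0.0, 0.09], [0.065, 0.0, 0.06]]
    pos_encoder = _PosEncoder(spat_dim=30, time_dim=34, ch_locs=ch_locs, sfreq_features=1.0)
    tokens = torch.zeros(2, 3 * 5, 64)  # (batch, n_chans * n_times_out, emb_dim)
    encoding = pos_encoder(tokens)  # same shape as `tokens`
    # dims 0:30 encode the channel position, dims 30:64 encode the time step
    return encoding.shape  # (2, 15, 64)

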

def _n_times_out(conv_layers_spec, n_times):
    # it would be equal to n_times//ds_factor without edge effects:
    n_times_out_ = n_times
    for _, width, stride in conv_layers_spec:
        n_times_out_ = int((n_times_out_ - width) / stride) + 1
    return n_times_out_


def _get_out_emb_dim(conv_layers_spec, n_times, n_spat_filters=4):
    n_time_out = _n_times_out(conv_layers_spec, n_times)
    emb_dim = conv_layers_spec[-1][0]
    return n_spat_filters * n_time_out * emb_dim

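# Worked example (illustrative): with the default spec and a 4 s window at 128 Hz
# (n_times = 512), each channel yields int((512 - 32) / 8) + 1 = 61, then 30, 15, 7 and
# finally 3 output time steps, so the flattened classifier input has
# n_spat_filters * n_times_out * emb_dim = 4 * 3 * 64 = 768 features.
def _out_emb_dim_arithmetic_sketch():
    assert _n_times_out(_DEFAULT_CONV_LAYER_SPEC, 512) == 3
    assert _get_out_emb_dim(_DEFAULT_CONV_LAYER_SPEC, n_times=512, n_spat_filters=4) == 768
    return 768

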

def _get_separable_clf_layer(
    conv_layers_spec, n_chans, n_times, n_classes, n_spat_filters=4
):
    out_emb_dim = _get_out_emb_dim(
        conv_layers_spec=conv_layers_spec,
        n_times=n_times,
        n_spat_filters=n_spat_filters,
    )
    clf_layer = nn.Sequential()
    clf_layer.add_module(
        "unflatten_tokens",
        Rearrange("b (n_chans tokens) d -> b 1 n_chans tokens d", n_chans=n_chans),
    )
    clf_layer.add_module("spat_conv", nn.Conv3d(1, n_spat_filters, (n_chans, 1, 1)))
    clf_layer.add_module("flatten", nn.Flatten(start_dim=1))
    clf_layer.add_module("linear", nn.Linear(out_emb_dim, n_classes))
    return clf_layer


def _pos_encode_time(
    n_times: int,
    n_dim: int,
    max_n_times: int,
    device: torch.device = torch.device("cpu"),
):
    """1-dimensional positional encoding.

    Parameters
    ----------
    n_times: int
        Number of time samples to encode.
    n_dim: int
        Number of dimensions of the positional encoding. Must be even.
    max_n_times: int
        The largest possible number of time samples to encode.
        Used to scale the positional encoding.
    device: torch.device
        Device to put the output on.

    Returns
    -------
    pos_encoding: (n_times, n_dim)
    """
    assert n_dim % 2 == 0
    position = torch.arange(n_times, device=device).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, n_dim, 2, device=device) * (-math.log(max_n_times) / n_dim)
    )
    pos_encoding = torch.empty((n_times, n_dim), dtype=torch.float32, device=device)
    pos_encoding[:, 0::2] = torch.sin(position * div_term)
    pos_encoding[:, 1::2] = torch.cos(position * div_term)
    return pos_encoding

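# Sketch (illustrative): this is the standard transformer sinusoidal encoding with
# wavelengths spread geometrically up to `max_n_times` samples, i.e.
# pe[t, 2i] = sin(t / max_n_times ** (2i / n_dim)) and pe[t, 2i + 1] = cos(...).
def _pos_encode_time_sketch():
    pe = _pos_encode_time(n_times=4, n_dim=6, max_n_times=600)
    assert pe.shape == (4, 6)
    assert torch.allclose(pe[0, 0::2], torch.zeros(3))  # sin(0) = 0 at t = 0
    assert torch.allclose(pe[0, 1::2], torch.ones(3))  # cos(0) = 1 at t = 0
    return pe

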

def _pos_encode_contineous(
    x, x_min, x_max, n_dim, device: torch.device = torch.device("cpu")
):
    """1-dimensional positional encoding.

    Parameters
    ----------
    x: float
        The position to encode.
    x_min: float
        The minimum possible value of x.
    x_max: float
        The maximum possible value of x.
    n_dim: int
        Number of dimensions of the positional encoding. Must be even.
    device: torch.device
        Device to put the output on.

    Returns
    -------
    pos_encoding: (n_dim,)
    """
    assert n_dim % 2 == 0
    div_term = torch.exp(
        (1 - torch.arange(0, n_dim, 2, device=device) / n_dim) * 2 * math.pi
    )
    pos_encoding = torch.empty((n_dim,), dtype=torch.float32, device=device)
    xx = (x - x_min) / (x_max - x_min)
    pos_encoding[0::2] = torch.sin(xx * div_term)
    pos_encoding[1::2] = torch.cos(xx * div_term)
    return pos_encoding