braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1300 @@
1
+ # Authors: Cédric Rommel <cedric.rommel@inria.fr>
2
+ # Alexandre Gramfort <alexandre.gramfort@inria.fr>
3
+ # Gustavo Rodrigues <gustavenrique01@gmail.com>
4
+ # Bruna Lopes <brunajaflopes@gmail.com>
5
+ #
6
+ # License: BSD (3-clause)
7
+
8
+ from __future__ import annotations
9
+
10
+ from numbers import Real
11
+ from typing import Literal
12
+
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+ import torch
16
+ from mne.filter import notch_filter
17
+ from scipy.interpolate import Rbf
18
+ from sklearn.utils import check_random_state
19
+ from torch.fft import fft, ifft
20
+ from torch.nn.functional import one_hot, pad
21
+
22
+
23
def identity(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """No-op augmentation: hand back the inputs and labels untouched.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.

    Returns
    -------
    torch.Tensor
        The very same ``X`` object that was passed in.
    torch.Tensor
        The very same ``y`` object that was passed in.
    """
    unchanged = (X, y)
    return unchanged
41
+
42
+
43
def time_reverse(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Reverse every signal along its last (time) dimension.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.

    Returns
    -------
    torch.Tensor
        New tensor holding ``X`` read backwards along the last dimension.
    torch.Tensor
        Labels, returned unchanged.
    """
    reversed_X = X.flip(-1)
    return reversed_X, y
61
+
62
+
63
def sign_flip(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Negate the amplitude of every input value (flip the signal's sign).

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.

    Returns
    -------
    torch.Tensor
        New tensor equal to ``-X``.
    torch.Tensor
        Labels, returned unchanged.
    """
    negated = torch.neg(X)
    return negated, y
81
+
82
+
83
def _new_random_fft_phase_odd(
    batch_size: int,
    c: int,
    n: int,
    device: torch.device,
    random_state: int | np.random.RandomState | None,
) -> torch.Tensor:
    """Draw random FFT phases for signals of odd length ``n``.

    Layout along the last dimension (length ``n``):

    * bin 0 gets a zero phase;
    * bins ``1 .. (n - 1) // 2`` get ``2j * pi * u`` with ``u`` uniform in
      ``[0, 1)`` drawn from ``random_state``;
    * the remaining bins hold the negated mirror of those values, so the
      phase at bin ``k`` is the negative of the phase at bin ``n - k``.

    Returns a tensor of shape ``(batch_size, c, n)`` on ``device``.
    """
    generator = check_random_state(random_state)
    n_half = (n - 1) // 2
    uniform_draws = generator.random((batch_size, c, n_half))
    half_phase = torch.from_numpy(2j * np.pi * uniform_draws).to(device)
    zero_bin = torch.zeros((batch_size, c, 1), device=device)
    mirrored_half = -half_phase.flip(-1)
    return torch.cat([zero_bin, half_phase, mirrored_half], dim=-1)
102
+
103
+
104
def _new_random_fft_phase_even(
    batch_size: int,
    c: int,
    n: int,
    device: torch.device,
    random_state: int | np.random.RandomState | None,
) -> torch.Tensor:
    """Draw random FFT phases for signals of even length ``n``.

    Layout along the last dimension (length ``n``):

    * bins 0 and ``n // 2`` get a zero phase;
    * bins ``1 .. n // 2 - 1`` get ``2j * pi * u`` with ``u`` uniform in
      ``[0, 1)`` drawn from ``random_state``;
    * the remaining bins hold the negated mirror of those values, so the
      phase at bin ``k`` is the negative of the phase at bin ``n - k``.

    Returns a tensor of shape ``(batch_size, c, n)`` on ``device``.
    """
    generator = check_random_state(random_state)
    n_half = n // 2 - 1
    uniform_draws = generator.random((batch_size, c, n_half))
    half_phase = torch.from_numpy(2j * np.pi * uniform_draws).to(device)
    zero_bin_shape = (batch_size, c, 1)
    pieces = [
        torch.zeros(zero_bin_shape, device=device),  # bin 0
        half_phase,
        torch.zeros(zero_bin_shape, device=device),  # bin n // 2
        -half_phase.flip(-1),
    ]
    return torch.cat(pieces, dim=-1)
124
+
125
+
126
# Dispatch table used by ``ft_surrogate``: it is indexed with ``n % 2``
# (``n`` being the FFT length) to pick the phase generator for that parity.
_new_random_fft_phase = {
    0: _new_random_fft_phase_even,
    1: _new_random_fft_phase_odd,
}
127
+
128
+
129
def ft_surrogate(
    X: torch.Tensor,
    y: torch.Tensor,
    phase_noise_magnitude: float,
    channel_indep: bool,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """FT surrogate augmentation of EEG signals, as proposed in [1]_.

    The signals are moved to the frequency domain (along the last dimension),
    each frequency bin is multiplied by a random unit-modulus phase factor,
    and the result is brought back to the time domain. The random phases are
    antisymmetric (bin ``k`` and bin ``n - k`` get opposite phases), so the
    inverse FFT is real up to numerical error; its real part is returned.

    Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
    and modified.

    Parameters
    ----------
    X : torch.Tensor
        EEG batch. Dimension 0 is used as the batch dimension and dimension
        -2 as the channel dimension.
    y : torch.Tensor
        EEG labels for the example or batch.
    phase_noise_magnitude : float | torch.Tensor
        Value between 0 and 1 setting the range over which the phase
        perturbation is uniformly sampled:
        [0, `phase_noise_magnitude` * 2 * `pi`]. A tensor is moved to the
        device of ``X``.
    channel_indep : bool
        Whether to sample phase perturbations independently for each channel or
        not. It is advised to set it to False when spatial information is
        important for the task, like in BCI.
    random_state : int | numpy.random.RandomState | None, optional
        Seed or generator used to draw the phase perturbation. It goes
        through ``sklearn.utils.check_random_state``, which does not accept
        ``numpy.random.Generator`` instances. Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs (float32).
    torch.Tensor
        Transformed labels.

    Raises
    ------
    AssertionError
        If ``phase_noise_magnitude`` is not a real number/float tensor in
        ``[0, 1]``.

    References
    ----------
    .. [1] Schwabedal, J. T., Snyder, J. C., Cakmak, A., Nemati, S., &
       Clifford, G. D. (2018). Addressing Class Imbalance in Classification
       Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
       preprint arXiv:1806.08675.
    """
    assert (
        isinstance(
            phase_noise_magnitude, (Real, torch.FloatTensor, torch.cuda.FloatTensor)
        )
        and 0 <= phase_noise_magnitude <= 1
    ), f"eps must be a float between 0 and 1. Got {phase_noise_magnitude}."

    # Work in double precision in the frequency domain.
    f = fft(X.double(), dim=-1)
    device = X.device

    n = f.shape[-1]
    # Pick the even/odd phase generator matching the FFT length parity.
    random_phase = _new_random_fft_phase[n % 2](
        f.shape[0],
        f.shape[-2] if channel_indep else 1,
        n,
        device=device,
        random_state=random_state,
    )
    if not channel_indep:
        # Share one phase draw across all channels of an example.
        random_phase = torch.tile(random_phase, (1, f.shape[-2], 1))
    if isinstance(phase_noise_magnitude, torch.Tensor):
        phase_noise_magnitude = phase_noise_magnitude.to(device)
    f_shifted = f * torch.exp(phase_noise_magnitude * random_phase)
    shifted = ifft(f_shifted, dim=-1)
    transformed_X = shifted.real.float()

    return transformed_X, y
199
+
200
+
201
def _pick_channels_randomly(
    X: torch.Tensor, p_pick: float, random_state: int | np.random.RandomState | None
) -> torch.Tensor:
    """Build a soft, nearly binary mask with one value per (example, channel).

    One uniform draw is made per entry. Entries whose draw exceeds ``p_pick``
    come out close to 1; the others come out close to 0 (this happens with
    probability ``p_pick``). The result is a float32 tensor of shape
    ``(batch_size, n_channels)`` on the device of ``X``, which must be 3-D.
    """
    generator = check_random_state(random_state)
    n_examples, n_chans, _ = X.shape
    # Draw with the NumPy generator so every augmentation can share one rng.
    draws = generator.uniform(0, 1, size=(n_examples, n_chans))
    draws_t = torch.as_tensor(draws, dtype=torch.float, device=X.device)
    # Very steep sigmoid: numerically a 0/1 step at ``p_pick`` while keeping
    # ``p_pick`` inside a smooth expression (presumably so gradients can reach
    # it when it is a tensor -- confirm with callers).
    steep_logits = 1000 * (draws_t - p_pick)
    return torch.sigmoid(steep_logits)
214
+
215
+
216
def channels_dropout(
    X: torch.Tensor,
    y: torch.Tensor,
    p_drop: float,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly set channels to flat signal.

    Part of the CMSAugment policy proposed in [1]_

    Each (example, channel) pair is multiplied by a steep-sigmoid mask value
    that is ~0 with probability ``p_drop`` and ~1 otherwise, so dropped
    channels become (numerically) zero over the whole time axis.

    Parameters
    ----------
    X : torch.Tensor
        EEG batch of shape ``(batch_size, n_channels, n_times)``.
    y : torch.Tensor
        EEG labels for the example or batch.
    p_drop : float
        Float between 0 and 1 setting the probability of dropping each channel.
    random_state : int | numpy.random.RandomState | None, optional
        Seed or generator used to draw the mask. It goes through
        ``sklearn.utils.check_random_state``, which does not accept
        ``numpy.random.Generator`` instances. Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """
    mask = _pick_channels_randomly(X, p_drop, random_state=random_state)
    # (batch, channels) -> (batch, channels, 1) to broadcast over time.
    return X * mask.unsqueeze(-1), y
253
+
254
+
255
def _make_permutation_matrix(
    X: torch.Tensor,
    mask: torch.Tensor,
    random_state: int | np.random.RandomState | None,
) -> torch.Tensor:
    """Build one channel-permutation matrix per example.

    For each example, the channels whose (rounded) ``mask`` value is 1 are
    shuffled among themselves with ``random_state``; the other channels keep
    their position. Row ``i`` of the matrix is the one-hot encoding of the
    channel that ends up at position ``i``, so ``matrix @ X[b]`` reorders the
    channels of example ``b``.

    Parameters
    ----------
    X : torch.Tensor
        EEG batch of shape ``(batch_size, n_channels, n_times)``; only its
        shape, dtype and device are used.
    mask : torch.Tensor
        Soft mask of shape ``(batch_size, n_channels)``; rounded to decide
        which channels take part in the shuffle.
    random_state : int | numpy.random.RandomState | None
        Seed or generator, passed to ``sklearn.utils.check_random_state``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(batch_size, n_channels, n_channels)`` with the
        dtype and device of ``X``.
    """
    rng = check_random_state(random_state)
    batch_size, n_channels, _ = X.shape
    hard_mask = mask.round().bool()
    # Use X's dtype so ``torch.matmul(permutations, X)`` also works for
    # non-float32 inputs (a float32 matrix fails against float64/float16 X).
    batch_permutations = torch.empty(
        batch_size, n_channels, n_channels, dtype=X.dtype, device=X.device
    )
    all_channels = torch.arange(n_channels, device=X.device)
    for b, example_mask in enumerate(hard_mask):
        channels_to_shuffle = all_channels[example_mask]
        reordered_channels = torch.as_tensor(
            rng.permutation(channels_to_shuffle.cpu().numpy()), device=X.device
        )
        channels_permutation = all_channels.clone()
        channels_permutation[channels_to_shuffle] = reordered_channels
        batch_permutations[b] = one_hot(channels_permutation, n_channels)
    return batch_permutations
274
+
275
+
276
def channels_shuffle(
    X: torch.Tensor,
    y: torch.Tensor,
    p_shuffle: float,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly shuffle channels in EEG data matrix.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG batch of shape ``(batch_size, n_channels, n_times)``.
    y : torch.Tensor
        EEG labels for the example or batch.
    p_shuffle : float
        Float between 0 and 1 setting the probability of including the channel
        in the set of permuted channels. When it equals 0, ``X`` and ``y`` are
        returned as-is.
    random_state : int | numpy.random.RandomState | None, optional
        Seed or generator used to sample which channels to shuffle and to
        carry out the shuffle. It goes through
        ``sklearn.utils.check_random_state``, which does not accept
        ``numpy.random.Generator`` instances. Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """
    if p_shuffle == 0:
        return X, y
    # The mask is ~1 where a uniform draw exceeds 1 - p_shuffle, i.e. each
    # channel is selected for shuffling with probability p_shuffle.
    # NOTE: if ``random_state`` is an int, both helpers below build their own
    # generator from that same seed; pass a RandomState instance to share a
    # single random stream between selection and permutation.
    mask = _pick_channels_randomly(X, 1 - p_shuffle, random_state)
    batch_permutations = _make_permutation_matrix(X, mask, random_state)
    return torch.matmul(batch_permutations, X), y
318
+
319
+
320
def gaussian_noise(
    X: torch.Tensor,
    y: torch.Tensor,
    std: float,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly add white Gaussian noise to all channels.

    Suggested e.g. in [1]_, [2]_ and [3]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    std : float | torch.Tensor
        Standard deviation to use for the additive noise. A tensor is moved to
        the device of ``X`` before use.
    random_state : int | numpy.random.RandomState | None, optional
        Seed or generator used to draw the noise. It goes through
        ``sklearn.utils.check_random_state``, which does not accept
        ``numpy.random.Generator`` instances. Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Wang, F., Zhong, S. H., Peng, J., Jiang, J., & Liu, Y. (2018). Data
       augmentation for eeg-based emotion recognition with deep convolutional
       neural networks. In International Conference on Multimedia Modeling
       (pp. 82-93).
    .. [2] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [3] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """
    rng = check_random_state(random_state)
    if isinstance(std, torch.Tensor):
        std = std.to(X.device)
    # Standard-normal draws with the shape of X. Passing ``size`` yields the
    # same values as an all-zeros ``loc`` array, without allocating it.
    standard_noise = rng.normal(loc=0.0, scale=1.0, size=tuple(X.shape))
    noise = torch.from_numpy(standard_noise).float().to(X.device) * std
    transformed_X = X + noise
    return transformed_X, y
375
+
376
+
377
def channels_permute(
    X: torch.Tensor, y: torch.Tensor, permutation: list[int]
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reorder EEG channels following a fixed permutation.

    Suggested e.g. in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch; channels are indexed on the
        second-to-last dimension.
    y : torch.Tensor
        EEG labels for the example or batch.
    permutation : list
        List of integers giving the new order of the channels: output channel
        ``i`` is input channel ``permutation[i]``.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Labels, returned unchanged.

    References
    ----------
    .. [1] Deiss, O., Biswal, S., Jin, J., Sun, H., Westover, M. B., & Sun, J.
       (2018). HAMLET: interpretable human and machine co-learning technique.
       arXiv preprint arXiv:1803.09702.
    """
    permuted = X[..., permutation, :]
    return permuted, y
407
+
408
+
409
def smooth_time_mask(
    X: torch.Tensor,
    y: torch.Tensor,
    mask_start_per_sample: torch.Tensor,
    mask_len_samples: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Smoothly replace a contiguous part of all channels by zeros.

    Originally proposed in [1]_ and [2]_

    The mask is the sum of two steep sigmoids, one falling at the mask start
    and one rising at ``start + mask_len_samples``. It is ~0 inside the
    window and ~1 outside, and is shared by all channels of an example.

    Parameters
    ----------
    X : torch.Tensor
        EEG batch of shape ``(batch_size, n_channels, n_times)``.
    y : torch.Tensor
        EEG labels for the example or batch.
    mask_start_per_sample : torch.tensor
        Tensor of integers containing the position (in last dimension) where to
        start masking the signal. Should have the same size as the first
        dimension of X (i.e. one start position per example in the batch).
    mask_len_samples : int
        Number of consecutive samples to zero out.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """
    _, _, seq_len = X.shape
    # The mask only depends on (example, time), so the time index is broadcast
    # against (batch, 1, 1) start positions instead of being materialized for
    # every channel; the product with X broadcasts over channels.
    t = torch.arange(seq_len, device=X.device).float()
    mask_start = mask_start_per_sample.to(X.device).view(-1, 1, 1)
    # Edge steepness, scaled by the window length.
    s = 1000 / seq_len
    mask = (
        torch.sigmoid(s * -(t - mask_start))
        + torch.sigmoid(s * (t - mask_start - mask_len_samples))
    ).float()
    return X * mask, y
462
+
463
+
464
def bandstop_filter(
    X: torch.Tensor,
    y: torch.Tensor,
    sfreq: float,
    bandwidth: float,
    freqs_to_notch: npt.ArrayLike | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply a band-stop filter with desired bandwidth at the desired frequency position.

    Suggested e.g. in [1]_ and [2]_

    Each example is filtered separately with ``mne.filter.notch_filter``
    (FIR method) centered on its own frequency from ``freqs_to_notch``. The
    filtering runs in float64 on the CPU, and results are written back into a
    copy of ``X``.

    Parameters
    ----------
    X : torch.Tensor
        EEG input batch; iterated over its first dimension.
    y : torch.Tensor
        EEG labels for the example or batch.
    sfreq : float
        Sampling frequency of the signals to be filtered.
    bandwidth : float
        Bandwidth of the filter, i.e. distance between the low and high cut
        frequencies. If 0, ``X`` and ``y`` are returned untouched.
    freqs_to_notch : array-like | None
        Array of floats of size ``(batch_size,)`` containing the center of the
        frequency band to filter out for each sample in the batch. Frequencies
        should be greater than ``bandwidth/2 + transition`` and lower than
        ``sfreq/2 - bandwidth/2 - transition`` (where ``transition = 1 Hz``).
        If None, ``X`` and ``y`` are returned untouched.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """
    if bandwidth == 0 or freqs_to_notch is None:
        return X, y
    transformed_X = X.clone()
    for i, (sample, notched_freq) in enumerate(zip(transformed_X, freqs_to_notch)):
        sample_np = sample.cpu().numpy().astype(np.float64)
        filtered = notch_filter(
            sample_np,
            Fs=sfreq,
            freqs=notched_freq,
            method="fir",
            notch_widths=bandwidth,
            verbose=False,
        )
        # Assignment casts back to X's dtype/device.
        transformed_X[i] = torch.as_tensor(filtered)
    return transformed_X, y
526
+
527
+
528
def _analytic_transform(x: torch.Tensor) -> torch.Tensor:
    """Return the analytic signal of a real tensor along its last dimension.

    The FFT spectrum is kept at DC (and at the Nyquist bin for even lengths),
    doubled at positive frequencies and zeroed at negative ones, and then
    inverse-transformed. This is the same construction as
    ``scipy.signal.hilbert``.

    Raises
    ------
    ValueError
        If ``x`` is complex.
    """
    if torch.is_complex(x):
        raise ValueError("x must be real.")

    n_samples = x.shape[-1]
    spectrum = fft(x, n_samples, dim=-1)
    # 1-D weights broadcast over all leading dimensions of the spectrum.
    weights = torch.zeros(n_samples, dtype=spectrum.dtype, device=spectrum.device)
    half = n_samples // 2
    weights[0] = 1
    if n_samples % 2:
        weights[1 : half + 1] = 2
    else:
        weights[half] = 1
        weights[1:half] = 2

    return ifft(spectrum * weights, dim=-1)
543
+
544
+
545
def _nextpow2(n: int) -> int:
    """Return the smallest integer ``p`` such that ``2**p >= abs(n)``."""
    exponent = np.log2(np.abs(n))
    return int(np.ceil(exponent))
548
+
549
+
550
def _frequency_shift(
    X: torch.Tensor, fs: float, f_shift: float | torch.Tensor | np.ndarray | list
) -> torch.Tensor:
    """Shift the spectrum of ``X`` by ``f_shift`` Hz along its last dimension.

    The signal is zero-padded to the next power of two and converted to its
    analytic signal. It is then multiplied by ``exp(2j * pi * f_shift * t)``,
    and the real part of the first ``N_orig`` samples is returned.

    See https://gist.github.com/lebedov/4428122

    Parameters
    ----------
    X : torch.Tensor
        Real signals whose last two dimensions are (channels, time).
    fs : float
        Sampling frequency of ``X`` in Hz.
    f_shift : float | int | torch.Tensor | numpy.ndarray | list
        Frequency shift in Hz. A scalar applies the same shift everywhere; a
        1-D sequence ends up on the leading (batch) dimension, presumably one
        shift per example -- TODO confirm with callers.

    Returns
    -------
    torch.Tensor
        Shifted signals (float32), truncated back to the original length.

    Raises
    ------
    ValueError
        If ``f_shift`` is not a tensor, float, int, NumPy array or list.
    """
    # Validate/convert the shift first so bad input fails before any FFT work.
    if isinstance(f_shift, torch.Tensor):
        _f_shift = f_shift
    elif isinstance(f_shift, (float, int, np.ndarray, list)):
        _f_shift = torch.as_tensor(f_shift).float()
    else:
        raise ValueError(f"Invalid f_shift type: {type(f_shift)}")
    # Shifts built from Python/NumPy values live on the CPU; move them next to
    # X so the products below also work for GPU inputs.
    _f_shift = _f_shift.to(X.device)

    # Pad the signal with zeros to prevent the FFT invoked by the transform
    # from slowing down the computation:
    n_channels, N_orig = X.shape[-2:]
    N_padded = 2 ** _nextpow2(N_orig)
    t = torch.arange(N_padded, device=X.device) / fs
    padded = pad(X, (0, N_padded - N_orig))
    analytical = _analytic_transform(padded)

    # For a scalar or length-B shift, repeat gives (N_padded, n_channels, B);
    # reversing the axes gives (B, n_channels, N_padded), which broadcasts
    # against ``analytical`` and ``t``.
    f_shift_stack = _f_shift.repeat(N_padded, n_channels, 1)
    reshaped_f_shift = f_shift_stack.permute(*range(f_shift_stack.ndim - 1, -1, -1))
    shifted = analytical * torch.exp(2j * np.pi * reshaped_f_shift * t)
    return shifted[..., :N_orig].real.float()
575
+
576
+
577
def frequency_shift(
    X: torch.Tensor, y: torch.Tensor, delta_freq: float, sfreq: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shift the spectrum of every channel by ``delta_freq`` Hz.

    All channels of a given example receive the same shift.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    delta_freq : float
        The amplitude of the frequency shift (in Hz).
    sfreq : float
        Sampling frequency of the signals to be transformed.

    Returns
    -------
    torch.Tensor
        Frequency-shifted inputs.
    torch.Tensor
        Labels, returned unchanged.
    """
    shifted_X = _frequency_shift(X, sfreq, delta_freq)
    return shifted_X, y
608
+
609
+
610
def _torch_normalize_vectors(rr: torch.Tensor) -> torch.Tensor:
    """Normalize surface vertices to unit Euclidean norm along dimension 1.

    For a 2-D ``rr`` (presumably ``(n_points, 3)`` positions -- confirm with
    callers) each row is divided by its norm. Entries whose norm is not
    strictly positive are divided by 1 instead, which avoids NaNs for zero
    vectors.

    The safe divisor is built with ``torch.where`` rather than by writing into
    the norm tensor in place. ``torch.linalg.norm`` keeps its output for the
    backward pass, so an in-place write would break autograd when ``rr``
    requires gradients.
    """
    norm = torch.linalg.norm(rr, dim=1, keepdim=True)
    safe_norm = torch.where(norm > 0, norm, torch.ones_like(norm))
    return rr / safe_norm
617
+
618
+
619
def _torch_legval(
    x: torch.Tensor, c: torch.Tensor, tensor: bool = True
) -> torch.Tensor:
    """
    Evaluate a Legendre series at points x (PyTorch port of NumPy's legval).

    If `c` is of length `n + 1`, this function returns the value:
    .. math:: p(x) = c_0 * L_0(x) + c_1 * L_1(x) + ... + c_n * L_n(x)
    The series is evaluated with the same backward (Clenshaw) recurrence as
    NumPy. Trailing zeros in the coefficients are still used in the
    evaluation.

    Parameters
    ----------
    x : torch.Tensor | list | tuple | scalar
        Evaluation points. Lists and tuples are converted with
        ``torch.as_tensor``; other non-tensor values are used as-is and treated
        as scalars. In either case `x` or its elements must support addition
        and multiplication with themselves and with the elements of `c`.
    c : array_like
        Coefficients, converted to a float64 tensor, ordered so that the
        coefficient of degree n is c[n]. If `c` is multidimensional the
        remaining indices enumerate multiple polynomials.
    tensor : bool, optional
        If True (default) and `x` is a tensor, `c` is reshaped with one
        trailing singleton dimension per dimension of `x`, so every column of
        coefficients is evaluated at every element of `x` (result shape
        ``c.shape[1:] + x.shape``). If False, `x` is broadcast over the columns
        of `c`.

    Returns
    -------
    values : torch.Tensor
        Values of the series; shape as described above.

    Notes
    -----
    Code copied and modified from Numpy:
    https://github.com/numpy/numpy/blob/v1.20.0/numpy/polynomial/legendre.py#L835-L920

    Copyright (c) 2005-2021, NumPy Developers.
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
    met:
        * Redistributions of source code must retain the above copyright
           notice, this list of conditions and the following disclaimer.
        * Redistributions in binary form must reproduce the above
           copyright notice, this list of conditions and the following
           disclaimer in the documentation and/or other materials provided
           with the distribution.
        * Neither the name of the NumPy Developers nor the names of any
           contributors may be used to endorse or promote products derived
           from this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    """
    c = torch.as_tensor(c)
    c = c.double()
    if isinstance(x, (tuple, list)):
        x = torch.as_tensor(x)
    if isinstance(x, torch.Tensor):
        if tensor:
            c = c.view(c.shape + (1,) * x.ndim)
        # Only tensors carry a device; plain scalars previously crashed here
        # with AttributeError on ``x.device``.
        c = c.to(x.device)

    if len(c) == 1:
        c0 = c[0]
        c1 = 0
    elif len(c) == 2:
        c0 = c[0]
        c1 = c[1]
    else:
        nd = len(c)
        c0 = c[-2]
        c1 = c[-1]
        for i in range(3, len(c) + 1):
            tmp = c0
            nd = nd - 1
            c0 = c[-i] - (c1 * (nd - 1)) / nd
            c1 = tmp + (c1 * x * (2 * nd - 1)) / nd
    return c0 + c1 * x
728
+
729
+
730
def _torch_calc_g(
    cosang: torch.Tensor, stiffness: float = 4, n_legendre_terms: int = 50
) -> torch.Tensor:
    """Evaluate the spherical-spline g function between points on a sphere.

    The g function is a Legendre series in ``cosang`` with zero constant term
    and coefficient ``(2n + 1) / (n**stiffness * (n + 1)**stiffness * 4 * pi)``
    for degrees ``n = 1 .. n_legendre_terms``.

    Parameters
    ----------
    cosang : array-like of float, shape(n_channels, n_channels)
        cosine of angles between pairs of points on a spherical surface. This
        is equivalent to the dot product of unit vectors.
    stiffness : float
        stiffness of the spline.
    n_legendre_terms : int
        number of Legendre terms to evaluate.

    Returns
    -------
    G : torch.Tensor of float, shape(n_channels, n_channels)
        The G matrix.

    Notes
    -----
    Code copied and modified from MNE-Python:
    https://github.com/mne-tools/mne-python/blob/bdaa1d460201a3bc3cec95b67fc2b8d31a933652/mne/channels/interpolation.py#L35

    Copyright © 2011-2019, authors of MNE-Python
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:
        * Redistributions of source code must retain the above copyright
        notice, this list of conditions and the following disclaimer.
        * Redistributions in binary form must reproduce the above copyright
        notice, this list of conditions and the following disclaimer in the
        documentation and/or other materials provided with the distribution.
        * Neither the name of the copyright holder nor the names of its
        contributors may be used to endorse or promote products derived from
        this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
    ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
    FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
    DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
    CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
    LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
    OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
    DAMAGE.
    """
    # Degree-0 coefficient is zero; degrees 1..n_legendre_terms follow.
    coefficients = [0]
    for degree in range(1, n_legendre_terms + 1):
        denominator = degree**stiffness * (degree + 1) ** stiffness * 4 * np.pi
        coefficients.append((2 * degree + 1) / denominator)
    return _torch_legval(cosang, coefficients)
786
+
787
+
788
def _torch_make_interpolation_matrix(
    pos_from: torch.Tensor, pos_to: torch.Tensor, alpha: float | None = 1e-5
) -> torch.Tensor:
    """Compute interpolation matrix based on spherical splines.

    Implementation based on [1]_

    Parameters
    ----------
    pos_from : torch.Tensor of float, shape(n_good_sensors, 3)
        The positions to interpolate from.
    pos_to : torch.Tensor of float, shape(n_bad_sensors, 3)
        The positions to interpolate.
    alpha : float | None
        Regularization parameter added to the diagonal of the G matrix.
        Defaults to 1e-5. ``None`` disables the regularization.

    Returns
    -------
    interpolation : torch.Tensor of float, shape(len(pos_to), len(pos_from))
        The interpolation matrix that maps good signals to the location
        of bad signals.

    Notes
    -----
    Code copied and modified from MNE-Python:
    https://github.com/mne-tools/mne-python/blob/bdaa1d460201a3bc3cec95b67fc2b8d31a933652/mne/channels/interpolation.py#L59

    Copyright © 2011-2019, authors of MNE-Python
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:
        * Redistributions of source code must retain the above copyright
          notice, this list of conditions and the following disclaimer.
        * Redistributions in binary form must reproduce the above copyright
          notice, this list of conditions and the following disclaimer in the
          documentation and/or other materials provided with the distribution.
        * Neither the name of the copyright holder nor the names of its
          contributors may be used to endorse or promote products derived from
          this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
    ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
    FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
    DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
    CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
    LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
    OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
    DAMAGE.

    References
    ----------
    [1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
        Spherical splines for scalp potential and current density mapping.
        Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.
    """
    n_from = pos_from.shape[0]
    n_to = pos_to.shape[0]

    # normalize sensor positions to sphere (on copies, never on the inputs)
    pos_from = _torch_normalize_vectors(pos_from.clone())
    pos_to = _torch_normalize_vectors(pos_to.clone())

    # cosine angles between source positions
    cosang_from = torch.matmul(pos_from, pos_from.T)
    cosang_to_from = torch.matmul(pos_to, pos_from.T)
    G_from = _torch_calc_g(cosang_from)
    G_to_from = _torch_calc_g(cosang_to_from)
    assert G_from.shape == (n_from, n_from)
    assert G_to_from.shape == (n_to, n_from)

    if alpha is not None:
        # ``diagonal()`` is always a view, so the in-place add reaches G_from;
        # ``flatten()`` is allowed to return a copy, which would silently
        # drop the regularization.
        G_from.diagonal().add_(alpha)

    dtype, device = G_from.dtype, G_from.device
    ones_col = torch.ones((n_from, 1), dtype=dtype, device=device)
    # Bordered system [[G, 1], [1^T, 0]] enforcing the constant term.
    C = torch.vstack(
        [
            torch.hstack([G_from, ones_col]),
            torch.hstack([ones_col.T, torch.zeros((1, 1), dtype=dtype, device=device)]),
        ]
    )

    try:
        C_inv = torch.linalg.inv(C)
    except torch.linalg.LinAlgError:
        # There is a stability issue with pinv since torch v1.8.0
        # see https://github.com/pytorch/pytorch/issues/75494
        C_inv = torch.linalg.pinv(C.cpu()).to(device)

    ones_to = torch.ones((n_to, 1), dtype=G_to_from.dtype, device=device)
    interpolation = torch.hstack([G_to_from, ones_to]).matmul(C_inv[:, :-1])
    assert interpolation.shape == (n_to, n_from)
    return interpolation
892
+
893
+
894
def _rotate_signals(
    X: torch.Tensor,
    rotations: list[torch.Tensor],
    sensors_positions_matrix: torch.Tensor,
    spherical: bool = True,
) -> torch.Tensor:
    """Interpolate each example's signals onto rotated sensor positions.

    Parameters
    ----------
    X : torch.Tensor
        Batch of signals indexed as ``X[example, channel, time]``.
    rotations : list of torch.Tensor
        One 3x3 rotation matrix per example of ``X``.
    sensors_positions_matrix : torch.Tensor
        Sensor positions of shape ``(3, n_channels)`` (one row per cartesian
        coordinate).
    spherical : bool
        If True, use spherical-spline interpolation. Otherwise interpolate
        every time sample independently with ``scipy.interpolate.Rbf``.

    Returns
    -------
    torch.Tensor
        Signals with the same shape as ``X``, evaluated at the rotated
        sensor positions.
    """
    sensors_positions_matrix = sensors_positions_matrix.to(X.device)
    # Align each rotation with the positions (device and dtype) so that
    # rotations built on another device or in float64 can be multiplied.
    rot_sensors_matrices = [
        rotation.to(sensors_positions_matrix).matmul(sensors_positions_matrix)
        for rotation in rotations
    ]
    if spherical:
        # Use X's floating dtype instead of forcing float32, otherwise the
        # matmul below fails for float64 inputs.
        out_dtype = X.dtype if X.is_floating_point() else torch.float32
        interpolation_matrix = torch.stack(
            [
                _torch_make_interpolation_matrix(
                    sensors_positions_matrix.T, rot_sensors_matrix.T
                ).to(device=X.device, dtype=out_dtype)
                for rot_sensors_matrix in rot_sensors_matrices
            ]
        )
        return torch.matmul(interpolation_matrix, X)

    # Rbf works on NumPy arrays: convert the inputs once, outside the loops.
    sensors_positions = list(sensors_positions_matrix.detach().cpu().numpy())
    X_np = X.detach().cpu().numpy()
    transformed_X = X.clone()
    for s, rot_sensors_matrix in enumerate(rot_sensors_matrices):
        # Unpack the rotated positions exactly like ``sensors_positions``:
        # one array per coordinate (rows of the (3, n_channels) matrix), not
        # one 3-vector per channel.
        rot_sensors_positions = list(rot_sensors_matrix.detach().cpu().numpy())
        for time in range(X.shape[-1]):
            interpolator_t = Rbf(*sensors_positions, X_np[s, :, time])
            transformed_X[s, :, time] = torch.as_tensor(
                interpolator_t(*rot_sensors_positions),
                dtype=transformed_X.dtype,
                device=transformed_X.device,
            )
    return transformed_X
928
+
929
+
930
+ def _make_rotation_matrix(
931
+ axis: Literal["x", "y", "z"],
932
+ angle: float | int | np.ndarray | list | torch.Tensor,
933
+ degrees: bool = True,
934
+ ) -> torch.Tensor:
935
+ assert axis in ["x", "y", "z"], "axis should be either x, y or z."
936
+ if isinstance(angle, torch.Tensor):
937
+ _angle = angle
938
+ elif isinstance(angle, (float, int, np.ndarray, list)):
939
+ _angle = torch.as_tensor(angle)
940
+ else:
941
+ raise ValueError(f"Invalid angle type: {type(angle)}")
942
+
943
+ if degrees:
944
+ _angle = _angle * np.pi / 180
945
+
946
+ device = _angle.device
947
+ zero = torch.zeros(1, device=device)
948
+ rot = torch.stack(
949
+ [
950
+ torch.as_tensor([1, 0, 0], device=device),
951
+ torch.hstack([zero, torch.cos(_angle), -torch.sin(_angle)]),
952
+ torch.hstack([zero, torch.sin(_angle), torch.cos(_angle)]),
953
+ ]
954
+ )
955
+ if axis == "x":
956
+ return rot
957
+ elif axis == "y":
958
+ rot = rot[[2, 0, 1], :]
959
+ return rot[:, [2, 0, 1]]
960
+ else:
961
+ rot = rot[[1, 2, 0], :]
962
+ return rot[:, [1, 2, 0]]
963
+
964
+
965
def sensors_rotation(
    X: torch.Tensor,
    y: torch.Tensor,
    sensors_positions_matrix: torch.Tensor,
    axis: Literal["x", "y", "z"],
    angles: npt.ArrayLike,
    spherical_splines: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Interpolate EEG signals over sensors rotated around an axis.

    Each example of the batch is rotated around ``axis`` by its own angle,
    and its signals are interpolated at the rotated sensor positions.
    Suggested in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    sensors_positions_matrix : torch.Tensor
        Matrix giving the positions of each sensor in a 3D cartesian coordinate
        system. Should have shape (3, n_channels), where n_channels is the
        number of channels. Standard 10-20 positions can be obtained from
        ``mne`` through::

         >>> ten_twenty_montage = mne.channels.make_standard_montage(
         ...    'standard_1020'
         ... ).get_positions()['ch_pos']
    axis : 'x' | 'y' | 'z'
        Axis around which to rotate.
    angles : array-like
        Array of float of shape ``(batch_size,)`` containing the rotation
        angles (in degrees) for each element of the input batch.
    spherical_splines : bool
        Whether to use spherical splines for the interpolation or not. When
        `False`, scipy.interpolate.Rbf (with its default kernel) is used
        instead, as in the original paper.

    Returns
    -------
    torch.Tensor
        Signals interpolated at the rotated sensor positions.
    torch.Tensor
        Labels, returned unchanged.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """
    rotations = []
    for angle in angles:
        rotations.append(_make_rotation_matrix(axis, angle, degrees=True))
    rotated = _rotate_signals(
        X, rotations, sensors_positions_matrix, spherical_splines
    )
    return rotated, y
1014
+
1015
+
1016
def mixup(
    X: torch.Tensor, y: torch.Tensor, lam: torch.Tensor, idx_perm: torch.Tensor
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Mix pairs of EEG examples of a batch.

    Each example ``i`` is linearly combined with example ``idx_perm[i]``:
    ``lam[i] * X[i] + (1 - lam[i]) * X[idx_perm[i]]``.

    See [1]_ for details.
    Implementation based on [2]_.

    Parameters
    ----------
    X : torch.Tensor
        EEG data in form ``batch_size, n_channels, n_times``
    y : torch.Tensor
        Target of length ``batch_size``
    lam : torch.Tensor
        Values between 0 and 1 setting the linear interpolation between
        examples (one per example).
    idx_perm : torch.Tensor
        Permuted indices of example that are mixed into original examples.

    Returns
    -------
    tuple
        ``X``, ``y``. Where ``X`` is augmented (in the floating dtype of the
        input) and ``y`` is a tuple of length 3 containing the int64 labels of
        the two mixed examples and the mixing coefficients ``lam``.

    References
    ----------
    .. [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
       (2018). mixup: Beyond Empirical Risk Minimization. In 2018
       International Conference on Learning Representations (ICLR)
       Online: https://arxiv.org/abs/1710.09412
    .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
    """
    device = X.device
    # Mix in X's own floating dtype (the output used to be forced to float32).
    dtype = X.dtype if X.is_floating_point() else torch.get_default_dtype()
    # One coefficient per example, broadcast over all remaining dimensions.
    lam_b = torch.as_tensor(lam, dtype=dtype, device=device).reshape(
        (-1,) + (1,) * (X.ndim - 1)
    )
    X_mix = lam_b * X.to(dtype) + (1 - lam_b) * X[idx_perm].to(dtype)

    # Labels are returned as int64 class indices, as before.
    y_t = torch.as_tensor(y, device=device).long()
    y_a = y_t.clone()
    y_b = y_t[idx_perm]

    return X_mix, (y_a, y_b, lam)
1064
+
1065
+
1066
+ def segmentation_reconstruction(
1067
+ X: torch.Tensor,
1068
+ y: torch.Tensor,
1069
+ n_segments: int,
1070
+ data_classes: list[tuple[int, torch.Tensor]],
1071
+ rand_indices: npt.ArrayLike,
1072
+ idx_shuffle: npt.ArrayLike,
1073
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1074
+ """Segment and reconstruct EEG data from [1]_.
1075
+
1076
+ See [1]_ for details.
1077
+
1078
+ Parameters
1079
+ ----------
1080
+ X : torch.Tensor
1081
+ EEG input example or batch.
1082
+ y : torch.Tensor
1083
+ EEG labels for the example or batch.
1084
+ n_segments : int
1085
+ Number of segments to use in the batch.
1086
+ data_classes : list[tuple[int, torch.Tensor]]
1087
+ List of tuples. Each tuple contains the class index and the corresponding EEG data.
1088
+ rand_indices : array-like
1089
+ Array of indices that indicates which trial to use in each segment.
1090
+ idx_shuffle : array-like
1091
+ Array of indices to shuffle the new generated trials.
1092
+
1093
+ Returns
1094
+ -------
1095
+ torch.Tensor
1096
+ Transformed inputs.
1097
+ torch.Tensor
1098
+ Transformed labels.
1099
+
1100
+ References
1101
+ ----------
1102
+ .. [1] Lotte, F. (2015). Signal processing approaches to minimize or
1103
+ suppress calibration time in oscillatory activity-based brain–computer
1104
+ interfaces. Proceedings of the IEEE, 103(6), 871-890.
1105
+ """
1106
+
1107
+ # Initialize lists to store augmented data and corresponding labels
1108
+ aug_data: list[torch.Tensor] = []
1109
+ aug_label: list[torch.Tensor] = []
1110
+
1111
+ # Iterate through each class to separate and augment data
1112
+ for class_index, X_class in data_classes:
1113
+ # Determine class-specific dimensions
1114
+ # Store the augmented data and the corresponding class labels
1115
+ n_trials, n_channels, window_size = X_class.shape
1116
+ # Segment Size
1117
+ segment_size = window_size // n_segments
1118
+ # Initialize an empty tensor for augmented data
1119
+ X_aug = torch.zeros_like(X_class)
1120
+ # Generate random indices within the class-specific dataset
1121
+ rand_idx = rand_indices[class_index]
1122
+ for idx_segment in range(n_segments):
1123
+ start = idx_segment * segment_size
1124
+ end = (idx_segment + 1) * segment_size
1125
+
1126
+ # Perform the data augmentation
1127
+ X_aug[np.arange(n_trials), :, start:end] = X_class[
1128
+ rand_idx[:, idx_segment], :, start:end
1129
+ ]
1130
+ aug_data.append(X_aug)
1131
+ aug_label.append(torch.full((n_trials,), class_index))
1132
+ # Concatenate the augmented data and labels
1133
+ concat_aug_data = torch.cat(aug_data, dim=0)
1134
+ concat_aug_data = concat_aug_data.to(dtype=X.dtype, device=X.device)
1135
+ concat_aug_data = concat_aug_data[idx_shuffle]
1136
+
1137
+ if y is not None:
1138
+ concat_label = torch.cat(aug_label, dim=0)
1139
+ concat_label = concat_label.to(dtype=y.dtype, device=y.device)
1140
+ concat_label = concat_label[idx_shuffle]
1141
+ return concat_aug_data, concat_label
1142
+
1143
+ return concat_aug_data, None
1144
+
1145
+
1146
def mask_encoding(
    X: torch.Tensor,
    y: torch.Tensor,
    time_start: torch.Tensor,
    segment_length: int,
    n_segments: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Mask encoding from Ding et al (2024) from [ding2024]_.

    Replaces a contiguous part (or parts) of all channels by zeros
    (if more than one segment, it may overlap).

    Implementation based on [ding2024]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input batch of shape ``(batch_size, n_channels, n_times)``.
        It is not modified; a masked copy is returned.
    y : torch.Tensor
        EEG labels for the example or batch.
    time_start : torch.Tensor
        Tensor of integers containing the position (in last dimension) where to
        start masking the signal. Should have "n_segments" times the size of the first
        dimension of X (i.e. "n_segments" start positions per example in the batch),
        ordered example by example.
    segment_length : int
        Length of each segment to zero out. Segments are clipped to the
        signal bounds.
    n_segments : int
        Number of segments to zero out in each example.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [ding2024] Ding, Wenlong, et al. A Novel Data Augmentation Approach
        Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI.
        IEEE Transactions on Neural Systems and Rehabilitation Engineering
        32 (2024): 875-886.
    """
    batch_size, n_times = X.shape[0], X.shape[-1]
    # (batch_size, n_segments, 1): the flattened starts are grouped per example.
    starts = torch.as_tensor(time_start, device=X.device).reshape(
        batch_size, n_segments, 1
    )
    t = torch.arange(n_times, device=X.device)
    # A time sample is masked if it falls inside any segment of its example.
    time_mask = ((t >= starts) & (t < starts + segment_length)).any(dim=1)

    # Out-of-place fill: the caller's tensor is left untouched.
    X_masked = X.masked_fill(time_mask[:, None, :], 0)
    return X_masked, y
1203
+
1204
+
1205
def channels_rereference(
    X: torch.Tensor,
    y: torch.Tensor,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly re-reference channels in EEG data matrix.

    For every example one channel is drawn at random and subtracted from all
    channels; the slot of the drawn channel then receives its negated signal.
    Part of the augmentations proposed in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input batch of shape ``(batch_size, n_channels, n_times)``.
    y : torch.Tensor
        EEG labels for the example or batch.
    random_state : int | numpy.random.RandomState | None, optional
        Seed or generator passed to ``check_random_state``; the resulting
        generator's ``randint`` draws the reference channel of each example.
        Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Mohsenvand, M.N., Izadi, M.R. &amp; Maes, P.. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. Proceedings
       of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
       Learning Research 136:238-253
    """
    rng = check_random_state(random_state)
    batch_size, n_channels, _ = X.shape

    ref_channels = rng.randint(0, n_channels, size=batch_size)
    rows = torch.arange(batch_size)

    reference = X[rows, ref_channels, :]
    rereferenced = X - reference[:, None, :]
    rereferenced[rows, ref_channels, :] = -reference

    return rereferenced, y
1249
+
1250
+
1251
+ def amplitude_scale(
1252
+ X: torch.Tensor,
1253
+ y: torch.Tensor,
1254
+ scale: tuple,
1255
+ random_state: int | np.random.RandomState | None = None,
1256
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1257
+ """Rescale amplitude of each channel based on a random sampled scaling value.
1258
+
1259
+ Part of the augmentations proposed in [1]_
1260
+
1261
+ Parameters
1262
+ ----------
1263
+ X : torch.Tensor
1264
+ EEG input example or batch.
1265
+ y : torch.Tensor
1266
+ EEG labels for the example or batch.
1267
+ scale : tuple of floats
1268
+ Interval from which ypu sample the scaling value
1269
+ random_state : int | numpy.random.Generator, optional
1270
+ Seed to be used to instantiate numpy random number generator instance.
1271
+ Defaults to None.
1272
+
1273
+ Returns
1274
+ -------
1275
+ torch.Tensor
1276
+ Transformed inputs.
1277
+ torch.Tensor
1278
+ Transformed labels.
1279
+
1280
+ References
1281
+ ----------
1282
+ .. [1] Mohsenvand, M.N., Izadi, M.R. &amp; Maes, P.. (2020). Contrastive
1283
+ Representation Learning for Electroencephalogram Classification. Proceedings
1284
+ of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
1285
+ Learning Research 136:238-253
1286
+ """
1287
+
1288
+ rng = torch.Generator()
1289
+ rng.manual_seed(random_state)
1290
+ batch_size, n_channels, _ = X.shape
1291
+
1292
+ # Parameter for scaling amplitude / channel / trial
1293
+ l, h = scale
1294
+ s = l + (h - l) * torch.rand(
1295
+ batch_size, n_channels, 1, generator=rng, device=X.device, dtype=X.dtype
1296
+ )
1297
+
1298
+ X = s * X
1299
+
1300
+ return X, y