braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1096 @@
1
+ # Authors: Cédric Rommel <cedric.rommel@inria.fr>
2
+ # Alexandre Gramfort <alexandre.gramfort@inria.fr>
3
+ # Gustavo Rodrigues <gustavenrique01@gmail.com>
4
+ #
5
+ # License: BSD (3-clause)
6
+
7
+ from numbers import Real
8
+
9
+ import numpy as np
10
+ import torch
11
+ from mne.filter import notch_filter
12
+ from scipy.interpolate import Rbf
13
+ from sklearn.utils import check_random_state
14
+ from torch.fft import fft, ifft
15
+ from torch.nn.functional import one_hot, pad
16
+
17
+
18
def identity(X, y):
    """Return inputs and labels untouched (no-op augmentation).

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.

    Returns
    -------
    torch.Tensor
        Transformed inputs (the very same object as ``X``).
    torch.Tensor
        Transformed labels (the very same object as ``y``).
    """
    unchanged = (X, y)
    return unchanged
36
+
37
+
38
def time_reverse(X, y):
    """Reverse the order of the samples along the last (time) axis.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels (unchanged).
    """
    reversed_X = X.flip(dims=(-1,))
    return reversed_X, y
56
+
57
+
58
def sign_flip(X, y):
    """Negate every value of each input (flip the sign of the signals).

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels (unchanged).
    """
    flipped = torch.neg(X)
    return flipped, y
76
+
77
+
78
def _new_random_fft_phase_odd(batch_size, c, n, device, random_state):
    """Draw random FFT phase offsets for odd-length signals.

    Returns a complex tensor of shape ``(batch_size, c, n)`` laid out like an
    FFT: 0 at the DC bin, ``1j * 2 * pi * u`` (``u`` uniform in ``[0, 1)``)
    on the ``(n - 1) // 2`` positive-frequency bins, and the mirrored
    negation on the negative-frequency bins.
    """
    generator = check_random_state(random_state)
    n_half = (n - 1) // 2
    positive = torch.from_numpy(
        2j * np.pi * generator.random((batch_size, c, n_half))
    ).to(device)
    dc_bin = torch.zeros((batch_size, c, 1), device=device)
    return torch.cat([dc_bin, positive, -positive.flip(-1)], dim=-1)
91
+
92
+
93
def _new_random_fft_phase_even(batch_size, c, n, device, random_state):
    """Draw random FFT phase offsets for even-length signals.

    Returns a complex tensor of shape ``(batch_size, c, n)`` laid out like an
    FFT: 0 at the DC and Nyquist bins, ``1j * 2 * pi * u`` (``u`` uniform in
    ``[0, 1)``) on the ``n // 2 - 1`` positive-frequency bins, and the
    mirrored negation on the negative-frequency bins.
    """
    generator = check_random_state(random_state)
    n_half = n // 2 - 1
    positive = torch.from_numpy(
        2j * np.pi * generator.random((batch_size, c, n_half))
    ).to(device)
    zero_bin = torch.zeros((batch_size, c, 1), device=device)
    return torch.cat(
        [zero_bin, positive, zero_bin, -positive.flip(-1)],
        dim=-1,
    )
107
+
108
+
109
# Phase generator dispatch keyed by ``n % 2``: 0 -> even length, 1 -> odd length.
_new_random_fft_phase = dict(
    enumerate((_new_random_fft_phase_even, _new_random_fft_phase_odd))
)
110
+
111
+
112
def ft_surrogate(X, y, phase_noise_magnitude, channel_indep, random_state=None):
    """Fourier-transform surrogate augmentation, as proposed in [1]_.

    Each signal is moved to the frequency domain, its phases are shifted by
    random amounts and it is brought back to the time domain; the phase
    offsets are antisymmetric across positive/negative frequencies and the
    real part of the inverse FFT is returned.

    Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
    and modified.

    Parameters
    ----------
    X : torch.Tensor
        EEG batch; the first axis indexes examples, the second-to-last
        channels and the last one time.
    y : torch.Tensor
        EEG labels for the example or batch.
    phase_noise_magnitude : float | torch.Tensor
        Value between 0 and 1 setting the range over which the phase
        perturbation is uniformly sampled:
        [0, `phase_noise_magnitude` * 2 * `pi`].
    channel_indep : bool
        Whether to sample phase perturbations independently for each channel
        or not. It is advised to set it to False when spatial information is
        important for the task, like in BCI.
    random_state : int | numpy.random.Generator, optional
        Used to draw the phase perturbation. Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs, as float32.
    torch.Tensor
        Transformed labels (unchanged).

    References
    ----------
    .. [1] Schwabedal, J. T., Snyder, J. C., Cakmak, A., Nemati, S., &
       Clifford, G. D. (2018). Addressing Class Imbalance in Classification
       Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
       preprint arXiv:1806.08675.
    """
    assert isinstance(
        phase_noise_magnitude, (Real, torch.FloatTensor, torch.cuda.FloatTensor)
    ) and (
        0 <= phase_noise_magnitude <= 1
    ), f"eps must be a float between 0 and 1. Got {phase_noise_magnitude}."

    # Work in double precision in the frequency domain.
    spectrum = fft(X.double(), dim=-1)
    device = X.device
    n_times = spectrum.shape[-1]
    n_chans = spectrum.shape[-2]

    make_phases = _new_random_fft_phase[n_times % 2]
    phases = make_phases(
        spectrum.shape[0],
        n_chans if channel_indep else 1,
        n_times,
        device=device,
        random_state=random_state,
    )
    if not channel_indep:
        # One shared perturbation, copied to every channel.
        phases = phases.tile((1, n_chans, 1))

    if isinstance(phase_noise_magnitude, torch.Tensor):
        phase_noise_magnitude = phase_noise_magnitude.to(device)
    perturbed = spectrum * torch.exp(phase_noise_magnitude * phases)
    return ifft(perturbed, dim=-1).real.float(), y
176
+
177
+
178
def _pick_channels_randomly(X, p_pick, random_state):
    """Return a near-binary ``(batch, channels)`` mask of randomly picked channels.

    One uniform draw is made per (example, channel). Entries whose draw is
    above ``p_pick`` end up close to 1, the others close to 0. A steep
    sigmoid stands in for the hard threshold; presumably this keeps the mask
    differentiable with respect to ``p_pick`` (TODO confirm intent).
    """
    generator = check_random_state(random_state)
    n_examples, n_chans, _ = X.shape
    # Drawing with the caller's RNG lets several helpers share one stream.
    draws = generator.uniform(0, 1, size=(n_examples, n_chans))
    draws = torch.as_tensor(draws, dtype=torch.float, device=X.device)
    steepness = 1000
    return torch.sigmoid(steepness * (draws - p_pick))
189
+
190
+
191
def channels_dropout(X, y, p_drop, random_state=None):
    """Randomly turn whole channels into flat (zero) signals.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG batch of shape ``(batch_size, n_channels, n_times)``.
    y : torch.Tensor
        EEG labels for the example or batch.
    p_drop : float
        Value between 0 and 1 setting the probability of dropping each
        channel.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels (unchanged).

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """
    keep_mask = _pick_channels_randomly(X, p_drop, random_state=random_state)
    # Broadcast the (batch, channels) mask over the time axis.
    return X * keep_mask[..., None], y
223
+
224
+
225
def _make_permutation_matrix(X, mask, random_state):
    """Build one channel-permutation matrix per example.

    ``mask`` is rounded to 0/1. In each example the channels marked 1 are
    shuffled among themselves and the others keep their position. The result
    has shape ``(batch, n_channels, n_channels)``, with a 1 at
    ``[b, i, order[i]]``, so ``matrix[b] @ X[b]`` places channel ``order[i]``
    at row ``i``.
    """
    generator = check_random_state(random_state)
    n_examples, n_chans, _ = X.shape
    selected = mask.round()
    permutations = torch.empty(n_examples, n_chans, n_chans, device=X.device)
    for idx, row in enumerate(selected):
        chosen = torch.arange(n_chans, device=X.device)[row.bool()]
        shuffled = torch.tensor(
            generator.permutation(chosen.cpu()), device=X.device
        )
        order = torch.arange(n_chans, device=X.device)
        order[chosen] = shuffled
        permutations[idx] = one_hot(order)
    return permutations
242
+
243
+
244
def channels_shuffle(X, y, p_shuffle, random_state=None):
    """Randomly shuffle a subset of the channels of each example.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG batch of shape ``(batch_size, n_channels, n_times)``.
    y : torch.Tensor
        EEG labels for the example or batch.
    p_shuffle : float
        Value between 0 and 1 setting the probability of including each
        channel in the set of permuted channels. When 0, ``X`` is returned
        as is.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to sample which channels to shuffle and to carry the shuffle.
        Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels (unchanged).

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """
    if p_shuffle != 0:
        # A channel is picked when its uniform draw exceeds 1 - p_shuffle,
        # i.e. with probability p_shuffle.
        selection = _pick_channels_randomly(X, 1 - p_shuffle, random_state)
        perm_matrices = _make_permutation_matrix(X, selection, random_state)
        X = torch.matmul(perm_matrices, X)
    return X, y
281
+
282
+
283
def gaussian_noise(X, y, std, random_state=None):
    """Add white Gaussian noise, with one independent draw per entry of ``X``.

    Suggested e.g. in [1]_, [2]_ and [3]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    std : float | torch.Tensor
        Standard deviation to use for the additive noise.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels (unchanged).

    References
    ----------
    .. [1] Wang, F., Zhong, S. H., Peng, J., Jiang, J., & Liu, Y. (2018). Data
       augmentation for eeg-based emotion recognition with deep convolutional
       neural networks. In International Conference on Multimedia Modeling
       (pp. 82-93).
    .. [2] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [3] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """
    generator = check_random_state(random_state)
    if isinstance(std, torch.Tensor):
        std = std.to(X.device)
    # Standard-normal samples drawn with numpy so the caller's RNG is used.
    standard_normal = generator.normal(loc=np.zeros(X.shape), scale=1)
    noise = torch.from_numpy(standard_normal).float().to(X.device) * std
    return X + noise, y
333
+
334
+
335
def channels_permute(X, y, permutation):
    """Reorder the EEG channels following a fixed permutation.

    Suggested e.g. in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch; channels are on the second-to-last axis.
    y : torch.Tensor
        EEG labels for the example or batch.
    permutation : list
        List of integers defining the new channels order.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels (unchanged).

    References
    ----------
    .. [1] Deiss, O., Biswal, S., Jin, J., Sun, H., Westover, M. B., & Sun, J.
       (2018). HAMLET: interpretable human and machine co-learning technique.
       arXiv preprint arXiv:1803.09702.
    """
    # Same as X[..., permutation, :]: index the channel (second-to-last) axis.
    channel_index = (Ellipsis, permutation, slice(None))
    return X[channel_index], y
363
+
364
+
365
def smooth_time_mask(X, y, mask_start_per_sample, mask_len_samples):
    """Smoothly replace a contiguous part of all channels by zeros.

    Originally proposed in [1]_ and [2]_

    The mask is the sum of two steep sigmoids. It is close to 0 between
    ``start`` and ``start + mask_len_samples`` and close to 1 elsewhere.

    Parameters
    ----------
    X : torch.Tensor
        EEG batch of shape ``(batch_size, n_channels, n_times)``.
    y : torch.Tensor
        EEG labels for the example or batch.
    mask_start_per_sample : torch.Tensor | array-like
        Positions (in last dimension) where to start masking the signal, one
        per example in the batch (i.e. same size as the first dimension of
        ``X``). Sequences, and tensors stored on another device, are converted
        to a tensor on ``X.device``.
    mask_len_samples : int
        Number of consecutive samples to zero out.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels (unchanged).

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """
    # Unpacking enforces a 3-D (batch, channels, time) input.
    _, _, seq_len = X.shape
    # One start per example, shaped (batch, 1, 1) and placed on X's device so
    # it broadcasts over channels and time (a CPU tensor would clash with CUDA X).
    starts = torch.as_tensor(mask_start_per_sample, device=X.device).reshape(
        -1, 1, 1
    )
    # A 1-D time grid is enough: broadcasting against ``starts`` yields a
    # (batch, 1, seq_len) mask instead of materialising (batch, channels, time).
    t = torch.arange(seq_len, device=X.device).float()
    # Sigmoid steepness scaled with the window length.
    s = 1000 / seq_len
    mask = torch.sigmoid(s * -(t - starts)) + torch.sigmoid(
        s * (t - starts - mask_len_samples)
    )
    return X * mask.float(), y
413
+
414
+
415
def bandstop_filter(X, y, sfreq, bandwidth, freqs_to_notch):
    """Apply a band-stop filter with desired bandwidth at the desired frequency
    position.

    Suggested e.g. in [1]_ and [2]_

    Each example of the batch is notch-filtered (FIR, via MNE) around its own
    center frequency. A zero ``bandwidth`` returns ``X`` unchanged.

    Parameters
    ----------
    X : torch.Tensor
        EEG batch; the first axis indexes examples.
    y : torch.Tensor
        EEG labels for the example or batch.
    sfreq : float
        Sampling frequency of the signals to be filtered.
    bandwidth : float
        Bandwidth of the filter, i.e. distance between the low and high cut
        frequencies.
    freqs_to_notch : array-like
        Array of floats of size ``(batch_size,)`` containing the center of the
        frequency band to filter out for each sample in the batch. Frequencies
        should be greater than ``bandwidth/2 + transition`` and lower than
        ``sfreq/2 - bandwidth/2 - transition`` (where ``transition = 1 Hz``).

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels (unchanged).

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """
    if bandwidth == 0:
        return X, y
    filtered = X.clone()
    for idx, (example, center_freq) in enumerate(zip(filtered, freqs_to_notch)):
        # MNE works on float64 numpy arrays living on the CPU.
        example_np = example.cpu().numpy().astype(np.float64)
        notched = notch_filter(
            example_np,
            Fs=sfreq,
            freqs=center_freq,
            method="fir",
            notch_widths=bandwidth,
            verbose=False,
        )
        filtered[idx] = torch.as_tensor(notched)
    return filtered, y
470
+
471
+
472
+ def _analytic_transform(x):
473
+ if torch.is_complex(x):
474
+ raise ValueError("x must be real.")
475
+
476
+ N = x.shape[-1]
477
+ f = fft(x, N, dim=-1)
478
+ h = torch.zeros_like(f)
479
+ if N % 2 == 0:
480
+ h[..., 0] = h[..., N // 2] = 1
481
+ h[..., 1 : N // 2] = 2
482
+ else:
483
+ h[..., 0] = 1
484
+ h[..., 1 : (N + 1) // 2] = 2
485
+
486
+ return ifft(f * h, dim=-1)
487
+
488
+
489
+ def _nextpow2(n):
490
+ """Return the first integer N such that 2**N >= abs(n)."""
491
+ return int(np.ceil(np.log2(np.abs(n))))
492
+
493
+
494
def _frequency_shift(X, fs, f_shift):
    """Shift the spectrum of ``X`` by ``f_shift`` Hz along its last axis.

    The signal is zero-padded to the next power of two and converted to its
    analytic representation. It is then modulated by
    ``exp(2j * pi * f_shift * t)``, and the real part of the result, cropped
    back to the original length, is returned.

    See https://gist.github.com/lebedov/4428122

    Parameters
    ----------
    X : torch.Tensor
        Real signals, time on the last axis. A per-example shift is aligned
        with the first axis (assumed ``(batch, channels, times)`` -- TODO
        confirm for other ranks).
    fs : float
        Sampling frequency in Hz.
    f_shift : float | int | list | tuple | numpy.ndarray | torch.Tensor
        Shift in Hz: a scalar (same shift for every example) or a 1-D
        sequence with one value per example.

    Returns
    -------
    torch.Tensor
        Shifted signals (float32), with the same length as ``X`` on the last
        axis.
    """
    n_times = X.shape[-1]
    # Pad with zeros up to a power of two so the FFT invoked by the analytic
    # transform stays fast.
    n_padded = 2 ** _nextpow2(n_times)
    t = torch.arange(n_padded, device=X.device) / fs
    padded = pad(X, (0, n_padded - n_times))
    analytical = _analytic_transform(padded)

    if not isinstance(f_shift, torch.Tensor):
        # Python scalars, lists/tuples, numpy arrays and numpy scalars.
        f_shift = torch.as_tensor(f_shift).float()
    # Move the shift next to the signal: a CPU shift multiplied with a CUDA
    # time vector would otherwise fail. Shape it (n_shifts, 1, 1) so it
    # broadcasts over channels and time.
    f_shift = f_shift.to(X.device).reshape(-1, 1, 1)
    shifted = analytical * torch.exp(2j * np.pi * f_shift * t)
    return shifted[..., :n_times].real.float()
515
+
516
+
517
def frequency_shift(X, y, delta_freq, sfreq):
    """Shift the spectrum of every channel by ``delta_freq`` Hz.

    Note that here, the shift is the same for all channels of a single example.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    delta_freq : float | torch.Tensor
        The amplitude of the frequency shift (in Hz), either a scalar or one
        value per example.
    sfreq : float
        Sampling frequency of the signals to be transformed.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels (unchanged).
    """
    shifted_X = _frequency_shift(X, sfreq, delta_freq)
    return shifted_X, y
546
+
547
+
548
+ def _torch_normalize_vectors(rr):
549
+ """Normalize surface vertices."""
550
+ norm = torch.linalg.norm(rr, axis=1, keepdim=True)
551
+ mask = norm > 0
552
+ norm[~mask] = 1 # in case norm is zero, divide by 1
553
+ new_rr = rr / norm
554
+ return new_rr
555
+
556
+
557
+ def _torch_legval(x, c, tensor=True):
558
+ """
559
+ Evaluate a Legendre series at points x.
560
+ If `c` is of length `n + 1`, this function returns the value:
561
+ .. math:: p(x) = c_0 * L_0(x) + c_1 * L_1(x) + ... + c_n * L_n(x)
562
+ The parameter `x` is converted to an array only if it is a tuple or a
563
+ list, otherwise it is treated as a scalar. In either case, either `x`
564
+ or its elements must support multiplication and addition both with
565
+ themselves and with the elements of `c`.
566
+ If `c` is a 1-D array, then `p(x)` will have the same shape as `x`. If
567
+ `c` is multidimensional, then the shape of the result depends on the
568
+ value of `tensor`. If `tensor` is true the shape will be c.shape[1:] +
569
+ x.shape. If `tensor` is false the shape will be c.shape[1:]. Note that
570
+ scalars have shape (,).
571
+ Trailing zeros in the coefficients will be used in the evaluation, so
572
+ they should be avoided if efficiency is a concern.
573
+
574
+ Parameters
575
+ ----------
576
+ x : array_like, compatible object
577
+ If `x` is a list or tuple, it is converted to an ndarray, otherwise
578
+ it is left unchanged and treated as a scalar. In either case, `x`
579
+ or its elements must support addition and multiplication with
580
+ with themselves and with the elements of `c`.
581
+ c : array_like
582
+ Array of coefficients ordered so that the coefficients for terms of
583
+ degree n are contained in c[n]. If `c` is multidimensional the
584
+ remaining indices enumerate multiple polynomials. In the two
585
+ dimensional case the coefficients may be thought of as stored in
586
+ the columns of `c`.
587
+ tensor : boolean, optional
588
+ If True, the shape of the coefficient array is extended with ones
589
+ on the right, one for each dimension of `x`. Scalars have dimension 0
590
+ for this action. The result is that every column of coefficients in
591
+ `c` is evaluated for every element of `x`. If False, `x` is broadcast
592
+ over the columns of `c` for the evaluation. This keyword is useful
593
+ when `c` is multidimensional. The default value is True.
594
+ .. versionadded:: 1.7.0
595
+
596
+ Returns
597
+ -------
598
+ values : ndarray, algebra_like
599
+ The shape of the return value is described above.
600
+
601
+ See Also
602
+ --------
603
+ legval2d, leggrid2d, legval3d, leggrid3d
604
+
605
+ Notes
606
+ -----
607
+ Code copied and modified from Numpy:
608
+ https://github.com/numpy/numpy/blob/v1.20.0/numpy/polynomial/legendre.py#L835-L920
609
+
610
+ Copyright (c) 2005-2021, NumPy Developers.
611
+ All rights reserved.
612
+
613
+ Redistribution and use in source and binary forms, with or without
614
+ modification, are permitted provided that the following conditions are
615
+ met:
616
+ * Redistributions of source code must retain the above copyright
617
+ notice, this list of conditions and the following disclaimer.
618
+ * Redistributions in binary form must reproduce the above
619
+ copyright notice, this list of conditions and the following
620
+ disclaimer in the documentation and/or other materials provided
621
+ with the distribution.
622
+ * Neither the name of the NumPy Developers nor the names of any
623
+ contributors may be used to endorse or promote products derived
624
+ from this software without specific prior written permission.
625
+
626
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
627
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
628
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
629
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
630
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
631
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
632
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
633
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
634
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
635
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
636
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
637
+ """
638
+ c = torch.as_tensor(c)
639
+ c = c.double()
640
+ if isinstance(x, (tuple, list)):
641
+ x = torch.as_tensor(x)
642
+ if isinstance(x, torch.Tensor) and tensor:
643
+ c = c.view(c.shape + (1,) * x.ndim)
644
+
645
+ c = c.to(x.device)
646
+
647
+ if len(c) == 1:
648
+ c0 = c[0]
649
+ c1 = 0
650
+ elif len(c) == 2:
651
+ c0 = c[0]
652
+ c1 = c[1]
653
+ else:
654
+ nd = len(c)
655
+ c0 = c[-2]
656
+ c1 = c[-1]
657
+ for i in range(3, len(c) + 1):
658
+ tmp = c0
659
+ nd = nd - 1
660
+ c0 = c[-i] - (c1 * (nd - 1)) / nd
661
+ c1 = tmp + (c1 * x * (2 * nd - 1)) / nd
662
+ return c0 + c1 * x
663
+
664
+
665
def _torch_calc_g(cosang, stiffness=4, n_legendre_terms=50):
    """Calculate spherical spline g function between points on a sphere.

    The g function is a Legendre series in ``cosang`` whose degree-n
    coefficient is ``(2n + 1) / (n**stiffness * (n + 1)**stiffness * 4 * pi)``
    for n = 1..``n_legendre_terms``; the degree-0 coefficient is 0.

    Parameters
    ----------
    cosang : array-like of float, shape(n_channels, n_channels)
        cosine of angles between pairs of points on a spherical surface. This
        is equivalent to the dot product of unit vectors.
    stiffness : float
        stiffness of the spline.
    n_legendre_terms : int
        number of Legendre terms to evaluate.

    Returns
    -------
    G : torch.Tensor of float, shape(n_channels, n_channels)
        The G matrix.

    Notes
    -----
    Code copied and modified from MNE-Python:
    https://github.com/mne-tools/mne-python/blob/bdaa1d460201a3bc3cec95b67fc2b8d31a933652/mne/channels/interpolation.py#L35

    Copyright © 2011-2019, authors of MNE-Python
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:
        * Redistributions of source code must retain the above copyright
          notice, this list of conditions and the following disclaimer.
        * Redistributions in binary form must reproduce the above copyright
          notice, this list of conditions and the following disclaimer in the
          documentation and/or other materials provided with the distribution.
        * Neither the name of the copyright holder nor the names of its
          contributors may be used to endorse or promote products derived from
          this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
    ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
    FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
    DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
    CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
    LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
    OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
    DAMAGE.
    """
    # Leading 0: the degree-0 Legendre term does not contribute.
    coefficients = [0]
    for degree in range(1, n_legendre_terms + 1):
        denominator = degree**stiffness * (degree + 1) ** stiffness * 4 * np.pi
        coefficients.append((2 * degree + 1) / denominator)
    return _torch_legval(cosang, coefficients)
719
+
720
+
721
def _torch_make_interpolation_matrix(pos_from, pos_to, alpha=1e-5):
    """Compute interpolation matrix based on spherical splines.

    Implementation based on [1]_

    Parameters
    ----------
    pos_from : torch.Tensor of float, shape(n_good_sensors, 3)
        The positions to interpolate from.
    pos_to : torch.Tensor of float, shape(n_bad_sensors, 3)
        The positions to interpolate.
    alpha : float | None
        Regularization parameter added to the diagonal of the spline matrix
        of ``pos_from``. Defaults to 1e-5. If None, no regularization is
        added.

    Returns
    -------
    interpolation : torch.Tensor of float, shape(len(pos_to), len(pos_from))
        The interpolation matrix that maps good signals to the location
        of bad signals.

    References
    ----------
    [1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
        Spherical splines for scalp potential and current density mapping.
        Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.

    Notes
    -----
    Code copied and modified from MNE-Python:
    https://github.com/mne-tools/mne-python/blob/bdaa1d460201a3bc3cec95b67fc2b8d31a933652/mne/channels/interpolation.py#L59

    Copyright © 2011-2019, authors of MNE-Python
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are met:
        * Redistributions of source code must retain the above copyright
        notice, this list of conditions and the following disclaimer.
        * Redistributions in binary form must reproduce the above copyright
        notice, this list of conditions and the following disclaimer in the
        documentation and/or other materials provided with the distribution.
        * Neither the name of the copyright holder nor the names of its
        contributors may be used to endorse or promote products derived from
        this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
    ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
    FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
    DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
    CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
    LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
    OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
    DAMAGE.
    """
    # Work on copies so the caller's position tensors are never modified by
    # the normalization helper.
    pos_from = pos_from.clone()
    pos_to = pos_to.clone()
    n_from = pos_from.shape[0]
    n_to = pos_to.shape[0]

    # normalize sensor positions to sphere
    pos_from = _torch_normalize_vectors(pos_from)
    pos_to = _torch_normalize_vectors(pos_to)

    # cosine angles between source positions
    cosang_from = torch.matmul(pos_from, pos_from.T)
    cosang_to_from = torch.matmul(pos_to, pos_from.T)
    G_from = _torch_calc_g(cosang_from)
    G_to_from = _torch_calc_g(cosang_to_from)
    assert G_from.shape == (n_from, n_from)
    assert G_to_from.shape == (n_to, n_from)

    if alpha is not None:
        # ``diagonal()`` is always a view, so the in-place add is guaranteed
        # to reach G_from (``flatten()`` may return a copy, which would
        # silently drop the regularization).
        G_from.diagonal().add_(alpha)

    device, dtype = G_from.device, G_from.dtype
    # Augmented system [[G, 1], [1^T, 0]] enforcing a constant term.
    C = torch.vstack(
        [
            torch.hstack(
                [G_from, torch.ones((n_from, 1), device=device, dtype=dtype)]
            ),
            torch.hstack(
                [
                    torch.ones((1, n_from), device=device, dtype=dtype),
                    torch.zeros((1, 1), device=device, dtype=dtype),
                ]
            ),
        ]
    )

    try:
        C_inv = torch.linalg.inv(C)
    except torch.linalg.LinAlgError:
        # There is a stability issue with pinv since torch v1.8.0
        # see https://github.com/pytorch/pytorch/issues/75494
        C_inv = torch.linalg.pinv(C.cpu()).to(device)

    interpolation = torch.hstack(
        [G_to_from, torch.ones((n_to, 1), device=device, dtype=dtype)]
    ).matmul(C_inv[:, :-1])
    assert interpolation.shape == (n_to, n_from)
    return interpolation
823
+
824
+
825
+ def _rotate_signals(X, rotations, sensors_positions_matrix, spherical=True):
826
+ sensors_positions_matrix = sensors_positions_matrix.to(X.device)
827
+ rot_sensors_matrices = [
828
+ rotation.matmul(sensors_positions_matrix) for rotation in rotations
829
+ ]
830
+ if spherical:
831
+ interpolation_matrix = torch.stack(
832
+ [
833
+ torch.as_tensor(
834
+ _torch_make_interpolation_matrix(
835
+ sensors_positions_matrix.T, rot_sensors_matrix.T
836
+ ),
837
+ device=X.device,
838
+ ).float()
839
+ for rot_sensors_matrix in rot_sensors_matrices
840
+ ]
841
+ )
842
+ return torch.matmul(interpolation_matrix, X)
843
+ else:
844
+ transformed_X = X.clone()
845
+ sensors_positions = list(sensors_positions_matrix)
846
+ for s, rot_sensors_matrix in enumerate(rot_sensors_matrices):
847
+ rot_sensors_positions = list(rot_sensors_matrix.T)
848
+ for time in range(X.shape[-1]):
849
+ interpolator_t = Rbf(*sensors_positions, X[s, :, time])
850
+ transformed_X[s, :, time] = torch.from_numpy(
851
+ interpolator_t(*rot_sensors_positions)
852
+ )
853
+ return transformed_X
854
+
855
+
856
+ def _make_rotation_matrix(axis, angle, degrees=True):
857
+ assert axis in ["x", "y", "z"], "axis should be either x, y or z."
858
+
859
+ if isinstance(angle, (float, int, np.ndarray, list)):
860
+ angle = torch.as_tensor(angle)
861
+
862
+ if degrees:
863
+ angle = angle * np.pi / 180
864
+
865
+ device = angle.device
866
+ zero = torch.zeros(1, device=device)
867
+ rot = torch.stack(
868
+ [
869
+ torch.as_tensor([1, 0, 0], device=device),
870
+ torch.hstack([zero, torch.cos(angle), -torch.sin(angle)]),
871
+ torch.hstack([zero, torch.sin(angle), torch.cos(angle)]),
872
+ ]
873
+ )
874
+ if axis == "x":
875
+ return rot
876
+ elif axis == "y":
877
+ rot = rot[[2, 0, 1], :]
878
+ return rot[:, [2, 0, 1]]
879
+ else:
880
+ rot = rot[[1, 2, 0], :]
881
+ return rot[:, [1, 2, 0]]
882
+
883
+
884
def sensors_rotation(X, y, sensors_positions_matrix, axis, angles, spherical_splines):
    """Interpolates EEG signals over sensors rotated around the desired axis
    with the desired angle.

    Suggested in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch. Returned unchanged.
    sensors_positions_matrix : torch.Tensor
        Matrix giving the positions of each sensor in a 3D cartesian coordinate
        system. Should have shape (3, n_channels), where n_channels is the
        number of channels. It must be a ``torch.Tensor`` (its ``.to`` method
        is called internally); convert NumPy arrays with ``torch.as_tensor``
        first. Standard 10-20 positions can be obtained from ``mne``
        through::

            >>> ten_twenty_montage = mne.channels.make_standard_montage(
            ...     'standard_1020'
            ... ).get_positions()['ch_pos']

        which returns a dict of per-channel positions that still has to be
        stacked into a (3, n_channels) matrix.
    axis : 'x' | 'y' | 'z'
        Axis around which to rotate.
    angles : array-like
        Array of float of shape ``(batch_size,)`` containing the rotation
        angles (in degrees) for each element of the input batch.
    spherical_splines : bool
        Whether to use spherical splines for the interpolation or not. When
        `False`, standard scipy.interpolate.Rbf (with its default multiquadric
        kernel) will be used (as in the original paper).

    Returns
    -------
    torch.Tensor
        Signals interpolated at the rotated sensor positions.
    torch.Tensor
        The labels ``y``, unchanged.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """
    # One rotation matrix per example, angles given in degrees.
    rots = [_make_rotation_matrix(axis, angle, degrees=True) for angle in angles]
    rotated_X = _rotate_signals(X, rots, sensors_positions_matrix, spherical_splines)
    return rotated_X, y
925
+
926
+
927
def mixup(X, y, lam, idx_perm):
    """Mixes two channels of EEG data.

    See [1]_ for details.
    Implementation based on [2]_.

    Parameters
    ----------
    X : torch.Tensor
        EEG data in form ``batch_size, n_channels, n_times``
    y : torch.Tensor
        Target of length ``batch_size``
    lam : torch.Tensor
        Values between 0 and 1 setting the linear interpolation between
        examples.
    idx_perm: torch.Tensor
        Permuted indices of example that are mixed into original examples.

    Returns
    -------
    tuple
        ``X``, ``y``. Where ``X`` is augmented (same dtype and device as the
        input) and ``y`` is a tuple of length 3 containing the labels of the
        two mixed channels (int64, on the device of ``X``) and the mixing
        coefficient ``lam`` as passed in.

    References
    ----------
    .. [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
       (2018). mixup: Beyond Empirical Risk Minimization. In 2018
       International Conference on Learning Representations (ICLR)
       Online: https://arxiv.org/abs/1710.09412
    .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
    """
    device = X.device
    idx_perm = torch.as_tensor(idx_perm, device=device)

    # One mixing weight per example, broadcast over channels and time. Cast
    # to X's dtype so the output keeps the input precision.
    lam_b = torch.as_tensor(lam, dtype=X.dtype, device=device).reshape(-1, 1, 1)
    X_mix = lam_b * X + (1 - lam_b) * X[idx_perm]

    # Labels are returned as int64 on X's device, as before.
    y_a = torch.as_tensor(y, device=device).to(torch.long)
    y_b = y_a[idx_perm]

    return X_mix, (y_a, y_b, lam)
973
+
974
+
975
def segmentation_reconstruction(
    X, y, n_segments, data_classes, rand_indices, idx_shuffle
):
    """Segment and reconstruct EEG data from [1]_.

    Each trial of a class is cut into ``n_segments`` consecutive time
    segments, and new trials are built by concatenating segments taken from
    (randomly chosen) trials of the same class.

    See [1]_ for details.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch. Only its dtype and device are used, to
        cast the output.
    y : torch.Tensor | None
        EEG labels for the example or batch. If None, the output labels are
        None as well.
    n_segments : int
        Number of segments to use in the batch. When the window length is not
        divisible by ``n_segments``, the last segment also covers the
        remaining samples.
    data_classes : iterable of (class_index, torch.Tensor)
        Pairs of a class label and the trials of that class, of shape
        ``(n_trials, n_channels, n_times)``.
    rand_indices: array-like
        Array of indices that indicates which trial to use in each segment.
        Indexed by ``class_index``; each entry is indexed as
        ``[:, idx_segment]``, i.e. presumably of shape
        ``(n_trials, n_segments)``.
    idx_shuffle: array-like
        Array of indices to shuffle the new generated trials.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Lotte, F. (2015). Signal processing approaches to minimize or
       suppress calibration time in oscillatory activity-based brain–computer
       interfaces. Proceedings of the IEEE, 103(6), 871-890.
    """
    aug_data = []
    aug_label = []

    for class_index, X_class in data_classes:
        n_trials, _, window_size = X_class.shape
        segment_size = window_size // n_segments
        X_aug = torch.zeros_like(X_class)
        rand_idx = rand_indices[class_index]
        for idx_segment in range(n_segments):
            start = idx_segment * segment_size
            # The last segment absorbs the remainder of the integer division
            # so no trailing samples are left as zeros.
            if idx_segment == n_segments - 1:
                end = window_size
            else:
                end = (idx_segment + 1) * segment_size
            # Segment ``idx_segment`` of every new trial is copied from the
            # trial selected by ``rand_idx``.
            X_aug[:, :, start:end] = X_class[rand_idx[:, idx_segment], :, start:end]
        aug_data.append(X_aug)
        aug_label.append(torch.full((n_trials,), class_index))

    aug_data = torch.cat(aug_data, dim=0)
    aug_data = aug_data.to(dtype=X.dtype, device=X.device)
    aug_data = aug_data[idx_shuffle]

    if y is not None:
        aug_label = torch.cat(aug_label, dim=0)
        aug_label = aug_label.to(dtype=y.dtype, device=y.device)
        aug_label = aug_label[idx_shuffle]
        return aug_data, aug_label

    return aug_data, y
1044
+
1045
+
1046
def mask_encoding(X, y, time_start, segment_length, n_segments):
    """Mask encoding from Ding et al. (2024) from [ding2024]_.

    Replaces a contiguous part (or parts) of all channels by zeros
    (if more than one segment, it may overlap).

    Implementation based on [ding2024]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch, of shape
        ``(batch_size, n_channels, n_times)``. It is not modified; a masked
        copy is returned.
    y : torch.Tensor
        EEG labels for the example or batch. Returned unchanged.
    time_start : torch.Tensor
        Tensor of integers containing the position (in last dimension) where to
        start masking the signal. Should have "n_segments" times the size of the first
        dimension of X (i.e. "n_segments" start positions per example in the batch).
        After flattening, the first ``n_segments`` values belong to the first
        example, the next ``n_segments`` to the second, and so on.
    segment_length : int
        Length of each segment to zero out.
    n_segments : int
        Number of segments to zero out in each example.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [ding2024] Ding, Wenlong, et al. A Novel Data Augmentation Approach
       Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI.
       IEEE Transactions on Neural Systems and Rehabilitation Engineering
       32 (2024): 875-886.
    """
    device = X.device
    batch_size, n_times = X.shape[0], X.shape[-1]

    # Example index of every segment, aligned with the flattened start times.
    batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
        n_segments
    )
    start_indices = torch.as_tensor(time_start, device=device).flatten()
    # (n_total_segments, segment_length) time indices to mask.
    mask_indices = start_indices[:, None] + torch.arange(
        segment_length, device=device
    )

    # Mask over (example, time); shared by all channels via broadcasting.
    time_mask = torch.zeros((batch_size, n_times), dtype=torch.bool, device=device)
    time_mask[batch_indices[:, None], mask_indices] = True

    # Out-of-place, so the caller's tensor is left untouched.
    X = X.masked_fill(time_mask[:, None, :], 0)

    return X, y  # Return the masked tensor and labels