braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102) hide show
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1274 @@
1
+ # Authors: Cédric Rommel <cedric.rommel@inria.fr>
2
+ # Alexandre Gramfort <alexandre.gramfort@inria.fr>
3
+ # Gustavo Rodrigues <gustavenrique01@gmail.com>
4
+ #
5
+ # License: BSD (3-clause)
6
+
7
+ import warnings
8
+ from numbers import Real
9
+ from typing import Callable
10
+
11
+ import numpy as np
12
+ import torch
13
+ from mne.channels import make_standard_montage
14
+
15
+ from .base import Transform
16
+ from .functional import (
17
+ bandstop_filter,
18
+ channels_dropout,
19
+ channels_permute,
20
+ channels_shuffle,
21
+ frequency_shift,
22
+ ft_surrogate,
23
+ gaussian_noise,
24
+ mask_encoding,
25
+ mixup,
26
+ segmentation_reconstruction,
27
+ sensors_rotation,
28
+ sign_flip,
29
+ smooth_time_mask,
30
+ time_reverse,
31
+ )
32
+
33
+
34
class TimeReverse(Transform):
    """Reverse each input along its time axis with a given probability.

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator that
        decides whether or not to transform, given ``probability``.
        Defaults to None.
    """

    operation = staticmethod(time_reverse)  # type: ignore[assignment]

    def __init__(self, probability, random_state=None):
        # No extra hyper-parameters: everything is handled by the base class.
        super().__init__(probability=probability, random_state=random_state)
58
+
59
+
60
class SignFlip(Transform):
    """Flip the sign of each input with a given probability.

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator that
        decides whether or not to transform, given ``probability``.
        Defaults to None.
    """

    operation = staticmethod(sign_flip)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        random_state=None,
    ):
        # No extra hyper-parameters: everything is handled by the base class.
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
77
+
78
+
79
class FTSurrogate(Transform):
    """Fourier-transform surrogate augmentation, as proposed in [1]_.

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    phase_noise_magnitude : float | torch.Tensor, optional
        Value between 0 and 1 setting the range over which the phase
        perturbation is uniformly sampled:
        ``[0, phase_noise_magnitude * 2 * pi]``. Defaults to 1.
    channel_indep : bool, optional
        Whether to sample phase perturbations independently for each channel
        or not. It is advised to set it to False when spatial information is
        important for the task, like in BCI. Defaults to False.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator that
        decides whether or not to transform, given ``probability``.
        Defaults to None.

    References
    ----------
    .. [1] Schwabedal, J. T., Snyder, J. C., Cakmak, A., Nemati, S., &
       Clifford, G. D. (2018). Addressing Class Imbalance in Classification
       Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
       preprint arXiv:1806.08675.
    """

    operation = staticmethod(ft_surrogate)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        phase_noise_magnitude=1,
        channel_indep=False,
        random_state=None,
    ):
        super().__init__(probability=probability, random_state=random_state)

        # Validate the hyper-parameters before storing them.
        accepted_types = (float, int, torch.Tensor)
        assert isinstance(
            phase_noise_magnitude, accepted_types
        ), "phase_noise_magnitude should be a float."
        assert (
            0 <= phase_noise_magnitude <= 1
        ), "phase_noise_magnitude should be between 0 and 1."
        assert isinstance(
            channel_indep, bool
        ), "channel_indep is expected to be a boolean"

        self.channel_indep = channel_indep
        self.phase_noise_magnitude = phase_noise_magnitude

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to the operation.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains:

            * phase_noise_magnitude : float
                The magnitude of the transformation.
            * channel_indep : bool
                Whether phase perturbations are sampled per channel.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        params = {
            "phase_noise_magnitude": self.phase_noise_magnitude,
            "channel_indep": self.channel_indep,
            "random_state": self.rng,
        }
        return params
154
+
155
+
156
class ChannelsDropout(Transform):
    """Randomly set channels to flat signal.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    p_drop : float | None, optional
        Value between 0 and 1 setting the probability of dropping each
        channel. Defaults to 0.2.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator that
        decides whether or not to transform, given ``probability``, and that
        samples the channels to erase. Defaults to None.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """

    operation = staticmethod(channels_dropout)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        p_drop=0.2,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        self.p_drop = p_drop

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to the operation.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains:

            * p_drop : float
                Value between 0 and 1 setting the probability of dropping
                each channel.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        params = {"p_drop": self.p_drop, "random_state": self.rng}
        return params
211
+
212
+
213
class ChannelsShuffle(Transform):
    """Randomly shuffle channels in the EEG data matrix.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    p_shuffle : float | None, optional
        Value between 0 and 1 setting the probability of including a channel
        in the set of permuted channels. Defaults to 0.2.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator that
        decides whether or not to transform, given ``probability``, that
        samples which channels to shuffle and that carries out the shuffle.
        Defaults to None.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """

    operation = staticmethod(channels_shuffle)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        p_shuffle=0.2,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        self.p_shuffle = p_shuffle

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to the operation.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains:

            * p_shuffle : float
                Value between 0 and 1 setting the probability of including
                a channel in the set of permuted channels.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        params = {"p_shuffle": self.p_shuffle, "random_state": self.rng}
        return params
269
+
270
+
271
class GaussianNoise(Transform):
    """Randomly add white noise to all channels.

    Suggested e.g. in [1]_, [2]_ and [3]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    std : float, optional
        Standard deviation to use for the additive noise. Defaults to 0.1.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator.
        Defaults to None.

    References
    ----------
    .. [1] Wang, F., Zhong, S. H., Peng, J., Jiang, J., & Liu, Y. (2018). Data
       augmentation for eeg-based emotion recognition with deep convolutional
       neural networks. In International Conference on Multimedia Modeling
       (pp. 82-93).
    .. [2] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [3] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """

    operation = staticmethod(gaussian_noise)  # type: ignore[assignment]

    def __init__(self, probability, std=0.1, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
        self.std = std

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to the operation.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains:

            * std : float
                Standard deviation to use for the additive noise.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        params = {"std": self.std, "random_state": self.rng}
        return params
333
+
334
+
335
class ChannelsSymmetry(Transform):
    """Permute EEG channels, swapping left- and right-side sensors.

    Suggested e.g. in [1]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    ordered_ch_names : list
        Ordered list of strings containing the names (in 10-20
        nomenclature) of the EEG channels that will be transformed. The
        first name should correspond to the data in the first row of X, the
        second name to the second row and so on.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator that
        decides whether or not to transform, given ``probability``.
        Defaults to None.

    References
    ----------
    .. [1] Deiss, O., Biswal, S., Jin, J., Sun, H., Westover, M. B., & Sun, J.
       (2018). HAMLET: interpretable human and machine co-learning technique.
       arXiv preprint arXiv:1803.09702.
    """

    operation = staticmethod(channels_permute)  # type: ignore[assignment]

    def __init__(self, probability, ordered_ch_names, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
        assert isinstance(ordered_ch_names, list) and all(
            isinstance(name, str) for name in ordered_ch_names
        ), "ordered_ch_names should be a list of str."

        # For every channel, record the index of its mirrored counterpart;
        # channels without digits, or whose counterpart is not in the list,
        # keep their own index.
        permutation = []
        for position, ch_name in enumerate(ordered_ch_names):
            target = position
            # Collect the digits of the channel name (assuming 10-20 system)
            digits = "".join(char for char in ch_name if char.isdigit())
            if digits:
                number = int(digits)
                # Even numbers (right-side electrodes) map to the preceding
                # odd number; odd numbers (left side) to the following even.
                mirrored = number - 1 if number % 2 == 0 else number + 1
                counterpart = ch_name.replace(str(number), str(mirrored))
                if counterpart in ordered_ch_names:
                    target = ordered_ch_names.index(counterpart)
            permutation.append(target)
        self.permutation = permutation

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to the operation.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains:

            * permutation : list of int
                List of integers defining the new channels order.
        """
        params = {"permutation": self.permutation}
        return params
408
+
409
+
410
class SmoothTimeMask(Transform):
    """Smoothly replace a randomly chosen contiguous part of all channels by
    zeros.

    Suggested e.g. in [1]_ and [2]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    mask_len_samples : int | torch.Tensor, optional
        Number of consecutive samples to zero out. Must be positive.
        Defaults to 100.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator.
        Defaults to None.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """

    operation = staticmethod(smooth_time_mask)  # type: ignore[assignment]

    def __init__(self, probability, mask_len_samples=100, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
        assert (
            isinstance(mask_len_samples, (int, torch.Tensor))
            and mask_len_samples > 0
        ), "mask_len_samples has to be a positive integer"
        self.mask_len_samples = mask_len_samples

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to the operation.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains two elements:

            * mask_start_per_sample : torch.tensor
                Tensor containing the position (in last dimension) where to
                start masking the signal, one value per example in the
                batch (i.e. same size as the first dimension of X).
            * mask_len_samples : int
                Number of consecutive samples to zero out.
        """
        # Without a batch there is nothing to sample from: defer to the
        # parent implementation.
        if not batch:
            return super().get_augmentation_params(*batch)

        X = batch[0]
        device = X.device
        n_times = torch.as_tensor(X.shape[-1], device=device)

        mask_len = self.mask_len_samples
        if isinstance(mask_len, torch.Tensor):
            mask_len = mask_len.to(device)

        # One uniform draw in [0, 1) per example, rescaled to the range of
        # start positions [0, n_times - mask_len).
        relative_start = torch.as_tensor(
            self.rng.uniform(low=0, high=1, size=X.shape[0]),
            device=device,
        )
        mask_start = relative_start * (n_times - mask_len)

        params = {
            "mask_start_per_sample": mask_start,
            "mask_len_samples": mask_len,
        }
        return params
493
+
494
+
495
class BandstopFilter(Transform):
    """Apply a band-stop filter with desired bandwidth at a randomly selected
    frequency position between 0 and ``max_freq``.

    Suggested e.g. in [1]_ and [2]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    sfreq : float
        Sampling frequency of the signals to be filtered. Must be positive.
    bandwidth : float, optional
        Bandwidth of the filter, i.e. distance between the low and high cut
        frequencies. Must be non-negative and smaller than ``max_freq``.
        Defaults to 1.
    max_freq : float | None, optional
        Maximal admissible frequency. The low cut frequency will be sampled so
        that the corresponding high cut frequency + transition (=1Hz) are below
        ``max_freq``. If omitted, `None` or greater than the Nyquist frequency
        (``sfreq / 2``), the Nyquist frequency is used and a warning is
        emitted.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator.
        Defaults to None.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """

    operation = staticmethod(bandstop_filter)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        sfreq,
        bandwidth=1,
        max_freq=None,
        random_state=None,
    ):
        super().__init__(probability=probability, random_state=random_state)

        assert (
            isinstance(bandwidth, Real) and bandwidth >= 0
        ), "bandwidth should be a non-negative float."
        assert (
            isinstance(sfreq, Real) and sfreq > 0
        ), "sfreq should be a positive float."
        if max_freq is not None:
            assert (
                isinstance(max_freq, Real) and max_freq > 0
            ), "max_freq should be a positive float."

        # Clip max_freq to the Nyquist frequency (also used when it is None).
        nyquist = sfreq / 2
        if max_freq is None or max_freq > nyquist:
            max_freq = nyquist
            warnings.warn(
                "You either passed None or a frequency greater than the"
                f" Nyquist frequency ({nyquist} Hz)."
                f" Falling back to max_freq = {nyquist}."
            )
        assert (
            bandwidth < max_freq
        ), f"`bandwidth` needs to be smaller than max_freq={max_freq}"

        self.bandwidth = bandwidth
        self.max_freq = max_freq
        self.sfreq = sfreq

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to the operation.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains:

            * sfreq : float
                Sampling frequency of the signals to be filtered.
            * bandwidth : float
                Bandwidth of the filter, i.e. distance between the low and
                high cut frequencies.
            * freqs_to_notch : array-like
                Array of floats of size ``(batch_size,)`` containing the
                center of the frequency band to filter out for each sample
                in the batch, drawn uniformly between
                ``1 + 2 * bandwidth`` and ``max_freq - 1 - 2 * bandwidth``.
        """
        # Without a batch there is nothing to sample from: defer to the
        # parent implementation.
        if not batch:
            return super().get_augmentation_params(*batch)

        batch_size = batch[0].shape[0]

        # Prevents transitions from going below 0 and above max_freq
        lowest_center = 1 + 2 * self.bandwidth
        highest_center = self.max_freq - 1 - 2 * self.bandwidth
        freqs_to_notch = self.rng.uniform(
            low=lowest_center,
            high=highest_center,
            size=batch_size,
        )

        params = {
            "sfreq": self.sfreq,
            "bandwidth": self.bandwidth,
            "freqs_to_notch": freqs_to_notch,
        }
        return params
608
+
609
+
610
class FrequencyShift(Transform):
    """Add a random shift in the frequency domain to all channels.

    Note that here, the shift is the same for all channels of a single example.

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    sfreq : float
        Sampling frequency of the signals to be transformed. Must be
        positive.
    max_delta_freq : float | torch.Tensor, optional
        Maximum shift in Hz that can be sampled (in absolute value).
        Defaults to 2 (shift sampled between -2 and 2 Hz).
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator.
        Defaults to None.
    """

    operation = staticmethod(frequency_shift)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        sfreq,
        max_delta_freq=2,
        random_state=None,
    ):
        super().__init__(probability=probability, random_state=random_state)
        assert (
            isinstance(sfreq, Real) and sfreq > 0
        ), "sfreq should be a positive float."
        self.sfreq = sfreq
        self.max_delta_freq = max_delta_freq

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to the operation.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains:

            * delta_freq : torch.Tensor
                The amplitude of the frequency shift (in Hz), one value per
                example in the batch.
            * sfreq : float
                Sampling frequency of the signals to be transformed.
        """
        # Without a batch there is nothing to sample from: defer to the
        # parent implementation.
        if not batch:
            return super().get_augmentation_params(*batch)

        X = batch[0]
        device = X.device
        unit_draws = torch.as_tensor(
            self.rng.uniform(size=X.shape[0]),
            device=device,
        )

        max_shift = self.max_delta_freq
        if isinstance(max_shift, torch.Tensor):
            max_shift = max_shift.to(device)

        # Map the draws from [0, 1) to [-max_shift, max_shift).
        delta_freq = unit_draws * 2 * max_shift - max_shift

        params = {"delta_freq": delta_freq, "sfreq": self.sfreq}
        return params
676
+
677
+
678
def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
    """Return the standard 10-20 sensors position matrix (e.g. for
    instantiating SensorsRotation).

    Parameters
    ----------
    raw_or_epoch : mne.io.Raw | mne.Epoch, optional
        Example of raw or epoch to retrieve the ordered channels list from
        (read from ``info["ch_names"]``). Channels need to be named as in
        10-20. Only used when ``ordered_ch_names`` is None. By default None.
    ordered_ch_names : list, optional
        List of strings representing the channels of the montage considered.
        The order has to be consistent with the order of channels in the input
        matrices that will be fed to `SensorsRotation` transform. By
        default None.

    Returns
    -------
    numpy.ndarray
        Positions of the requested channels found in the ``standard_1020``
        montage, stacked and transposed (presumably of shape
        ``(3, n_matched_channels)``). Names absent from the montage are
        skipped.
    """
    assert raw_or_epoch is not None or ordered_ch_names is not None, (
        "At least one of raw_or_epoch and ordered_ch_names is needed."
    )
    if ordered_ch_names is None:
        ordered_ch_names = raw_or_epoch.info["ch_names"]

    montage = make_standard_montage("standard_1020")
    montage_positions = montage.get_positions()["ch_pos"]

    # Keep only the channels known to the montage, in the requested order.
    matched_positions = {}
    for ch_name in ordered_ch_names:
        if ch_name in montage_positions:
            matched_positions[ch_name] = montage_positions[ch_name]

    stacked = np.stack(list(matched_positions.values()))
    return stacked.T
704
+
705
+
706
class SensorsRotation(Transform):
    """Interpolates EEG signals over sensors rotated around the desired axis
    with an angle sampled uniformly between ``-max_degrees`` and
    ``max_degrees``.

    Suggested in [1]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    sensors_positions_matrix : numpy.ndarray
        Matrix giving the positions of each sensor in a 3D cartesian coordinate
        system. Should have shape (3, n_channels), where n_channels is the
        number of channels. Lists and arrays are converted to a tensor.
        Standard 10-20 positions can be obtained from `mne` through::

            >>> ten_twenty_montage = mne.channels.make_standard_montage(
            ...     'standard_1020'
            ... ).get_positions()['ch_pos']

    axis : 'x' | 'y' | 'z', optional
        Axis around which to rotate. Defaults to 'z'.
    max_degrees : float, optional
        Maximum rotation. Rotation angles will be sampled between
        ``-max_degrees`` and ``max_degrees``. Defaults to 15 degrees.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
        be used (as in the original paper). Defaults to True.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    operation = staticmethod(sensors_rotation)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        sensors_positions_matrix,
        axis="z",
        max_degrees=15,
        spherical_splines=True,
        random_state=None,
    ):
        super().__init__(probability=probability, random_state=random_state)

        # Arrays and lists are accepted but stored as a tensor.
        if isinstance(sensors_positions_matrix, (np.ndarray, list)):
            sensors_positions_matrix = torch.as_tensor(sensors_positions_matrix)

        assert isinstance(
            sensors_positions_matrix, torch.Tensor
        ), "sensors_positions should be an Tensor"
        assert (
            isinstance(max_degrees, (Real, torch.Tensor)) and max_degrees >= 0
        ), "max_degrees should be non-negative float."
        assert isinstance(axis, str) and axis in (
            "x",
            "y",
            "z",
        ), "axis can be either x, y or z."
        assert (
            sensors_positions_matrix.shape[0] == 3
        ), "sensors_positions_matrix shape should be 3 x n_channels."
        assert isinstance(
            spherical_splines, bool
        ), "spherical_splines should be a boolean"

        self.sensors_positions_matrix = sensors_positions_matrix
        self.axis = axis
        self.spherical_splines = spherical_splines
        self.max_degrees = max_degrees

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to the operation.

        Parameters
        ----------
        X : tensor.Tensor
            The data.
        y : tensor.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains four elements:

            * sensors_positions_matrix : torch.Tensor
                Matrix giving the positions of each sensor in a 3D cartesian
                coordinate system. Should have shape (3, n_channels), where
                n_channels is the number of channels.
            * axis : 'x' | 'y' | 'z'
                Axis around which to rotate.
            * angles : array-like
                Tensor of floats of shape ``(batch_size,)`` containing the
                rotation angles (in degrees) for each element of the input
                batch, sampled uniformly between ``-max_degrees`` and
                ``max_degrees``.
            * spherical_splines : bool
                Whether to use spherical splines for the interpolation or not.
                When ``False``, standard scipy.interpolate.Rbf (with quadratic
                kernel) will be used (as in the original paper).
        """
        # Without a batch there is nothing to sample from: defer to the
        # parent implementation.
        if not batch:
            return super().get_augmentation_params(*batch)

        X = batch[0]
        device = X.device
        unit_draws = self.rng.uniform(low=0, high=1, size=X.shape[0])

        max_angle = self.max_degrees
        if isinstance(max_angle, torch.Tensor):
            max_angle = max_angle.to(device)

        # Map the draws from [0, 1) to [-max_angle, max_angle).
        angles = torch.as_tensor(unit_draws, device=device) * 2 * max_angle - max_angle

        params = {
            "sensors_positions_matrix": self.sensors_positions_matrix,
            "axis": self.axis,
            "angles": angles,
            "spherical_splines": self.spherical_splines,
        }
        return params
831
+
832
+
833
class SensorsZRotation(SensorsRotation):
    """Interpolates EEG signals over sensors rotated around the Z axis
    with an angle sampled uniformly between ``-max_degrees`` and
    ``max_degrees``.

    Suggested in [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. These names are used to compute approximate sensors
        positions from a standard 10-20 montage.
    max_degrees : float, optional
        Maximum rotation. Rotation angles will be sampled between
        ``-max_degrees`` and ``max_degrees``. Defaults to 15 degrees.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
        be used (as in the original paper). Defaults to True.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    def __init__(
        self,
        probability,
        ordered_ch_names,
        max_degrees=15,
        spherical_splines=True,
        random_state=None,
    ):
        # Look up the montage positions first, then hand them to the generic
        # rotation transform with the axis pinned to "z".
        montage_positions = _get_standard_10_20_positions(
            ordered_ch_names=ordered_ch_names
        )
        super().__init__(
            probability=probability,
            sensors_positions_matrix=torch.as_tensor(montage_positions),
            axis="z",
            max_degrees=max_degrees,
            spherical_splines=spherical_splines,
            random_state=random_state,
        )
887
+
888
+
889
class SensorsYRotation(SensorsRotation):
    """Interpolates EEG signals over sensors rotated around the Y axis
    with an angle sampled uniformly between ``-max_degrees`` and
    ``max_degrees``.

    Suggested in [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. These names are used to compute approximate sensors
        positions from a standard 10-20 montage.
    max_degrees : float, optional
        Maximum rotation. Rotation angles will be sampled between
        ``-max_degrees`` and ``max_degrees``. Defaults to 15 degrees.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
        be used (as in the original paper). Defaults to True.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    def __init__(
        self,
        probability,
        ordered_ch_names,
        max_degrees=15,
        spherical_splines=True,
        random_state=None,
    ):
        # Look up the montage positions first, then hand them to the generic
        # rotation transform with the axis pinned to "y".
        montage_positions = _get_standard_10_20_positions(
            ordered_ch_names=ordered_ch_names
        )
        super().__init__(
            probability=probability,
            sensors_positions_matrix=torch.as_tensor(montage_positions),
            axis="y",
            max_degrees=max_degrees,
            spherical_splines=spherical_splines,
            random_state=random_state,
        )
943
+
944
+
945
class SensorsXRotation(SensorsRotation):
    """Interpolates EEG signals over sensors rotated around the X axis
    with an angle sampled uniformly between ``-max_degrees`` and
    ``max_degrees``.

    Suggested in [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. These names are used to compute approximate sensors
        positions from a standard 10-20 montage.
    max_degrees : float, optional
        Maximum rotation. Rotation angles will be sampled between
        ``-max_degrees`` and ``max_degrees``. Defaults to 15 degrees.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
        be used (as in the original paper). Defaults to True.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    def __init__(
        self,
        probability,
        ordered_ch_names,
        max_degrees=15,
        spherical_splines=True,
        random_state=None,
    ):
        # Look up the montage positions first, then hand them to the generic
        # rotation transform with the axis pinned to "x".
        montage_positions = _get_standard_10_20_positions(
            ordered_ch_names=ordered_ch_names
        )
        super().__init__(
            probability=probability,
            sensors_positions_matrix=torch.as_tensor(montage_positions),
            axis="x",
            max_degrees=max_degrees,
            spherical_splines=spherical_splines,
            random_state=random_state,
        )
999
+
1000
+
1001
class Mixup(Transform):
    """Mixup augmentation for EEG data. See [1]_.
    Implementation based on [2]_.

    The transform is always built with ``probability=1.0`` because mixup
    has to be applied to whole batches.

    Parameters
    ----------
    alpha: float
        Mixup hyperparameter. When positive, mixing coefficients are drawn
        from a Beta(alpha, alpha) distribution; otherwise they are all 1.
    beta_per_sample: bool (default=False)
        By default, one mixing coefficient per batch is drawn from a beta
        distribution. If True, one mixing coefficient per sample is drawn.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
       (2018). mixup: Beyond Empirical Risk Minimization. In 2018
       International Conference on Learning Representations (ICLR)
       Online: https://arxiv.org/abs/1710.09412
    .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
    """

    operation = staticmethod(mixup)  # type: ignore[assignment]

    def __init__(self, alpha, beta_per_sample=False, random_state=None):
        # Mixup has to be applied to whole batches
        super().__init__(probability=1.0, random_state=random_state)
        self.alpha = alpha
        self.beta_per_sample = beta_per_sample

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data, expected to be three-dimensional; the size of the
            first dimension is the batch size.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params: dict
            Contains two elements:

            * lam : torch.Tensor
                Tensor of shape ``(batch_size,)`` on ``X.device`` holding
                the coefficients of the linear interpolation between
                examples. Drawn from Beta(alpha, alpha) when ``alpha > 0``
                (one draw per sample if ``beta_per_sample`` is set, a
                single shared draw otherwise); all ones when
                ``alpha <= 0``.
            * idx_perm : torch.Tensor
                Random permutation of ``range(batch_size)`` giving the
                examples that are mixed into the original examples.
        """
        X = batch[0]
        batch_size, _, _ = X.shape
        device = X.device

        sample_from_beta = self.alpha > 0
        if sample_from_beta and self.beta_per_sample:
            # Independent coefficient for every example.
            per_sample = self.rng.beta(self.alpha, self.alpha, batch_size)
            lam = torch.as_tensor(per_sample).to(device)
        else:
            lam = torch.ones(batch_size).to(device)
            if sample_from_beta:
                # Single coefficient shared by the whole batch.
                lam *= self.rng.beta(self.alpha, self.alpha)

        shuffled = self.rng.permutation(batch_size)
        idx_perm = torch.as_tensor(shuffled)

        return dict(lam=lam, idx_perm=idx_perm)
1078
+
1079
+
1080
class SegmentationReconstruction(Transform):
    """Segmentation Reconstruction from Lotte (2015) [Lotte2015]_.

    Applies a segmentation-reconstruction transform to the input data, as
    proposed in [Lotte2015]_. It segments each trial in the batch and randomly
    mixes the segments to generate new synthetic trials by label, preserving
    the original order of the segments in the time domain.

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether to transform given the probability
        argument and to sample the segments mixing. Defaults to None.
    n_segments : int, optional
        Number of segments to split each window into. If None, it is derived
        for every batch as the largest divisor of the window length that does
        not exceed its square root. Defaults to None.

    References
    ----------
    .. [Lotte2015] Lotte, F. (2015). Signal processing approaches to minimize
       or suppress calibration time in oscillatory activity-based brain–computer
       interfaces. Proceedings of the IEEE, 103(6), 871-890.
    """

    operation = staticmethod(segmentation_reconstruction)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        n_segments=None,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        self.n_segments = n_segments

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data; its third dimension is taken as the window length.
        y : torch.Tensor | None
            The labels. May be omitted or None, in which case the whole
            batch is treated as a single group.

        Returns
        -------
        params : dict
            Contains four elements:

            * n_segments : int
                Number of segments to split the signal into.
            * data_classes : list of tuple
                ``(label, X_class)`` pairs, one per distinct label in ``y``
                (or a single ``(nan, X)`` pair when there are no labels).
            * rand_indices : dict
                Maps each label of ``data_classes`` to an integer array of
                shape ``(n_trials, n_segments)`` with values drawn in
                ``[0, n_trials)``, where ``n_trials`` is the number of
                examples carrying that label.
            * idx_shuffle : numpy.ndarray
                Random permutation of ``range(batch_size)``.

        Raises
        ------
        ValueError
            If labels are given and X or y is not a tensor, if X and y
            disagree on the number of samples, or if a user-supplied
            ``n_segments`` is not between 1 and the window length.
        """
        X = batch[0]
        # Labels are optional: a batch holding only X is handled like y=None.
        y = batch[1] if len(batch) > 1 else None

        if y is not None:
            if not isinstance(X, torch.Tensor) or not isinstance(y, torch.Tensor):
                raise ValueError("X and y must be torch tensors.")

            if X.shape[0] != y.shape[0]:
                raise ValueError("Number of samples in X and y must be the same.")

        n_times = int(X.shape[2])

        # Work on a local value: the automatically derived count depends on
        # the current window length and must not be written back to the
        # instance, otherwise the first batch would freeze it for every
        # later batch, whatever its window length.
        n_segments = self.n_segments
        if n_segments is None:
            # Largest divisor of the window length not above its square
            # root (1 always qualifies).
            n_segments = 1
            for candidate in range(int(n_times**0.5), 1, -1):
                if n_times % candidate == 0:
                    n_segments = candidate
                    break
        elif not (
            isinstance(n_segments, (int, float)) and 1 <= n_segments <= n_times
        ):
            raise ValueError(
                f"Number of segments must be a positive integer less than "
                f"(or equal) the window size. Got {n_segments}"
            )

        if y is None:
            data_classes = [(np.nan, X)]
        else:
            data_classes = [(label, X[y == label]) for label in torch.unique(y)]

        # For each group, every synthetic trial picks one source trial per
        # segment. The label objects of data_classes are reused as keys.
        rand_indices = {}
        for label, X_class in data_classes:
            n_trials = X_class.shape[0]
            rand_indices[label] = self.rng.randint(
                0, n_trials, (n_trials, n_segments)
            )

        idx_shuffle = self.rng.permutation(X.shape[0])

        return {
            "n_segments": n_segments,
            "data_classes": data_classes,
            "rand_indices": rand_indices,
            "idx_shuffle": idx_shuffle,
        }
1185
+
1186
+
1187
class MaskEncoding(Transform):
    """MaskEncoding from [1]_.

    Replaces randomly chosen contiguous part (or parts) of all channels by
    zeros (if more than one segment, it may overlap).

    Implementation based on [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    max_mask_ratio: float, optional
        Signal ratio to zero out, between 0 and 1. The ratio is split evenly
        across the ``n_segments`` segments. Defaults to 0.1.
    n_segments : int, optional
        Number of segments to zero out in each example.
        Defaults to 1.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Ding, Wenlong, et al. "A Novel Data Augmentation Approach
       Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI."
       IEEE Transactions on Neural Systems and Rehabilitation Engineering
       32 (2024): 875-886.
    """

    operation = staticmethod(mask_encoding)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        max_mask_ratio=0.1,
        n_segments=1,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        assert isinstance(n_segments, int) and n_segments > 0, (
            "n_segments should be a positive integer."
        )
        assert isinstance(max_mask_ratio, (int, float)) and 0 <= max_mask_ratio <= 1, (
            "mask_ratio should be a float between 0 and 1."
        )

        self.mask_ratio = max_mask_ratio
        self.n_segments = n_segments

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data, expected to be three-dimensional with the window
            length as last dimension.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains three elements:

            * time_start : torch.Tensor
                Integer tensor of shape ``(batch_size, n_segments)`` with
                the first sample of every masked segment, drawn uniformly
                from ``0`` to ``n_times - segment_length`` inclusive.
            * segment_length : int
                Length of each masked segment, i.e.
                ``int(n_times * max_mask_ratio / n_segments)``.
            * n_segments : int
                Number of segments to zero out in each example.

        Raises
        ------
        AssertionError
            If the computed segment length is smaller than one sample.
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]

        batch_size, _, n_times = X.shape

        segment_length = int((n_times * self.mask_ratio) / self.n_segments)

        assert segment_length >= 1, (
            "n_segments should be a positive integer not higher than (max_mask_ratio * window size)."
        )

        # randint excludes its upper bound, hence the "+ 1": the last start
        # that still fits a whole segment is n_times - segment_length. Without
        # it that position could never be drawn, and a segment spanning the
        # whole window (max_mask_ratio=1, n_segments=1) asked for
        # randint(0, 0), which raises.
        # NOTE(review): assumes the operation masks
        # [start, start + segment_length) -- confirm against mask_encoding.
        last_start = n_times - segment_length
        time_start = self.rng.randint(
            0, last_start + 1, (batch_size, self.n_segments)
        )
        time_start = torch.from_numpy(time_start)

        return {
            "time_start": time_start,
            "segment_length": segment_length,
            "n_segments": self.n_segments,
        }