braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1356 @@
1
+ # Authors: Cédric Rommel <cedric.rommel@inria.fr>
2
+ # Alexandre Gramfort <alexandre.gramfort@inria.fr>
3
+ # Gustavo Rodrigues <gustavenrique01@gmail.com>
4
+ # Bruna Lopes <brunajaflopes@gmail.com>
5
+ #
6
+ # License: BSD (3-clause)
7
+
8
+ import warnings
9
+ from numbers import Real
10
+
11
+ import numpy as np
12
+ import torch
13
+ from mne.channels import make_standard_montage
14
+
15
+ from .base import Transform
16
+ from .functional import (
17
+ amplitude_scale,
18
+ bandstop_filter,
19
+ channels_dropout,
20
+ channels_permute,
21
+ channels_rereference,
22
+ channels_shuffle,
23
+ frequency_shift,
24
+ ft_surrogate,
25
+ gaussian_noise,
26
+ mask_encoding,
27
+ mixup,
28
+ segmentation_reconstruction,
29
+ sensors_rotation,
30
+ sign_flip,
31
+ smooth_time_mask,
32
+ time_reverse,
33
+ )
34
+
35
+
36
class TimeReverse(Transform):
    """Reverse the time axis of each input with a given probability.

    Parameters
    ----------
    probability : float
        Probability with which the time reversal is applied.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator that
        decides, according to ``probability``, whether or not each input is
        transformed. Defaults to None.
    """

    operation = staticmethod(time_reverse)  # type: ignore[assignment]

    def __init__(self, probability, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
60
+
61
+
62
class SignFlip(Transform):
    """Flip the sign of each input with a given probability.

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.
    """

    operation = staticmethod(sign_flip)  # type: ignore[assignment]

    def __init__(self, probability, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
79
+
80
+
81
class FTSurrogate(Transform):
    """FT surrogate augmentation of a single EEG channel, as proposed in [1]_.

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    phase_noise_magnitude : float | torch.Tensor, optional
        Float between 0 and 1 setting the range over which the phase
        perturbation is uniformly sampled:
        ``[0, phase_noise_magnitude * 2 * pi]``. Defaults to 1.
    channel_indep : bool, optional
        Whether to sample phase perturbations independently for each channel or
        not. It is advised to set it to False when spatial information is
        important for the task, like in BCI. Default False.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.

    Raises
    ------
    AssertionError
        If ``phase_noise_magnitude`` is not a float, int or tensor lying in
        ``[0, 1]``, or if ``channel_indep`` is not a boolean.

    References
    ----------
    .. [1] Schwabedal, J. T., Snyder, J. C., Cakmak, A., Nemati, S., &
       Clifford, G. D. (2018). Addressing Class Imbalance in Classification
       Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
       preprint arXiv:1806.08675.
    """

    operation = staticmethod(ft_surrogate)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        phase_noise_magnitude=1,
        channel_indep=False,
        random_state=None,
    ):
        super().__init__(probability=probability, random_state=random_state)
        assert isinstance(phase_noise_magnitude, (float, int, torch.Tensor)), (
            "phase_noise_magnitude should be a float."
        )
        assert 0 <= phase_noise_magnitude <= 1, (
            "phase_noise_magnitude should be between 0 and 1."
        )
        assert isinstance(channel_indep, bool), (
            "channel_indep is expected to be a boolean"
        )
        self.phase_noise_magnitude = phase_noise_magnitude
        self.channel_indep = channel_indep

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains:

            * phase_noise_magnitude : float | torch.Tensor
                The magnitude of the transformation.
            * channel_indep : bool
                Whether phase perturbations are sampled independently for
                each channel.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        return {
            "phase_noise_magnitude": self.phase_noise_magnitude,
            "channel_indep": self.channel_indep,
            "random_state": self.rng,
        }
156
+
157
+
158
class ChannelsDropout(Transform):
    """Randomly replace channels by a flat signal.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    p_drop : float | None, optional
        Probability, between 0 and 1, of dropping each individual channel.
        Defaults to 0.2.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator. It both
        decides whether or not to transform given ``probability`` and samples
        the channels to erase. Defaults to None.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """

    operation = staticmethod(channels_dropout)  # type: ignore[assignment]

    def __init__(self, probability, p_drop=0.2, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
        self.p_drop = p_drop

    def get_augmentation_params(self, *batch):
        """Build the keyword arguments passed to the operation.

        Parameters
        ----------
        X : torch.Tensor
            The data.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains

            * p_drop : float
                Probability, between 0 and 1, of dropping each channel.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        params = dict(p_drop=self.p_drop, random_state=self.rng)
        return params
213
+
214
+
215
class ChannelsShuffle(Transform):
    """Randomly shuffle a subset of the channels of the EEG data matrix.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    p_shuffle : float | None, optional
        Probability, between 0 and 1, for each channel to be part of the set
        of permuted channels. Defaults to 0.2.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator. It decides
        whether or not to transform given ``probability``, selects the
        channels to shuffle and carries out the shuffle. Defaults to None.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """

    operation = staticmethod(channels_shuffle)  # type: ignore[assignment]

    def __init__(self, probability, p_shuffle=0.2, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
        self.p_shuffle = p_shuffle

    def get_augmentation_params(self, *batch):
        """Build the keyword arguments passed to the operation.

        Parameters
        ----------
        X : torch.Tensor
            The data.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains

            * p_shuffle : float
                Probability, between 0 and 1, of including each channel in
                the set of permuted channels.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        params = dict(p_shuffle=self.p_shuffle, random_state=self.rng)
        return params
271
+
272
+
273
class GaussianNoise(Transform):
    """Randomly add white noise to all channels.

    Suggested e.g. in [1]_, [2]_ and [3]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    std : float, optional
        Standard deviation of the additive noise. Defaults to 0.1.
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator.
        Defaults to None.

    References
    ----------
    .. [1] Wang, F., Zhong, S. H., Peng, J., Jiang, J., & Liu, Y. (2018). Data
       augmentation for eeg-based emotion recognition with deep convolutional
       neural networks. In International Conference on Multimedia Modeling
       (pp. 82-93).
    .. [2] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [3] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """

    operation = staticmethod(gaussian_noise)  # type: ignore[assignment]

    def __init__(self, probability, std=0.1, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
        self.std = std

    def get_augmentation_params(self, *batch):
        """Build the keyword arguments passed to the operation.

        Parameters
        ----------
        X : torch.Tensor
            The data.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains

            * std : float
                Standard deviation of the additive noise.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        params = dict(std=self.std, random_state=self.rng)
        return params
335
+
336
+
337
class ChannelsSymmetry(Transform):
    """Permute EEG channels inverting left and right-side sensors.

    Suggested e.g. in [1]_

    Channels whose name contains no digit (e.g. midline electrodes such as
    ``"Cz"``), or whose mirrored name is not in ``ordered_ch_names``, keep
    their original position.

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    ordered_ch_names : list
        Ordered list of strings containing the names (in 10-20
        nomenclature) of the EEG channels that will be transformed. The
        first name should correspond the data in the first row of X, the
        second name in the second row and so on.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.

    Raises
    ------
    AssertionError
        If ``ordered_ch_names`` is not a list of str.

    References
    ----------
    .. [1] Deiss, O., Biswal, S., Jin, J., Sun, H., Westover, M. B., & Sun, J.
       (2018). HAMLET: interpretable human and machine co-learning technique.
       arXiv preprint arXiv:1803.09702.
    """

    operation = staticmethod(channels_permute)  # type: ignore[assignment]

    def __init__(self, probability, ordered_ch_names, random_state=None):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        assert isinstance(ordered_ch_names, list) and all(
            isinstance(ch, str) for ch in ordered_ch_names
        ), "ordered_ch_names should be a list of str."

        # Index every name once so looking up the mirrored channel is O(1)
        # instead of a linear ``list.index`` call per channel. ``setdefault``
        # keeps the first occurrence, matching ``list.index`` semantics.
        name_to_idx = {}
        for idx, ch_name in enumerate(ordered_ch_names):
            name_to_idx.setdefault(ch_name, idx)

        permutation = []
        for idx, ch_name in enumerate(ordered_ch_names):
            new_position = idx
            # Find digits in channel name (assuming 10-20 system)
            digits = "".join(filter(str.isdigit, ch_name))
            if digits:
                d = int(digits)
                # Even numbers are right-side electrodes, odd numbers
                # left-side ones: 2k is mirrored by 2k - 1 and vice versa.
                sym = d - 1 if d % 2 == 0 else d + 1
                new_channel = ch_name.replace(str(d), str(sym))
                new_position = name_to_idx.get(new_channel, idx)
            permutation.append(new_position)
        self.permutation = permutation

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains

            * permutation : list of int
                List of integers defining the new channels order.
        """
        return {"permutation": self.permutation}
410
+
411
+
412
class SmoothTimeMask(Transform):
    """Smoothly replace a randomly chosen contiguous part of all channels by zeros.

    Suggested e.g. in [1]_ and [2]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    mask_len_samples : int | torch.Tensor, optional
        Number of consecutive samples to zero out. Must be positive.
        Defaults to 100.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    Raises
    ------
    AssertionError
        If ``mask_len_samples`` is not a positive int or tensor.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """

    operation = staticmethod(smooth_time_mask)  # type: ignore[assignment]

    def __init__(self, probability, mask_len_samples=100, random_state=None):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )

        assert (
            isinstance(mask_len_samples, (int, torch.Tensor)) and mask_len_samples > 0
        ), "mask_len_samples has to be a positive integer"
        self.mask_len_samples = mask_len_samples

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains two elements:

            * mask_start_per_sample : torch.Tensor
                Tensor of shape ``(batch_size,)`` (on the device of X)
                holding, for each example, the position along the last
                dimension where masking starts. Positions are floats sampled
                uniformly in ``[0, n_times - mask_len_samples)``, where
                ``n_times = X.shape[-1]``.
            * mask_len_samples : int | torch.Tensor
                Number of consecutive samples to zero out (moved to the
                device of X when it is a tensor).
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]

        seq_length = torch.as_tensor(X.shape[-1], device=X.device)
        mask_len_samples = self.mask_len_samples
        if isinstance(mask_len_samples, torch.Tensor):
            mask_len_samples = mask_len_samples.to(X.device)
        # NOTE(review): mask_len_samples is not checked against seq_length
        # here; if it is larger, the sampled start positions are negative.
        mask_start = torch.as_tensor(
            self.rng.uniform(
                low=0,
                high=1,
                size=X.shape[0],
            ),
            device=X.device,
        ) * (seq_length - mask_len_samples)
        return {
            "mask_start_per_sample": mask_start,
            "mask_len_samples": mask_len_samples,
        }
496
+
497
+
498
class BandstopFilter(Transform):
    """Apply a band-stop filter at a randomly selected frequency.

    The filter has the desired bandwidth and, for each example, its center
    frequency is sampled uniformly below ``max_freq`` (see ``max_freq`` for
    the exact interval).

    Suggested e.g. in [1]_ and [2]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    sfreq : float
        Sampling frequency of the signals to be filtered.
    bandwidth : float, optional
        Bandwidth of the filter, i.e. distance between the low and high cut
        frequencies. Defaults to 1.
    max_freq : float | None, optional
        Maximal admissible frequency. Notch center frequencies are sampled
        uniformly in ``[1 + 2 * bandwidth, max_freq - 1 - 2 * bandwidth]``,
        i.e. keeping a margin of twice the bandwidth plus a 1 Hz transition
        away from both 0 and ``max_freq``. If omitted, `None`, or greater than
        the Nyquist frequency (``sfreq / 2``), it falls back to the Nyquist
        frequency and a warning is emitted.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    Raises
    ------
    AssertionError
        If ``bandwidth`` is negative, ``sfreq`` or ``max_freq`` is not
        positive, or ``bandwidth`` is too large for the sampling interval
        above to be non-empty (``1 + 2 * bandwidth > max_freq - 1 -
        2 * bandwidth``).

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """

    operation = staticmethod(bandstop_filter)  # type: ignore[assignment]

    def __init__(
        self, probability, sfreq, bandwidth=1, max_freq=None, random_state=None
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        assert isinstance(bandwidth, Real) and bandwidth >= 0, (
            "bandwidth should be a non-negative float."
        )
        assert isinstance(sfreq, Real) and sfreq > 0, (
            "sfreq should be a positive float."
        )
        if max_freq is not None:
            assert isinstance(max_freq, Real) and max_freq > 0, (
                "max_freq should be a positive float."
            )
        nyq = sfreq / 2
        if max_freq is None or max_freq > nyq:
            max_freq = nyq
            warnings.warn(
                "You either passed None or a frequency greater than the"
                f" Nyquist frequency ({nyq} Hz)."
                f" Falling back to max_freq = {nyq}."
            )
        assert bandwidth < max_freq, (
            f"`bandwidth` needs to be smaller than max_freq={max_freq}"
        )
        # get_augmentation_params samples notch centers with
        # ``rng.uniform(1 + 2 * bandwidth, max_freq - 1 - 2 * bandwidth)``;
        # numpy leaves ``uniform`` undefined when low > high (in practice it
        # then yields centers outside the admissible band, possibly negative),
        # so reject bandwidths for which that interval is empty.
        assert 1 + 2 * bandwidth <= max_freq - 1 - 2 * bandwidth, (
            f"`bandwidth`={bandwidth} is too large for max_freq={max_freq}: "
            "notch frequencies are sampled in "
            "[1 + 2 * bandwidth, max_freq - 1 - 2 * bandwidth], "
            "which must not be empty."
        )

        self.sfreq = sfreq
        self.max_freq = max_freq
        self.bandwidth = bandwidth

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains

            * sfreq : float
                Sampling frequency of the signals to be filtered.
            * bandwidth : float
                Bandwidth of the filter, i.e. distance between the low and high
                cut frequencies.
            * freqs_to_notch : array-like
                Array of floats of size ``(batch_size,)`` containing the center
                of the frequency band to filter out for each sample in the
                batch, sampled uniformly in
                ``[1 + 2 * bandwidth, max_freq - 1 - 2 * bandwidth]``.
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]

        # Prevents transitions from going below 0 and above max_freq
        notched_freqs = self.rng.uniform(
            low=1 + 2 * self.bandwidth,
            high=self.max_freq - 1 - 2 * self.bandwidth,
            size=X.shape[0],
        )
        return {
            "sfreq": self.sfreq,
            "bandwidth": self.bandwidth,
            "freqs_to_notch": notched_freqs,
        }
612
+
613
+
614
class FrequencyShift(Transform):
    """Add a random shift in the frequency domain to all channels.

    The shift is shared by all channels of a given example, but sampled
    independently for each example of the batch.

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    sfreq : float
        Sampling frequency of the signals to be transformed.
    max_delta_freq : float | torch.Tensor, optional
        Largest absolute shift, in Hz, that can be sampled. Defaults to 2
        (shifts sampled between -2 and 2 Hz).
    random_state : int | numpy.random.Generator, optional
        Seed used to instantiate the numpy random number generator.
        Defaults to None.
    """

    operation = staticmethod(frequency_shift)  # type: ignore[assignment]

    def __init__(self, probability, sfreq, max_delta_freq=2, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
        assert isinstance(sfreq, Real) and sfreq > 0, (
            "sfreq should be a positive float."
        )
        self.sfreq = sfreq
        self.max_delta_freq = max_delta_freq

    def get_augmentation_params(self, *batch):
        """Build the keyword arguments passed to the operation.

        Parameters
        ----------
        X : torch.Tensor
            The data.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains

            * delta_freq : torch.Tensor
                Tensor of shape ``(batch_size,)`` (on the device of X) with
                the frequency shift, in Hz, of each example, sampled
                uniformly in ``[-max_delta_freq, max_delta_freq)``.
            * sfreq : float
                Sampling frequency of the signals to be transformed.
        """
        if not batch:
            return super().get_augmentation_params(*batch)
        X = batch[0]
        device = X.device

        bound = self.max_delta_freq
        if isinstance(bound, torch.Tensor):
            bound = bound.to(device)
        # One uniform draw in [0, 1) per example, rescaled to [-bound, bound).
        unit = torch.as_tensor(self.rng.uniform(size=X.shape[0]), device=device)
        delta_freq = unit * 2 * bound - bound
        return dict(delta_freq=delta_freq, sfreq=self.sfreq)
680
+
681
+
682
def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
    """Return the standard 10-20 sensors position matrix.

    Useful e.g. to build the ``sensors_positions_matrix`` expected by
    ``SensorsRotation``.

    Parameters
    ----------
    raw_or_epoch : mne.io.Raw | mne.Epochs, optional
        Example of raw or epoch to retrieve ordered channels list from. Need to
        be named as in 10-20. Only used when ``ordered_ch_names`` is None.
        By default None.
    ordered_ch_names : list, optional
        List of strings representing the channels of the montage considered.
        The order has to be consistent with the order of channels in the input
        matrices that will be fed to `SensorsRotation` transform. By
        default None.

    Returns
    -------
    positions : numpy.ndarray
        Array of shape ``(3, n_matched)`` with the 3D cartesian position of
        each channel found in the ``"standard_1020"`` montage, in the order
        of ``ordered_ch_names``. Names are matched exactly (case-sensitive).
        Channels absent from the montage are silently skipped, and repeated
        names appear only once, so ``n_matched`` can be smaller than the
        number of channels.

    Raises
    ------
    AssertionError
        If both ``raw_or_epoch`` and ``ordered_ch_names`` are None.
    ValueError
        If none of the channels is found in the ``"standard_1020"`` montage.
    """
    assert raw_or_epoch is not None or ordered_ch_names is not None, (
        "At least one of raw_or_epoch and ordered_ch_names is needed."
    )
    if ordered_ch_names is None:
        ordered_ch_names = raw_or_epoch.info["ch_names"]
    ten_twenty_montage = make_standard_montage("standard_1020")
    positions_dict = ten_twenty_montage.get_positions()["ch_pos"]
    positions_subdict = {
        k: positions_dict[k] for k in ordered_ch_names if k in positions_dict
    }
    if not positions_subdict:
        # np.stack on an empty list would raise an opaque
        # "need at least one array to stack"; explain the actual problem.
        raise ValueError(
            f"None of the channels {list(ordered_ch_names)} were found in the "
            "'standard_1020' montage. Channel names must follow the 10-20 "
            "nomenclature exactly (e.g. 'Fp1', 'Cz')."
        )
    return np.stack(list(positions_subdict.values())).T
709
+
710
+
711
+ class SensorsRotation(Transform):
712
+ """Interpolates EEG signals over sensors rotated around the desired axis.
713
+
714
+ with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
715
+
716
+ Suggested in [1]_
717
+
718
+ Parameters
719
+ ----------
720
+ probability : float
721
+ Float setting the probability of applying the operation.
722
+ sensors_positions_matrix : numpy.ndarray
723
+ Matrix giving the positions of each sensor in a 3D cartesian coordinate
724
+ system. Should have shape (3, n_channels), where n_channels is the
725
+ number of channels. Standard 10-20 positions can be obtained from
726
+ `mne` through::
727
+
728
+ >>> ten_twenty_montage = mne.channels.make_standard_montage(
729
+ ... 'standard_1020'
730
+ ... ).get_positions()['ch_pos']
731
+
732
+ axis : 'x' | 'y' | 'z', optional
733
+ Axis around which to rotate. Defaults to 'z'.
734
+ max_degree : float, optional
735
+ Maximum rotation. Rotation angles will be sampled between
736
+ ``-max_degree`` and ``max_degree``. Defaults to 15 degrees.
737
+ spherical_splines : bool, optional
738
+ Whether to use spherical splines for the interpolation or not. When
739
+ ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
740
+ be used (as in the original paper). Defaults to True.
741
+ random_state : int | numpy.random.Generator, optional
742
+ Seed to be used to instantiate numpy random number generator instance.
743
+ Defaults to None.
744
+
745
+ References
746
+ ----------
747
+ .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
748
+ electroencephalographic data. In 2017 39th Annual International
749
+ Conference of the IEEE Engineering in Medicine and Biology Society
750
+ (EMBC) (pp. 471-474).
751
+ """
752
+
753
+ operation = staticmethod(sensors_rotation) # type: ignore[assignment]
754
+
755
+ def __init__(
756
+ self,
757
+ probability,
758
+ sensors_positions_matrix,
759
+ axis="z",
760
+ max_degrees=15,
761
+ spherical_splines=True,
762
+ random_state=None,
763
+ ):
764
+ super().__init__(probability=probability, random_state=random_state)
765
+ if isinstance(sensors_positions_matrix, (np.ndarray, list)):
766
+ sensors_positions_matrix = torch.as_tensor(sensors_positions_matrix)
767
+ assert isinstance(sensors_positions_matrix, torch.Tensor), (
768
+ "sensors_positions should be an Tensor"
769
+ )
770
+ assert isinstance(max_degrees, (Real, torch.Tensor)) and max_degrees >= 0, (
771
+ "max_degrees should be non-negative float."
772
+ )
773
+ assert isinstance(axis, str) and axis in [
774
+ "x",
775
+ "y",
776
+ "z",
777
+ ], "axis can be either x, y or z."
778
+ assert sensors_positions_matrix.shape[0] == 3, (
779
+ "sensors_positions_matrix shape should be 3 x n_channels."
780
+ )
781
+ assert isinstance(spherical_splines, bool), (
782
+ "spherical_splines should be a boolean"
783
+ )
784
+ self.sensors_positions_matrix = sensors_positions_matrix
785
+ self.axis = axis
786
+ self.spherical_splines = spherical_splines
787
+ self.max_degrees = max_degrees
788
+
789
+ def get_augmentation_params(self, *batch):
790
+ """Return transform parameters.
791
+
792
+ Parameters
793
+ ----------
794
+ X : tensor.Tensor
795
+ The data.
796
+ y : tensor.Tensor
797
+ The labels.
798
+
799
+ Returns
800
+ -------
801
+ params : dict
802
+ Contains four elements:
803
+
804
+ * sensors_positions_matrix : numpy.ndarray
805
+ Matrix giving the positions of each sensor in a 3D cartesian
806
+ coordinate system. Should have shape (3, n_channels), where
807
+ n_channels is the number of channels.
808
+ * axis : 'x' | 'y' | 'z'
809
+ Axis around which to rotate.
810
+ * angles : array-like
811
+ Array of float of shape ``(batch_size,)`` containing the
812
+ rotation angles (in degrees) for each element of the input
813
+ batch, sampled uniformly between ``-max_degrees``and
814
+ ``max_degrees``.
815
+ * spherical_splines : bool
816
+ Whether to use spherical splines for the interpolation or not.
817
+ When ``False``, standard scipy.interpolate.Rbf (with quadratic
818
+ kernel) will be used (as in the original paper).
819
+ """
820
+ if len(batch) == 0:
821
+ return super().get_augmentation_params(*batch)
822
+ X = batch[0]
823
+
824
+ u = self.rng.uniform(low=0, high=1, size=X.shape[0])
825
+ max_degrees = self.max_degrees
826
+ if isinstance(max_degrees, torch.Tensor):
827
+ max_degrees = max_degrees.to(X.device)
828
+ random_angles = (
829
+ torch.as_tensor(u, device=X.device) * 2 * max_degrees - max_degrees
830
+ )
831
+ return {
832
+ "sensors_positions_matrix": self.sensors_positions_matrix,
833
+ "axis": self.axis,
834
+ "angles": random_angles,
835
+ "spherical_splines": self.spherical_splines,
836
+ }
837
+
838
+
839
+ class SensorsZRotation(SensorsRotation):
840
+ """Interpolates EEG signals over sensors rotated around the Z axis.
841
+
842
+ with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
843
+
844
+ Suggested in [1]_
845
+
846
+ Parameters
847
+ ----------
848
+ probability : float
849
+ Float setting the probability of applying the operation.
850
+ ordered_ch_names : list
851
+ List of strings representing the channels of the montage considered.
852
+ Has to be in standard 10-20 style. The order has to be consistent with
853
+ the order of channels in the input matrices that will be fed to the
854
+ transform. This channel will be used to compute approximate sensors
855
+ positions from a standard 10-20 montage.
856
+ max_degree : float, optional
857
+ Maximum rotation. Rotation angles will be sampled between
858
+ ``-max_degree`` and ``max_degree``. Defaults to 15 degrees.
859
+ spherical_splines : bool, optional
860
+ Whether to use spherical splines for the interpolation or not. When
861
+ ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
862
+ be used (as in the original paper). Defaults to True.
863
+ random_state : int | numpy.random.Generator, optional
864
+ Seed to be used to instantiate numpy random number generator instance.
865
+ Defaults to None.
866
+
867
+ References
868
+ ----------
869
+ .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
870
+ electroencephalographic data. In 2017 39th Annual International
871
+ Conference of the IEEE Engineering in Medicine and Biology Society
872
+ (EMBC) (pp. 471-474).
873
+ """
874
+
875
+ def __init__(
876
+ self,
877
+ probability,
878
+ ordered_ch_names,
879
+ max_degrees=15,
880
+ spherical_splines=True,
881
+ random_state=None,
882
+ ):
883
+ sensors_positions_matrix = torch.as_tensor(
884
+ _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
885
+ )
886
+ super().__init__(
887
+ probability=probability,
888
+ sensors_positions_matrix=sensors_positions_matrix,
889
+ axis="z",
890
+ max_degrees=max_degrees,
891
+ spherical_splines=spherical_splines,
892
+ random_state=random_state,
893
+ )
894
+
895
+
896
+ class SensorsYRotation(SensorsRotation):
897
+ """Interpolates EEG signals over sensors rotated around the Y axis.
898
+
899
+ with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
900
+
901
+ Suggested in [1]_
902
+
903
+ Parameters
904
+ ----------
905
+ probability : float
906
+ Float setting the probability of applying the operation.
907
+ ordered_ch_names : list
908
+ List of strings representing the channels of the montage considered.
909
+ Has to be in standard 10-20 style. The order has to be consistent with
910
+ the order of channels in the input matrices that will be fed to the
911
+ transform. This channel will be used to compute approximate sensors
912
+ positions from a standard 10-20 montage.
913
+ max_degree : float, optional
914
+ Maximum rotation. Rotation angles will be sampled between
915
+ ``-max_degree`` and ``max_degree``. Defaults to 15 degrees.
916
+ spherical_splines : bool, optional
917
+ Whether to use spherical splines for the interpolation or not. When
918
+ ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
919
+ be used (as in the original paper). Defaults to True.
920
+ random_state : int | numpy.random.Generator, optional
921
+ Seed to be used to instantiate numpy random number generator instance.
922
+ Defaults to None.
923
+
924
+ References
925
+ ----------
926
+ .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
927
+ electroencephalographic data. In 2017 39th Annual International
928
+ Conference of the IEEE Engineering in Medicine and Biology Society
929
+ (EMBC) (pp. 471-474).
930
+ """
931
+
932
+ def __init__(
933
+ self,
934
+ probability,
935
+ ordered_ch_names,
936
+ max_degrees=15,
937
+ spherical_splines=True,
938
+ random_state=None,
939
+ ):
940
+ sensors_positions_matrix = torch.as_tensor(
941
+ _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
942
+ )
943
+ super().__init__(
944
+ probability=probability,
945
+ sensors_positions_matrix=sensors_positions_matrix,
946
+ axis="y",
947
+ max_degrees=max_degrees,
948
+ spherical_splines=spherical_splines,
949
+ random_state=random_state,
950
+ )
951
+
952
+
953
+ class SensorsXRotation(SensorsRotation):
954
+ """Interpolates EEG signals over sensors rotated around the X axis.
955
+
956
+ with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
957
+
958
+ Suggested in [1]_
959
+
960
+ Parameters
961
+ ----------
962
+ probability : float
963
+ Float setting the probability of applying the operation.
964
+ ordered_ch_names : list
965
+ List of strings representing the channels of the montage considered.
966
+ Has to be in standard 10-20 style. The order has to be consistent with
967
+ the order of channels in the input matrices that will be fed to the
968
+ transform. This channel will be used to compute approximate sensors
969
+ positions from a standard 10-20 montage.
970
+ max_degree : float, optional
971
+ Maximum rotation. Rotation angles will be sampled between
972
+ ``-max_degree`` and ``max_degree``. Defaults to 15 degrees.
973
+ spherical_splines : bool, optional
974
+ Whether to use spherical splines for the interpolation or not. When
975
+ ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
976
+ be used (as in the original paper). Defaults to True.
977
+ random_state : int | numpy.random.Generator, optional
978
+ Seed to be used to instantiate numpy random number generator instance.
979
+ Defaults to None.
980
+
981
+ References
982
+ ----------
983
+ .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
984
+ electroencephalographic data. In 2017 39th Annual International
985
+ Conference of the IEEE Engineering in Medicine and Biology Society
986
+ (EMBC) (pp. 471-474).
987
+ """
988
+
989
+ def __init__(
990
+ self,
991
+ probability,
992
+ ordered_ch_names,
993
+ max_degrees=15,
994
+ spherical_splines=True,
995
+ random_state=None,
996
+ ):
997
+ sensors_positions_matrix = torch.as_tensor(
998
+ _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
999
+ )
1000
+ super().__init__(
1001
+ probability=probability,
1002
+ sensors_positions_matrix=sensors_positions_matrix,
1003
+ axis="x",
1004
+ max_degrees=max_degrees,
1005
+ spherical_splines=spherical_splines,
1006
+ random_state=random_state,
1007
+ )
1008
+
1009
+
1010
+ class Mixup(Transform):
1011
+ """Implements Iterator for Mixup for EEG data.
1012
+
1013
+ See [1]_.
1014
+ Implementation based on [2]_.
1015
+
1016
+ Parameters
1017
+ ----------
1018
+ alpha : float
1019
+ Mixup hyperparameter.
1020
+ beta_per_sample : bool (default=False)
1021
+ By default, one mixing coefficient per batch is drawn from a beta
1022
+ distribution. If True, one mixing coefficient per sample is drawn.
1023
+ random_state : int | numpy.random.Generator, optional
1024
+ Seed to be used to instantiate numpy random number generator instance.
1025
+ Defaults to None.
1026
+
1027
+ References
1028
+ ----------
1029
+ .. [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
1030
+ (2018). mixup: Beyond Empirical Risk Minimization. In 2018
1031
+ International Conference on Learning Representations (ICLR)
1032
+ Online: https://arxiv.org/abs/1710.09412
1033
+ .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
1034
+ """
1035
+
1036
+ operation = staticmethod(mixup) # type: ignore[assignment]
1037
+
1038
+ def __init__(self, alpha, beta_per_sample=False, random_state=None):
1039
+ super().__init__(
1040
+ probability=1.0, # Mixup has to be applied to whole batches
1041
+ random_state=random_state,
1042
+ )
1043
+ self.alpha = alpha
1044
+ self.beta_per_sample = beta_per_sample
1045
+
1046
+ def get_augmentation_params(self, *batch):
1047
+ """Return transform parameters.
1048
+
1049
+ Parameters
1050
+ ----------
1051
+ X : tensor.Tensor
1052
+ The data.
1053
+ y : tensor.Tensor
1054
+ The labels.
1055
+
1056
+ Returns
1057
+ -------
1058
+ params : dict
1059
+ Contains the values sampled uniformly between 0 and 1 setting the
1060
+ linear interpolation between examples (lam) and the shuffled
1061
+ indices of examples that are mixed into original examples
1062
+ (idx_perm).
1063
+ """
1064
+ X = batch[0]
1065
+ device = X.device
1066
+ batch_size, _, _ = X.shape
1067
+
1068
+ if self.alpha > 0:
1069
+ if self.beta_per_sample:
1070
+ lam = torch.as_tensor(
1071
+ self.rng.beta(self.alpha, self.alpha, batch_size)
1072
+ ).to(device)
1073
+ else:
1074
+ lam = torch.ones(batch_size).to(device)
1075
+ lam *= self.rng.beta(self.alpha, self.alpha)
1076
+ else:
1077
+ lam = torch.ones(batch_size).to(device)
1078
+
1079
+ idx_perm = torch.as_tensor(
1080
+ self.rng.permutation(
1081
+ batch_size,
1082
+ )
1083
+ )
1084
+
1085
+ return {
1086
+ "lam": lam,
1087
+ "idx_perm": idx_perm,
1088
+ }
1089
+
1090
+
1091
+ class SegmentationReconstruction(Transform):
1092
+ """Segmentation Reconstruction from Lotte (2015) [Lotte2015]_.
1093
+
1094
+ Applies a segmentation-reconstruction transform to the input data, as
1095
+ proposed in [Lotte2015]_. It segments each trial in the batch and randomly mix
1096
+ it to generate new synthetic trials by label, preserving the original
1097
+ order of the segments in time domain.
1098
+
1099
+ Parameters
1100
+ ----------
1101
+ probability : float
1102
+ Float setting the probability of applying the operation.
1103
+ random_state : int | numpy.random.Generator, optional
1104
+ Seed to be used to instantiate numpy random number generator instance.
1105
+ Used to decide whether to transform given the probability
1106
+ argument and to sample the segments mixing. Defaults to None.
1107
+ n_segments : int, optional
1108
+ Number of segments to use in the batch. If None, X will be
1109
+ automatically segmented, getting the last element in a list
1110
+ of factors of the number of samples's square root. Defaults to None.
1111
+
1112
+ References
1113
+ ----------
1114
+ .. [Lotte2015] Lotte, F. (2015). Signal processing approaches to minimize
1115
+ or suppress calibration time in oscillatory activity-based brain–computer
1116
+ interfaces. Proceedings of the IEEE, 103(6), 871-890.
1117
+ """
1118
+
1119
+ operation = staticmethod(segmentation_reconstruction) # type: ignore[assignment]
1120
+
1121
+ def __init__(
1122
+ self,
1123
+ probability,
1124
+ n_segments=None,
1125
+ random_state=None,
1126
+ ):
1127
+ super().__init__(
1128
+ probability=probability,
1129
+ random_state=random_state,
1130
+ )
1131
+ self.n_segments = n_segments
1132
+
1133
+ def get_augmentation_params(self, *batch):
1134
+ """Return transform parameters.
1135
+
1136
+ Parameters
1137
+ ----------
1138
+ X : tensor.Tensor
1139
+ The data.
1140
+ y : tensor.Tensor
1141
+ The labels.
1142
+
1143
+ Returns
1144
+ -------
1145
+ params : dict
1146
+ Contains the number of segments to split the signal into.
1147
+ """
1148
+ X, y = batch[0], batch[1]
1149
+
1150
+ if y is not None:
1151
+ if not isinstance(X, torch.Tensor) or not isinstance(y, torch.Tensor):
1152
+ raise ValueError("X and y must be torch tensors.")
1153
+
1154
+ if X.shape[0] != y.shape[0]:
1155
+ raise ValueError("Number of samples in X and y must be the same.")
1156
+
1157
+ if self.n_segments is None:
1158
+ self.n_segments = int(X.shape[2])
1159
+ n_segments_list = []
1160
+ for i in range(1, int(self.n_segments**0.5) + 1):
1161
+ if self.n_segments % i == 0:
1162
+ n_segments_list.append(i)
1163
+ self.n_segments = n_segments_list[-1]
1164
+
1165
+ elif not (
1166
+ isinstance(self.n_segments, (int, float))
1167
+ and 1 <= self.n_segments <= X.shape[2]
1168
+ ):
1169
+ raise ValueError(
1170
+ f"Number of segments must be a positive integer less than "
1171
+ f"(or equal) the window size. Got {self.n_segments}"
1172
+ )
1173
+
1174
+ if y is None:
1175
+ data_classes = [(np.nan, X)]
1176
+
1177
+ else:
1178
+ classes = torch.unique(y)
1179
+
1180
+ data_classes = [(i, X[y == i]) for i in classes]
1181
+
1182
+ rand_indices = dict()
1183
+ for label, X_class in data_classes:
1184
+ n_trials = X_class.shape[0]
1185
+ rand_indices[label] = self.rng.randint(
1186
+ 0, n_trials, (n_trials, self.n_segments)
1187
+ )
1188
+
1189
+ idx_shuffle = self.rng.permutation(X.shape[0])
1190
+
1191
+ return {
1192
+ "n_segments": self.n_segments,
1193
+ "data_classes": data_classes,
1194
+ "rand_indices": rand_indices,
1195
+ "idx_shuffle": idx_shuffle,
1196
+ }
1197
+
1198
+
1199
+ class MaskEncoding(Transform):
1200
+ """MaskEncoding from [1]_.
1201
+
1202
+ Replaces randomly chosen contiguous part (or parts) of all channels by
1203
+ zeros (if more than one segment, it may overlap).
1204
+
1205
+ Implementation based on [1]_
1206
+
1207
+ Parameters
1208
+ ----------
1209
+ probability : float
1210
+ Float setting the probability of applying the operation.
1211
+ max_mask_ratio : float, optional
1212
+ Signal ratio to zero out. Defaults to 0.1.
1213
+ n_segments : int, optional
1214
+ Number of segments to zero out in each example.
1215
+ Defaults to 1.
1216
+ random_state : int | numpy.random.Generator, optional
1217
+ Seed to be used to instantiate numpy random number generator instance.
1218
+ Defaults to None.
1219
+
1220
+ References
1221
+ ----------
1222
+ .. [1] Ding, Wenlong, et al. "A Novel Data Augmentation Approach
1223
+ Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI."
1224
+ IEEE Transactions on Neural Systems and Rehabilitation Engineering
1225
+ 32 (2024): 875-886.
1226
+ """
1227
+
1228
+ operation = staticmethod(mask_encoding) # type: ignore[assignment]
1229
+
1230
+ def __init__(
1231
+ self,
1232
+ probability,
1233
+ max_mask_ratio=0.1,
1234
+ n_segments=1,
1235
+ random_state=None,
1236
+ ):
1237
+ super().__init__(
1238
+ probability=probability,
1239
+ random_state=random_state,
1240
+ )
1241
+ assert isinstance(n_segments, int) and n_segments > 0, (
1242
+ "n_segments should be a positive integer."
1243
+ )
1244
+ assert isinstance(max_mask_ratio, (int, float)) and 0 <= max_mask_ratio <= 1, (
1245
+ "mask_ratio should be a float between 0 and 1."
1246
+ )
1247
+
1248
+ self.mask_ratio = max_mask_ratio
1249
+ self.n_segments = n_segments
1250
+
1251
+ def get_augmentation_params(self, *batch):
1252
+ """Return transform parameters.
1253
+
1254
+ Parameters
1255
+ ----------
1256
+ X : tensor.Tensor
1257
+ The data.
1258
+ y : tensor.Tensor
1259
+ The labels.
1260
+
1261
+ Returns
1262
+ -------
1263
+ params : dict
1264
+ Contains ...
1265
+ """
1266
+ if len(batch) == 0:
1267
+ return super().get_augmentation_params(*batch)
1268
+ X = batch[0]
1269
+
1270
+ batch_size, _, n_times = X.shape
1271
+
1272
+ segment_length = int((n_times * self.mask_ratio) / self.n_segments)
1273
+
1274
+ assert segment_length >= 1, (
1275
+ "n_segments should be a positive integer not higher than (max_mask_ratio * window size)."
1276
+ )
1277
+
1278
+ time_start = self.rng.randint(
1279
+ 0, n_times - segment_length, (batch_size, self.n_segments)
1280
+ )
1281
+ time_start = torch.from_numpy(time_start)
1282
+
1283
+ return {
1284
+ "time_start": time_start,
1285
+ "segment_length": segment_length,
1286
+ "n_segments": self.n_segments,
1287
+ }
1288
+
1289
+
1290
+ class ChannelsReref(Transform):
1291
+ """Randomly re-reference channels in EEG data matrix.
1292
+
1293
+ Part of the augmentations proposed in [1]_
1294
+
1295
+ Parameters
1296
+ ----------
1297
+ probability : float
1298
+ Float setting the probability of applying the operation.
1299
+ random_state : int | numpy.random.Generator, optional
1300
+ Seed to be used to instantiate numpy random number generator instance.
1301
+ Used to decide whether or not to transform given the probability
1302
+ argument, to sample which channels to shuffle and to carry the shuffle.
1303
+ Defaults to None.
1304
+
1305
+ References
1306
+ ----------
1307
+ .. [1] Mohsenvand, M.N., Izadi, M.R. &amp; Maes, P.. (2020). Contrastive
1308
+ Representation Learning for Electroencephalogram Classification. Proceedings
1309
+ of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
1310
+ Learning Research 136:238-253 Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
1311
+ """
1312
+
1313
+ operation = staticmethod(channels_rereference) # type: ignore[assignment]
1314
+
1315
+ def __init__(self, probability, random_state=None):
1316
+ super().__init__(probability=probability, random_state=random_state)
1317
+
1318
+ def get_augmentation_params(self, *batch):
1319
+ """Return transform parameters."""
1320
+ return {
1321
+ "random_state": self.rng,
1322
+ }
1323
+
1324
+
1325
+ class AmplitudeScale(Transform):
1326
+ """Rescale amplitude based on a random sampled scaling value.
1327
+
1328
+ Part of the augmentations proposed in [1]_
1329
+
1330
+ Parameters
1331
+ ----------
1332
+ probability : float
1333
+ Float setting the probability of applying the operation.
1334
+ random_state : int | numpy.random.Generator, optional
1335
+ Seed to be used to instantiate numpy random number generator instance.
1336
+ Used to decide whether or not to transform given the probability
1337
+ argument, to sample which channels to shuffle and to carry the shuffle.
1338
+ Defaults to None.
1339
+
1340
+ References
1341
+ ----------
1342
+ .. [1] Mohsenvand, M.N., Izadi, M.R. &amp; Maes, P.. (2020). Contrastive
1343
+ Representation Learning for Electroencephalogram Classification. Proceedings
1344
+ of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
1345
+ Learning Research 136:238-253 Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
1346
+ """
1347
+
1348
+ operation = staticmethod(amplitude_scale) # type: ignore[assignment]
1349
+
1350
+ def __init__(self, probability, interval=(0.5, 2), random_state=None):
1351
+ super().__init__(probability=probability, random_state=random_state)
1352
+ self.scale = interval
1353
+
1354
+ def get_augmentation_params(self, *batch):
1355
+ """Return transform parameters."""
1356
+ return {"random_state": self.rng, "scale": self.scale}