braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1356 @@
|
|
|
1
|
+
# Authors: Cédric Rommel <cedric.rommel@inria.fr>
|
|
2
|
+
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
3
|
+
# Gustavo Rodrigues <gustavenrique01@gmail.com>
|
|
4
|
+
# Bruna Lopes <brunajaflopes@gmail.com>
|
|
5
|
+
#
|
|
6
|
+
# License: BSD (3-clause)
|
|
7
|
+
|
|
8
|
+
import warnings
|
|
9
|
+
from numbers import Real
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
from mne.channels import make_standard_montage
|
|
14
|
+
|
|
15
|
+
from .base import Transform
|
|
16
|
+
from .functional import (
|
|
17
|
+
amplitude_scale,
|
|
18
|
+
bandstop_filter,
|
|
19
|
+
channels_dropout,
|
|
20
|
+
channels_permute,
|
|
21
|
+
channels_rereference,
|
|
22
|
+
channels_shuffle,
|
|
23
|
+
frequency_shift,
|
|
24
|
+
ft_surrogate,
|
|
25
|
+
gaussian_noise,
|
|
26
|
+
mask_encoding,
|
|
27
|
+
mixup,
|
|
28
|
+
segmentation_reconstruction,
|
|
29
|
+
sensors_rotation,
|
|
30
|
+
sign_flip,
|
|
31
|
+
smooth_time_mask,
|
|
32
|
+
time_reverse,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TimeReverse(Transform):
    """Reverse the time axis of each input with a given probability.

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    random_state : int | numpy.random.Generator, optional
        Seed (or generator) used to instantiate the numpy random number
        generator that decides, given ``probability``, whether the
        transform is applied. Defaults to None.
    """

    operation = staticmethod(time_reverse)  # type: ignore[assignment]

    def __init__(self, probability, random_state=None):
        # All the work is done by the base class: this transform has no
        # parameters of its own besides the application probability.
        super().__init__(probability=probability, random_state=random_state)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class SignFlip(Transform):
    """Flip the sign of each input with a given probability.

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    random_state : int | numpy.random.Generator, optional
        Seed (or generator) used to instantiate the numpy random number
        generator that decides, given ``probability``, whether the
        transform is applied. Defaults to None.
    """

    operation = staticmethod(sign_flip)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class FTSurrogate(Transform):
    """FT surrogate augmentation of a single EEG channel, as proposed in [1]_.

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    phase_noise_magnitude : float | torch.Tensor, optional
        Value between 0 and 1 setting the range over which the phase
        perturbation is uniformly sampled:
        ``[0, phase_noise_magnitude * 2 * pi]``. Defaults to 1.
    channel_indep : bool, optional
        Whether to sample phase perturbations independently for each channel
        or not. It is advised to set it to False when spatial information is
        important for the task, like in BCI. Defaults to False.
    random_state : int | numpy.random.Generator, optional
        Seed (or generator) used to instantiate the numpy random number
        generator that decides, given ``probability``, whether the
        transform is applied. Defaults to None.

    References
    ----------
    .. [1] Schwabedal, J. T., Snyder, J. C., Cakmak, A., Nemati, S., &
       Clifford, G. D. (2018). Addressing Class Imbalance in Classification
       Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
       preprint arXiv:1806.08675.
    """

    operation = staticmethod(ft_surrogate)  # type: ignore[assignment]

    def __init__(
        self, probability, phase_noise_magnitude=1, channel_indep=False,
        random_state=None,
    ):
        super().__init__(probability=probability, random_state=random_state)

        accepted_magnitude_types = (float, int, torch.Tensor)
        assert isinstance(phase_noise_magnitude, accepted_magnitude_types), (
            "phase_noise_magnitude should be a float."
        )
        # Chained comparison kept on purpose: it rejects NaN.
        assert 0 <= phase_noise_magnitude <= 1, (
            "phase_noise_magnitude should be between 0 and 1."
        )
        assert isinstance(channel_indep, bool), (
            "channel_indep is expected to be a boolean"
        )

        self.phase_noise_magnitude = phase_noise_magnitude
        self.channel_indep = channel_indep

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to ``ft_surrogate``.

        Parameters
        ----------
        *batch : torch.Tensor
            The batch (data and labels). Not used by this transform.

        Returns
        -------
        params : dict
            Contains:

            * phase_noise_magnitude : float | torch.Tensor
                The magnitude of the transformation.
            * channel_indep : bool
                Whether phase perturbations are sampled per channel.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        params = dict(
            phase_noise_magnitude=self.phase_noise_magnitude,
            channel_indep=self.channel_indep,
            random_state=self.rng,
        )
        return params
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class ChannelsDropout(Transform):
    """Randomly set channels to a flat signal.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    p_drop : float | None, optional
        Value between 0 and 1 setting the probability of dropping each
        channel. Defaults to 0.2.
    random_state : int | numpy.random.Generator, optional
        Seed (or generator) used to instantiate the numpy random number
        generator. It decides, given ``probability``, whether the transform
        is applied, and samples the channels to erase. Defaults to None.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """

    operation = staticmethod(channels_dropout)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        p_drop=0.2,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        self.p_drop = p_drop

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to ``channels_dropout``.

        Parameters
        ----------
        *batch : torch.Tensor
            The batch (data and labels). Not used by this transform.

        Returns
        -------
        params : dict
            Contains:

            * p_drop : float
                Probability of dropping each channel.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        return dict(p_drop=self.p_drop, random_state=self.rng)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class ChannelsShuffle(Transform):
    """Randomly shuffle channels in the EEG data matrix.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    p_shuffle : float | None, optional
        Value between 0 and 1 setting the probability of including each
        channel in the set of permuted channels. Defaults to 0.2.
    random_state : int | numpy.random.Generator, optional
        Seed (or generator) used to instantiate the numpy random number
        generator. It decides, given ``probability``, whether the transform
        is applied, samples which channels to shuffle, and carries out the
        shuffle. Defaults to None.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """

    operation = staticmethod(channels_shuffle)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        p_shuffle=0.2,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        self.p_shuffle = p_shuffle

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to ``channels_shuffle``.

        Parameters
        ----------
        *batch : torch.Tensor
            The batch (data and labels). Not used by this transform.

        Returns
        -------
        params : dict
            Contains:

            * p_shuffle : float
                Probability of including each channel in the set of
                permuted channels.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        return dict(p_shuffle=self.p_shuffle, random_state=self.rng)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class GaussianNoise(Transform):
    """Randomly add white noise to all channels.

    Suggested e.g. in [1]_, [2]_ and [3]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    std : float, optional
        Standard deviation of the additive noise. Defaults to 0.1.
    random_state : int | numpy.random.Generator, optional
        Seed (or generator) used to instantiate the numpy random number
        generator. Defaults to None.

    References
    ----------
    .. [1] Wang, F., Zhong, S. H., Peng, J., Jiang, J., & Liu, Y. (2018). Data
       augmentation for eeg-based emotion recognition with deep convolutional
       neural networks. In International Conference on Multimedia Modeling
       (pp. 82-93).
    .. [2] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [3] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """

    operation = staticmethod(gaussian_noise)  # type: ignore[assignment]

    def __init__(self, probability, std=0.1, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
        self.std = std

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to ``gaussian_noise``.

        Parameters
        ----------
        *batch : torch.Tensor
            The batch (data and labels). Not used by this transform.

        Returns
        -------
        params : dict
            Contains:

            * std : float
                Standard deviation of the additive noise.
            * random_state : numpy.random.Generator
                The generator to use.
        """
        params = {}
        params["std"] = self.std
        params["random_state"] = self.rng
        return params
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
class ChannelsSymmetry(Transform):
    """Permute EEG channels, swapping left and right-side sensors.

    Suggested e.g. in [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    ordered_ch_names : list of str
        Ordered list of strings containing the names (in 10-20
        nomenclature) of the EEG channels that will be transformed. The
        first name should correspond to the data in the first row of X, the
        second name to the second row and so on.

        A channel whose name contains digits ``d`` (all digits of the name
        concatenated) is swapped with the channel obtained by replacing
        ``d`` with ``d - 1`` when ``d`` is even, or ``d + 1`` when ``d`` is
        odd (e.g. ``C3`` <-> ``C4``). The swap only happens when that
        channel is in ``ordered_ch_names``. Channels without digits (e.g.
        ``Cz``) or without a counterpart keep their position.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.

    References
    ----------
    .. [1] Deiss, O., Biswal, S., Jin, J., Sun, H., Westover, M. B., & Sun, J.
       (2018). HAMLET: interpretable human and machine co-learning technique.
       arXiv preprint arXiv:1803.09702.
    """

    operation = staticmethod(channels_permute)  # type: ignore[assignment]

    def __init__(self, probability, ordered_ch_names, random_state=None):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        assert isinstance(ordered_ch_names, list) and all(
            isinstance(ch, str) for ch in ordered_ch_names
        ), "ordered_ch_names should be a list of str."

        # Position of the first occurrence of each name. This mirrors
        # ``list.index`` semantics while avoiding one linear scan per channel.
        first_position = {}
        for idx, ch_name in enumerate(ordered_ch_names):
            first_position.setdefault(ch_name, idx)

        permutation = []
        for idx, ch_name in enumerate(ordered_ch_names):
            new_position = idx
            # Find digits in channel name (assuming 10-20 system)
            digits = "".join(filter(str.isdigit, ch_name))
            if digits:
                number = int(digits)
                if number % 2 == 0:  # even/right electrodes
                    sym = number - 1
                else:  # odd/left electrodes
                    sym = number + 1
                new_channel = ch_name.replace(str(number), str(sym))
                # Keep the channel in place if its counterpart is absent.
                new_position = first_position.get(new_channel, idx)
            permutation.append(new_position)
        self.permutation = permutation

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        *batch : torch.Tensor
            The batch (data and labels). Not used by this transform.

        Returns
        -------
        params : dict
            Contains

            * permutation : list of int
                List of integers defining the new channels order.
        """
        return {"permutation": self.permutation}
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
class SmoothTimeMask(Transform):
    """Smoothly replace a randomly chosen contiguous part of all channels by zeros.

    Suggested e.g. in [1]_ and [2]_

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    mask_len_samples : int | torch.Tensor, optional
        Number of consecutive samples to zero out. Must be positive.
        Defaults to 100.
    random_state : int | numpy.random.Generator, optional
        Seed (or generator) used to instantiate the numpy random number
        generator. Defaults to None.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """

    operation = staticmethod(smooth_time_mask)  # type: ignore[assignment]

    def __init__(self, probability, mask_len_samples=100, random_state=None):
        super().__init__(probability=probability, random_state=random_state)

        assert (
            isinstance(mask_len_samples, (int, torch.Tensor)) and mask_len_samples > 0
        ), "mask_len_samples has to be a positive integer"
        self.mask_len_samples = mask_len_samples

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to ``smooth_time_mask``.

        Parameters
        ----------
        *batch : torch.Tensor
            The batch. Only its first element ``X`` is used, for its shape
            and device.

        Returns
        -------
        params : dict
            Contains two elements:

            * mask_start_per_sample : torch.tensor
                Tensor with one start position (in last dimension) per
                example in the batch, i.e. of the same size as the first
                dimension of X. Each start is drawn as
                ``u * (X.shape[-1] - mask_len_samples)`` with ``u`` uniform
                in ``[0, 1)``.
            * mask_len_samples : int | torch.Tensor
                Number of consecutive samples to zero out.
        """
        if not batch:
            return super().get_augmentation_params(*batch)

        X = batch[0]
        device = X.device
        n_examples, n_times = X.shape[0], X.shape[-1]

        seq_length = torch.as_tensor(n_times, device=device)
        mask_len_samples = self.mask_len_samples
        if isinstance(mask_len_samples, torch.Tensor):
            mask_len_samples = mask_len_samples.to(device)

        unit_draws = self.rng.uniform(low=0, high=1, size=n_examples)
        mask_start = torch.as_tensor(unit_draws, device=device) * (
            seq_length - mask_len_samples
        )
        return {
            "mask_start_per_sample": mask_start,
            "mask_len_samples": mask_len_samples,
        }
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
class BandstopFilter(Transform):
    """Apply a band-stop filter at a randomly selected frequency.

    The stop band has the desired ``bandwidth``. Its center frequency is
    sampled uniformly and independently for each example, between 0 and
    ``max_freq`` (see ``get_augmentation_params`` for the exact range).

    Suggested e.g. in [1]_ and [2]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    sfreq : float
        Sampling frequency of the signals to be filtered.
    bandwidth : float, optional
        Bandwidth of the filter, i.e. distance between the low and high cut
        frequencies. Defaults to 1.
    max_freq : float | None, optional
        Maximal admissible frequency. Notch centers are sampled so that
        ``center + 2 * bandwidth + 1`` (1 Hz transition) stays below
        ``max_freq``. Falls back to the Nyquist frequency (``sfreq / 2``),
        with a warning, when omitted, `None`, or greater than it.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    Raises
    ------
    AssertionError
        If an argument is invalid, or if ``bandwidth`` is too large to leave
        room for any notch frequency below ``max_freq``, i.e. when
        ``4 * bandwidth + 2 > max_freq``.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """

    operation = staticmethod(bandstop_filter)  # type: ignore[assignment]

    def __init__(
        self, probability, sfreq, bandwidth=1, max_freq=None, random_state=None
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        assert isinstance(bandwidth, Real) and bandwidth >= 0, (
            "bandwidth should be a non-negative float."
        )
        assert isinstance(sfreq, Real) and sfreq > 0, (
            "sfreq should be a positive float."
        )
        if max_freq is not None:
            assert isinstance(max_freq, Real) and max_freq > 0, (
                "max_freq should be a positive float."
            )
        nyq = sfreq / 2
        if max_freq is None or max_freq > nyq:
            max_freq = nyq
            warnings.warn(
                "You either passed None or a frequency greater than the"
                f" Nyquist frequency ({nyq} Hz)."
                f" Falling back to max_freq = {nyq}."
            )
        assert bandwidth < max_freq, (
            f"`bandwidth` needs to be smaller than max_freq={max_freq}"
        )
        # Notch centers are later drawn from this interval. With low > high,
        # numpy's ``uniform`` behaviour is undefined and can yield
        # frequencies outside [0, max_freq], so reject such configurations
        # up front.
        low, high = self._notch_center_range(bandwidth, max_freq)
        assert low <= high, (
            f"bandwidth={bandwidth} is too large for max_freq={max_freq}: "
            "notch frequencies are sampled in "
            "[1 + 2 * bandwidth, max_freq - 1 - 2 * bandwidth], which would "
            "be empty. `4 * bandwidth + 2` must not exceed max_freq."
        )

        self.sfreq = sfreq
        self.max_freq = max_freq
        self.bandwidth = bandwidth

    @staticmethod
    def _notch_center_range(bandwidth, max_freq):
        """Return the ``(low, high)`` bounds used to sample notch centers.

        The 1 Hz transition and the ``2 * bandwidth`` padding keep the stop
        band away from 0 Hz and from ``max_freq``.
        """
        return 1 + 2 * bandwidth, max_freq - 1 - 2 * bandwidth

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        *batch : torch.Tensor
            The batch. Only its first element ``X`` is used, for its batch
            size.

        Returns
        -------
        params : dict
            Contains

            * sfreq : float
                Sampling frequency of the signals to be filtered.
            * bandwidth : float
                Bandwidth of the filter, i.e. distance between the low and
                high cut frequencies.
            * freqs_to_notch : array-like | None
                Array of floats of size ``(batch_size,)`` containing the
                center of the frequency band to filter out for each sample
                in the batch. Drawn uniformly from
                ``[1 + 2 * bandwidth, max_freq - 1 - 2 * bandwidth)``.
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]

        # Prevents transitions from going below 0 and above max_freq
        low, high = self._notch_center_range(self.bandwidth, self.max_freq)
        notched_freqs = self.rng.uniform(
            low=low,
            high=high,
            size=X.shape[0],
        )
        return {
            "sfreq": self.sfreq,
            "bandwidth": self.bandwidth,
            "freqs_to_notch": notched_freqs,
        }
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
class FrequencyShift(Transform):
    """Add a random shift in the frequency domain to all channels.

    Note that here, the shift is the same for all channels of a single example.

    Parameters
    ----------
    probability : float
        Probability of applying the operation.
    sfreq : float
        Sampling frequency of the signals to be transformed.
    max_delta_freq : float | torch.Tensor, optional
        Maximum shift in Hz that can be sampled (in absolute value).
        Defaults to 2 (shift sampled between -2 and 2 Hz).
    random_state : int | numpy.random.Generator, optional
        Seed (or generator) used to instantiate the numpy random number
        generator. Defaults to None.
    """

    operation = staticmethod(frequency_shift)  # type: ignore[assignment]

    def __init__(self, probability, sfreq, max_delta_freq=2, random_state=None):
        super().__init__(probability=probability, random_state=random_state)
        assert isinstance(sfreq, Real) and sfreq > 0, (
            "sfreq should be a positive float."
        )
        self.sfreq, self.max_delta_freq = sfreq, max_delta_freq

    def get_augmentation_params(self, *batch):
        """Return the keyword arguments passed to ``frequency_shift``.

        Parameters
        ----------
        *batch : torch.Tensor
            The batch. Only its first element ``X`` is used, for its batch
            size and device.

        Returns
        -------
        params : dict
            Contains

            * delta_freq : torch.Tensor
                One frequency shift (in Hz) per example, uniformly drawn
                in ``[-max_delta_freq, max_delta_freq)``.
            * sfreq : float
                Sampling frequency of the signals to be transformed.
        """
        if not batch:
            return super().get_augmentation_params(*batch)

        X = batch[0]
        device = X.device
        u = torch.as_tensor(self.rng.uniform(size=X.shape[0]), device=device)

        max_delta_freq = self.max_delta_freq
        if isinstance(max_delta_freq, torch.Tensor):
            max_delta_freq = max_delta_freq.to(device)

        return {
            "delta_freq": u * 2 * max_delta_freq - max_delta_freq,
            "sfreq": self.sfreq,
        }
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
    """Return the standard 10-20 positions matrix of the given channels.

    The resulting matrix can be used, for example, to instantiate
    :class:`SensorsRotation`.

    Parameters
    ----------
    raw_or_epoch : mne.io.Raw | mne.Epochs, optional
        Example of raw or epochs object to retrieve the ordered channel list
        from (``raw_or_epoch.info["ch_names"]``). Channels need to be named as
        in 10-20. Ignored when ``ordered_ch_names`` is given. By default None.
    ordered_ch_names : list of str, optional
        List of strings representing the channels of the montage considered.
        The order has to be consistent with the order of channels in the input
        matrices that will be fed to `SensorsRotation` transform. By
        default None.

    Returns
    -------
    positions : numpy.ndarray
        Array of shape ``(3, n_found)``: one column per channel name found in
        the ``standard_1020`` montage, in the order of ``ordered_ch_names``.
        Names absent from the montage are skipped (and duplicated names are
        only kept once), so ``n_found`` can be smaller than the number of
        requested channels.

    Raises
    ------
    AssertionError
        If neither ``raw_or_epoch`` nor ``ordered_ch_names`` is given.
    ValueError
        If none of the requested channel names belongs to the montage.
    """
    assert raw_or_epoch is not None or ordered_ch_names is not None, (
        "At least one of raw_or_epoch and ordered_ch_names is needed."
    )
    if ordered_ch_names is None:
        ordered_ch_names = raw_or_epoch.info["ch_names"]
    ten_twenty_montage = make_standard_montage("standard_1020")
    positions_dict = ten_twenty_montage.get_positions()["ch_pos"]
    # A dict keeps the first-seen order of ``ordered_ch_names`` while
    # dropping repeated names.
    positions_subdict = {
        k: positions_dict[k] for k in ordered_ch_names if k in positions_dict
    }
    if not positions_subdict:
        # np.stack would otherwise fail with an uninformative
        # "need at least one array to stack".
        raise ValueError(
            f"None of the channels {list(ordered_ch_names)} was found in the "
            "standard 10-20 montage."
        )
    return np.stack(list(positions_subdict.values())).T
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
class SensorsRotation(Transform):
    """Interpolate EEG signals over sensors rotated around the desired axis.

    The rotation angle is sampled uniformly between ``-max_degrees`` and
    ``max_degrees``, independently for each example of the batch.

    Suggested in [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    sensors_positions_matrix : numpy.ndarray | list | torch.Tensor
        Matrix giving the positions of each sensor in a 3D cartesian coordinate
        system. Should have shape (3, n_channels), where n_channels is the
        number of channels. Arrays and lists are converted to tensors.
        Standard 10-20 positions can be obtained from `mne` through::

            >>> ten_twenty_montage = mne.channels.make_standard_montage(
            ...    'standard_1020'
            ... ).get_positions()['ch_pos']

        which returns a dict mapping channel names to 3D positions; the
        positions of the channels of interest then need to be stacked
        column-wise, in the channel order of the input data.
    axis : 'x' | 'y' | 'z', optional
        Axis around which to rotate. Defaults to 'z'.
    max_degrees : float | torch.Tensor, optional
        Maximum rotation, in degrees. Rotation angles will be sampled between
        ``-max_degrees`` and ``max_degrees``. Must be non-negative.
        Defaults to 15 degrees.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
        be used (as in the original paper). Defaults to True.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    operation = staticmethod(sensors_rotation)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        sensors_positions_matrix,
        axis="z",
        max_degrees=15,
        spherical_splines=True,
        random_state=None,
    ):
        super().__init__(probability=probability, random_state=random_state)
        # Accept array-likes for convenience; everything downstream works on
        # a torch.Tensor.
        if isinstance(sensors_positions_matrix, (np.ndarray, list)):
            sensors_positions_matrix = torch.as_tensor(sensors_positions_matrix)
        assert isinstance(sensors_positions_matrix, torch.Tensor), (
            "sensors_positions should be an Tensor"
        )
        assert isinstance(max_degrees, (Real, torch.Tensor)) and max_degrees >= 0, (
            "max_degrees should be non-negative float."
        )
        assert isinstance(axis, str) and axis in [
            "x",
            "y",
            "z",
        ], "axis can be either x, y or z."
        assert sensors_positions_matrix.shape[0] == 3, (
            "sensors_positions_matrix shape should be 3 x n_channels."
        )
        assert isinstance(spherical_splines, bool), (
            "spherical_splines should be a boolean"
        )
        self.sensors_positions_matrix = sensors_positions_matrix
        self.axis = axis
        self.spherical_splines = spherical_splines
        self.max_degrees = max_degrees

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data. Only its first dimension (the batch size) and its
            device are used here.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains four elements:

            * sensors_positions_matrix : torch.Tensor
                Matrix giving the positions of each sensor in a 3D cartesian
                coordinate system. Should have shape (3, n_channels), where
                n_channels is the number of channels.
            * axis : 'x' | 'y' | 'z'
                Axis around which to rotate.
            * angles : torch.Tensor
                Tensor of float of shape ``(batch_size,)`` containing the
                rotation angles (in degrees) for each element of the input
                batch, sampled uniformly between ``-max_degrees`` and
                ``max_degrees``.
            * spherical_splines : bool
                Whether to use spherical splines for the interpolation or not.
                When ``False``, standard scipy.interpolate.Rbf (with quadratic
                kernel) will be used (as in the original paper).
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]

        u = self.rng.uniform(low=0, high=1, size=X.shape[0])
        max_degrees = self.max_degrees
        if isinstance(max_degrees, torch.Tensor):
            max_degrees = max_degrees.to(X.device)
        # Affine map of u in [0, 1) onto [-max_degrees, max_degrees).
        random_angles = (
            torch.as_tensor(u, device=X.device) * 2 * max_degrees - max_degrees
        )
        return {
            "sensors_positions_matrix": self.sensors_positions_matrix,
            "axis": self.axis,
            "angles": random_angles,
            "spherical_splines": self.spherical_splines,
        }
|
|
837
|
+
|
|
838
|
+
|
|
839
|
+
class SensorsZRotation(SensorsRotation):
    """Interpolate EEG signals over sensors rotated around the Z axis.

    The rotation angle is sampled uniformly between ``-max_degrees`` and
    ``max_degrees``, independently for each example of the batch.

    Suggested in [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. These channel names are used to compute approximate sensors
        positions from a standard 10-20 montage. Names absent from the
        ``standard_1020`` montage are dropped when building the positions
        matrix, so every channel of the input should be present in it.
    max_degrees : float, optional
        Maximum rotation, in degrees. Rotation angles will be sampled between
        ``-max_degrees`` and ``max_degrees``. Defaults to 15 degrees.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
        be used (as in the original paper). Defaults to True.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    def __init__(
        self,
        probability,
        ordered_ch_names,
        max_degrees=15,
        spherical_splines=True,
        random_state=None,
    ):
        sensors_positions_matrix = torch.as_tensor(
            _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
        )
        super().__init__(
            probability=probability,
            sensors_positions_matrix=sensors_positions_matrix,
            axis="z",
            max_degrees=max_degrees,
            spherical_splines=spherical_splines,
            random_state=random_state,
        )
|
|
894
|
+
|
|
895
|
+
|
|
896
|
+
class SensorsYRotation(SensorsRotation):
    """Interpolate EEG signals over sensors rotated around the Y axis.

    The rotation angle is sampled uniformly between ``-max_degrees`` and
    ``max_degrees``, independently for each example of the batch.

    Suggested in [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. These channel names are used to compute approximate sensors
        positions from a standard 10-20 montage. Names absent from the
        ``standard_1020`` montage are dropped when building the positions
        matrix, so every channel of the input should be present in it.
    max_degrees : float, optional
        Maximum rotation, in degrees. Rotation angles will be sampled between
        ``-max_degrees`` and ``max_degrees``. Defaults to 15 degrees.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
        be used (as in the original paper). Defaults to True.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    def __init__(
        self,
        probability,
        ordered_ch_names,
        max_degrees=15,
        spherical_splines=True,
        random_state=None,
    ):
        sensors_positions_matrix = torch.as_tensor(
            _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
        )
        super().__init__(
            probability=probability,
            sensors_positions_matrix=sensors_positions_matrix,
            axis="y",
            max_degrees=max_degrees,
            spherical_splines=spherical_splines,
            random_state=random_state,
        )
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
class SensorsXRotation(SensorsRotation):
    """Interpolate EEG signals over sensors rotated around the X axis.

    The rotation angle is sampled uniformly between ``-max_degrees`` and
    ``max_degrees``, independently for each example of the batch.

    Suggested in [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. These channel names are used to compute approximate sensors
        positions from a standard 10-20 montage. Names absent from the
        ``standard_1020`` montage are dropped when building the positions
        matrix, so every channel of the input should be present in it.
    max_degrees : float, optional
        Maximum rotation, in degrees. Rotation angles will be sampled between
        ``-max_degrees`` and ``max_degrees``. Defaults to 15 degrees.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        ``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
        be used (as in the original paper). Defaults to True.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    def __init__(
        self,
        probability,
        ordered_ch_names,
        max_degrees=15,
        spherical_splines=True,
        random_state=None,
    ):
        sensors_positions_matrix = torch.as_tensor(
            _get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
        )
        super().__init__(
            probability=probability,
            sensors_positions_matrix=sensors_positions_matrix,
            axis="x",
            max_degrees=max_degrees,
            spherical_splines=spherical_splines,
            random_state=random_state,
        )
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
class Mixup(Transform):
    """Mix pairs of examples within a batch, following mixup for EEG data.

    See [1]_.
    Implementation based on [2]_.

    Parameters
    ----------
    alpha : float
        Mixup hyperparameter: both shape parameters of the ``Beta(alpha,
        alpha)`` distribution the mixing coefficients are drawn from. When
        ``alpha <= 0``, all mixing coefficients are set to 1 (presumably
        leaving the examples unmixed).
    beta_per_sample : bool (default=False)
        By default, one mixing coefficient per batch is drawn from a beta
        distribution. If True, one mixing coefficient per sample is drawn.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
       (2018). mixup: Beyond Empirical Risk Minimization. In 2018
       International Conference on Learning Representations (ICLR)
       Online: https://arxiv.org/abs/1710.09412
    .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
    """

    operation = staticmethod(mixup)  # type: ignore[assignment]

    def __init__(self, alpha, beta_per_sample=False, random_state=None):
        super().__init__(
            probability=1.0,  # Mixup has to be applied to whole batches
            random_state=random_state,
        )
        self.alpha = alpha
        self.beta_per_sample = beta_per_sample

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data. Only its first dimension (the batch size) and its device
            are used here; it is unpacked as a 3D tensor.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains the mixing coefficients (``lam``, a tensor of shape
            ``(batch_size,)`` in the default torch float dtype) and the
            shuffled indices of the examples that are mixed into the original
            examples (``idx_perm``). The coefficients are drawn from
            ``Beta(alpha, alpha)`` (once per batch, or once per example when
            ``beta_per_sample=True``), or are all 1 when ``alpha <= 0``.
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]
        device = X.device
        batch_size, _, _ = X.shape
        # numpy's beta draws are float64 whereas torch.ones uses the default
        # dtype; force the same dtype in every branch so the result does not
        # depend on ``beta_per_sample``.
        dtype = torch.get_default_dtype()

        if self.alpha > 0:
            if self.beta_per_sample:
                lam = torch.as_tensor(
                    self.rng.beta(self.alpha, self.alpha, batch_size),
                    dtype=dtype,
                    device=device,
                )
            else:
                lam = torch.ones(batch_size, dtype=dtype, device=device)
                lam *= self.rng.beta(self.alpha, self.alpha)
        else:
            lam = torch.ones(batch_size, dtype=dtype, device=device)

        idx_perm = torch.as_tensor(self.rng.permutation(batch_size))

        return {
            "lam": lam,
            "idx_perm": idx_perm,
        }
|
|
1089
|
+
|
|
1090
|
+
|
|
1091
|
+
class SegmentationReconstruction(Transform):
    """Segmentation Reconstruction from Lotte (2015) [Lotte2015]_.

    Applies a segmentation-reconstruction transform to the input data, as
    proposed in [Lotte2015]_. It segments each trial in the batch and randomly
    mixes the segments of trials sharing the same label to generate new
    synthetic trials, preserving the original order of the segments in the
    time domain.

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    n_segments : int, optional
        Number of segments to use in the batch. If None, it is chosen for each
        batch as the largest divisor of the number of time samples that does
        not exceed the square root of that number. Defaults to None.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether to transform given the probability
        argument and to sample the segments mixing. Defaults to None.

    References
    ----------
    .. [Lotte2015] Lotte, F. (2015). Signal processing approaches to minimize
       or suppress calibration time in oscillatory activity-based brain–computer
       interfaces. Proceedings of the IEEE, 103(6), 871-890.
    """

    operation = staticmethod(segmentation_reconstruction)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        n_segments=None,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        self.n_segments = n_segments

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data. Its last (third) axis is treated as time; presumably of
            shape ``(batch_size, n_channels, n_times)`` — TODO confirm.
        y : torch.Tensor | None
            The labels. When None (or not given), all trials are treated as
            belonging to a single class.

        Returns
        -------
        params : dict
            Contains the number of segments to split the signal into
            (``n_segments``), the trials grouped by label (``data_classes``,
            a list of ``(label, X_class)`` pairs), the randomly drawn trial
            index for every segment of every trial of each class
            (``rand_indices``, arrays of shape ``(n_trials, n_segments)`` keyed
            by label) and a permutation of the batch (``idx_shuffle``).

        Raises
        ------
        ValueError
            If ``X``/``y`` are not tensors or have a different number of
            samples, or if ``n_segments`` is not an integer between 1 and the
            window size.
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]
        y = batch[1] if len(batch) > 1 else None

        if y is not None:
            if not isinstance(X, torch.Tensor) or not isinstance(y, torch.Tensor):
                raise ValueError("X and y must be torch tensors.")

            if X.shape[0] != y.shape[0]:
                raise ValueError("Number of samples in X and y must be the same.")

        n_times = int(X.shape[2])
        # Resolve the number of segments locally: writing it back to
        # ``self.n_segments`` would freeze the automatic choice made for the
        # first batch and reuse it for later batches of another length.
        n_segments = self.n_segments
        if n_segments is None:
            # Largest divisor of n_times not exceeding sqrt(n_times).
            divisors = [
                i for i in range(1, int(n_times**0.5) + 1) if n_times % i == 0
            ]
            if not divisors:  # only happens for an empty time axis
                raise ValueError(
                    f"Cannot segment a window of {n_times} time samples."
                )
            n_segments = divisors[-1]
        elif not (
            isinstance(n_segments, (int, float, np.integer))
            and float(n_segments).is_integer()
            and 1 <= n_segments <= n_times
        ):
            raise ValueError(
                f"Number of segments must be a positive integer less than "
                f"(or equal) the window size. Got {n_segments}"
            )
        # Integral floats such as 4.0 are accepted; numpy's randint needs an
        # int in its ``size`` argument.
        n_segments = int(n_segments)

        if y is None:
            data_classes = [(np.nan, X)]

        else:
            classes = torch.unique(y)

            data_classes = [(i, X[y == i]) for i in classes]

        rand_indices = dict()
        for label, X_class in data_classes:
            n_trials = X_class.shape[0]
            rand_indices[label] = self.rng.randint(
                0, n_trials, (n_trials, n_segments)
            )

        idx_shuffle = self.rng.permutation(X.shape[0])

        return {
            "n_segments": n_segments,
            "data_classes": data_classes,
            "rand_indices": rand_indices,
            "idx_shuffle": idx_shuffle,
        }
|
|
1197
|
+
|
|
1198
|
+
|
|
1199
|
+
class MaskEncoding(Transform):
    """MaskEncoding from [1]_.

    Replaces randomly chosen contiguous part (or parts) of all channels by
    zeros (if more than one segment, they may overlap).

    Implementation based on [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    max_mask_ratio : float, optional
        Signal ratio to zero out, between 0 and 1. The total masked length
        ``int(n_times * max_mask_ratio)`` is split evenly (rounding down)
        between the ``n_segments`` segments; since segments may overlap, the
        fraction actually zeroed out can be smaller. Defaults to 0.1.
    n_segments : int, optional
        Number of segments to zero out in each example.
        Defaults to 1.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Ding, Wenlong, et al. "A Novel Data Augmentation Approach
       Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI."
       IEEE Transactions on Neural Systems and Rehabilitation Engineering
       32 (2024): 875-886.
    """

    operation = staticmethod(mask_encoding)  # type: ignore[assignment]

    def __init__(
        self,
        probability,
        max_mask_ratio=0.1,
        n_segments=1,
        random_state=None,
    ):
        super().__init__(
            probability=probability,
            random_state=random_state,
        )
        assert isinstance(n_segments, int) and n_segments > 0, (
            "n_segments should be a positive integer."
        )
        assert isinstance(max_mask_ratio, (int, float)) and 0 <= max_mask_ratio <= 1, (
            "mask_ratio should be a float between 0 and 1."
        )

        self.mask_ratio = max_mask_ratio
        self.n_segments = n_segments

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Parameters
        ----------
        X : torch.Tensor
            The data, unpacked as ``(batch_size, n_channels, n_times)``.
        y : torch.Tensor
            The labels.

        Returns
        -------
        params : dict
            Contains

            * time_start : torch.Tensor
                Integer tensor of shape ``(batch_size, n_segments)`` with the
                first time index of each masked segment, drawn uniformly in
                ``[0, n_times - segment_length]``.
            * segment_length : int
                Length (in samples) of every masked segment.
            * n_segments : int
                Number of segments masked in each example.

        Raises
        ------
        AssertionError
            If the computed segment length is smaller than one sample.
        """
        if len(batch) == 0:
            return super().get_augmentation_params(*batch)
        X = batch[0]

        batch_size, _, n_times = X.shape

        segment_length = int((n_times * self.mask_ratio) / self.n_segments)

        assert segment_length >= 1, (
            "n_segments should be a positive integer not higher than (max_mask_ratio * window size)."
        )

        # randint's upper bound is exclusive: ``+ 1`` allows a segment that
        # ends exactly on the last sample, and avoids ``low >= high`` errors
        # when a segment spans the whole window.
        time_start = self.rng.randint(
            0, n_times - segment_length + 1, (batch_size, self.n_segments)
        )
        time_start = torch.from_numpy(time_start)

        return {
            "time_start": time_start,
            "segment_length": segment_length,
            "n_segments": self.n_segments,
        }
|
|
1288
|
+
|
|
1289
|
+
|
|
1290
|
+
class ChannelsReref(Transform):
    """Randomly re-reference channels in EEG data matrix.

    Part of the augmentations proposed in [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument; the same generator is also exposed by
        ``get_augmentation_params`` under the ``random_state`` key.
        Defaults to None.

    References
    ----------
    .. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P.. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. Proceedings
       of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
       Learning Research 136:238-253 Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
    """

    operation = staticmethod(channels_rereference)  # type: ignore[assignment]

    def __init__(self, probability, random_state=None):
        super().__init__(random_state=random_state, probability=probability)

    def get_augmentation_params(self, *batch):
        """Expose the transform's random generator as its only parameter.

        Returns
        -------
        params : dict
            Single-entry mapping ``{"random_state": self.rng}``.
        """
        params = dict(random_state=self.rng)
        return params
|
|
1323
|
+
|
|
1324
|
+
|
|
1325
|
+
class AmplitudeScale(Transform):
    """Rescale amplitude based on a random sampled scaling value.

    Part of the augmentations proposed in [1]_

    Parameters
    ----------
    probability : float
        Float setting the probability of applying the operation.
    interval : tuple of float, optional
        Bounds used for the random scaling value, stored as ``self.scale`` and
        returned by ``get_augmentation_params`` under the ``scale`` key
        (presumably the ``(low, high)`` range the scaling factor is drawn from
        by the operation — confirm against ``amplitude_scale``).
        Defaults to ``(0.5, 2)``.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument; the same generator is also returned by
        ``get_augmentation_params`` under the ``random_state`` key.
        Defaults to None.

    References
    ----------
    .. [1] Mohsenvand, M.N., Izadi, M.R. & Maes, P.. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. Proceedings
       of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
       Learning Research 136:238-253 Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
    """

    operation = staticmethod(amplitude_scale)  # type: ignore[assignment]

    def __init__(self, probability, interval=(0.5, 2), random_state=None):
        super().__init__(probability=probability, random_state=random_state)
        self.scale = interval

    def get_augmentation_params(self, *batch):
        """Return transform parameters.

        Returns
        -------
        params : dict
            ``{"random_state": self.rng, "scale": self.scale}``.
        """
        return {"random_state": self.rng, "scale": self.scale}
|