braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1300 @@
# Authors: Cédric Rommel <cedric.rommel@inria.fr>
#          Alexandre Gramfort <alexandre.gramfort@inria.fr>
#          Gustavo Rodrigues <gustavenrique01@gmail.com>
#          Bruna Lopes <brunajaflopes@gmail.com>
#
# License: BSD (3-clause)

from __future__ import annotations

from numbers import Real
from typing import Literal

import numpy as np
import numpy.typing as npt
import torch
from mne.filter import notch_filter
from scipy.interpolate import Rbf
from sklearn.utils import check_random_state
from torch.fft import fft, ifft
from torch.nn.functional import one_hot, pad

def identity(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Identity operation.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.
    """
    return X, y


def time_reverse(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Flip the time axis of each input.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.
    """
    return torch.flip(X, [-1]), y


def sign_flip(X: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Flip the sign of each input.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.
    """
    return -X, y

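
# A minimal usage sketch for the three label-preserving transforms above
# (illustrative only; the (batch, channels, time) shapes are assumptions):
#
#     X = torch.randn(8, 22, 1000)            # 8 trials, 22 channels, 1000 samples
#     y = torch.zeros(8, dtype=torch.long)
#     X_id, _ = identity(X, y)                 # passes through unchanged
#     X_rev, _ = time_reverse(X, y)            # flips the time axis
#     X_neg, _ = sign_flip(X, y)               # negates the signal
#     assert torch.equal(X_rev, torch.flip(X, [-1]))
#     assert torch.equal(X_neg, -X)
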
def _new_random_fft_phase_odd(
    batch_size: int,
    c: int,
    n: int,
    device: torch.device,
    random_state: int | np.random.RandomState | None,
) -> torch.Tensor:
    rng = check_random_state(random_state)
    random_phase = torch.from_numpy(
        2j * np.pi * rng.random((batch_size, c, (n - 1) // 2))
    ).to(device)
    return torch.cat(
        [
            torch.zeros((batch_size, c, 1), device=device),
            random_phase,
            -torch.flip(random_phase, [-1]),
        ],
        dim=-1,
    )


def _new_random_fft_phase_even(
    batch_size: int,
    c: int,
    n: int,
    device: torch.device,
    random_state: int | np.random.RandomState | None,
) -> torch.Tensor:
    rng = check_random_state(random_state)
    random_phase = torch.from_numpy(
        2j * np.pi * rng.random((batch_size, c, n // 2 - 1))
    ).to(device)
    return torch.cat(
        [
            torch.zeros((batch_size, c, 1), device=device),
            random_phase,
            torch.zeros((batch_size, c, 1), device=device),
            -torch.flip(random_phase, [-1]),
        ],
        dim=-1,
    )


_new_random_fft_phase = {0: _new_random_fft_phase_even, 1: _new_random_fft_phase_odd}

def ft_surrogate(
    X: torch.Tensor,
    y: torch.Tensor,
    phase_noise_magnitude: float,
    channel_indep: bool,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """FT surrogate augmentation of a single EEG channel, as proposed in [1]_.

    Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
    and modified.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    phase_noise_magnitude : float
        Float between 0 and 1 setting the range over which the phase
        perturbation is uniformly sampled:
        [0, `phase_noise_magnitude` * 2 * `pi`].
    channel_indep : bool
        Whether to sample phase perturbations independently for each channel
        or not. It is advised to set it to False when spatial information is
        important for the task, like in BCI.
    random_state : int | numpy.random.Generator, optional
        Used to draw the phase perturbation. Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Schwabedal, J. T., Snyder, J. C., Cakmak, A., Nemati, S., &
       Clifford, G. D. (2018). Addressing Class Imbalance in Classification
       Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
       preprint arXiv:1806.08675.
    """
    assert (
        isinstance(
            phase_noise_magnitude, (Real, torch.FloatTensor, torch.cuda.FloatTensor)
        )
        and 0 <= phase_noise_magnitude <= 1
    ), f"eps must be a float between 0 and 1. Got {phase_noise_magnitude}."

    f = fft(X.double(), dim=-1)
    device = X.device

    n = f.shape[-1]
    random_phase = _new_random_fft_phase[n % 2](
        f.shape[0],
        f.shape[-2] if channel_indep else 1,
        n,
        device=device,
        random_state=random_state,
    )
    if not channel_indep:
        random_phase = torch.tile(random_phase, (1, f.shape[-2], 1))
    if isinstance(phase_noise_magnitude, torch.Tensor):
        phase_noise_magnitude = phase_noise_magnitude.to(device)
    f_shifted = f * torch.exp(phase_noise_magnitude * random_phase)
    shifted = ifft(f_shifted, dim=-1)
    transformed_X = shifted.real.float()

    return transformed_X, y

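
# Usage sketch for `ft_surrogate` (illustrative shapes and parameter values
# assumed): the transform keeps each channel's amplitude spectrum but draws
# random Fourier phases, here shared across channels (`channel_indep=False`).
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     X_aug, _ = ft_surrogate(
#         X, y, phase_noise_magnitude=0.5, channel_indep=False, random_state=0
#     )
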
def _pick_channels_randomly(
    X: torch.Tensor, p_pick: float, random_state: int | np.random.RandomState | None
) -> torch.Tensor:
    rng = check_random_state(random_state)
    batch_size, n_channels, _ = X.shape
    # allows to use the same RNG
    unif_samples = torch.as_tensor(
        rng.uniform(0, 1, size=(batch_size, n_channels)),
        dtype=torch.float,
        device=X.device,
    )
    # equivalent to a 0s and 1s mask
    return torch.sigmoid(1000 * (unif_samples - p_pick))

def channels_dropout(
    X: torch.Tensor,
    y: torch.Tensor,
    p_drop: float,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly set channels to flat signal.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    p_drop : float
        Float between 0 and 1 setting the probability of dropping each
        channel.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """
    mask = _pick_channels_randomly(X, p_drop, random_state=random_state)
    return X * mask.unsqueeze(-1), y

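
# Usage sketch for `channels_dropout` (illustrative): each channel of each
# trial is zeroed out independently with probability `p_drop`.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     X_aug, _ = channels_dropout(X, y, p_drop=0.2, random_state=0)
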
def _make_permutation_matrix(
    X: torch.Tensor, mask: torch.Tensor, random_state: int | np.random.Generator | None
) -> torch.Tensor:
    rng = check_random_state(random_state)
    batch_size, n_channels, _ = X.shape
    hard_mask = mask.round()
    batch_permutations = torch.empty(
        batch_size, n_channels, n_channels, device=X.device
    )
    for b, mask in enumerate(hard_mask):
        channels_to_shuffle = torch.arange(n_channels, device=X.device)
        channels_to_shuffle = channels_to_shuffle[mask.bool()]
        reordered_channels = torch.tensor(
            rng.permutation(channels_to_shuffle.cpu()), device=X.device
        )
        channels_permutation = torch.arange(n_channels, device=X.device)
        channels_permutation[channels_to_shuffle] = reordered_channels
        batch_permutations[b, ...] = one_hot(channels_permutation)
    return batch_permutations

def channels_shuffle(
    X: torch.Tensor,
    y: torch.Tensor,
    p_shuffle: float,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly shuffle channels in EEG data matrix.

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    p_shuffle : float | None
        Float between 0 and 1 setting the probability of including the channel
        in the set of permuted channels.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to sample which channels to shuffle and to carry the shuffle.
        Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """
    if p_shuffle == 0:
        return X, y
    mask = _pick_channels_randomly(X, 1 - p_shuffle, random_state)
    batch_permutations = _make_permutation_matrix(X, mask, random_state)
    return torch.matmul(batch_permutations, X), y

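
# Usage sketch for `channels_shuffle` (illustrative): roughly a fraction
# `p_shuffle` of the channels is selected and permuted within each trial.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     X_aug, _ = channels_shuffle(X, y, p_shuffle=0.3, random_state=0)
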
def gaussian_noise(
    X: torch.Tensor,
    y: torch.Tensor,
    std: float,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly add white Gaussian noise to all channels.

    Suggested e.g. in [1]_, [2]_ and [3]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    std : float
        Standard deviation to use for the additive noise.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Wang, F., Zhong, S. H., Peng, J., Jiang, J., & Liu, Y. (2018). Data
       augmentation for eeg-based emotion recognition with deep convolutional
       neural networks. In International Conference on Multimedia Modeling
       (pp. 82-93).
    .. [2] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [3] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """
    rng = check_random_state(random_state)
    if isinstance(std, torch.Tensor):
        std = std.to(X.device)
    noise = (
        torch.from_numpy(
            rng.normal(loc=np.zeros(X.shape), scale=1),
        )
        .float()
        .to(X.device)
        * std
    )
    transformed_X = X + noise
    return transformed_X, y

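
# Usage sketch for `gaussian_noise` (illustrative): `std` is expressed in the
# same units as `X`, so a value well below the signal scale is typical.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     X_aug, _ = gaussian_noise(X, y, std=0.1, random_state=0)
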
def channels_permute(
    X: torch.Tensor, y: torch.Tensor, permutation: list[int]
) -> tuple[torch.Tensor, torch.Tensor]:
    """Permute EEG channels according to fixed permutation matrix.

    Suggested e.g. in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    permutation : list
        List of integers defining the new channels order.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Deiss, O., Biswal, S., Jin, J., Sun, H., Westover, M. B., & Sun, J.
       (2018). HAMLET: interpretable human and machine co-learning technique.
       arXiv preprint arXiv:1803.09702.
    """
    return X[..., permutation, :], y

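
# Usage sketch for `channels_permute` (illustrative): the permutation is a
# fixed channel reordering, here simply reversing the channel axis.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     X_aug, _ = channels_permute(X, y, permutation=list(range(21, -1, -1)))
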
def smooth_time_mask(
    X: torch.Tensor,
    y: torch.Tensor,
    mask_start_per_sample: torch.Tensor,
    mask_len_samples: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Smoothly replace a contiguous part of all channels by zeros.

    Originally proposed in [1]_ and [2]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    mask_start_per_sample : torch.Tensor
        Tensor of integers containing the position (in last dimension) where
        to start masking the signal. Should have the same size as the first
        dimension of X (i.e. one start position per example in the batch).
    mask_len_samples : int
        Number of consecutive samples to zero out.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """
    batch_size, n_channels, seq_len = X.shape
    t = torch.arange(seq_len, device=X.device).float()
    t = t.repeat(batch_size, n_channels, 1)
    mask_start_per_sample = mask_start_per_sample.view(-1, 1, 1)
    s = 1000 / seq_len
    mask = (
        (
            torch.sigmoid(s * -(t - mask_start_per_sample))
            + torch.sigmoid(s * (t - mask_start_per_sample - mask_len_samples))
        )
        .float()
        .to(X.device)
    )
    return X * mask, y

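
# Usage sketch for `smooth_time_mask` (illustrative): one mask start position
# is drawn per trial, and 100 samples are smoothly zeroed from each start.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     starts = torch.randint(0, 1000 - 100, (8,))
#     X_aug, _ = smooth_time_mask(X, y, mask_start_per_sample=starts,
#                                 mask_len_samples=100)
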
def bandstop_filter(
    X: torch.Tensor,
    y: torch.Tensor,
    sfreq: float,
    bandwidth: float,
    freqs_to_notch: npt.ArrayLike | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply a band-stop filter with desired bandwidth at the desired frequency position.

    Suggested e.g. in [1]_ and [2]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    sfreq : float
        Sampling frequency of the signals to be filtered.
    bandwidth : float
        Bandwidth of the filter, i.e. distance between the low and high cut
        frequencies.
    freqs_to_notch : array-like | None
        Array of floats of size ``(batch_size,)`` containing the center of the
        frequency band to filter out for each sample in the batch. Frequencies
        should be greater than ``bandwidth/2 + transition`` and lower than
        ``sfreq/2 - bandwidth/2 - transition`` (where ``transition = 1 Hz``).

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    """
    if bandwidth == 0 or freqs_to_notch is None:
        return X, y
    transformed_X = X.clone()
    for c, (sample, notched_freq) in enumerate(zip(transformed_X, freqs_to_notch)):
        sample = sample.cpu().numpy().astype(np.float64)
        transformed_X[c] = torch.as_tensor(
            notch_filter(
                sample,
                Fs=sfreq,
                freqs=notched_freq,
                method="fir",
                notch_widths=bandwidth,
                verbose=False,
            )
        )
    return transformed_X, y

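
# Usage sketch for `bandstop_filter` (illustrative): one notch frequency per
# trial; 50 Hz with a 1 Hz bandwidth satisfies the documented bounds for a
# 250 Hz sampling rate.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     freqs = np.full(8, 50.0)
#     X_aug, _ = bandstop_filter(X, y, sfreq=250.0, bandwidth=1.0,
#                                freqs_to_notch=freqs)
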
def _analytic_transform(x: torch.Tensor) -> torch.Tensor:
    if torch.is_complex(x):
        raise ValueError("x must be real.")

    N = x.shape[-1]
    f = fft(x, N, dim=-1)
    h = torch.zeros_like(f)
    if N % 2 == 0:
        h[..., 0] = h[..., N // 2] = 1
        h[..., 1 : N // 2] = 2
    else:
        h[..., 0] = 1
        h[..., 1 : (N + 1) // 2] = 2

    return ifft(f * h, dim=-1)


def _nextpow2(n: int) -> int:
    """Return the first integer N such that 2**N >= abs(n)."""
    return int(np.ceil(np.log2(np.abs(n))))


def _frequency_shift(X: torch.Tensor, fs: float, f_shift: float) -> torch.Tensor:
    """
    Shift the specified signal by the specified frequency.

    See https://gist.github.com/lebedov/4428122
    """
    # Pad the signal with zeros to prevent the FFT invoked by the transform
    # from slowing down the computation:
    n_channels, N_orig = X.shape[-2:]
    N_padded = 2 ** _nextpow2(N_orig)
    t = torch.arange(N_padded, device=X.device) / fs
    padded = pad(X, (0, N_padded - N_orig))
    analytical = _analytic_transform(padded)
    if isinstance(f_shift, torch.Tensor):
        _f_shift = f_shift
    elif isinstance(f_shift, (float, int, np.ndarray, list)):
        _f_shift = torch.as_tensor(f_shift).float()
    else:
        raise ValueError(f"Invalid f_shift type: {type(f_shift)}")
    f_shift_stack = _f_shift.repeat(N_padded, n_channels, 1)
    reshaped_f_shift = f_shift_stack.permute(
        *torch.arange(f_shift_stack.ndim - 1, -1, -1)
    )
    shifted = analytical * torch.exp(2j * np.pi * reshaped_f_shift * t)
    return shifted[..., :N_orig].real.float()

def frequency_shift(
    X: torch.Tensor, y: torch.Tensor, delta_freq: float, sfreq: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Adds a shift in the frequency domain to all channels.

    Note that here, the shift is the same for all channels of a single
    example.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    delta_freq : float
        The amplitude of the frequency shift (in Hz).
    sfreq : float
        Sampling frequency of the signals to be transformed.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.
    """
    transformed_X = _frequency_shift(
        X=X,
        fs=sfreq,
        f_shift=delta_freq,
    )
    return transformed_X, y

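
# Usage sketch for `frequency_shift` (illustrative): shifts the spectrum of
# every channel by `delta_freq` Hz via the analytic-signal helpers above.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     X_aug, _ = frequency_shift(X, y, delta_freq=2.0, sfreq=250.0)
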
def _torch_normalize_vectors(rr: torch.Tensor) -> torch.Tensor:
    """Normalize surface vertices."""
    norm = torch.linalg.norm(rr, axis=1, keepdim=True)
    mask = norm > 0
    norm[~mask] = 1  # in case norm is zero, divide by 1
    new_rr = rr / norm
    return new_rr


def _torch_legval(
    x: torch.Tensor, c: torch.Tensor, tensor: bool = True
) -> torch.Tensor:
    """
    Evaluate a Legendre series at points x.

    If `c` is of length `n + 1`, this function returns the value:

    .. math:: p(x) = c_0 * L_0(x) + c_1 * L_1(x) + ... + c_n * L_n(x)

    The parameter `x` is converted to an array only if it is a tuple or a
    list, otherwise it is treated as a scalar. In either case, either `x`
    or its elements must support multiplication and addition both with
    themselves and with the elements of `c`.
    If `c` is a 1-D array, then `p(x)` will have the same shape as `x`. If
    `c` is multidimensional, then the shape of the result depends on the
    value of `tensor`. If `tensor` is true the shape will be c.shape[1:] +
    x.shape. If `tensor` is false the shape will be c.shape[1:]. Note that
    scalars have shape (,).
    Trailing zeros in the coefficients will be used in the evaluation, so
    they should be avoided if efficiency is a concern.

    Parameters
    ----------
    x : array_like, compatible object
        If `x` is a list or tuple, it is converted to an ndarray, otherwise
        it is left unchanged and treated as a scalar. In either case, `x`
        or its elements must support addition and multiplication with
        themselves and with the elements of `c`.
    c : array_like
        Array of coefficients ordered so that the coefficients for terms of
        degree n are contained in c[n]. If `c` is multidimensional the
        remaining indices enumerate multiple polynomials. In the two
        dimensional case the coefficients may be thought of as stored in
        the columns of `c`.
    tensor : boolean, optional
        If True, the shape of the coefficient array is extended with ones
        on the right, one for each dimension of `x`. Scalars have dimension 0
        for this action. The result is that every column of coefficients in
        `c` is evaluated for every element of `x`. If False, `x` is broadcast
        over the columns of `c` for the evaluation. This keyword is useful
        when `c` is multidimensional. The default value is True.

        .. versionadded:: 1.7.0

    Returns
    -------
    values : ndarray, algebra_like
        The shape of the return value is described above.

    See Also
    --------
    legval2d, leggrid2d, legval3d, leggrid3d

    Notes
    -----
    Code copied and modified from Numpy:
    https://github.com/numpy/numpy/blob/v1.20.0/numpy/polynomial/legendre.py#L835-L920

    Copyright (c) 2005-2021, NumPy Developers.
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
    met:

    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above
      copyright notice, this list of conditions and the following
      disclaimer in the documentation and/or other materials provided
      with the distribution.
    * Neither the name of the NumPy Developers nor the names of any
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    """
    c = torch.as_tensor(c)
    c = c.double()
    if isinstance(x, (tuple, list)):
        x = torch.as_tensor(x)
    if isinstance(x, torch.Tensor) and tensor:
        c = c.view(c.shape + (1,) * x.ndim)

    c = c.to(x.device)

    if len(c) == 1:
        c0 = c[0]
        c1 = 0
    elif len(c) == 2:
        c0 = c[0]
        c1 = c[1]
    else:
        nd = len(c)
        c0 = c[-2]
        c1 = c[-1]
        for i in range(3, len(c) + 1):
            tmp = c0
            nd = nd - 1
            c0 = c[-i] - (c1 * (nd - 1)) / nd
            c1 = tmp + (c1 * x * (2 * nd - 1)) / nd
    return c0 + c1 * x

def _torch_calc_g(
    cosang: torch.Tensor, stiffness: float = 4, n_legendre_terms: int = 50
) -> torch.Tensor:
    """Calculate spherical spline g function between points on a sphere.

    Parameters
    ----------
    cosang : array-like of float, shape(n_channels, n_channels)
        cosine of angles between pairs of points on a spherical surface. This
        is equivalent to the dot product of unit vectors.
    stiffness : float
        stiffness of the spline.
    n_legendre_terms : int
        number of Legendre terms to evaluate.

    Returns
    -------
    G : np.ndarray of float, shape(n_channels, n_channels)
        The G matrix.

    Notes
    -----
    Code copied and modified from MNE-Python:
    https://github.com/mne-tools/mne-python/blob/bdaa1d460201a3bc3cec95b67fc2b8d31a933652/mne/channels/interpolation.py#L35

    Copyright © 2011-2019, authors of MNE-Python
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
    met:

    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    * Neither the name of the copyright holder nor the names of its
      contributors may be used to endorse or promote products derived from
      this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
    IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
    THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
    PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT OWNER OR CONTRIBUTORS
    BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
    THE POSSIBILITY OF SUCH DAMAGE.
    """
    factors = [
        (2 * n + 1) / (n**stiffness * (n + 1) ** stiffness * 4 * np.pi)
        for n in range(1, n_legendre_terms + 1)
    ]
    return _torch_legval(cosang, [0] + factors)

def _torch_make_interpolation_matrix(
    pos_from: torch.Tensor, pos_to: torch.Tensor, alpha: float = 1e-5
) -> torch.Tensor:
    """Compute interpolation matrix based on spherical splines.

    Implementation based on [1]_

    Parameters
    ----------
    pos_from : torch.Tensor of float, shape(n_good_sensors, 3)
        The positions to interpolate from.
    pos_to : torch.Tensor of float, shape(n_bad_sensors, 3)
        The positions to interpolate.
    alpha : float
        Regularization parameter. Defaults to 1e-5.

    Returns
    -------
    interpolation : torch.Tensor of float, shape(len(pos_to), len(pos_from))
        The interpolation matrix that maps good signals to the location
        of bad signals.

    Notes
    -----
    Code copied and modified from MNE-Python:
    https://github.com/mne-tools/mne-python/blob/bdaa1d460201a3bc3cec95b67fc2b8d31a933652/mne/channels/interpolation.py#L59

    Copyright © 2011-2019, authors of MNE-Python
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
    met:

    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    * Neither the name of the copyright holder nor the names of its
      contributors may be used to endorse or promote products derived from
      this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
    IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
    THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
    PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT OWNER OR CONTRIBUTORS
    BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
    CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
    SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
    INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
    CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
    ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
    THE POSSIBILITY OF SUCH DAMAGE.

    References
    ----------
    [1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
        Spherical splines for scalp potential and current density mapping.
        Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.
    """
    pos_from = pos_from.clone()
    pos_to = pos_to.clone()
    n_from = pos_from.shape[0]
    n_to = pos_to.shape[0]

    # normalize sensor positions to sphere
    pos_from = _torch_normalize_vectors(pos_from)
    pos_to = _torch_normalize_vectors(pos_to)

    # cosine angles between source positions
    cosang_from = torch.matmul(pos_from, pos_from.T)
    cosang_to_from = torch.matmul(pos_to, pos_from.T)
    G_from = _torch_calc_g(cosang_from)
    G_to_from = _torch_calc_g(cosang_to_from)
    assert G_from.shape == (n_from, n_from)
    assert G_to_from.shape == (n_to, n_from)

    if alpha is not None:
        G_from.flatten()[:: len(G_from) + 1] += alpha

    device = G_from.device
    C = torch.vstack(
        [
            torch.hstack([G_from, torch.ones((n_from, 1), device=device)]),
            torch.hstack(
                [
                    torch.ones((1, n_from), device=device),
                    torch.as_tensor([[0]], device=device),
                ]
            ),
        ]
    )

    try:
        C_inv = torch.linalg.inv(C)
    except torch._C._LinAlgError:
        # There is a stability issue with pinv since torch v1.8.0
        # see https://github.com/pytorch/pytorch/issues/75494
        C_inv = torch.linalg.pinv(C.cpu()).to(device)

    interpolation = torch.hstack(
        [G_to_from, torch.ones((n_to, 1), device=device)]
    ).matmul(C_inv[:, :-1])
    assert interpolation.shape == (n_to, n_from)
    return interpolation

def _rotate_signals(
    X: torch.Tensor,
    rotations: list[torch.Tensor],
    sensors_positions_matrix: torch.Tensor,
    spherical: bool = True,
) -> torch.Tensor:
    sensors_positions_matrix = sensors_positions_matrix.to(X.device)
    rot_sensors_matrices = [
        rotation.matmul(sensors_positions_matrix) for rotation in rotations
    ]
    if spherical:
        interpolation_matrix = torch.stack(
            [
                torch.as_tensor(
                    _torch_make_interpolation_matrix(
                        sensors_positions_matrix.T, rot_sensors_matrix.T
                    ),
                    device=X.device,
                ).float()
                for rot_sensors_matrix in rot_sensors_matrices
            ]
        )
        return torch.matmul(interpolation_matrix, X)
    else:
        transformed_X = X.clone()
        sensors_positions = list(sensors_positions_matrix)
        for s, rot_sensors_matrix in enumerate(rot_sensors_matrices):
            rot_sensors_positions = list(rot_sensors_matrix.T)
            for time in range(X.shape[-1]):
                interpolator_t = Rbf(*sensors_positions, X[s, :, time])
                transformed_X[s, :, time] = torch.from_numpy(
                    interpolator_t(*rot_sensors_positions)
                )
        return transformed_X

def _make_rotation_matrix(
    axis: Literal["x", "y", "z"],
    angle: float | int | np.ndarray | list | torch.Tensor,
    degrees: bool = True,
) -> torch.Tensor:
    assert axis in ["x", "y", "z"], "axis should be either x, y or z."
    if isinstance(angle, torch.Tensor):
        _angle = angle
    elif isinstance(angle, (float, int, np.ndarray, list)):
        _angle = torch.as_tensor(angle)
    else:
        raise ValueError(f"Invalid angle type: {type(angle)}")

    if degrees:
        _angle = _angle * np.pi / 180

    device = _angle.device
    zero = torch.zeros(1, device=device)
    rot = torch.stack(
        [
            torch.as_tensor([1, 0, 0], device=device),
            torch.hstack([zero, torch.cos(_angle), -torch.sin(_angle)]),
            torch.hstack([zero, torch.sin(_angle), torch.cos(_angle)]),
        ]
    )
    if axis == "x":
        return rot
    elif axis == "y":
        rot = rot[[2, 0, 1], :]
        return rot[:, [2, 0, 1]]
    else:
        rot = rot[[1, 2, 0], :]
        return rot[:, [1, 2, 0]]

def sensors_rotation(
    X: torch.Tensor,
    y: torch.Tensor,
    sensors_positions_matrix: torch.Tensor,
    axis: Literal["x", "y", "z"],
    angles: npt.ArrayLike,
    spherical_splines: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Interpolate EEG signals over sensors rotated around the desired axis
    with the desired angle.

    Suggested in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    sensors_positions_matrix : torch.Tensor
        Matrix giving the positions of each sensor in a 3D cartesian
        coordinate system. Should have shape (3, n_channels), where
        n_channels is the number of channels. Standard 10-20 positions can
        be obtained from ``mne`` through::

            >>> ten_twenty_montage = mne.channels.make_standard_montage(
            ...     'standard_1020'
            ... ).get_positions()['ch_pos']

    axis : 'x' | 'y' | 'z'
        Axis around which to rotate.
    angles : array-like
        Array of float of shape ``(batch_size,)`` containing the rotation
        angles (in degrees) for each element of the input batch.
    spherical_splines : bool
        Whether to use spherical splines for the interpolation or not. When
        `False`, standard scipy.interpolate.Rbf (with quadratic kernel) will
        be used (as in the original paper).

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """
    rots = [_make_rotation_matrix(axis, angle, degrees=True) for angle in angles]
    rotated_X = _rotate_signals(X, rots, sensors_positions_matrix, spherical_splines)
    return rotated_X, y

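
# Usage sketch for `sensors_rotation` (illustrative): the random positions
# below are a stand-in only; real usage would take montage positions from
# ``mne`` as shown in the docstring above. One rotation angle per trial.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     positions = torch.randn(3, 22)           # stand-in for (3, n_channels)
#     angles = np.random.uniform(-15, 15, size=8)
#     X_aug, _ = sensors_rotation(X, y, positions, axis="z", angles=angles,
#                                 spherical_splines=True)
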
def mixup(
    X: torch.Tensor, y: torch.Tensor, lam: torch.Tensor, idx_perm: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Mixes pairs of EEG examples by linear interpolation.

    See [1]_ for details.
    Implementation based on [2]_.

    Parameters
    ----------
    X : torch.Tensor
        EEG data in form ``batch_size, n_channels, n_times``
    y : torch.Tensor
        Target of length ``batch_size``
    lam : torch.Tensor
        Values between 0 and 1 setting the linear interpolation between
        examples.
    idx_perm : torch.Tensor
        Permuted indices of examples that are mixed into original examples.

    Returns
    -------
    tuple
        ``X``, ``y``. Where ``X`` is augmented and ``y`` is a tuple of length
        3 containing the labels of the two mixed examples and the mixing
        coefficient.

    References
    ----------
    .. [1] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz
       (2018). mixup: Beyond Empirical Risk Minimization. In 2018
       International Conference on Learning Representations (ICLR)
       Online: https://arxiv.org/abs/1710.09412
    .. [2] https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
    """
    device = X.device
    batch_size, n_channels, n_times = X.shape

    X_mix = torch.zeros((batch_size, n_channels, n_times)).to(device)
    y_a = torch.arange(batch_size).to(device)
    y_b = torch.arange(batch_size).to(device)

    for idx in range(batch_size):
        X_mix[idx] = lam[idx] * X[idx] + (1 - lam[idx]) * X[idx_perm[idx]]
        y_a[idx] = y[idx]
        y_b[idx] = y[idx_perm[idx]]

    return X_mix, (y_a, y_b, lam)

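
# Usage sketch for `mixup` (illustrative): mixing coefficients are commonly
# drawn from a Beta distribution, and `idx_perm` pairs each trial with a
# random other trial of the batch.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.randint(0, 2, (8,))
#     lam = torch.from_numpy(np.random.beta(0.5, 0.5, size=8)).float()
#     idx_perm = torch.randperm(8)
#     X_mix, (y_a, y_b, lam) = mixup(X, y, lam, idx_perm)
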
def segmentation_reconstruction(
    X: torch.Tensor,
    y: torch.Tensor,
    n_segments: int,
    data_classes: list[tuple[int, torch.Tensor]],
    rand_indices: npt.ArrayLike,
    idx_shuffle: npt.ArrayLike,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Segment and reconstruct EEG data from [1]_.

    See [1]_ for details.

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    n_segments : int
        Number of segments to use in the batch.
    data_classes : list[tuple[int, torch.Tensor]]
        List of tuples. Each tuple contains the class index and the
        corresponding EEG data.
    rand_indices : array-like
        Array of indices that indicates which trial to use in each segment.
    idx_shuffle : array-like
        Array of indices to shuffle the newly generated trials.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Lotte, F. (2015). Signal processing approaches to minimize or
       suppress calibration time in oscillatory activity-based brain-computer
       interfaces. Proceedings of the IEEE, 103(6), 871-890.
    """
    # Initialize lists to store augmented data and corresponding labels
    aug_data: list[torch.Tensor] = []
    aug_label: list[torch.Tensor] = []

    # Iterate through each class to separate and augment data
    for class_index, X_class in data_classes:
        # Determine class-specific dimensions
        n_trials, n_channels, window_size = X_class.shape
        # Segment size
        segment_size = window_size // n_segments
        # Initialize an empty tensor for augmented data
        X_aug = torch.zeros_like(X_class)
        # Random indices within the class-specific dataset
        rand_idx = rand_indices[class_index]
        for idx_segment in range(n_segments):
            start = idx_segment * segment_size
            end = (idx_segment + 1) * segment_size

            # Perform the data augmentation
            X_aug[np.arange(n_trials), :, start:end] = X_class[
                rand_idx[:, idx_segment], :, start:end
            ]
        # Store the augmented data and the corresponding class labels
        aug_data.append(X_aug)
        aug_label.append(torch.full((n_trials,), class_index))
    # Concatenate the augmented data and labels
    concat_aug_data = torch.cat(aug_data, dim=0)
    concat_aug_data = concat_aug_data.to(dtype=X.dtype, device=X.device)
    concat_aug_data = concat_aug_data[idx_shuffle]

    if y is not None:
        concat_label = torch.cat(aug_label, dim=0)
        concat_label = concat_label.to(dtype=y.dtype, device=y.device)
        concat_label = concat_label[idx_shuffle]
        return concat_aug_data, concat_label

    return concat_aug_data, None

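
# Usage sketch for `segmentation_reconstruction` (illustrative): trials of
# each class are cut into `n_segments` slices and recombined across trials of
# the same class; all index arrays below are assumptions about the caller.
#
#     X0, X1 = torch.randn(10, 22, 1000), torch.randn(10, 22, 1000)
#     X = torch.cat([X0, X1])
#     y = torch.repeat_interleave(torch.arange(2), 10)
#     rand_indices = {c: np.random.randint(0, 10, size=(10, 4)) for c in (0, 1)}
#     idx_shuffle = np.random.permutation(20)
#     X_aug, y_aug = segmentation_reconstruction(
#         X, y, n_segments=4, data_classes=[(0, X0), (1, X1)],
#         rand_indices=rand_indices, idx_shuffle=idx_shuffle,
#     )
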
def mask_encoding(
    X: torch.Tensor,
    y: torch.Tensor,
    time_start: torch.Tensor,
    segment_length: int,
    n_segments: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Mask encoding from Ding et al. (2024) [ding2024]_.

    Replaces a contiguous part (or parts) of all channels by zeros
    (if more than one segment, the segments may overlap).

    Implementation based on [ding2024]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    time_start : torch.Tensor
        Tensor of integers containing the position (in last dimension) where
        to start masking the signal. Should have ``n_segments`` times the size
        of the first dimension of X (i.e. ``n_segments`` start positions per
        example in the batch).
    segment_length : int
        Length of each segment to zero out.
    n_segments : int
        Number of segments to zero out in each example.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [ding2024] Ding, Wenlong, et al. A Novel Data Augmentation Approach
       Using Mask Encoding for Deep Learning-Based Asynchronous SSVEP-BCI.
       IEEE Transactions on Neural Systems and Rehabilitation Engineering
       32 (2024): 875-886.
    """
    batch_indices = torch.arange(X.shape[0]).repeat_interleave(n_segments)
    start_indices = time_start.flatten()
    mask_indices = start_indices[:, None] + torch.arange(segment_length)

    # Create a boolean mask with the same shape as X
    mask = torch.zeros_like(X, dtype=torch.bool)
    for batch_index, grouped_mask_indices in zip(batch_indices, mask_indices):
        mask[batch_index, :, grouped_mask_indices] = True

    # Apply the mask to set the values to 0
    X[mask] = 0

    return X, y  # Return the masked tensor and labels

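
# Usage sketch for `mask_encoding` (illustrative): two segments of 50 samples
# are zeroed per trial, so `time_start` holds two start positions per trial.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     time_start = torch.randint(0, 1000 - 50, (8, 2))
#     X_aug, _ = mask_encoding(X, y, time_start=time_start,
#                              segment_length=50, n_segments=2)
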
def channels_rereference(
    X: torch.Tensor,
    y: torch.Tensor,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly re-reference channels in EEG data matrix.

    Part of the augmentations proposed in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification.
       Proceedings of the Machine Learning for Health NeurIPS Workshop, in
       Proceedings of Machine Learning Research 136:238-253
    """
    rng = check_random_state(random_state)
    batch_size, n_channels, _ = X.shape

    ch = rng.randint(0, n_channels, size=batch_size)

    X_ch = X[torch.arange(batch_size), ch, :]
    X = X - X_ch.unsqueeze(1)
    X[torch.arange(batch_size), ch, :] = -X_ch

    return X, y

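
# Usage sketch for `channels_rereference` (illustrative): one channel per
# trial is drawn at random and used as the new reference for all channels.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     X_aug, _ = channels_rereference(X, y, random_state=0)
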
def amplitude_scale(
    X: torch.Tensor,
    y: torch.Tensor,
    scale: tuple,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rescale the amplitude of each channel based on a randomly sampled scaling value.

    Part of the augmentations proposed in [1]_

    Parameters
    ----------
    X : torch.Tensor
        EEG input example or batch.
    y : torch.Tensor
        EEG labels for the example or batch.
    scale : tuple of floats
        Interval from which the scaling value is sampled.
    random_state : int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    Returns
    -------
    torch.Tensor
        Transformed inputs.
    torch.Tensor
        Transformed labels.

    References
    ----------
    .. [1] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification.
       Proceedings of the Machine Learning for Health NeurIPS Workshop, in
       Proceedings of Machine Learning Research 136:238-253
    """
    rng = torch.Generator()
    if random_state is not None:
        # manual_seed requires an int; seeding only when a seed is given
        # keeps the documented default of None usable.
        rng.manual_seed(random_state)
    batch_size, n_channels, _ = X.shape

    # Scaling factor sampled per amplitude / channel / trial
    l, h = scale
    s = l + (h - l) * torch.rand(
        batch_size, n_channels, 1, generator=rng, device=X.device, dtype=X.dtype
    )

    X = s * X

    return X, y
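
# Usage sketch for `amplitude_scale` (illustrative): each channel of each
# trial is multiplied by a factor drawn uniformly from `scale`.
#
#     X = torch.randn(8, 22, 1000)
#     y = torch.zeros(8, dtype=torch.long)
#     X_aug, _ = amplitude_scale(X, y, scale=(0.5, 2.0), random_state=0)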