braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
|
@@ -1,12 +1,12 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""
|
|
2
|
+
Utilities for data augmentation.
|
|
3
|
+
"""
|
|
2
4
|
|
|
3
5
|
from . import functional
|
|
4
6
|
from .base import AugmentedDataLoader, Compose, IdentityTransform, Transform
|
|
5
7
|
from .transforms import (
|
|
6
|
-
AmplitudeScale,
|
|
7
8
|
BandstopFilter,
|
|
8
9
|
ChannelsDropout,
|
|
9
|
-
ChannelsReref,
|
|
10
10
|
ChannelsShuffle,
|
|
11
11
|
ChannelsSymmetry,
|
|
12
12
|
FrequencyShift,
|
|
@@ -46,7 +46,5 @@ __all__ = [
|
|
|
46
46
|
"Mixup",
|
|
47
47
|
"SegmentationReconstruction",
|
|
48
48
|
"MaskEncoding",
|
|
49
|
-
"AmplitudeScale",
|
|
50
|
-
"ChannelsReref",
|
|
51
49
|
"functional",
|
|
52
50
|
]
|
braindecode/augmentation/base.py
CHANGED
|
@@ -31,8 +31,7 @@ Operation = Callable[
|
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
class Transform(torch.nn.Module):
|
|
34
|
-
"""Basic transform class used for implementing data augmentation
|
|
35
|
-
|
|
34
|
+
"""Basic transform class used for implementing data augmentation
|
|
36
35
|
operations.
|
|
37
36
|
|
|
38
37
|
Parameters
|
|
@@ -43,7 +42,7 @@ class Transform(torch.nn.Module):
|
|
|
43
42
|
probability : float, optional
|
|
44
43
|
Float between 0 and 1 defining the uniform probability of applying the
|
|
45
44
|
operation. Set to 1.0 by default (e.g always apply the operation).
|
|
46
|
-
random_state
|
|
45
|
+
random_state: int, optional
|
|
47
46
|
Seed to be used to instantiate numpy random number generator instance.
|
|
48
47
|
Used to decide whether or not to transform given the probability
|
|
49
48
|
argument. Defaults to None.
|
|
@@ -126,7 +125,7 @@ class Transform(torch.nn.Module):
|
|
|
126
125
|
return out_X
|
|
127
126
|
|
|
128
127
|
def _get_mask(self, batch_size, device) -> torch.Tensor:
|
|
129
|
-
"""Samples whether to apply operation or not over the whole batch
|
|
128
|
+
"""Samples whether to apply operation or not over the whole batch"""
|
|
130
129
|
return torch.as_tensor(self.probability > self.rng.uniform(size=batch_size)).to(
|
|
131
130
|
device
|
|
132
131
|
)
|
|
@@ -154,7 +153,7 @@ class Compose(Transform):
|
|
|
154
153
|
|
|
155
154
|
Parameters
|
|
156
155
|
----------
|
|
157
|
-
transforms
|
|
156
|
+
transforms: list
|
|
158
157
|
Sequence of Transforms to be composed.
|
|
159
158
|
"""
|
|
160
159
|
|
|
@@ -170,9 +169,7 @@ class Compose(Transform):
|
|
|
170
169
|
|
|
171
170
|
def _make_collateable(transform, device=None):
|
|
172
171
|
"""Wraps a transform to make it collateable.
|
|
173
|
-
|
|
174
|
-
with device control.
|
|
175
|
-
"""
|
|
172
|
+
with device control."""
|
|
176
173
|
|
|
177
174
|
def _collate_fn(batch):
|
|
178
175
|
collated_batch = default_collate(batch)
|
|
@@ -144,7 +144,7 @@ def ft_surrogate(
|
|
|
144
144
|
EEG input example or batch.
|
|
145
145
|
y : torch.Tensor
|
|
146
146
|
EEG labels for the example or batch.
|
|
147
|
-
phase_noise_magnitude
|
|
147
|
+
phase_noise_magnitude: float
|
|
148
148
|
Float between 0 and 1 setting the range over which the phase
|
|
149
149
|
perturbation is uniformly sampled:
|
|
150
150
|
[0, `phase_noise_magnitude` * 2 * `pi`].
|
|
@@ -152,7 +152,7 @@ def ft_surrogate(
|
|
|
152
152
|
Whether to sample phase perturbations independently for each channel or
|
|
153
153
|
not. It is advised to set it to False when spatial information is
|
|
154
154
|
important for the task, like in BCI.
|
|
155
|
-
random_state
|
|
155
|
+
random_state: int | numpy.random.Generator, optional
|
|
156
156
|
Used to draw the phase perturbation. Defaults to None.
|
|
157
157
|
|
|
158
158
|
Returns
|
|
@@ -289,10 +289,10 @@ def channels_shuffle(
|
|
|
289
289
|
EEG input example or batch.
|
|
290
290
|
y : torch.Tensor
|
|
291
291
|
EEG labels for the example or batch.
|
|
292
|
-
p_shuffle
|
|
292
|
+
p_shuffle: float | None
|
|
293
293
|
Float between 0 and 1 setting the probability of including the channel
|
|
294
294
|
in the set of permutted channels.
|
|
295
|
-
random_state
|
|
295
|
+
random_state: int | numpy.random.Generator, optional
|
|
296
296
|
Seed to be used to instantiate numpy random number generator instance.
|
|
297
297
|
Used to sample which channels to shuffle and to carry the shuffle.
|
|
298
298
|
Defaults to None.
|
|
@@ -335,7 +335,7 @@ def gaussian_noise(
|
|
|
335
335
|
EEG labels for the example or batch.
|
|
336
336
|
std : float
|
|
337
337
|
Standard deviation to use for the additive noise.
|
|
338
|
-
random_state
|
|
338
|
+
random_state: int | numpy.random.Generator, optional
|
|
339
339
|
Seed to be used to instantiate numpy random number generator instance.
|
|
340
340
|
Defaults to None.
|
|
341
341
|
|
|
@@ -468,8 +468,7 @@ def bandstop_filter(
|
|
|
468
468
|
bandwidth: float,
|
|
469
469
|
freqs_to_notch: npt.ArrayLike | None,
|
|
470
470
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
471
|
-
"""Apply a band-stop filter with desired bandwidth at the desired frequency
|
|
472
|
-
|
|
471
|
+
"""Apply a band-stop filter with desired bandwidth at the desired frequency
|
|
473
472
|
position.
|
|
474
473
|
|
|
475
474
|
Suggested e.g. in [1]_ and [2]_
|
|
@@ -621,7 +620,6 @@ def _torch_legval(
|
|
|
621
620
|
) -> torch.Tensor:
|
|
622
621
|
"""
|
|
623
622
|
Evaluate a Legendre series at points x.
|
|
624
|
-
|
|
625
623
|
If `c` is of length `n + 1`, this function returns the value:
|
|
626
624
|
.. math:: p(x) = c_0 * L_0(x) + c_1 * L_1(x) + ... + c_n * L_n(x)
|
|
627
625
|
The parameter `x` is converted to an array only if it is a tuple or a
|
|
@@ -807,6 +805,12 @@ def _torch_make_interpolation_matrix(
|
|
|
807
805
|
The interpolation matrix that maps good signals to the location
|
|
808
806
|
of bad signals.
|
|
809
807
|
|
|
808
|
+
References
|
|
809
|
+
----------
|
|
810
|
+
[1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
|
|
811
|
+
Spherical splines for scalp potential and current density mapping.
|
|
812
|
+
Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.
|
|
813
|
+
|
|
810
814
|
Notes
|
|
811
815
|
-----
|
|
812
816
|
Code copied and modified from MNE-Python:
|
|
@@ -837,12 +841,6 @@ def _torch_make_interpolation_matrix(
|
|
|
837
841
|
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
|
|
838
842
|
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
|
|
839
843
|
DAMAGE.
|
|
840
|
-
|
|
841
|
-
References
|
|
842
|
-
----------
|
|
843
|
-
[1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989).
|
|
844
|
-
Spherical splines for scalp potential and current density mapping.
|
|
845
|
-
Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7.
|
|
846
844
|
"""
|
|
847
845
|
pos_from = pos_from.clone()
|
|
848
846
|
pos_to = pos_to.clone()
|
|
@@ -970,8 +968,7 @@ def sensors_rotation(
|
|
|
970
968
|
angles: npt.ArrayLike,
|
|
971
969
|
spherical_splines: bool,
|
|
972
970
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
973
|
-
"""Interpolates EEG signals over sensors rotated around the desired axis
|
|
974
|
-
|
|
971
|
+
"""Interpolates EEG signals over sensors rotated around the desired axis
|
|
975
972
|
with the desired angle.
|
|
976
973
|
|
|
977
974
|
Suggested in [1]_
|
|
@@ -1030,7 +1027,7 @@ def mixup(
|
|
|
1030
1027
|
lam : torch.Tensor
|
|
1031
1028
|
Values between 0 and 1 setting the linear interpolation between
|
|
1032
1029
|
examples.
|
|
1033
|
-
idx_perm
|
|
1030
|
+
idx_perm: torch.Tensor
|
|
1034
1031
|
Permuted indices of example that are mixed into original examples.
|
|
1035
1032
|
|
|
1036
1033
|
Returns
|
|
@@ -1083,20 +1080,18 @@ def segmentation_reconstruction(
|
|
|
1083
1080
|
EEG labels for the example or batch.
|
|
1084
1081
|
n_segments : int
|
|
1085
1082
|
Number of segments to use in the batch.
|
|
1086
|
-
data_classes
|
|
1083
|
+
data_classes: list[tuple[int, torch.Tensor]]
|
|
1087
1084
|
List of tuples. Each tuple contains the class index and the corresponding EEG data.
|
|
1088
|
-
rand_indices
|
|
1085
|
+
rand_indices: array-like
|
|
1089
1086
|
Array of indices that indicates which trial to use in each segment.
|
|
1090
|
-
idx_shuffle
|
|
1087
|
+
idx_shuffle: array-like
|
|
1091
1088
|
Array of indices to shuffle the new generated trials.
|
|
1092
|
-
|
|
1093
1089
|
Returns
|
|
1094
1090
|
-------
|
|
1095
1091
|
torch.Tensor
|
|
1096
1092
|
Transformed inputs.
|
|
1097
1093
|
torch.Tensor
|
|
1098
1094
|
Transformed labels.
|
|
1099
|
-
|
|
1100
1095
|
References
|
|
1101
1096
|
----------
|
|
1102
1097
|
.. [1] Lotte, F. (2015). Signal processing approaches to minimize or
|
|
@@ -1150,7 +1145,7 @@ def mask_encoding(
|
|
|
1150
1145
|
segment_length: int,
|
|
1151
1146
|
n_segments: int,
|
|
1152
1147
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
1153
|
-
"""Mark encoding from Ding et al (2024) from [ding2024]_.
|
|
1148
|
+
"""Mark encoding from Ding et al. (2024) from [ding2024]_.
|
|
1154
1149
|
|
|
1155
1150
|
Replaces a contiguous part (or parts) of all channels by zeros
|
|
1156
1151
|
(if more than one segment, it may overlap).
|
|
@@ -1217,7 +1212,7 @@ def channels_rereference(
|
|
|
1217
1212
|
EEG input example or batch.
|
|
1218
1213
|
y : torch.Tensor
|
|
1219
1214
|
EEG labels for the example or batch.
|
|
1220
|
-
random_state
|
|
1215
|
+
random_state: int | numpy.random.Generator, optional
|
|
1221
1216
|
Seed to be used to instantiate numpy random number generator instance.
|
|
1222
1217
|
Defaults to None.
|
|
1223
1218
|
|
|
@@ -1234,6 +1229,7 @@ def channels_rereference(
|
|
|
1234
1229
|
Representation Learning for Electroencephalogram Classification. Proceedings
|
|
1235
1230
|
of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
|
|
1236
1231
|
Learning Research 136:238-253
|
|
1232
|
+
|
|
1237
1233
|
"""
|
|
1238
1234
|
|
|
1239
1235
|
rng = check_random_state(random_state)
|
|
@@ -1266,7 +1262,7 @@ def amplitude_scale(
|
|
|
1266
1262
|
EEG labels for the example or batch.
|
|
1267
1263
|
scale : tuple of floats
|
|
1268
1264
|
Interval from which ypu sample the scaling value
|
|
1269
|
-
random_state
|
|
1265
|
+
random_state: int | numpy.random.Generator, optional
|
|
1270
1266
|
Seed to be used to instantiate numpy random number generator instance.
|
|
1271
1267
|
Defaults to None.
|
|
1272
1268
|
|
|
@@ -1283,6 +1279,7 @@ def amplitude_scale(
|
|
|
1283
1279
|
Representation Learning for Electroencephalogram Classification. Proceedings
|
|
1284
1280
|
of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
|
|
1285
1281
|
Learning Research 136:238-253
|
|
1282
|
+
|
|
1286
1283
|
"""
|
|
1287
1284
|
|
|
1288
1285
|
rng = torch.Generator()
|
|
@@ -40,7 +40,7 @@ class TimeReverse(Transform):
|
|
|
40
40
|
----------
|
|
41
41
|
probability : float
|
|
42
42
|
Float setting the probability of applying the operation.
|
|
43
|
-
random_state
|
|
43
|
+
random_state: int | numpy.random.Generator, optional
|
|
44
44
|
Seed to be used to instantiate numpy random number generator instance.
|
|
45
45
|
Used to decide whether or not to transform given the probability
|
|
46
46
|
argument. Defaults to None.
|
|
@@ -66,7 +66,7 @@ class SignFlip(Transform):
|
|
|
66
66
|
----------
|
|
67
67
|
probability : float
|
|
68
68
|
Float setting the probability of applying the operation.
|
|
69
|
-
random_state
|
|
69
|
+
random_state: int | numpy.random.Generator, optional
|
|
70
70
|
Seed to be used to instantiate numpy random number generator instance.
|
|
71
71
|
Used to decide whether or not to transform given the probability
|
|
72
72
|
argument. Defaults to None.
|
|
@@ -83,7 +83,7 @@ class FTSurrogate(Transform):
|
|
|
83
83
|
|
|
84
84
|
Parameters
|
|
85
85
|
----------
|
|
86
|
-
probability
|
|
86
|
+
probability: float
|
|
87
87
|
Float setting the probability of applying the operation.
|
|
88
88
|
phase_noise_magnitude : float | torch.Tensor, optional
|
|
89
89
|
Float between 0 and 1 setting the range over which the phase
|
|
@@ -93,7 +93,7 @@ class FTSurrogate(Transform):
|
|
|
93
93
|
Whether to sample phase perturbations independently for each channel or
|
|
94
94
|
not. It is advised to set it to False when spatial information is
|
|
95
95
|
important for the task, like in BCI. Default False.
|
|
96
|
-
random_state
|
|
96
|
+
random_state: int | numpy.random.Generator, optional
|
|
97
97
|
Seed to be used to instantiate numpy random number generator instance.
|
|
98
98
|
Used to decide whether or not to transform given the probability
|
|
99
99
|
argument. Defaults to None.
|
|
@@ -162,12 +162,12 @@ class ChannelsDropout(Transform):
|
|
|
162
162
|
|
|
163
163
|
Parameters
|
|
164
164
|
----------
|
|
165
|
-
probability
|
|
165
|
+
probability: float
|
|
166
166
|
Float setting the probability of applying the operation.
|
|
167
|
-
p_drop
|
|
167
|
+
p_drop: float | None, optional
|
|
168
168
|
Float between 0 and 1 setting the probability of dropping each channel.
|
|
169
169
|
Defaults to 0.2.
|
|
170
|
-
random_state
|
|
170
|
+
random_state: int | numpy.random.RandomState, optional
|
|
171
171
|
Seed to be used to instantiate numpy random number generator instance.
|
|
172
172
|
Used to decide whether or not to transform given the probability
|
|
173
173
|
argument and to sample channels to erase. Defaults to None.
|
|
@@ -219,12 +219,12 @@ class ChannelsShuffle(Transform):
|
|
|
219
219
|
|
|
220
220
|
Parameters
|
|
221
221
|
----------
|
|
222
|
-
probability
|
|
222
|
+
probability: float
|
|
223
223
|
Float setting the probability of applying the operation.
|
|
224
|
-
p_shuffle
|
|
224
|
+
p_shuffle: float | None, optional
|
|
225
225
|
Float between 0 and 1 setting the probability of including the channel
|
|
226
226
|
in the set of permuted channels. Defaults to 0.2.
|
|
227
|
-
random_state
|
|
227
|
+
random_state: int | numpy.random.Generator, optional
|
|
228
228
|
Seed to be used to instantiate numpy random number generator instance.
|
|
229
229
|
Used to decide whether or not to transform given the probability
|
|
230
230
|
argument, to sample which channels to shuffle and to carry the shuffle.
|
|
@@ -281,7 +281,7 @@ class GaussianNoise(Transform):
|
|
|
281
281
|
Float setting the probability of applying the operation.
|
|
282
282
|
std : float, optional
|
|
283
283
|
Standard deviation to use for the additive noise. Defaults to 0.1.
|
|
284
|
-
random_state
|
|
284
|
+
random_state: int | numpy.random.Generator, optional
|
|
285
285
|
Seed to be used to instantiate numpy random number generator instance.
|
|
286
286
|
Defaults to None.
|
|
287
287
|
|
|
@@ -348,7 +348,7 @@ class ChannelsSymmetry(Transform):
|
|
|
348
348
|
nomenclature) of the EEG channels that will be transformed. The
|
|
349
349
|
first name should correspond the data in the first row of X, the
|
|
350
350
|
second name in the second row and so on.
|
|
351
|
-
random_state
|
|
351
|
+
random_state: int | numpy.random.Generator, optional
|
|
352
352
|
Seed to be used to instantiate numpy random number generator instance.
|
|
353
353
|
Used to decide whether or not to transform given the probability
|
|
354
354
|
argument. Defaults to None.
|
|
@@ -410,8 +410,7 @@ class ChannelsSymmetry(Transform):
|
|
|
410
410
|
|
|
411
411
|
|
|
412
412
|
class SmoothTimeMask(Transform):
|
|
413
|
-
"""Smoothly replace a randomly chosen contiguous part of all channels by
|
|
414
|
-
|
|
413
|
+
"""Smoothly replace a randomly chosen contiguous part of all channels by
|
|
415
414
|
zeros.
|
|
416
415
|
|
|
417
416
|
Suggested e.g. in [1]_ and [2]_
|
|
@@ -423,7 +422,7 @@ class SmoothTimeMask(Transform):
|
|
|
423
422
|
mask_len_samples : int | torch.Tensor, optional
|
|
424
423
|
Number of consecutive samples to zero out. Will be ignored if
|
|
425
424
|
magnitude is not set to None. Defaults to 100.
|
|
426
|
-
random_state
|
|
425
|
+
random_state: int | numpy.random.Generator, optional
|
|
427
426
|
Seed to be used to instantiate numpy random number generator instance.
|
|
428
427
|
Defaults to None.
|
|
429
428
|
|
|
@@ -496,8 +495,7 @@ class SmoothTimeMask(Transform):
|
|
|
496
495
|
|
|
497
496
|
|
|
498
497
|
class BandstopFilter(Transform):
|
|
499
|
-
"""Apply a band-stop filter with desired bandwidth at a randomly selected
|
|
500
|
-
|
|
498
|
+
"""Apply a band-stop filter with desired bandwidth at a randomly selected
|
|
501
499
|
frequency position between 0 and ``max_freq``.
|
|
502
500
|
|
|
503
501
|
Suggested e.g. in [1]_ and [2]_
|
|
@@ -516,7 +514,7 @@ class BandstopFilter(Transform):
|
|
|
516
514
|
that the corresponding high cut frequency + transition (=1Hz) are below
|
|
517
515
|
``max_freq``. If omitted or `None`, will default to the Nyquist
|
|
518
516
|
frequency (``sfreq / 2``).
|
|
519
|
-
random_state
|
|
517
|
+
random_state: int | numpy.random.Generator, optional
|
|
520
518
|
Seed to be used to instantiate numpy random number generator instance.
|
|
521
519
|
Defaults to None.
|
|
522
520
|
|
|
@@ -625,7 +623,7 @@ class FrequencyShift(Transform):
|
|
|
625
623
|
max_delta_freq : float | torch.Tensor, optional
|
|
626
624
|
Maximum shift in Hz that can be sampled (in absolute value).
|
|
627
625
|
Defaults to 2 (shift sampled between -2 and 2 Hz).
|
|
628
|
-
random_state
|
|
626
|
+
random_state: int | numpy.random.Generator, optional
|
|
629
627
|
Seed to be used to instantiate numpy random number generator instance.
|
|
630
628
|
Defaults to None.
|
|
631
629
|
"""
|
|
@@ -680,8 +678,7 @@ class FrequencyShift(Transform):
|
|
|
680
678
|
|
|
681
679
|
|
|
682
680
|
def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
|
|
683
|
-
"""Returns standard 10-20 sensors position matrix (for instantiating
|
|
684
|
-
|
|
681
|
+
"""Returns standard 10-20 sensors position matrix (for instantiating
|
|
685
682
|
SensorsRotation for example).
|
|
686
683
|
|
|
687
684
|
Parameters
|
|
@@ -709,8 +706,7 @@ def _get_standard_10_20_positions(raw_or_epoch=None, ordered_ch_names=None):
|
|
|
709
706
|
|
|
710
707
|
|
|
711
708
|
class SensorsRotation(Transform):
|
|
712
|
-
"""Interpolates EEG signals over sensors rotated around the desired axis
|
|
713
|
-
|
|
709
|
+
"""Interpolates EEG signals over sensors rotated around the desired axis
|
|
714
710
|
with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
|
|
715
711
|
|
|
716
712
|
Suggested in [1]_
|
|
@@ -738,7 +734,7 @@ class SensorsRotation(Transform):
|
|
|
738
734
|
Whether to use spherical splines for the interpolation or not. When
|
|
739
735
|
``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
|
|
740
736
|
be used (as in the original paper). Defaults to True.
|
|
741
|
-
random_state
|
|
737
|
+
random_state: int | numpy.random.Generator, optional
|
|
742
738
|
Seed to be used to instantiate numpy random number generator instance.
|
|
743
739
|
Defaults to None.
|
|
744
740
|
|
|
@@ -837,8 +833,7 @@ class SensorsRotation(Transform):
|
|
|
837
833
|
|
|
838
834
|
|
|
839
835
|
class SensorsZRotation(SensorsRotation):
|
|
840
|
-
"""Interpolates EEG signals over sensors rotated around the Z axis
|
|
841
|
-
|
|
836
|
+
"""Interpolates EEG signals over sensors rotated around the Z axis
|
|
842
837
|
with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
|
|
843
838
|
|
|
844
839
|
Suggested in [1]_
|
|
@@ -860,7 +855,7 @@ class SensorsZRotation(SensorsRotation):
|
|
|
860
855
|
Whether to use spherical splines for the interpolation or not. When
|
|
861
856
|
``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
|
|
862
857
|
be used (as in the original paper). Defaults to True.
|
|
863
|
-
random_state
|
|
858
|
+
random_state: int | numpy.random.Generator, optional
|
|
864
859
|
Seed to be used to instantiate numpy random number generator instance.
|
|
865
860
|
Defaults to None.
|
|
866
861
|
|
|
@@ -894,8 +889,7 @@ class SensorsZRotation(SensorsRotation):
|
|
|
894
889
|
|
|
895
890
|
|
|
896
891
|
class SensorsYRotation(SensorsRotation):
|
|
897
|
-
"""Interpolates EEG signals over sensors rotated around the Y axis
|
|
898
|
-
|
|
892
|
+
"""Interpolates EEG signals over sensors rotated around the Y axis
|
|
899
893
|
with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
|
|
900
894
|
|
|
901
895
|
Suggested in [1]_
|
|
@@ -917,7 +911,7 @@ class SensorsYRotation(SensorsRotation):
|
|
|
917
911
|
Whether to use spherical splines for the interpolation or not. When
|
|
918
912
|
``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
|
|
919
913
|
be used (as in the original paper). Defaults to True.
|
|
920
|
-
random_state
|
|
914
|
+
random_state: int | numpy.random.Generator, optional
|
|
921
915
|
Seed to be used to instantiate numpy random number generator instance.
|
|
922
916
|
Defaults to None.
|
|
923
917
|
|
|
@@ -951,8 +945,7 @@ class SensorsYRotation(SensorsRotation):
|
|
|
951
945
|
|
|
952
946
|
|
|
953
947
|
class SensorsXRotation(SensorsRotation):
|
|
954
|
-
"""Interpolates EEG signals over sensors rotated around the X axis
|
|
955
|
-
|
|
948
|
+
"""Interpolates EEG signals over sensors rotated around the X axis
|
|
956
949
|
with an angle sampled uniformly between ``-max_degree`` and ``max_degree``.
|
|
957
950
|
|
|
958
951
|
Suggested in [1]_
|
|
@@ -974,7 +967,7 @@ class SensorsXRotation(SensorsRotation):
|
|
|
974
967
|
Whether to use spherical splines for the interpolation or not. When
|
|
975
968
|
``False``, standard scipy.interpolate.Rbf (with quadratic kernel) will
|
|
976
969
|
be used (as in the original paper). Defaults to True.
|
|
977
|
-
random_state
|
|
970
|
+
random_state: int | numpy.random.Generator, optional
|
|
978
971
|
Seed to be used to instantiate numpy random number generator instance.
|
|
979
972
|
Defaults to None.
|
|
980
973
|
|
|
@@ -1008,19 +1001,17 @@ class SensorsXRotation(SensorsRotation):
|
|
|
1008
1001
|
|
|
1009
1002
|
|
|
1010
1003
|
class Mixup(Transform):
|
|
1011
|
-
"""Implements Iterator for Mixup for EEG data.
|
|
1012
|
-
|
|
1013
|
-
See [1]_.
|
|
1004
|
+
"""Implements Iterator for Mixup for EEG data. See [1]_.
|
|
1014
1005
|
Implementation based on [2]_.
|
|
1015
1006
|
|
|
1016
1007
|
Parameters
|
|
1017
1008
|
----------
|
|
1018
|
-
alpha
|
|
1009
|
+
alpha: float
|
|
1019
1010
|
Mixup hyperparameter.
|
|
1020
|
-
beta_per_sample
|
|
1011
|
+
beta_per_sample: bool (default=False)
|
|
1021
1012
|
By default, one mixing coefficient per batch is drawn from a beta
|
|
1022
1013
|
distribution. If True, one mixing coefficient per sample is drawn.
|
|
1023
|
-
random_state
|
|
1014
|
+
random_state: int | numpy.random.Generator, optional
|
|
1024
1015
|
Seed to be used to instantiate numpy random number generator instance.
|
|
1025
1016
|
Defaults to None.
|
|
1026
1017
|
|
|
@@ -1055,7 +1046,7 @@ class Mixup(Transform):
|
|
|
1055
1046
|
|
|
1056
1047
|
Returns
|
|
1057
1048
|
-------
|
|
1058
|
-
params
|
|
1049
|
+
params: dict
|
|
1059
1050
|
Contains the values sampled uniformly between 0 and 1 setting the
|
|
1060
1051
|
linear interpolation between examples (lam) and the shuffled
|
|
1061
1052
|
indices of examples that are mixed into original examples
|
|
@@ -1100,7 +1091,7 @@ class SegmentationReconstruction(Transform):
|
|
|
1100
1091
|
----------
|
|
1101
1092
|
probability : float
|
|
1102
1093
|
Float setting the probability of applying the operation.
|
|
1103
|
-
random_state
|
|
1094
|
+
random_state: int | numpy.random.Generator, optional
|
|
1104
1095
|
Seed to be used to instantiate numpy random number generator instance.
|
|
1105
1096
|
Used to decide whether to transform given the probability
|
|
1106
1097
|
argument and to sample the segments mixing. Defaults to None.
|
|
@@ -1139,7 +1130,6 @@ class SegmentationReconstruction(Transform):
|
|
|
1139
1130
|
The data.
|
|
1140
1131
|
y : tensor.Tensor
|
|
1141
1132
|
The labels.
|
|
1142
|
-
|
|
1143
1133
|
Returns
|
|
1144
1134
|
-------
|
|
1145
1135
|
params : dict
|
|
@@ -1208,12 +1198,12 @@ class MaskEncoding(Transform):
|
|
|
1208
1198
|
----------
|
|
1209
1199
|
probability : float
|
|
1210
1200
|
Float setting the probability of applying the operation.
|
|
1211
|
-
max_mask_ratio
|
|
1201
|
+
max_mask_ratio: float, optional
|
|
1212
1202
|
Signal ratio to zero out. Defaults to 0.1.
|
|
1213
1203
|
n_segments : int, optional
|
|
1214
1204
|
Number of segments to zero out in each example.
|
|
1215
1205
|
Defaults to 1.
|
|
1216
|
-
random_state
|
|
1206
|
+
random_state: int | numpy.random.Generator, optional
|
|
1217
1207
|
Seed to be used to instantiate numpy random number generator instance.
|
|
1218
1208
|
Defaults to None.
|
|
1219
1209
|
|
|
@@ -1257,7 +1247,6 @@ class MaskEncoding(Transform):
|
|
|
1257
1247
|
The data.
|
|
1258
1248
|
y : tensor.Tensor
|
|
1259
1249
|
The labels.
|
|
1260
|
-
|
|
1261
1250
|
Returns
|
|
1262
1251
|
-------
|
|
1263
1252
|
params : dict
|
|
@@ -1294,9 +1283,9 @@ class ChannelsReref(Transform):
|
|
|
1294
1283
|
|
|
1295
1284
|
Parameters
|
|
1296
1285
|
----------
|
|
1297
|
-
probability
|
|
1286
|
+
probability: float
|
|
1298
1287
|
Float setting the probability of applying the operation.
|
|
1299
|
-
random_state
|
|
1288
|
+
random_state: int | numpy.random.Generator, optional
|
|
1300
1289
|
Seed to be used to instantiate numpy random number generator instance.
|
|
1301
1290
|
Used to decide whether or not to transform given the probability
|
|
1302
1291
|
argument, to sample which channels to shuffle and to carry the shuffle.
|
|
@@ -1308,6 +1297,7 @@ class ChannelsReref(Transform):
|
|
|
1308
1297
|
Representation Learning for Electroencephalogram Classification. Proceedings
|
|
1309
1298
|
of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
|
|
1310
1299
|
Learning Research 136:238-253 Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
|
|
1300
|
+
|
|
1311
1301
|
"""
|
|
1312
1302
|
|
|
1313
1303
|
operation = staticmethod(channels_rereference) # type: ignore[assignment]
|
|
@@ -1316,7 +1306,7 @@ class ChannelsReref(Transform):
|
|
|
1316
1306
|
super().__init__(probability=probability, random_state=random_state)
|
|
1317
1307
|
|
|
1318
1308
|
def get_augmentation_params(self, *batch):
|
|
1319
|
-
"""Return transform parameters
|
|
1309
|
+
"""Return transform parameters"""
|
|
1320
1310
|
return {
|
|
1321
1311
|
"random_state": self.rng,
|
|
1322
1312
|
}
|
|
@@ -1329,9 +1319,9 @@ class AmplitudeScale(Transform):
|
|
|
1329
1319
|
|
|
1330
1320
|
Parameters
|
|
1331
1321
|
----------
|
|
1332
|
-
probability
|
|
1322
|
+
probability: float
|
|
1333
1323
|
Float setting the probability of applying the operation.
|
|
1334
|
-
random_state
|
|
1324
|
+
random_state: int | numpy.random.Generator, optional
|
|
1335
1325
|
Seed to be used to instantiate numpy random number generator instance.
|
|
1336
1326
|
Used to decide whether or not to transform given the probability
|
|
1337
1327
|
argument, to sample which channels to shuffle and to carry the shuffle.
|
|
@@ -1343,6 +1333,7 @@ class AmplitudeScale(Transform):
|
|
|
1343
1333
|
Representation Learning for Electroencephalogram Classification. Proceedings
|
|
1344
1334
|
of the Machine Learning for Health NeurIPS Workshop, in Proceedings of Machine
|
|
1345
1335
|
Learning Research 136:238-253 Available from https://proceedings.mlr.press/v136/mohsenvand20a.html.
|
|
1336
|
+
|
|
1346
1337
|
"""
|
|
1347
1338
|
|
|
1348
1339
|
operation = staticmethod(amplitude_scale) # type: ignore[assignment]
|
|
@@ -1352,5 +1343,5 @@ class AmplitudeScale(Transform):
|
|
|
1352
1343
|
self.scale = interval
|
|
1353
1344
|
|
|
1354
1345
|
def get_augmentation_params(self, *batch):
|
|
1355
|
-
"""Return transform parameters
|
|
1346
|
+
"""Return transform parameters"""
|
|
1356
1347
|
return {"random_state": self.rng, "scale": self.scale}
|
braindecode/classifier.py
CHANGED
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
|
|
9
9
|
import warnings
|
|
10
10
|
|
|
11
|
+
import numpy as np
|
|
11
12
|
from skorch import NeuralNet
|
|
12
13
|
from skorch.callbacks import EpochScoring
|
|
13
14
|
from skorch.classifier import NeuralNetClassifier
|
|
@@ -94,8 +95,7 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
|
|
|
94
95
|
return iterator
|
|
95
96
|
|
|
96
97
|
def predict_proba(self, X):
|
|
97
|
-
"""Return the output of the module's forward method as a numpy
|
|
98
|
-
|
|
98
|
+
"""Return the output of the module's forward method as a numpy
|
|
99
99
|
array. In case of cropped decoding returns averaged values for
|
|
100
100
|
each trial.
|
|
101
101
|
|
|
@@ -125,6 +125,7 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
|
|
|
125
125
|
Returns
|
|
126
126
|
-------
|
|
127
127
|
y_proba : numpy ndarray
|
|
128
|
+
|
|
128
129
|
"""
|
|
129
130
|
y_pred = super().predict_proba(X)
|
|
130
131
|
# Normally, we have to average the predictions across crops/timesteps
|
|
@@ -191,28 +192,28 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
|
|
|
191
192
|
Returns
|
|
192
193
|
-------
|
|
193
194
|
y_pred : numpy ndarray
|
|
195
|
+
|
|
194
196
|
"""
|
|
195
197
|
return self.predict_proba(X).argmax(1)
|
|
196
198
|
|
|
197
199
|
def predict_trials(self, X, return_targets=True):
|
|
198
|
-
"""Create trialwise predictions and optionally also return trialwise
|
|
199
|
-
|
|
200
|
+
"""Create trialwise predictions and optionally also return trialwise
|
|
200
201
|
labels from cropped dataset.
|
|
201
202
|
|
|
202
203
|
Parameters
|
|
203
204
|
----------
|
|
204
|
-
X
|
|
205
|
+
X: braindecode.datasets.BaseConcatDataset
|
|
205
206
|
A braindecode dataset to be predicted.
|
|
206
|
-
return_targets
|
|
207
|
+
return_targets: bool
|
|
207
208
|
If True, additionally returns the trial targets.
|
|
208
209
|
|
|
209
210
|
Returns
|
|
210
211
|
-------
|
|
211
|
-
trial_predictions
|
|
212
|
+
trial_predictions: np.ndarray
|
|
212
213
|
3-dimensional array (n_trials x n_classes x n_predictions), where
|
|
213
214
|
the number of predictions depend on the chosen window size and the
|
|
214
215
|
receptive field of the network.
|
|
215
|
-
trial_labels
|
|
216
|
+
trial_labels: np.ndarray
|
|
216
217
|
2-dimensional array (n_trials x n_targets) where the number of
|
|
217
218
|
targets depends on the decoding paradigm and can be either a single
|
|
218
219
|
value, multiple values, or a sequence.
|
|
@@ -236,9 +237,13 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
|
|
|
236
237
|
num_workers=self.get_iterator(X, training=False).loader.num_workers,
|
|
237
238
|
)
|
|
238
239
|
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
240
|
+
def _get_n_outputs(self, y, classes):
|
|
241
|
+
classes_y = np.unique(y)
|
|
242
|
+
if classes is not None:
|
|
243
|
+
assert set(classes_y) <= set(classes)
|
|
244
|
+
else:
|
|
245
|
+
classes = classes_y
|
|
246
|
+
return len(classes)
|
|
242
247
|
|
|
243
248
|
# Only add the 'accuracy' callback if we are not in cropped mode.
|
|
244
249
|
@property
|
braindecode/datasets/__init__.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""
|
|
2
|
+
Loader code for some datasets.
|
|
3
|
+
"""
|
|
2
4
|
|
|
3
5
|
from .base import (
|
|
4
6
|
BaseConcatDataset,
|
|
@@ -9,11 +11,9 @@ from .base import (
|
|
|
9
11
|
)
|
|
10
12
|
from .bcicomp import BCICompetitionIVDataset4
|
|
11
13
|
from .bids import BIDSDataset, BIDSEpochsDataset
|
|
12
|
-
from .chb_mit import CHBMIT
|
|
13
14
|
from .mne import create_from_mne_epochs, create_from_mne_raw
|
|
14
15
|
from .moabb import BNCI2014_001, HGD, MOABBDataset
|
|
15
16
|
from .nmt import NMT
|
|
16
|
-
from .siena import SIENA
|
|
17
17
|
from .sleep_physio_challe_18 import SleepPhysionetChallenge2018
|
|
18
18
|
from .sleep_physionet import SleepPhysionet
|
|
19
19
|
from .tuh import TUH, TUHAbnormal
|
|
@@ -34,9 +34,7 @@ __all__ = [
|
|
|
34
34
|
"create_from_mne_epochs",
|
|
35
35
|
"TUH",
|
|
36
36
|
"TUHAbnormal",
|
|
37
|
-
"SIENA",
|
|
38
37
|
"NMT",
|
|
39
|
-
"CHBMIT",
|
|
40
38
|
"SleepPhysionet",
|
|
41
39
|
"SleepPhysionetChallenge2018",
|
|
42
40
|
"create_from_X_y",
|