braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/__init__.py
CHANGED
|
@@ -1,30 +1,50 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Utilities for data augmentation.
|
|
3
3
|
"""
|
|
4
|
-
|
|
4
|
+
|
|
5
|
+
from . import functional
|
|
6
|
+
from .base import AugmentedDataLoader, Compose, IdentityTransform, Transform
|
|
5
7
|
from .transforms import (
|
|
6
|
-
|
|
7
|
-
SignFlip,
|
|
8
|
-
FTSurrogate,
|
|
9
|
-
ChannelsShuffle,
|
|
8
|
+
BandstopFilter,
|
|
10
9
|
ChannelsDropout,
|
|
11
|
-
|
|
10
|
+
ChannelsShuffle,
|
|
12
11
|
ChannelsSymmetry,
|
|
13
|
-
SmoothTimeMask,
|
|
14
|
-
BandstopFilter,
|
|
15
12
|
FrequencyShift,
|
|
13
|
+
FTSurrogate,
|
|
14
|
+
GaussianNoise,
|
|
15
|
+
MaskEncoding,
|
|
16
|
+
Mixup,
|
|
17
|
+
SegmentationReconstruction,
|
|
16
18
|
SensorsRotation,
|
|
17
|
-
SensorsZRotation,
|
|
18
|
-
SensorsYRotation,
|
|
19
19
|
SensorsXRotation,
|
|
20
|
-
|
|
20
|
+
SensorsYRotation,
|
|
21
|
+
SensorsZRotation,
|
|
22
|
+
SignFlip,
|
|
23
|
+
SmoothTimeMask,
|
|
24
|
+
TimeReverse,
|
|
21
25
|
)
|
|
22
26
|
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
27
|
+
__all__ = [
|
|
28
|
+
"Transform",
|
|
29
|
+
"IdentityTransform",
|
|
30
|
+
"Compose",
|
|
31
|
+
"AugmentedDataLoader",
|
|
32
|
+
"TimeReverse",
|
|
33
|
+
"SignFlip",
|
|
34
|
+
"FTSurrogate",
|
|
35
|
+
"ChannelsShuffle",
|
|
36
|
+
"ChannelsDropout",
|
|
37
|
+
"GaussianNoise",
|
|
38
|
+
"ChannelsSymmetry",
|
|
39
|
+
"SmoothTimeMask",
|
|
40
|
+
"BandstopFilter",
|
|
41
|
+
"FrequencyShift",
|
|
42
|
+
"SensorsRotation",
|
|
43
|
+
"SensorsZRotation",
|
|
44
|
+
"SensorsYRotation",
|
|
45
|
+
"SensorsXRotation",
|
|
46
|
+
"Mixup",
|
|
47
|
+
"SegmentationReconstruction",
|
|
48
|
+
"MaskEncoding",
|
|
49
|
+
"functional",
|
|
50
|
+
]
|
braindecode/augmentation/base.py
CHANGED
|
@@ -5,31 +5,28 @@
|
|
|
5
5
|
# Valentin Iovene <val@too.gy>
|
|
6
6
|
# License: BSD (3-clause)
|
|
7
7
|
|
|
8
|
-
from typing import List, Tuple, Any, Optional, Union, Callable
|
|
9
8
|
from numbers import Real
|
|
9
|
+
from typing import Any, Callable, Optional, Union
|
|
10
10
|
|
|
11
|
-
from sklearn.utils import check_random_state
|
|
12
11
|
import torch
|
|
12
|
+
from sklearn.utils import check_random_state
|
|
13
13
|
from torch import Tensor, nn
|
|
14
14
|
from torch.utils.data import DataLoader
|
|
15
15
|
from torch.utils.data._utils.collate import default_collate
|
|
16
16
|
|
|
17
17
|
from .functional import identity
|
|
18
18
|
|
|
19
|
-
Batch =
|
|
19
|
+
Batch = list[tuple[torch.Tensor, int, Any]]
|
|
20
20
|
Output = Union[
|
|
21
21
|
# just outputting X
|
|
22
22
|
torch.Tensor,
|
|
23
23
|
# outputting (X, y) where y can be a tensor or tuple of tensors
|
|
24
|
-
|
|
24
|
+
tuple[torch.Tensor, Union[torch.Tensor, tuple[torch.Tensor, ...]]],
|
|
25
25
|
]
|
|
26
26
|
# (X, y) -> (X', y') where y' can be a tensor or a tuple of tensors
|
|
27
27
|
Operation = Callable[
|
|
28
28
|
[torch.Tensor, torch.Tensor],
|
|
29
|
-
|
|
30
|
-
torch.Tensor,
|
|
31
|
-
Union[torch.Tensor, Tuple[torch.Tensor, ...]]
|
|
32
|
-
]
|
|
29
|
+
tuple[torch.Tensor, Union[torch.Tensor, tuple[torch.Tensor, ...]]],
|
|
33
30
|
]
|
|
34
31
|
|
|
35
32
|
|
|
@@ -50,18 +47,20 @@ class Transform(torch.nn.Module):
|
|
|
50
47
|
Used to decide whether or not to transform given the probability
|
|
51
48
|
argument. Defaults to None.
|
|
52
49
|
"""
|
|
50
|
+
|
|
53
51
|
operation: Operation
|
|
54
52
|
|
|
55
53
|
def __init__(self, probability=1.0, random_state=None):
|
|
56
54
|
super().__init__()
|
|
57
55
|
if self.forward.__func__ is Transform.forward:
|
|
58
|
-
assert callable(self.operation), "operation should be a "
|
|
59
|
-
"``callable``. "
|
|
56
|
+
assert callable(self.operation), "operation should be a ``callable``. "
|
|
60
57
|
|
|
61
58
|
assert isinstance(probability, Real), (
|
|
62
|
-
f"probability should be a ``real``. Got {type(probability)}."
|
|
63
|
-
|
|
59
|
+
f"probability should be a ``real``. Got {type(probability)}."
|
|
60
|
+
)
|
|
61
|
+
assert probability <= 1.0 and probability >= 0.0, (
|
|
64
62
|
"probability should be between 0 and 1."
|
|
63
|
+
)
|
|
65
64
|
self._probability = probability
|
|
66
65
|
self.rng = check_random_state(random_state)
|
|
67
66
|
|
|
@@ -108,8 +107,9 @@ class Transform(torch.nn.Module):
|
|
|
108
107
|
if num_valid > 0:
|
|
109
108
|
# Uses the mask to define the output
|
|
110
109
|
out_X[mask, ...], tr_y = self.operation(
|
|
111
|
-
out_X[mask, ...],
|
|
112
|
-
|
|
110
|
+
out_X[mask, ...],
|
|
111
|
+
out_y[mask],
|
|
112
|
+
**self.get_augmentation_params(out_X[mask, ...], out_y[mask]),
|
|
113
113
|
)
|
|
114
114
|
# Apply the operation defining the Transform to the whole batch
|
|
115
115
|
if isinstance(tr_y, tuple):
|
|
@@ -125,11 +125,10 @@ class Transform(torch.nn.Module):
|
|
|
125
125
|
return out_X
|
|
126
126
|
|
|
127
127
|
def _get_mask(self, batch_size, device) -> torch.Tensor:
|
|
128
|
-
"""Samples whether to apply operation or not over the whole batch
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
).to(device)
|
|
128
|
+
"""Samples whether to apply operation or not over the whole batch"""
|
|
129
|
+
return torch.as_tensor(self.probability > self.rng.uniform(size=batch_size)).to(
|
|
130
|
+
device
|
|
131
|
+
)
|
|
133
132
|
|
|
134
133
|
@property
|
|
135
134
|
def probability(self):
|
|
@@ -141,7 +140,9 @@ class IdentityTransform(Transform):
|
|
|
141
140
|
|
|
142
141
|
Transform that does not change the input.
|
|
143
142
|
"""
|
|
144
|
-
|
|
143
|
+
|
|
144
|
+
operation = staticmethod(identity) # type: ignore[assignment]
|
|
145
|
+
# https://github.com/python/mypy/issues/4574
|
|
145
146
|
|
|
146
147
|
|
|
147
148
|
class Compose(Transform):
|
|
@@ -167,8 +168,8 @@ class Compose(Transform):
|
|
|
167
168
|
|
|
168
169
|
|
|
169
170
|
def _make_collateable(transform, device=None):
|
|
170
|
-
"""
|
|
171
|
-
|
|
171
|
+
"""Wraps a transform to make it collateable.
|
|
172
|
+
with device control."""
|
|
172
173
|
|
|
173
174
|
def _collate_fn(batch):
|
|
174
175
|
collated_batch = default_collate(batch)
|
|
@@ -205,7 +206,7 @@ class AugmentedDataLoader(DataLoader):
|
|
|
205
206
|
"to pass transform"
|
|
206
207
|
)
|
|
207
208
|
if transforms is None or (
|
|
208
|
-
|
|
209
|
+
isinstance(transforms, list) and len(transforms) == 0
|
|
209
210
|
):
|
|
210
211
|
self.collated_tr = _make_collateable(IdentityTransform(), device=device)
|
|
211
212
|
elif isinstance(transforms, (Transform, nn.Module)):
|
|
@@ -218,8 +219,4 @@ class AugmentedDataLoader(DataLoader):
|
|
|
218
219
|
"or a list of Transform objects."
|
|
219
220
|
)
|
|
220
221
|
|
|
221
|
-
super().__init__(
|
|
222
|
-
dataset,
|
|
223
|
-
collate_fn=self.collated_tr,
|
|
224
|
-
**kwargs
|
|
225
|
-
)
|
|
222
|
+
super().__init__(dataset, collate_fn=self.collated_tr, **kwargs)
|