braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Utilities for data augmentation."""
|
|
2
|
+
|
|
3
|
+
from . import functional
|
|
4
|
+
from .base import AugmentedDataLoader, Compose, IdentityTransform, Transform
|
|
5
|
+
from .transforms import (
|
|
6
|
+
AmplitudeScale,
|
|
7
|
+
BandstopFilter,
|
|
8
|
+
ChannelsDropout,
|
|
9
|
+
ChannelsReref,
|
|
10
|
+
ChannelsShuffle,
|
|
11
|
+
ChannelsSymmetry,
|
|
12
|
+
FrequencyShift,
|
|
13
|
+
FTSurrogate,
|
|
14
|
+
GaussianNoise,
|
|
15
|
+
MaskEncoding,
|
|
16
|
+
Mixup,
|
|
17
|
+
SegmentationReconstruction,
|
|
18
|
+
SensorsRotation,
|
|
19
|
+
SensorsXRotation,
|
|
20
|
+
SensorsYRotation,
|
|
21
|
+
SensorsZRotation,
|
|
22
|
+
SignFlip,
|
|
23
|
+
SmoothTimeMask,
|
|
24
|
+
TimeReverse,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Public API of this package: the base augmentation classes imported from
# ``.base``, every transform re-exported from ``.transforms``, and the
# ``functional`` submodule.
__all__ = [
    "Transform",
    "IdentityTransform",
    "Compose",
    "AugmentedDataLoader",
    "TimeReverse",
    "SignFlip",
    "FTSurrogate",
    "ChannelsShuffle",
    "ChannelsDropout",
    "GaussianNoise",
    "ChannelsSymmetry",
    "SmoothTimeMask",
    "BandstopFilter",
    "FrequencyShift",
    "SensorsRotation",
    "SensorsZRotation",
    "SensorsYRotation",
    "SensorsXRotation",
    "Mixup",
    "SegmentationReconstruction",
    "MaskEncoding",
    "AmplitudeScale",
    "ChannelsReref",
    "functional",
]
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
# Authors: Cédric Rommel <cedric.rommel@inria.fr>
|
|
2
|
+
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
3
|
+
# Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
4
|
+
# Martin Wimpff <martin.wimpff@iss.uni-stuttgart.de>
|
|
5
|
+
# Valentin Iovene <val@too.gy>
|
|
6
|
+
# License: BSD (3-clause)
|
|
7
|
+
|
|
8
|
+
from numbers import Real
|
|
9
|
+
from typing import Any, Callable, Optional, Union
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
from sklearn.utils import check_random_state
|
|
13
|
+
from torch import Tensor, nn
|
|
14
|
+
from torch.utils.data import DataLoader
|
|
15
|
+
from torch.utils.data._utils.collate import default_collate
|
|
16
|
+
|
|
17
|
+
from .functional import identity
|
|
18
|
+
|
|
19
|
+
# A batch represented as a list of (X, y, extra) tuples.
Batch = list[tuple[torch.Tensor, int, Any]]

# Targets produced by an augmentation: a single tensor, or a tuple of tensors.
_Targets = Union[torch.Tensor, tuple[torch.Tensor, ...]]

# Value returned by a transform: either X alone, or the pair (X, y) where y
# follows the ``_Targets`` form above.
Output = Union[torch.Tensor, tuple[torch.Tensor, _Targets]]

# Signature of an augmentation operation: (X, y) -> (X', y'), where y' follows
# the ``_Targets`` form above.
Operation = Callable[
    [torch.Tensor, torch.Tensor],
    tuple[torch.Tensor, _Targets],
]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Transform(torch.nn.Module):
    """Basic transform class used for implementing data augmentation operations.

    Parameters
    ----------
    operation : callable
        A function taking arrays X, y (inputs and targets resp.) and
        other required arguments, and returning the transformed X and y.
        It is not an ``__init__`` argument: subclasses provide it as the
        ``operation`` class attribute (or override ``forward`` instead).
    probability : float, optional
        Float between 0 and 1 defining the uniform probability of applying the
        operation to each example. Set to 1.0 by default (i.e. always apply
        the operation).
    random_state : int, optional
        Seed passed to ``sklearn.utils.check_random_state`` to instantiate the
        numpy random number generator. Used to decide whether or not to
        transform each example given the probability argument.
        Defaults to None.
    """

    operation: Operation

    def __init__(self, probability=1.0, random_state=None):
        super().__init__()
        # ``operation`` is only required when the default ``forward`` (which
        # calls it) is used; subclasses overriding ``forward`` may omit it.
        if self.forward.__func__ is Transform.forward:
            assert callable(self.operation), "operation should be a ``callable``. "

        assert isinstance(probability, Real), (
            f"probability should be a ``real``. Got {type(probability)}."
        )
        assert probability <= 1.0 and probability >= 0.0, (
            "probability should be between 0 and 1."
        )
        self._probability = probability
        self.rng = check_random_state(random_state)

    def get_augmentation_params(self, *batch):
        """Return extra keyword arguments forwarded to ``operation``.

        Called with the selected (X, y) sub-batch. The base implementation
        returns an empty dict; subclasses may override it to supply
        operation-specific parameters.
        """
        return dict()

    def forward(self, X: Tensor, y: Optional[Tensor] = None) -> Output:
        """General forward pass for an augmentation transform.

        Parameters
        ----------
        X : torch.Tensor
            EEG input example or batch.
        y : torch.Tensor | None
            EEG labels for the example or batch. Defaults to None.

        Returns
        -------
        torch.Tensor
            Transformed inputs.
        torch.Tensor | tuple of torch.Tensor, optional
            Transformed labels. Only returned when y is not None.

        Raises
        ------
        ValueError
            If ``operation`` returns a tuple of targets while only part of
            the batch was selected for augmentation (``probability < 1``).
            The per-example targets of the unselected examples cannot be
            merged with such a tuple.
        """
        X = torch.as_tensor(X).float()

        out_X = X.clone()
        # A single example (fewer than 3 dims) gets a leading batch dimension
        # so that masking below works on the first axis.
        if out_X.ndim < 3:
            out_X = out_X[None, ...]

        if y is not None:
            y = torch.as_tensor(y).to(out_X.device)
            out_y = y.clone()
            if out_y.ndim == 0:
                out_y = out_y.reshape(1)
        else:
            # Dummy targets so that ``operation`` always receives a y.
            out_y = torch.zeros(out_X.shape[0], device=out_X.device)

        # Samples a mask setting for each example whether it should be
        # transformed (True) or stay unchanged (False).
        mask = self._get_mask(out_X.shape[0], out_X.device)
        num_valid = mask.sum().long()

        if num_valid > 0:
            # Apply the operation only to the selected examples.
            out_X[mask, ...], tr_y = self.operation(
                out_X[mask, ...],
                out_y[mask],
                **self.get_augmentation_params(out_X[mask, ...], out_y[mask]),
            )
            if isinstance(tr_y, tuple):
                # ``tr_y`` only covers the selected sub-batch, so it can only
                # stand for the whole batch when every example was selected.
                if not bool(mask.all()):
                    raise ValueError(
                        "Operations returning a tuple of targets can only be "
                        "applied to the whole batch (e.g. probability=1.0); "
                        f"got {int(num_valid)} of {mask.shape[0]} examples "
                        "selected."
                    )
                out_y = tr_y
            else:
                out_y[mask] = tr_y
        # NOTE(review): when no example is selected, ``out_y`` stays a plain
        # tensor even for operations that normally return a tuple of targets.

        # Remove the batch dimension added above for single examples.
        out_X = out_X.reshape_as(X)
        if y is not None:
            return out_X, out_y
        return out_X

    def _get_mask(self, batch_size, device) -> torch.Tensor:
        """Sample, per example, whether to apply the operation.

        Returns a boolean tensor of length ``batch_size`` on ``device``; an
        entry is True when a uniform draw from ``self.rng`` falls below
        ``self.probability``.
        """
        return torch.as_tensor(self.probability > self.rng.uniform(size=batch_size)).to(
            device
        )

    @property
    def probability(self):
        """Probability of applying the operation to each example."""
        return self._probability
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class IdentityTransform(Transform):
    """Identity transform.

    Transform that does not change the input: its ``operation`` is the
    ``identity`` function from the functional module.
    """

    # ``staticmethod`` prevents ``identity`` from being bound to the instance
    # when looked up as ``self.operation``. The ignore silences mypy's
    # complaint about this assignment, see
    # https://github.com/python/mypy/issues/4574
    operation = staticmethod(identity)  # type: ignore[assignment]
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class Compose(Transform):
    """Transform composition.

    Callable class allowing to cast a sequence of Transform objects into a
    single one. The transforms are applied in the order they are given.

    Parameters
    ----------
    transforms : list
        Sequence of Transforms to be composed.
    """

    def __init__(self, transforms):
        # Kept as the given sequence (a plain list is not an nn.Module, so
        # torch does not register its elements as submodules).
        self.transforms = transforms
        super().__init__()

    def forward(self, X, y=None):
        """Apply each transform in sequence.

        Parameters
        ----------
        X : torch.Tensor
            EEG input example or batch.
        y : torch.Tensor | None
            EEG labels for the example or batch. Defaults to None, matching
            ``Transform.forward``.

        Returns
        -------
        torch.Tensor
            Transformed inputs, when ``y`` is None.
        tuple
            ``(X, y)`` after all transforms, when ``y`` is given.
        """
        if y is None:
            # Transforms called without targets return only X.
            for transform in self.transforms:
                X = transform(X)
            return X

        for transform in self.transforms:
            X, y = transform(X, y)
        return X, y
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _make_collateable(transform, device=None):
|
|
172
|
+
"""Wraps a transform to make it collateable.
|
|
173
|
+
|
|
174
|
+
with device control.
|
|
175
|
+
"""
|
|
176
|
+
|
|
177
|
+
def _collate_fn(batch):
|
|
178
|
+
collated_batch = default_collate(batch)
|
|
179
|
+
X, y = collated_batch[:2]
|
|
180
|
+
|
|
181
|
+
if device is not None:
|
|
182
|
+
X = X.to(device)
|
|
183
|
+
y = y.to(device)
|
|
184
|
+
|
|
185
|
+
return (*transform(X, y), *collated_batch[2:])
|
|
186
|
+
|
|
187
|
+
return _collate_fn
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class AugmentedDataLoader(DataLoader):
    """DataLoader that applies augmentation Transforms to every batch.

    The transforms are run inside the loader's ``collate_fn``, so the
    ``collate_fn`` keyword cannot be supplied by the caller.

    Parameters
    ----------
    dataset : RecordDataset
        The dataset containing the signals.
    transforms : list | Transform, optional
        Transform or sequence of Transform to be applied to each batch.
        None or an empty list means no augmentation (identity).
    device : str | torch.device | None, optional
        Device on which to transform the data. Defaults to None.
    **kwargs : dict, optional
        keyword arguments to pass to standard DataLoader class.

    Raises
    ------
    ValueError
        If ``collate_fn`` is passed in ``kwargs``.
    TypeError
        If ``transforms`` is neither None, a Transform/nn.Module, nor a list.
    """

    def __init__(self, dataset, transforms=None, device=None, **kwargs):
        if "collate_fn" in kwargs:
            raise ValueError(
                "collate_fn cannot be used in this context because it is used "
                "to pass transform"
            )

        # Normalise the user input into a single callable pipeline.
        no_transform = transforms is None or (
            isinstance(transforms, list) and not transforms
        )
        if no_transform:
            pipeline = IdentityTransform()
        elif isinstance(transforms, (Transform, nn.Module)):
            pipeline = transforms
        elif isinstance(transforms, list):
            pipeline = Compose(transforms)
        else:
            raise TypeError(
                "transforms can be either a Transform object "
                "or a list of Transform objects."
            )

        self.collated_tr = _make_collateable(pipeline, device=device)
        super().__init__(dataset, collate_fn=self.collated_tr, **kwargs)
|