braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/__init__.py CHANGED
@@ -1,7 +1,6 @@
1
- from .version import __version__
2
-
3
1
  from .classifier import EEGClassifier
4
2
  from .regressor import EEGRegressor
3
+ from .version import __version__
5
4
 
6
5
  __all__ = [
7
6
  "__version__",
@@ -0,0 +1,50 @@
1
+ """
2
+ Utilities for data augmentation.
3
+ """
4
+
5
+ from . import functional
6
+ from .base import AugmentedDataLoader, Compose, IdentityTransform, Transform
7
+ from .transforms import (
8
+ BandstopFilter,
9
+ ChannelsDropout,
10
+ ChannelsShuffle,
11
+ ChannelsSymmetry,
12
+ FrequencyShift,
13
+ FTSurrogate,
14
+ GaussianNoise,
15
+ MaskEncoding,
16
+ Mixup,
17
+ SegmentationReconstruction,
18
+ SensorsRotation,
19
+ SensorsXRotation,
20
+ SensorsYRotation,
21
+ SensorsZRotation,
22
+ SignFlip,
23
+ SmoothTimeMask,
24
+ TimeReverse,
25
+ )
26
+
27
+ __all__ = [
28
+ "Transform",
29
+ "IdentityTransform",
30
+ "Compose",
31
+ "AugmentedDataLoader",
32
+ "TimeReverse",
33
+ "SignFlip",
34
+ "FTSurrogate",
35
+ "ChannelsShuffle",
36
+ "ChannelsDropout",
37
+ "GaussianNoise",
38
+ "ChannelsSymmetry",
39
+ "SmoothTimeMask",
40
+ "BandstopFilter",
41
+ "FrequencyShift",
42
+ "SensorsRotation",
43
+ "SensorsZRotation",
44
+ "SensorsYRotation",
45
+ "SensorsXRotation",
46
+ "Mixup",
47
+ "SegmentationReconstruction",
48
+ "MaskEncoding",
49
+ "functional",
50
+ ]
@@ -0,0 +1,222 @@
1
+ # Authors: Cédric Rommel <cedric.rommel@inria.fr>
2
+ # Alexandre Gramfort <alexandre.gramfort@inria.fr>
3
+ # Bruno Aristimunha <b.aristimunha@gmail.com>
4
+ # Martin Wimpff <martin.wimpff@iss.uni-stuttgart.de>
5
+ # Valentin Iovene <val@too.gy>
6
+ # License: BSD (3-clause)
7
+
8
+ from numbers import Real
9
+ from typing import Any, Callable, Optional, Union
10
+
11
+ import torch
12
+ from sklearn.utils import check_random_state
13
+ from torch import Tensor, nn
14
+ from torch.utils.data import DataLoader
15
+ from torch.utils.data._utils.collate import default_collate
16
+
17
+ from .functional import identity
18
+
19
# A batch as a list of (input tensor, integer target, extra field) tuples.
Batch = list[tuple[torch.Tensor, int, Any]]

# Target produced by an augmentation: a single tensor or a tuple of tensors.
_TargetLike = Union[torch.Tensor, tuple[torch.Tensor, ...]]

# What a transform returns: either X alone, or the pair (X, y).
Output = Union[torch.Tensor, tuple[torch.Tensor, _TargetLike]]

# Signature of an augmentation operation: (X, y) -> (X', y').
Operation = Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, _TargetLike]]
31
+
32
+
33
class Transform(torch.nn.Module):
    """Basic transform class used for implementing data augmentation
    operations.

    Parameters
    ----------
    operation : callable
        A function taking arrays X, y (inputs and targets resp.) and
        other required arguments, and returning the transformed X and y.
    probability : float, optional
        Float between 0 and 1 defining the uniform probability of applying the
        operation. Set to 1.0 by default (i.e. always apply the operation).
    random_state: int, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.
    """

    operation: Operation

    def __init__(self, probability=1.0, random_state=None):
        super().__init__()
        # Only subclasses relying on the generic ``forward`` below need an
        # ``operation``; subclasses overriding ``forward`` (e.g. Compose) do not.
        if self.forward.__func__ is Transform.forward:
            assert callable(self.operation), "operation should be a ``callable``. "

        assert isinstance(probability, Real), (
            f"probability should be a ``real``. Got {type(probability)}."
        )
        assert probability <= 1.0 and probability >= 0.0, (
            "probability should be between 0 and 1."
        )
        self._probability = probability
        self.rng = check_random_state(random_state)

    def get_augmentation_params(self, *batch):
        """Return extra keyword arguments passed to ``operation``.

        The base implementation returns an empty dict; subclasses can override
        it to derive parameters from the (masked) batch ``X, y``.
        """
        return dict()

    def forward(self, X: Tensor, y: Optional[Tensor] = None) -> Output:
        """General forward pass for an augmentation transform.

        Parameters
        ----------
        X : torch.Tensor
            EEG input example or batch.
        y : torch.Tensor | None
            EEG labels for the example or batch. Defaults to None.

        Returns
        -------
        torch.Tensor
            Transformed inputs.
        torch.Tensor, optional
            Transformed labels. Only returned when y is not None.

        Raises
        ------
        ValueError
            If ``operation`` returns tuple-valued targets while only part of
            the batch was selected for augmentation (``probability < 1``).
        """
        X = torch.as_tensor(X).float()

        out_X = X.clone()
        # Inputs with fewer than 3 dims are treated as a single example and
        # get a leading batch dimension.
        if len(out_X.shape) < 3:
            out_X = out_X[None, ...]

        if y is not None:
            y = torch.as_tensor(y).to(out_X.device)
            out_y = y.clone()
            if len(out_y.shape) == 0:
                out_y = out_y.reshape(1)
        else:
            # Placeholder targets so ``operation`` can always be called as
            # ``operation(X, y, ...)``; they are not returned.
            out_y = torch.zeros(out_X.shape[0], device=out_X.device)

        # Samples a mask setting for each example whether they should stay
        # unchanged or not
        mask = self._get_mask(out_X.shape[0], out_X.device)
        num_valid = mask.sum().long()

        if num_valid > 0:
            # Apply the operation only to the selected examples and write the
            # result back in place.
            out_X[mask, ...], tr_y = self.operation(
                out_X[mask, ...],
                out_y[mask],
                **self.get_augmentation_params(out_X[mask, ...], out_y[mask]),
            )
            if isinstance(tr_y, tuple):
                # Tuple-valued targets only describe the augmented rows, and
                # there is no generic way to merge them with the untouched
                # rows. Previously this path failed with an opaque shape
                # mismatch when indexing by the full-batch mask.
                if not bool(mask.all()):
                    raise ValueError(
                        "Transforms returning tuple-valued targets require "
                        "every example of the batch to be augmented, but only "
                        f"{int(num_valid)} out of {mask.numel()} examples were "
                        "selected. Use probability=1.0 for this transform."
                    )
                out_y = tuple(tmp_y[mask] for tmp_y in tr_y)
            else:
                out_y[mask] = tr_y

        # potentially remove empty batch dimension again
        out_X = out_X.reshape_as(X)
        if y is not None:
            return out_X, out_y
        else:
            return out_X

    def _get_mask(self, batch_size, device) -> torch.Tensor:
        """Samples whether to apply operation or not over the whole batch"""
        return torch.as_tensor(self.probability > self.rng.uniform(size=batch_size)).to(
            device
        )

    @property
    def probability(self):
        """Probability of applying the operation to each example (read-only)."""
        return self._probability
136
+
137
+
138
class IdentityTransform(Transform):
    """No-op transform.

    Returns its input unchanged; ``AugmentedDataLoader`` falls back to it
    when no transforms are given.
    """

    # Wrapped in ``staticmethod`` so ``identity`` is not bound to the instance.
    # The type-ignore works around https://github.com/python/mypy/issues/4574
    operation = staticmethod(identity)  # type: ignore[assignment]
+
147
+
148
class Compose(Transform):
    """Transform composition.

    Callable class allowing to cast a sequence of Transform objects into a
    single one.

    Parameters
    ----------
    transforms: list
        Sequence of Transforms to be composed. They are applied in order.
    """

    def __init__(self, transforms):
        # Initialise the nn.Module machinery before storing ``transforms``.
        # Assigning an nn.Module container (e.g. nn.ModuleList) before
        # ``Module.__init__`` raises an error.
        super().__init__()
        self.transforms = transforms

    def forward(self, X, y=None):
        """Apply every transform in sequence.

        Parameters
        ----------
        X : torch.Tensor
            EEG input example or batch.
        y : torch.Tensor | None
            Labels for the example or batch. Defaults to None.

        Returns
        -------
        torch.Tensor
            Transformed inputs (returned alone when ``y`` is None).
        torch.Tensor | tuple, optional
            Transformed labels. Only returned when ``y`` is not None.
        """
        if y is None:
            # A Transform called without labels returns only X; unpacking
            # that tensor into (X, y) would silently split it along dim 0.
            for transform in self.transforms:
                X = transform(X)
            return X
        for transform in self.transforms:
            X, y = transform(X, y)
        return X, y
168
+
169
+
170
+ def _make_collateable(transform, device=None):
171
+ """Wraps a transform to make it collateable.
172
+ with device control."""
173
+
174
+ def _collate_fn(batch):
175
+ collated_batch = default_collate(batch)
176
+ X, y = collated_batch[:2]
177
+
178
+ if device is not None:
179
+ X = X.to(device)
180
+ y = y.to(device)
181
+
182
+ return (*transform(X, y), *collated_batch[2:])
183
+
184
+ return _collate_fn
185
+
186
+
187
class AugmentedDataLoader(DataLoader):
    """DataLoader that applies augmentation Transforms to every collated batch.

    Parameters
    ----------
    dataset : BaseDataset
        The dataset containing the signals.
    transforms : list | Transform, optional
        A single Transform (any ``nn.Module`` is accepted as-is) or a list of
        Transforms composed in order. ``None`` or an empty list means no
        augmentation.
    device : str | torch.device | None, optional
        Device to which each batch is moved before being transformed.
        Defaults to None (no move).
    **kwargs : dict, optional
        Keyword arguments forwarded to the standard DataLoader class.
        ``collate_fn`` is rejected because it is used to apply the transforms.

    Raises
    ------
    ValueError
        If ``collate_fn`` is passed in ``kwargs``.
    TypeError
        If ``transforms`` is neither None, an ``nn.Module`` nor a list.
    """

    def __init__(self, dataset, transforms=None, device=None, **kwargs):
        if "collate_fn" in kwargs:
            raise ValueError(
                "collate_fn cannot be used in this context because it is used "
                "to pass transform"
            )

        if transforms is None or (isinstance(transforms, list) and not transforms):
            pipeline = IdentityTransform()
        elif isinstance(transforms, nn.Module):
            # ``Transform`` derives from ``nn.Module``, so Transforms land here.
            pipeline = transforms
        elif isinstance(transforms, list):
            pipeline = Compose(transforms)
        else:
            raise TypeError(
                "transforms can be either a Transform object "
                "or a list of Transform objects."
            )

        self.collated_tr = _make_collateable(pipeline, device=device)
        super().__init__(dataset, collate_fn=self.collated_tr, **kwargs)