braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,9 @@
1
+ from .classifier import EEGClassifier
2
+ from .regressor import EEGRegressor
3
+ from .version import __version__
4
+
5
# Public API of the top-level package: the names re-exported from the
# ``.version``, ``.classifier`` and ``.regressor`` imports above.
__all__ = [
    "__version__",
    "EEGClassifier",
    "EEGRegressor",
]
@@ -0,0 +1,52 @@
1
+ """Utilities for data augmentation."""
2
+
3
+ from . import functional
4
+ from .base import AugmentedDataLoader, Compose, IdentityTransform, Transform
5
+ from .transforms import (
6
+ AmplitudeScale,
7
+ BandstopFilter,
8
+ ChannelsDropout,
9
+ ChannelsReref,
10
+ ChannelsShuffle,
11
+ ChannelsSymmetry,
12
+ FrequencyShift,
13
+ FTSurrogate,
14
+ GaussianNoise,
15
+ MaskEncoding,
16
+ Mixup,
17
+ SegmentationReconstruction,
18
+ SensorsRotation,
19
+ SensorsXRotation,
20
+ SensorsYRotation,
21
+ SensorsZRotation,
22
+ SignFlip,
23
+ SmoothTimeMask,
24
+ TimeReverse,
25
+ )
26
+
27
# Public API of the augmentation subpackage: the base classes imported from
# ``.base``, the concrete transforms imported from ``.transforms``, and the
# ``functional`` submodule.
__all__ = [
    "Transform",
    "IdentityTransform",
    "Compose",
    "AugmentedDataLoader",
    "TimeReverse",
    "SignFlip",
    "FTSurrogate",
    "ChannelsShuffle",
    "ChannelsDropout",
    "GaussianNoise",
    "ChannelsSymmetry",
    "SmoothTimeMask",
    "BandstopFilter",
    "FrequencyShift",
    "SensorsRotation",
    "SensorsZRotation",
    "SensorsYRotation",
    "SensorsXRotation",
    "Mixup",
    "SegmentationReconstruction",
    "MaskEncoding",
    "AmplitudeScale",
    "ChannelsReref",
    "functional",
]
@@ -0,0 +1,225 @@
1
+ # Authors: Cédric Rommel <cedric.rommel@inria.fr>
2
+ # Alexandre Gramfort <alexandre.gramfort@inria.fr>
3
+ # Bruno Aristimunha <b.aristimunha@gmail.com>
4
+ # Martin Wimpff <martin.wimpff@iss.uni-stuttgart.de>
5
+ # Valentin Iovene <val@too.gy>
6
+ # License: BSD (3-clause)
7
+
8
+ from numbers import Real
9
+ from typing import Any, Callable, Optional, Union
10
+
11
+ import torch
12
+ from sklearn.utils import check_random_state
13
+ from torch import Tensor, nn
14
+ from torch.utils.data import DataLoader
15
+ from torch.utils.data._utils.collate import default_collate
16
+
17
+ from .functional import identity
18
+
19
# A list of (X, y, extra) samples. Not referenced by the code visible in this
# module; presumably describes a pre-collation batch — TODO confirm usage.
Batch = list[tuple[torch.Tensor, int, Any]]
# Return type of ``Transform.forward``.
Output = Union[
    # just outputting X
    torch.Tensor,
    # outputting (X, y) where y can be a tensor or tuple of tensors
    tuple[torch.Tensor, Union[torch.Tensor, tuple[torch.Tensor, ...]]],
]
# (X, y) -> (X', y') where y' can be a tensor or a tuple of tensors.
# Type of ``Transform.operation``. Note that ``Transform.forward`` also passes
# the dict from ``get_augmentation_params`` as keyword arguments, which this
# alias does not express.
Operation = Callable[
    [torch.Tensor, torch.Tensor],
    tuple[torch.Tensor, Union[torch.Tensor, tuple[torch.Tensor, ...]]],
]
31
+
32
+
33
class Transform(torch.nn.Module):
    """Basic transform class used for implementing data augmentation
    operations.

    Subclasses either define an ``operation`` callable, which is used by the
    generic :meth:`forward`, or override :meth:`forward` themselves.

    Parameters
    ----------
    probability : float, optional
        Float between 0 and 1 defining the uniform probability of applying the
        operation to each example. Set to 1.0 by default (i.e. always apply
        the operation).
    random_state : int, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.

    Attributes
    ----------
    operation : callable
        Set by subclasses. A function taking arrays X, y (inputs and targets
        resp.) plus the keyword arguments returned by
        :meth:`get_augmentation_params`, and returning the transformed X and y.
    """

    operation: Operation

    def __init__(self, probability=1.0, random_state=None):
        super().__init__()
        # Only the generic forward relies on ``operation``; subclasses that
        # override forward do not need to define it.
        if self.forward.__func__ is Transform.forward:
            assert callable(self.operation), "operation should be a ``callable``. "

        assert isinstance(probability, Real), (
            f"probability should be a ``real``. Got {type(probability)}."
        )
        assert probability <= 1.0 and probability >= 0.0, (
            "probability should be between 0 and 1."
        )
        self._probability = probability
        self.rng = check_random_state(random_state)

    def get_augmentation_params(self, *batch):
        """Return extra keyword arguments passed to ``operation``.

        Called with the selected ``(X, y)`` examples. The base implementation
        returns an empty dict; subclasses may override it.
        """
        return dict()

    def forward(self, X: Tensor, y: Optional[Tensor] = None) -> Output:
        """General forward pass for an augmentation transform.

        Parameters
        ----------
        X : torch.Tensor
            EEG input example or batch.
        y : torch.Tensor | None
            EEG labels for the example or batch. Defaults to None.

        Returns
        -------
        torch.Tensor
            Transformed inputs.
        torch.Tensor | tuple of torch.Tensor, optional
            Transformed labels. Only returned when y is not None.

        Raises
        ------
        ValueError
            If ``operation`` returns tuple-valued targets while only part of
            the batch was selected for transformation (probability < 1).
        """
        X = torch.as_tensor(X).float()

        out_X = X.clone()
        # Operations work on batches: add a batch dimension to single examples
        if len(out_X.shape) < 3:
            out_X = out_X[None, ...]

        if y is not None:
            y = torch.as_tensor(y).to(out_X.device)
            out_y = y.clone()
            if len(out_y.shape) == 0:
                out_y = out_y.reshape(1)
        else:
            # Placeholder targets so that ``operation`` always receives a y
            out_y = torch.zeros(out_X.shape[0], device=out_X.device)

        # Samples a mask setting for each example whether it should be
        # transformed or stay unchanged
        mask = self._get_mask(out_X.shape[0], out_X.device)
        num_valid = mask.sum().long()

        if num_valid > 0:
            # Apply the operation to the selected examples only
            out_X[mask, ...], tr_y = self.operation(
                out_X[mask, ...],
                out_y[mask],
                **self.get_augmentation_params(out_X[mask, ...], out_y[mask]),
            )
            if isinstance(tr_y, tuple):
                # Tuple-valued targets are computed on the selected examples
                # only: they can neither be re-indexed with the full-batch
                # mask nor merged with the untouched examples.
                if not bool(mask.all()):
                    raise ValueError(
                        "Transforms returning tuple-valued targets must be "
                        "applied to the whole batch (probability=1.0)."
                    )
                out_y = tuple(tr_y)
            else:
                out_y[mask] = tr_y

        # potentially remove empty batch dimension again
        out_X = out_X.reshape_as(X)
        if y is not None:
            return out_X, out_y
        return out_X

    def _get_mask(self, batch_size, device) -> torch.Tensor:
        """Samples whether to apply operation or not over the whole batch.

        Returns a boolean tensor of length ``batch_size`` on ``device``; each
        entry is True with probability ``self.probability``.
        """
        return torch.as_tensor(self.probability > self.rng.uniform(size=batch_size)).to(
            device
        )

    @property
    def probability(self):
        """Probability of applying the operation to each example."""
        return self._probability
137
+
138
+
139
class IdentityTransform(Transform):
    """Identity transform.

    Transform that does not change the input. Used by ``AugmentedDataLoader``
    when no transforms are given.
    """

    # ``staticmethod`` keeps ``identity`` from being bound as an instance
    # method, so ``Transform.forward`` calls it as
    # ``identity(X, y, **self.get_augmentation_params(...))``.
    operation = staticmethod(identity)  # type: ignore[assignment]
    # https://github.com/python/mypy/issues/4574
147
+
148
+
149
class Compose(Transform):
    """Transform composition.

    Callable class allowing to cast a sequence of Transform objects into a
    single one. Transforms are applied in order, each receiving the output
    of the previous one.

    Parameters
    ----------
    transforms : list
        Sequence of Transforms to be composed.
    """

    def __init__(self, transforms):
        # Initialise the Module machinery before setting any attribute.
        super().__init__()
        self.transforms = transforms

    def forward(self, X, y=None):
        """Apply every transform sequentially.

        Parameters
        ----------
        X : torch.Tensor
            EEG input example or batch.
        y : torch.Tensor | None
            EEG labels for the example or batch. Defaults to None.

        Returns
        -------
        torch.Tensor
            Transformed inputs.
        torch.Tensor | tuple of torch.Tensor, optional
            Transformed labels. Only returned when y is not None.
        """
        if y is None:
            # Without labels a Transform returns only X; unpacking that tensor
            # into (X, y) would silently split it along its first dimension.
            for transform in self.transforms:
                X = transform(X)
            return X
        for transform in self.transforms:
            X, y = transform(X, y)
        return X, y
169
+
170
+
171
def _make_collateable(transform, device=None):
    """Build a ``collate_fn`` that applies ``transform`` after collation.

    The returned function collates the batch with ``default_collate``, moves
    the first two fields (inputs and targets) to ``device`` when one is
    given, passes them through ``transform``, and appends any remaining
    collated fields unchanged.
    """

    def _collate_fn(batch):
        collated = default_collate(batch)
        inputs, targets = collated[:2]
        extras = collated[2:]

        if device is not None:
            inputs, targets = inputs.to(device), targets.to(device)

        transformed = transform(inputs, targets)
        return (*transformed, *extras)

    return _collate_fn
188
+
189
+
190
class AugmentedDataLoader(DataLoader):
    """A base dataloader class customized to applying augmentation Transforms.

    Parameters
    ----------
    dataset : RecordDataset
        The dataset containing the signals.
    transforms : list | tuple | Transform, optional
        Transform or sequence of Transforms to be applied to each batch. None
        or an empty sequence applies the identity transform.
    device : str | torch.device | None, optional
        Device on which to transform the data. Defaults to None.
    **kwargs : dict, optional
        keyword arguments to pass to standard DataLoader class.

    Raises
    ------
    ValueError
        If ``collate_fn`` is passed in ``kwargs``; it is reserved for the
        transforms.
    TypeError
        If ``transforms`` is not a module, list, tuple or None.
    """

    def __init__(self, dataset, transforms=None, device=None, **kwargs):
        if "collate_fn" in kwargs:
            raise ValueError(
                "collate_fn cannot be used in this context because it is used "
                "to pass transform"
            )
        if transforms is None or (
            isinstance(transforms, (list, tuple)) and len(transforms) == 0
        ):
            transform = IdentityTransform()
        elif isinstance(transforms, (Transform, nn.Module)):
            transform = transforms
        elif isinstance(transforms, (list, tuple)):
            # Sequences are chained into a single Transform
            transform = Compose(transforms)
        else:
            raise TypeError(
                "transforms can be either a Transform object "
                "or a list of Transform objects."
            )

        self.collated_tr = _make_collateable(transform, device=device)
        super().__init__(dataset, collate_fn=self.collated_tr, **kwargs)