braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/__init__.py CHANGED
@@ -1,7 +1,6 @@
1
- from .version import __version__
2
-
3
1
  from .classifier import EEGClassifier
4
2
  from .regressor import EEGRegressor
3
+ from .version import __version__
5
4
 
6
5
  __all__ = [
7
6
  "__version__",
@@ -1,30 +1,50 @@
1
1
  """
2
2
  Utilities for data augmentation.
3
3
  """
4
- from .base import Transform, IdentityTransform, Compose, AugmentedDataLoader
4
+
5
+ from . import functional
6
+ from .base import AugmentedDataLoader, Compose, IdentityTransform, Transform
5
7
  from .transforms import (
6
- TimeReverse,
7
- SignFlip,
8
- FTSurrogate,
9
- ChannelsShuffle,
8
+ BandstopFilter,
10
9
  ChannelsDropout,
11
- GaussianNoise,
10
+ ChannelsShuffle,
12
11
  ChannelsSymmetry,
13
- SmoothTimeMask,
14
- BandstopFilter,
15
12
  FrequencyShift,
13
+ FTSurrogate,
14
+ GaussianNoise,
15
+ MaskEncoding,
16
+ Mixup,
17
+ SegmentationReconstruction,
16
18
  SensorsRotation,
17
- SensorsZRotation,
18
- SensorsYRotation,
19
19
  SensorsXRotation,
20
- Mixup,
20
+ SensorsYRotation,
21
+ SensorsZRotation,
22
+ SignFlip,
23
+ SmoothTimeMask,
24
+ TimeReverse,
21
25
  )
22
26
 
23
- from . import functional
24
-
25
- __all__ = ["Transform", "IdentityTransform", "Compose", "AugmentedDataLoader",
26
- "TimeReverse", "SignFlip", "FTSurrogate", "ChannelsShuffle",
27
- "ChannelsDropout", "GaussianNoise", "ChannelsSymmetry",
28
- "SmoothTimeMask", "BandstopFilter", "FrequencyShift",
29
- "SensorsRotation", "SensorsZRotation", "SensorsYRotation",
30
- "SensorsXRotation", "Mixup", "functional"]
27
+ __all__ = [
28
+ "Transform",
29
+ "IdentityTransform",
30
+ "Compose",
31
+ "AugmentedDataLoader",
32
+ "TimeReverse",
33
+ "SignFlip",
34
+ "FTSurrogate",
35
+ "ChannelsShuffle",
36
+ "ChannelsDropout",
37
+ "GaussianNoise",
38
+ "ChannelsSymmetry",
39
+ "SmoothTimeMask",
40
+ "BandstopFilter",
41
+ "FrequencyShift",
42
+ "SensorsRotation",
43
+ "SensorsZRotation",
44
+ "SensorsYRotation",
45
+ "SensorsXRotation",
46
+ "Mixup",
47
+ "SegmentationReconstruction",
48
+ "MaskEncoding",
49
+ "functional",
50
+ ]
@@ -5,31 +5,28 @@
5
5
  # Valentin Iovene <val@too.gy>
6
6
  # License: BSD (3-clause)
7
7
 
8
- from typing import List, Tuple, Any, Optional, Union, Callable
9
8
  from numbers import Real
9
+ from typing import Any, Callable, Optional, Union
10
10
 
11
- from sklearn.utils import check_random_state
12
11
  import torch
12
+ from sklearn.utils import check_random_state
13
13
  from torch import Tensor, nn
14
14
  from torch.utils.data import DataLoader
15
15
  from torch.utils.data._utils.collate import default_collate
16
16
 
17
17
  from .functional import identity
18
18
 
19
- Batch = List[Tuple[torch.Tensor, int, Any]]
19
+ Batch = list[tuple[torch.Tensor, int, Any]]
20
20
  Output = Union[
21
21
  # just outputting X
22
22
  torch.Tensor,
23
23
  # outputting (X, y) where y can be a tensor or tuple of tensors
24
- Tuple[torch.Tensor, Union[torch.Tensor, Tuple[torch.Tensor, ...]]]
24
+ tuple[torch.Tensor, Union[torch.Tensor, tuple[torch.Tensor, ...]]],
25
25
  ]
26
26
  # (X, y) -> (X', y') where y' can be a tensor or a tuple of tensors
27
27
  Operation = Callable[
28
28
  [torch.Tensor, torch.Tensor],
29
- Tuple[
30
- torch.Tensor,
31
- Union[torch.Tensor, Tuple[torch.Tensor, ...]]
32
- ]
29
+ tuple[torch.Tensor, Union[torch.Tensor, tuple[torch.Tensor, ...]]],
33
30
  ]
34
31
 
35
32
 
@@ -50,18 +47,20 @@ class Transform(torch.nn.Module):
50
47
  Used to decide whether or not to transform given the probability
51
48
  argument. Defaults to None.
52
49
  """
50
+
53
51
  operation: Operation
54
52
 
55
53
  def __init__(self, probability=1.0, random_state=None):
56
54
  super().__init__()
57
55
  if self.forward.__func__ is Transform.forward:
58
- assert callable(self.operation), "operation should be a " \
59
- "``callable``. "
56
+ assert callable(self.operation), "operation should be a ``callable``. "
60
57
 
61
58
  assert isinstance(probability, Real), (
62
- f"probability should be a ``real``. Got {type(probability)}.")
63
- assert probability <= 1. and probability >= 0., \
59
+ f"probability should be a ``real``. Got {type(probability)}."
60
+ )
61
+ assert probability <= 1.0 and probability >= 0.0, (
64
62
  "probability should be between 0 and 1."
63
+ )
65
64
  self._probability = probability
66
65
  self.rng = check_random_state(random_state)
67
66
 
@@ -108,8 +107,9 @@ class Transform(torch.nn.Module):
108
107
  if num_valid > 0:
109
108
  # Uses the mask to define the output
110
109
  out_X[mask, ...], tr_y = self.operation(
111
- out_X[mask, ...], out_y[mask],
112
- **self.get_augmentation_params(out_X[mask, ...], out_y[mask])
110
+ out_X[mask, ...],
111
+ out_y[mask],
112
+ **self.get_augmentation_params(out_X[mask, ...], out_y[mask]),
113
113
  )
114
114
  # Apply the operation defining the Transform to the whole batch
115
115
  if isinstance(tr_y, tuple):
@@ -125,11 +125,10 @@ class Transform(torch.nn.Module):
125
125
  return out_X
126
126
 
127
127
  def _get_mask(self, batch_size, device) -> torch.Tensor:
128
- """Samples whether to apply operation or not over the whole batch
129
- """
130
- return torch.as_tensor(
131
- self.probability > self.rng.uniform(size=batch_size)
132
- ).to(device)
128
+ """Samples whether to apply operation or not over the whole batch"""
129
+ return torch.as_tensor(self.probability > self.rng.uniform(size=batch_size)).to(
130
+ device
131
+ )
133
132
 
134
133
  @property
135
134
  def probability(self):
@@ -141,7 +140,9 @@ class IdentityTransform(Transform):
141
140
 
142
141
  Transform that does not change the input.
143
142
  """
144
- operation = staticmethod(identity)
143
+
144
+ operation = staticmethod(identity) # type: ignore[assignment]
145
+ # https://github.com/python/mypy/issues/4574
145
146
 
146
147
 
147
148
  class Compose(Transform):
@@ -167,8 +168,8 @@ class Compose(Transform):
167
168
 
168
169
 
169
170
  def _make_collateable(transform, device=None):
170
- """ Wraps a transform to make it collateable.
171
- with device control. """
171
+ """Wraps a transform to make it collateable.
172
+ with device control."""
172
173
 
173
174
  def _collate_fn(batch):
174
175
  collated_batch = default_collate(batch)
@@ -205,7 +206,7 @@ class AugmentedDataLoader(DataLoader):
205
206
  "to pass transform"
206
207
  )
207
208
  if transforms is None or (
208
- isinstance(transforms, list) and len(transforms) == 0
209
+ isinstance(transforms, list) and len(transforms) == 0
209
210
  ):
210
211
  self.collated_tr = _make_collateable(IdentityTransform(), device=device)
211
212
  elif isinstance(transforms, (Transform, nn.Module)):
@@ -218,8 +219,4 @@ class AugmentedDataLoader(DataLoader):
218
219
  "or a list of Transform objects."
219
220
  )
220
221
 
221
- super().__init__(
222
- dataset,
223
- collate_fn=self.collated_tr,
224
- **kwargs
225
- )
222
+ super().__init__(dataset, collate_fn=self.collated_tr, **kwargs)