braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/classifier.py CHANGED
@@ -10,8 +10,8 @@ import warnings
10
10
 
11
11
  import numpy as np
12
12
  from skorch import NeuralNet
13
- from skorch.classifier import NeuralNetClassifier
14
13
  from skorch.callbacks import EpochScoring
14
+ from skorch.classifier import NeuralNetClassifier
15
15
  from torch.nn import CrossEntropyLoss
16
16
 
17
17
  from .eegneuralnet import _EEGNeuralNet
@@ -63,16 +63,16 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
63
63
  __doc__ = update_estimator_docstring(NeuralNetClassifier, doc)
64
64
 
65
65
  def __init__(
66
- self,
67
- module,
68
- *args,
69
- criterion=CrossEntropyLoss,
70
- cropped=False,
71
- callbacks=None,
72
- iterator_train__shuffle=True,
73
- iterator_train__drop_last=True,
74
- aggregate_predictions=True,
75
- **kwargs
66
+ self,
67
+ module,
68
+ *args,
69
+ criterion=CrossEntropyLoss,
70
+ cropped=False,
71
+ callbacks=None,
72
+ iterator_train__shuffle=True,
73
+ iterator_train__drop_last=True,
74
+ aggregate_predictions=True,
75
+ **kwargs,
76
76
  ):
77
77
  self.cropped = cropped
78
78
  self.aggregate_predictions = aggregate_predictions
@@ -133,8 +133,7 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
133
133
  # Predictions may be already averaged in CroppedTrialEpochScoring (y_pred.shape==2).
134
134
  # However, when predictions are computed outside of CroppedTrialEpochScoring
135
135
  # we have to average predictions, hence the check if len(y_pred.shape) == 3
136
- if self.cropped and self.aggregate_predictions and len(
137
- y_pred.shape) == 3:
136
+ if self.cropped and self.aggregate_predictions and len(y_pred.shape) == 3:
138
137
  return y_pred.mean(axis=-1)
139
138
  else:
140
139
  return y_pred
@@ -223,18 +222,19 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
223
222
  warnings.warn(
224
223
  "This method was designed to predict trials in cropped mode. "
225
224
  "Calling it when cropped is False will give the same result as "
226
- "'.predict'.", UserWarning)
225
+ "'.predict'.",
226
+ UserWarning,
227
+ )
227
228
  preds = self.predict(X)
228
229
  if return_targets:
229
- return preds, X.get_metadata()['target'].to_numpy()
230
+ return preds, X.get_metadata()["target"].to_numpy()
230
231
  return preds
231
232
  return predict_trials(
232
233
  module=self.module,
233
234
  dataset=X,
234
235
  return_targets=return_targets,
235
236
  batch_size=self.batch_size,
236
- num_workers=self.get_iterator(X,
237
- training=False).loader.num_workers,
237
+ num_workers=self.get_iterator(X, training=False).loader.num_workers,
238
238
  )
239
239
 
240
240
  def _get_n_outputs(self, y, classes):
@@ -250,12 +250,14 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
250
250
  def _default_callbacks(self):
251
251
  callbacks = list(super()._default_callbacks)
252
252
  if not self.cropped:
253
- callbacks.append((
254
- 'valid_acc',
255
- EpochScoring(
256
- 'accuracy',
257
- name='valid_acc',
258
- lower_is_better=False,
253
+ callbacks.append(
254
+ (
255
+ "valid_acc",
256
+ EpochScoring(
257
+ "accuracy",
258
+ name="valid_acc",
259
+ lower_is_better=False,
260
+ ),
259
261
  )
260
- ))
262
+ )
261
263
  return callbacks
braindecode/datasets/__init__.py CHANGED
@@ -1,16 +1,34 @@
1
1
  """
2
2
  Loader code for some datasets.
3
3
  """
4
- from .base import WindowsDataset, BaseDataset, BaseConcatDataset
5
- from .moabb import MOABBDataset, HGD, BNCI2014001
6
- from .mne import create_from_mne_raw, create_from_mne_epochs
7
- from .tuh import TUH, TUHAbnormal
4
+
5
+ from .base import BaseConcatDataset, BaseDataset, WindowsDataset
6
+ from .bcicomp import BCICompetitionIVDataset4
7
+ from .bids import BIDSDataset, BIDSEpochsDataset
8
+ from .mne import create_from_mne_epochs, create_from_mne_raw
9
+ from .moabb import BNCI2014001, HGD, MOABBDataset
10
+ from .nmt import NMT
11
+ from .sleep_physio_challe_18 import SleepPhysionetChallenge2018
8
12
  from .sleep_physionet import SleepPhysionet
13
+ from .tuh import TUH, TUHAbnormal
9
14
  from .xy import create_from_X_y
10
- from .bcicomp import BCICompetitionIVDataset4
11
15
 
12
- __all__ = ["WindowsDataset", "BaseDataset", "BaseConcatDataset",
13
- "MOABBDataset", "HGD", "BNCI2014001",
14
- "create_from_mne_raw", "create_from_mne_epochs",
15
- "TUH", "TUHAbnormal", "SleepPhysionet", "create_from_X_y",
16
- "BCICompetitionIVDataset4"]
16
+ __all__ = [
17
+ "WindowsDataset",
18
+ "BaseDataset",
19
+ "BaseConcatDataset",
20
+ "BIDSDataset",
21
+ "BIDSEpochsDataset",
22
+ "MOABBDataset",
23
+ "HGD",
24
+ "BNCI2014001",
25
+ "create_from_mne_raw",
26
+ "create_from_mne_epochs",
27
+ "TUH",
28
+ "TUHAbnormal",
29
+ "NMT",
30
+ "SleepPhysionet",
31
+ "SleepPhysionetChallenge2018",
32
+ "create_from_X_y",
33
+ "BCICompetitionIVDataset4",
34
+ ]