braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/classifier.py CHANGED
@@ -10,8 +10,8 @@ import warnings
10
10
 
11
11
  import numpy as np
12
12
  from skorch import NeuralNet
13
- from skorch.classifier import NeuralNetClassifier
14
13
  from skorch.callbacks import EpochScoring
14
+ from skorch.classifier import NeuralNetClassifier
15
15
  from torch.nn import CrossEntropyLoss
16
16
 
17
17
  from .eegneuralnet import _EEGNeuralNet
@@ -63,16 +63,16 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
63
63
  __doc__ = update_estimator_docstring(NeuralNetClassifier, doc)
64
64
 
65
65
  def __init__(
66
- self,
67
- module,
68
- *args,
69
- criterion=CrossEntropyLoss,
70
- cropped=False,
71
- callbacks=None,
72
- iterator_train__shuffle=True,
73
- iterator_train__drop_last=True,
74
- aggregate_predictions=True,
75
- **kwargs
66
+ self,
67
+ module,
68
+ *args,
69
+ criterion=CrossEntropyLoss,
70
+ cropped=False,
71
+ callbacks=None,
72
+ iterator_train__shuffle=True,
73
+ iterator_train__drop_last=True,
74
+ aggregate_predictions=True,
75
+ **kwargs,
76
76
  ):
77
77
  self.cropped = cropped
78
78
  self.aggregate_predictions = aggregate_predictions
@@ -133,8 +133,7 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
133
133
  # Predictions may be already averaged in CroppedTrialEpochScoring (y_pred.shape==2).
134
134
  # However, when predictions are computed outside of CroppedTrialEpochScoring
135
135
  # we have to average predictions, hence the check if len(y_pred.shape) == 3
136
- if self.cropped and self.aggregate_predictions and len(
137
- y_pred.shape) == 3:
136
+ if self.cropped and self.aggregate_predictions and len(y_pred.shape) == 3:
138
137
  return y_pred.mean(axis=-1)
139
138
  else:
140
139
  return y_pred
@@ -223,18 +222,19 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
223
222
  warnings.warn(
224
223
  "This method was designed to predict trials in cropped mode. "
225
224
  "Calling it when cropped is False will give the same result as "
226
- "'.predict'.", UserWarning)
225
+ "'.predict'.",
226
+ UserWarning,
227
+ )
227
228
  preds = self.predict(X)
228
229
  if return_targets:
229
- return preds, X.get_metadata()['target'].to_numpy()
230
+ return preds, X.get_metadata()["target"].to_numpy()
230
231
  return preds
231
232
  return predict_trials(
232
233
  module=self.module,
233
234
  dataset=X,
234
235
  return_targets=return_targets,
235
236
  batch_size=self.batch_size,
236
- num_workers=self.get_iterator(X,
237
- training=False).loader.num_workers,
237
+ num_workers=self.get_iterator(X, training=False).loader.num_workers,
238
238
  )
239
239
 
240
240
  def _get_n_outputs(self, y, classes):
@@ -250,12 +250,14 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
250
250
  def _default_callbacks(self):
251
251
  callbacks = list(super()._default_callbacks)
252
252
  if not self.cropped:
253
- callbacks.append((
254
- 'valid_acc',
255
- EpochScoring(
256
- 'accuracy',
257
- name='valid_acc',
258
- lower_is_better=False,
253
+ callbacks.append(
254
+ (
255
+ "valid_acc",
256
+ EpochScoring(
257
+ "accuracy",
258
+ name="valid_acc",
259
+ lower_is_better=False,
260
+ ),
259
261
  )
260
- ))
262
+ )
261
263
  return callbacks
braindecode/datasets/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ """
2
+ Loader code for some datasets.
3
+ """
4
+
5
+ from .base import BaseConcatDataset, BaseDataset, WindowsDataset
6
+ from .bcicomp import BCICompetitionIVDataset4
7
+ from .bids import BIDSDataset, BIDSEpochsDataset
8
+ from .mne import create_from_mne_epochs, create_from_mne_raw
9
+ from .moabb import BNCI2014001, HGD, MOABBDataset
10
+ from .nmt import NMT
11
+ from .sleep_physio_challe_18 import SleepPhysionetChallenge2018
12
+ from .sleep_physionet import SleepPhysionet
13
+ from .tuh import TUH, TUHAbnormal
14
+ from .xy import create_from_X_y
15
+
16
+ __all__ = [
17
+ "WindowsDataset",
18
+ "BaseDataset",
19
+ "BaseConcatDataset",
20
+ "BIDSDataset",
21
+ "BIDSEpochsDataset",
22
+ "MOABBDataset",
23
+ "HGD",
24
+ "BNCI2014001",
25
+ "create_from_mne_raw",
26
+ "create_from_mne_epochs",
27
+ "TUH",
28
+ "TUHAbnormal",
29
+ "NMT",
30
+ "SleepPhysionet",
31
+ "SleepPhysionetChallenge2018",
32
+ "create_from_X_y",
33
+ "BCICompetitionIVDataset4",
34
+ ]