braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,258 @@
1
+ # Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
2
+ # Robin Schirrmeister <robintibor@gmail.com>
3
+ # Lukas Gemein <l.gemein@gmail.com>
4
+ # Bruno Aristimunha <b.aristimunha@gmail.com>
5
+ # Pierre Guetschel <pierre.guetschel@gmail.com>
6
+ #
7
+ # License: BSD (3-clause)
8
+
9
+ import warnings
10
+
11
+ from skorch import NeuralNet
12
+ from skorch.callbacks import EpochScoring
13
+ from skorch.classifier import NeuralNetClassifier
14
+ from torch.nn import CrossEntropyLoss
15
+
16
+ from .eegneuralnet import _EEGNeuralNet
17
+ from .training.scoring import predict_trials
18
+ from .util import ThrowAwayIndexLoader, update_estimator_docstring
19
+
20
+
21
class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
    # NOTE(review): ``update_estimator_docstring`` presumably merges this text
    # with skorch's NeuralNetClassifier docstring by parsing the numpydoc
    # sections; keep the "Parameters" header and its underline formatted
    # exactly as below.
    doc = """Classifier that does not assume softmax activation.
    Calls the loss function directly, without applying a log or any other
    transformation to the module output.

    Parameters
    ----------
    module: str or torch Module (class or instance)
        Either the name of one of the braindecode models (see
        :obj:`braindecode.models.util.models_dict`) or directly a PyTorch module.
        When passing directly a torch module, uninstantiated class should be preferred,
        although instantiated modules will also work.

    cropped: bool (default=False)
        Defines whether torch model passed to this class is cropped or not.
        Used for the default callbacks definition and to decide whether
        predictions are aggregated across crops (see ``aggregate_predictions``
        and ``predict_trials``).

    callbacks: None or list of strings or list of Callback instances (default=None)
        More callbacks, in addition to those returned by
        ``get_default_callbacks``. Each callback should inherit from
        :class:`skorch.callbacks.Callback`. If not ``None``, callbacks can be a
        list of strings specifying `sklearn` scoring functions (for scoring
        function names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter)
        or a list of callbacks where the callback names are inferred from the
        class name. Name conflicts are resolved by appending a count suffix
        starting with 1, e.g. ``EpochScoring_1``. Alternatively,
        a tuple ``(name, callback)`` can be passed, where ``name``
        should be unique. Callbacks may or may not be instantiated.
        The callback name can be used to set parameters on specific
        callbacks (e.g., for the callback with name ``'print_log'``, use
        ``net.set_params(callbacks__print_log__keys_ignored=['epoch',
        'train_loss'])``).

    iterator_train__shuffle: bool (default=True)
        Defines whether train dataset will be shuffled. skorch does not
        shuffle the train dataset by default; this default overrides that.

    iterator_train__drop_last: bool (default=True)
        Whether the last incomplete batch of the training set is dropped.
        Forwarded to the training iterator.

    aggregate_predictions: bool (default=True)
        Whether to average cropped predictions to obtain window predictions. Used only in the
        cropped mode.

    """  # noqa: E501
    __doc__ = update_estimator_docstring(NeuralNetClassifier, doc)

    def __init__(
        self,
        module,
        *args,
        criterion=CrossEntropyLoss,
        cropped=False,
        callbacks=None,
        iterator_train__shuffle=True,
        iterator_train__drop_last=True,
        aggregate_predictions=True,
        **kwargs,
    ):
        # Braindecode-specific options are stored before delegating to skorch.
        self.cropped = cropped
        self.aggregate_predictions = aggregate_predictions
        self._last_window_inds_ = None
        super().__init__(
            module,
            *args,
            criterion=criterion,
            callbacks=callbacks,
            iterator_train__shuffle=iterator_train__shuffle,
            iterator_train__drop_last=iterator_train__drop_last,
            **kwargs,
        )

    def get_iterator(self, dataset, training=False, drop_index=True):
        """Return the batch iterator for ``dataset``.

        Parameters
        ----------
        dataset : torch Dataset
            Dataset to iterate over.
        training : bool (default=False)
            Whether to build the training iterator (forwarded to skorch).
        drop_index : bool (default=True)
            If True, wrap skorch's iterator in ``ThrowAwayIndexLoader``
            (with ``is_regression=False``); otherwise return skorch's
            iterator unchanged.
        """
        iterator = super().get_iterator(dataset, training=training)
        if not drop_index:
            return iterator
        return ThrowAwayIndexLoader(self, iterator, is_regression=False)

    def predict_proba(self, X):
        """Return the output of the module's forward method as a numpy array.

        In case of cropped decoding, returns averaged values for each trial.

        If the module's forward method returns multiple outputs as a
        tuple, it is assumed that the first output contains the
        relevant information and the other values are ignored.
        If all values are relevant or module's output for each crop
        is needed, consider using :func:`~skorch.NeuralNet.forward`
        instead.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

                * numpy arrays
                * torch tensors
                * pandas DataFrame or Series
                * scipy sparse CSR matrices
                * a dictionary of the former three
                * a list/tuple of the former three
                * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        Returns
        -------
        y_proba : numpy ndarray
        """
        y_pred = super().predict_proba(X)
        # Normally, we have to average the predictions across crops/timesteps
        # (the last axis) to get one prediction per window/trial.
        # Predictions may already be averaged in CroppedTrialEpochScoring
        # (y_pred.ndim == 2). When predictions are computed outside of
        # CroppedTrialEpochScoring, the crop axis is still present
        # (y_pred.ndim == 3) and we average here.
        if self.cropped and self.aggregate_predictions and y_pred.ndim == 3:
            return y_pred.mean(axis=-1)
        return y_pred

    def get_loss(self, y_pred, y_true, *args, **kwargs):
        """Return the loss for this batch by calling ``NeuralNet.get_loss``.

        Parameters
        ----------
        y_pred : torch tensor
            Predicted target values.
        y_true : torch tensor
            True target values.
        *args, **kwargs
            Forwarded unchanged to :meth:`skorch.NeuralNet.get_loss`
            (e.g. its ``X`` and ``training`` arguments).

        Returns
        -------
        loss : torch tensor
            The loss value computed by the criterion.
        """
        # Deliberately bypass NeuralNetClassifier.get_loss: skorch's
        # classifier version may transform y_pred (it takes the log of the
        # output when the criterion is NLLLoss). Calling the NeuralNet
        # implementation passes the raw module output to the criterion.
        return NeuralNet.get_loss(self, y_pred, y_true, *args, **kwargs)

    def predict(self, X):
        """Return class labels for samples in X.

        The label is the argmax over axis 1 of :meth:`predict_proba`.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

                * numpy arrays
                * torch tensors
                * pandas DataFrame or Series
                * scipy sparse CSR matrices
                * a dictionary of the former three
                * a list/tuple of the former three
                * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        Returns
        -------
        y_pred : numpy ndarray
        """
        return self.predict_proba(X).argmax(1)

    def predict_trials(self, X, return_targets=True):
        """Create trialwise predictions and optionally return trialwise labels.

        Designed for cropped datasets. When ``cropped`` is False, a
        ``UserWarning`` is emitted and the output of :meth:`predict` is
        returned instead, together with ``X.get_metadata()["target"]`` if
        ``return_targets`` is True.

        Parameters
        ----------
        X : braindecode.datasets.BaseConcatDataset
            A braindecode dataset to be predicted.
        return_targets : bool
            If True, additionally returns the trial targets.

        Returns
        -------
        trial_predictions : np.ndarray
            3-dimensional array (n_trials x n_classes x n_predictions), where
            the number of predictions depend on the chosen window size and the
            receptive field of the network.
        trial_labels : np.ndarray
            2-dimensional array (n_trials x n_targets) where the number of
            targets depends on the decoding paradigm and can be either a single
            value, multiple values, or a sequence. Only returned if
            ``return_targets`` is True.
        """
        if not self.cropped:
            warnings.warn(
                "This method was designed to predict trials in cropped mode. "
                "Calling it when cropped is False will give the same result as "
                "'.predict'.",
                UserWarning,
            )
            preds = self.predict(X)
            if return_targets:
                return preds, X.get_metadata()["target"].to_numpy()
            return preds
        # Reuse the num_workers configured for the evaluation iterator; this
        # assumes ThrowAwayIndexLoader exposes the wrapped loader as ``.loader``.
        return predict_trials(
            module=self.module,
            dataset=X,
            return_targets=return_targets,
            batch_size=self.batch_size,
            num_workers=self.get_iterator(X, training=False).loader.num_workers,
        )

    @property
    def mode(self):
        """Return the task type of this estimator: ``"classification"``."""
        return "classification"

    @property
    def _default_callbacks(self):
        """Return the parent's default callbacks, plus ``valid_acc`` if not cropped."""
        callbacks = list(super()._default_callbacks)
        # Only add the 'accuracy' callback if we are not in cropped mode.
        if not self.cropped:
            callbacks.append(
                (
                    "valid_acc",
                    EpochScoring(
                        "accuracy",
                        name="valid_acc",
                        lower_is_better=False,
                    ),
                )
            )
        return callbacks
@@ -0,0 +1,44 @@
1
+ """Loader code for some datasets."""
2
+
3
+ from .base import (
4
+ BaseConcatDataset,
5
+ EEGWindowsDataset,
6
+ RawDataset,
7
+ RecordDataset,
8
+ WindowsDataset,
9
+ )
10
+ from .bcicomp import BCICompetitionIVDataset4
11
+ from .bids import BIDSDataset, BIDSEpochsDataset
12
+ from .chb_mit import CHBMIT
13
+ from .mne import create_from_mne_epochs, create_from_mne_raw
14
+ from .moabb import BNCI2014_001, HGD, MOABBDataset
15
+ from .nmt import NMT
16
+ from .siena import SIENA
17
+ from .sleep_physio_challe_18 import SleepPhysionetChallenge2018
18
+ from .sleep_physionet import SleepPhysionet
19
+ from .tuh import TUH, TUHAbnormal
20
+ from .xy import create_from_X_y
21
+
22
# Public API of ``braindecode.datasets``. Names are grouped by the submodule
# they are imported from; the order of the entries is unchanged.
__all__ = [
    # from .base
    "WindowsDataset",
    "EEGWindowsDataset",
    "RecordDataset",
    "RawDataset",
    "BaseConcatDataset",
    # from .bids
    "BIDSDataset",
    "BIDSEpochsDataset",
    # from .moabb
    "MOABBDataset",
    "HGD",
    "BNCI2014_001",
    # from .mne
    "create_from_mne_raw",
    "create_from_mne_epochs",
    # from .tuh
    "TUH",
    "TUHAbnormal",
    # from .siena, .nmt, .chb_mit
    "SIENA",
    "NMT",
    "CHBMIT",
    # from .sleep_physionet, .sleep_physio_challe_18
    "SleepPhysionet",
    "SleepPhysionetChallenge2018",
    # from .xy
    "create_from_X_y",
    # from .bcicomp
    "BCICompetitionIVDataset4",
]