braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/classifier.py
@@ -0,0 +1,258 @@
# Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
#          Robin Schirrmeister <robintibor@gmail.com>
#          Lukas Gemein <l.gemein@gmail.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com>
#          Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)

import warnings

from skorch import NeuralNet
from skorch.callbacks import EpochScoring
from skorch.classifier import NeuralNetClassifier
from torch.nn import CrossEntropyLoss

from .eegneuralnet import _EEGNeuralNet
from .training.scoring import predict_trials
from .util import ThrowAwayIndexLoader, update_estimator_docstring


class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
    doc = """Classifier that does not assume softmax activation.
    Calls loss function directly without applying log or anything.

    Parameters
    ----------
    module: str or torch Module (class or instance)
        Either the name of one of the braindecode models (see
        :obj:`braindecode.models.util.models_dict`) or directly a PyTorch module.
        When passing directly a torch module, uninstantiated class should be preferred,
        although instantiated modules will also work.

    cropped: bool (default=False)
        Defines whether torch model passed to this class is cropped or not.
        Currently used for callbacks definition.

    callbacks: None or list of strings or list of Callback instances (default=None)
        More callbacks, in addition to those returned by
        ``get_default_callbacks``. Each callback should inherit from
        :class:`skorch.callbacks.Callback`. If not ``None``, callbacks can be a
        list of strings specifying `sklearn` scoring functions (for scoring
        functions names see: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter)
        or a list of callbacks where the callback names are inferred from the
        class name. Name conflicts are resolved by appending a count suffix
        starting with 1, e.g. ``EpochScoring_1``. Alternatively,
        a tuple ``(name, callback)`` can be passed, where ``name``
        should be unique. Callbacks may or may not be instantiated.
        The callback name can be used to set parameters on specific
        callbacks (e.g., for the callback with name ``'print_log'``, use
        ``net.set_params(callbacks__print_log__keys_ignored=['epoch',
        'train_loss'])``).

    iterator_train__shuffle: bool (default=True)
        Defines whether train dataset will be shuffled. As skorch does not
        shuffle the train dataset by default this one overwrites this option.

    aggregate_predictions: bool (default=True)
        Whether to average cropped predictions to obtain window predictions. Used only in the
        cropped mode.

    """  # noqa: E501
    __doc__ = update_estimator_docstring(NeuralNetClassifier, doc)

    def __init__(
        self,
        module,
        *args,
        criterion=CrossEntropyLoss,
        cropped=False,
        callbacks=None,
        iterator_train__shuffle=True,
        iterator_train__drop_last=True,
        aggregate_predictions=True,
        **kwargs,
    ):
        self.cropped = cropped
        self.aggregate_predictions = aggregate_predictions
        self._last_window_inds_ = None
        super().__init__(
            module,
            *args,
            criterion=criterion,
            callbacks=callbacks,
            iterator_train__shuffle=iterator_train__shuffle,
            iterator_train__drop_last=iterator_train__drop_last,
            **kwargs,
        )

    def get_iterator(self, dataset, training=False, drop_index=True):
        iterator = super().get_iterator(dataset, training=training)
        if drop_index:
            return ThrowAwayIndexLoader(self, iterator, is_regression=False)
        else:
            return iterator

    def predict_proba(self, X):
        """Return the output of the module's forward method as a numpy.

        array. In case of cropped decoding returns averaged values for
        each trial.

        If the module's forward method returns multiple outputs as a
        tuple, it is assumed that the first output contains the
        relevant information and the other values are ignored.
        If all values are relevant or module's output for each crop
        is needed, consider using :func:`~skorch.NeuralNet.forward`
        instead.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        Returns
        -------
        y_proba : numpy ndarray
        """
        y_pred = super().predict_proba(X)
        # Normally, we have to average the predictions across crops/timesteps
        # to get one prediction per window/trial
        # Predictions may be already averaged in CroppedTrialEpochScoring (y_pred.shape==2).
        # However, when predictions are computed outside of CroppedTrialEpochScoring
        # we have to average predictions, hence the check if len(y_pred.shape) == 3
        if self.cropped and self.aggregate_predictions and len(y_pred.shape) == 3:
            return y_pred.mean(axis=-1)
        else:
            return y_pred

    def get_loss(self, y_pred, y_true, *args, **kwargs):
        """Return the loss for this batch by calling NeuralNet get_loss.

        Parameters
        ----------
        y_pred : torch tensor
            Predicted target values
        y_true : torch tensor
            True target values.
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.
        training : bool (default=False)
            Whether train mode should be used or not.

        Returns
        -------
        loss : float
            The loss value.
        """
        return NeuralNet.get_loss(self, y_pred, y_true, *args, **kwargs)

    def predict(self, X):
        """Return class labels for samples in X.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        Returns
        -------
        y_pred : numpy ndarray
        """
        return self.predict_proba(X).argmax(1)

    def predict_trials(self, X, return_targets=True):
        """Create trialwise predictions and optionally also return trialwise.

        labels from cropped dataset.

        Parameters
        ----------
        X : braindecode.datasets.BaseConcatDataset
            A braindecode dataset to be predicted.
        return_targets : bool
            If True, additionally returns the trial targets.

        Returns
        -------
        trial_predictions : np.ndarray
            3-dimensional array (n_trials x n_classes x n_predictions), where
            the number of predictions depend on the chosen window size and the
            receptive field of the network.
        trial_labels : np.ndarray
            2-dimensional array (n_trials x n_targets) where the number of
            targets depends on the decoding paradigm and can be either a single
            value, multiple values, or a sequence.
        """
        if not self.cropped:
            warnings.warn(
                "This method was designed to predict trials in cropped mode. "
                "Calling it when cropped is False will give the same result as "
                "'.predict'.",
                UserWarning,
            )
            preds = self.predict(X)
            if return_targets:
                return preds, X.get_metadata()["target"].to_numpy()
            return preds
        return predict_trials(
            module=self.module,
            dataset=X,
            return_targets=return_targets,
            batch_size=self.batch_size,
            num_workers=self.get_iterator(X, training=False).loader.num_workers,
        )

    @property
    def mode(self):
        return "classification"

    # Only add the 'accuracy' callback if we are not in cropped mode.
    @property
    def _default_callbacks(self):
        callbacks = list(super()._default_callbacks)
        if not self.cropped:
            callbacks.append(
                (
                    "valid_acc",
                    EpochScoring(
                        "accuracy",
                        name="valid_acc",
                        lower_is_better=False,
                    ),
                )
            )
        return callbacks
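
A minimal usage sketch for the EEGClassifier defined above (illustration only, not part of the package diff). The toy torch module, array shapes, and training settings are assumptions; per the docstring, a braindecode model name, model class, or plain PyTorch module can be passed, and remaining keyword arguments are forwarded to skorch's NeuralNetClassifier.

# Illustrative sketch, not part of the packaged sources.
import numpy as np
import torch

from braindecode.classifier import EEGClassifier

# Hypothetical stand-in module: any torch.nn.Module works per the docstring;
# a braindecode model (or its name from braindecode.models.util.models_dict)
# would normally be used here.
toy_module = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(4 * 100, 2),  # 4 channels x 100 time points -> 2 classes
)

clf = EEGClassifier(
    toy_module,
    criterion=torch.nn.CrossEntropyLoss,  # the default criterion
    cropped=False,   # trial-wise decoding; the 'valid_acc' callback is added
    max_epochs=2,    # any skorch NeuralNetClassifier argument passes through
    batch_size=8,
)

X = np.random.randn(16, 4, 100).astype("float32")  # (n_windows, n_channels, n_times)
y = np.array([0, 1] * 8, dtype="int64")

clf.fit(X, y=y)
probs = clf.predict_proba(X)  # (n_windows, n_classes); averaged over crops when cropped=True
labels = clf.predict(X)       # argmax over the class axis
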
braindecode/datasets/__init__.py
@@ -0,0 +1,44 @@
"""Loader code for some datasets."""

from .base import (
    BaseConcatDataset,
    EEGWindowsDataset,
    RawDataset,
    RecordDataset,
    WindowsDataset,
)
from .bcicomp import BCICompetitionIVDataset4
from .bids import BIDSDataset, BIDSEpochsDataset
from .chb_mit import CHBMIT
from .mne import create_from_mne_epochs, create_from_mne_raw
from .moabb import BNCI2014_001, HGD, MOABBDataset
from .nmt import NMT
from .siena import SIENA
from .sleep_physio_challe_18 import SleepPhysionetChallenge2018
from .sleep_physionet import SleepPhysionet
from .tuh import TUH, TUHAbnormal
from .xy import create_from_X_y

__all__ = [
    "WindowsDataset",
    "EEGWindowsDataset",
    "RecordDataset",
    "RawDataset",
    "BaseConcatDataset",
    "BIDSDataset",
    "BIDSEpochsDataset",
    "MOABBDataset",
    "HGD",
    "BNCI2014_001",
    "create_from_mne_raw",
    "create_from_mne_epochs",
    "TUH",
    "TUHAbnormal",
    "SIENA",
    "NMT",
    "CHBMIT",
    "SleepPhysionet",
    "SleepPhysionetChallenge2018",
    "create_from_X_y",
    "BCICompetitionIVDataset4",
]
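
For orientation, a small sketch of how the names re-exported above are typically used (illustration only; the download call is commented out because it fetches data over the network, and the argument names are assumptions based on the usual braindecode API rather than verified against this exact dev build).

# Illustrative sketch, not part of the packaged sources.
from braindecode.datasets import BaseConcatDataset, MOABBDataset  # re-exported above

# Fetch a MOABB dataset by name for selected subjects (downloads data):
# ds = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])
# print(ds.description)         # one row per recording (subject, session, run, ...)
# splits = ds.split("subject")  # dict of BaseConcatDataset, keyed by subject
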