braindecode-0.8-py3-none-any.whl → braindecode-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode may be problematic; see the registry's advisory page for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/classifier.py
CHANGED
|
@@ -10,8 +10,8 @@ import warnings
|
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
from skorch import NeuralNet
|
|
13
|
-
from skorch.classifier import NeuralNetClassifier
|
|
14
13
|
from skorch.callbacks import EpochScoring
|
|
14
|
+
from skorch.classifier import NeuralNetClassifier
|
|
15
15
|
from torch.nn import CrossEntropyLoss
|
|
16
16
|
|
|
17
17
|
from .eegneuralnet import _EEGNeuralNet
|
|
@@ -63,16 +63,16 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
|
|
|
63
63
|
__doc__ = update_estimator_docstring(NeuralNetClassifier, doc)
|
|
64
64
|
|
|
65
65
|
def __init__(
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
66
|
+
self,
|
|
67
|
+
module,
|
|
68
|
+
*args,
|
|
69
|
+
criterion=CrossEntropyLoss,
|
|
70
|
+
cropped=False,
|
|
71
|
+
callbacks=None,
|
|
72
|
+
iterator_train__shuffle=True,
|
|
73
|
+
iterator_train__drop_last=True,
|
|
74
|
+
aggregate_predictions=True,
|
|
75
|
+
**kwargs,
|
|
76
76
|
):
|
|
77
77
|
self.cropped = cropped
|
|
78
78
|
self.aggregate_predictions = aggregate_predictions
|
|
@@ -133,8 +133,7 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
|
|
|
133
133
|
# Predictions may be already averaged in CroppedTrialEpochScoring (y_pred.shape==2).
|
|
134
134
|
# However, when predictions are computed outside of CroppedTrialEpochScoring
|
|
135
135
|
# we have to average predictions, hence the check if len(y_pred.shape) == 3
|
|
136
|
-
if self.cropped and self.aggregate_predictions and len(
|
|
137
|
-
y_pred.shape) == 3:
|
|
136
|
+
if self.cropped and self.aggregate_predictions and len(y_pred.shape) == 3:
|
|
138
137
|
return y_pred.mean(axis=-1)
|
|
139
138
|
else:
|
|
140
139
|
return y_pred
|
|
@@ -223,18 +222,19 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
|
|
|
223
222
|
warnings.warn(
|
|
224
223
|
"This method was designed to predict trials in cropped mode. "
|
|
225
224
|
"Calling it when cropped is False will give the same result as "
|
|
226
|
-
"'.predict'.",
|
|
225
|
+
"'.predict'.",
|
|
226
|
+
UserWarning,
|
|
227
|
+
)
|
|
227
228
|
preds = self.predict(X)
|
|
228
229
|
if return_targets:
|
|
229
|
-
return preds, X.get_metadata()[
|
|
230
|
+
return preds, X.get_metadata()["target"].to_numpy()
|
|
230
231
|
return preds
|
|
231
232
|
return predict_trials(
|
|
232
233
|
module=self.module,
|
|
233
234
|
dataset=X,
|
|
234
235
|
return_targets=return_targets,
|
|
235
236
|
batch_size=self.batch_size,
|
|
236
|
-
num_workers=self.get_iterator(X,
|
|
237
|
-
training=False).loader.num_workers,
|
|
237
|
+
num_workers=self.get_iterator(X, training=False).loader.num_workers,
|
|
238
238
|
)
|
|
239
239
|
|
|
240
240
|
def _get_n_outputs(self, y, classes):
|
|
@@ -250,12 +250,14 @@ class EEGClassifier(_EEGNeuralNet, NeuralNetClassifier):
|
|
|
250
250
|
def _default_callbacks(self):
|
|
251
251
|
callbacks = list(super()._default_callbacks)
|
|
252
252
|
if not self.cropped:
|
|
253
|
-
callbacks.append(
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
253
|
+
callbacks.append(
|
|
254
|
+
(
|
|
255
|
+
"valid_acc",
|
|
256
|
+
EpochScoring(
|
|
257
|
+
"accuracy",
|
|
258
|
+
name="valid_acc",
|
|
259
|
+
lower_is_better=False,
|
|
260
|
+
),
|
|
259
261
|
)
|
|
260
|
-
)
|
|
262
|
+
)
|
|
261
263
|
return callbacks
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Loader code for some datasets.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from .base import BaseConcatDataset, BaseDataset, WindowsDataset
|
|
6
|
+
from .bcicomp import BCICompetitionIVDataset4
|
|
7
|
+
from .bids import BIDSDataset, BIDSEpochsDataset
|
|
8
|
+
from .mne import create_from_mne_epochs, create_from_mne_raw
|
|
9
|
+
from .moabb import BNCI2014001, HGD, MOABBDataset
|
|
10
|
+
from .nmt import NMT
|
|
11
|
+
from .sleep_physio_challe_18 import SleepPhysionetChallenge2018
|
|
12
|
+
from .sleep_physionet import SleepPhysionet
|
|
13
|
+
from .tuh import TUH, TUHAbnormal
|
|
14
|
+
from .xy import create_from_X_y
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"WindowsDataset",
|
|
18
|
+
"BaseDataset",
|
|
19
|
+
"BaseConcatDataset",
|
|
20
|
+
"BIDSDataset",
|
|
21
|
+
"BIDSEpochsDataset",
|
|
22
|
+
"MOABBDataset",
|
|
23
|
+
"HGD",
|
|
24
|
+
"BNCI2014001",
|
|
25
|
+
"create_from_mne_raw",
|
|
26
|
+
"create_from_mne_epochs",
|
|
27
|
+
"TUH",
|
|
28
|
+
"TUHAbnormal",
|
|
29
|
+
"NMT",
|
|
30
|
+
"SleepPhysionet",
|
|
31
|
+
"SleepPhysionetChallenge2018",
|
|
32
|
+
"create_from_X_y",
|
|
33
|
+
"BCICompetitionIVDataset4",
|
|
34
|
+
]
|