braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/eegneuralnet.py
CHANGED
|
@@ -5,32 +5,36 @@
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
import abc
|
|
8
|
-
import logging
|
|
9
8
|
import inspect
|
|
9
|
+
import logging
|
|
10
10
|
|
|
11
11
|
import mne
|
|
12
12
|
import numpy as np
|
|
13
13
|
import torch
|
|
14
|
-
from skorch import NeuralNet
|
|
15
14
|
from sklearn.metrics import get_scorer
|
|
15
|
+
from skorch import NeuralNet
|
|
16
16
|
from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
|
|
17
|
-
from skorch.
|
|
17
|
+
from skorch.helper import SliceDataset
|
|
18
|
+
from skorch.utils import is_dataset, noop, to_numpy, train_loss_score, valid_loss_score
|
|
18
19
|
|
|
19
|
-
from .training.scoring import (CroppedTimeSeriesEpochScoring,
|
|
20
|
-
CroppedTrialEpochScoring, PostEpochTrainScoring)
|
|
21
|
-
from .models.util import models_dict
|
|
22
20
|
from .datasets.base import BaseConcatDataset, WindowsDataset
|
|
21
|
+
from .models.util import models_dict
|
|
22
|
+
from .training.scoring import (
|
|
23
|
+
CroppedTimeSeriesEpochScoring,
|
|
24
|
+
CroppedTrialEpochScoring,
|
|
25
|
+
PostEpochTrainScoring,
|
|
26
|
+
)
|
|
23
27
|
|
|
24
28
|
log = logging.getLogger(__name__)
|
|
25
29
|
|
|
26
30
|
|
|
27
|
-
def _get_model(model):
|
|
28
|
-
|
|
31
|
+
def _get_model(model: str):
|
|
32
|
+
"""Returns the corresponding class in case the model passed is a string."""
|
|
29
33
|
if isinstance(model, str):
|
|
30
34
|
if model in models_dict:
|
|
31
35
|
model = models_dict[model]
|
|
32
36
|
else:
|
|
33
|
-
raise ValueError(f
|
|
37
|
+
raise ValueError(f"Unknown model name {model!r}.")
|
|
34
38
|
return model
|
|
35
39
|
|
|
36
40
|
|
|
@@ -50,7 +54,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
50
54
|
will be left as is.
|
|
51
55
|
|
|
52
56
|
"""
|
|
53
|
-
kwargs = self.get_params_for(
|
|
57
|
+
kwargs = self.get_params_for("module")
|
|
54
58
|
module = _get_model(self.module)
|
|
55
59
|
module = self.initialized_instance(module, kwargs)
|
|
56
60
|
# pylint: disable=attribute-defined-outside-init
|
|
@@ -61,7 +65,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
61
65
|
# Here we parse the callbacks supplied as strings,
|
|
62
66
|
# e.g. 'accuracy', to the callbacks skorch expects
|
|
63
67
|
for name, cb, named_by_user in super()._yield_callbacks():
|
|
64
|
-
if name ==
|
|
68
|
+
if name == "str":
|
|
65
69
|
train_cb, valid_cb = self._parse_str_callback(cb)
|
|
66
70
|
yield train_cb
|
|
67
71
|
if self.train_split is not None:
|
|
@@ -72,15 +76,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
72
76
|
def _parse_str_callback(self, cb_supplied_name):
|
|
73
77
|
scoring = get_scorer(cb_supplied_name)
|
|
74
78
|
scoring_name = scoring._score_func.__name__
|
|
75
|
-
assert scoring_name.endswith(
|
|
76
|
-
|
|
77
|
-
if (scoring_name.endswith('_score') or
|
|
78
|
-
cb_supplied_name.startswith('neg_')):
|
|
79
|
+
assert scoring_name.endswith(("_score", "_error", "_deviance", "_loss"))
|
|
80
|
+
if scoring_name.endswith("_score") or cb_supplied_name.startswith("neg_"):
|
|
79
81
|
lower_is_better = False
|
|
80
82
|
else:
|
|
81
83
|
lower_is_better = True
|
|
82
|
-
train_name = f
|
|
83
|
-
valid_name = f
|
|
84
|
+
train_name = f"train_{cb_supplied_name}"
|
|
85
|
+
valid_name = f"valid_{cb_supplied_name}"
|
|
84
86
|
if self.cropped:
|
|
85
87
|
train_scoring = CroppedTrialEpochScoring(
|
|
86
88
|
cb_supplied_name, lower_is_better, on_train=True, name=train_name
|
|
@@ -98,7 +100,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
98
100
|
named_by_user = True
|
|
99
101
|
train_valid_callbacks = [
|
|
100
102
|
(train_name, train_scoring, named_by_user),
|
|
101
|
-
(valid_name, valid_scoring, named_by_user)
|
|
103
|
+
(valid_name, valid_scoring, named_by_user),
|
|
102
104
|
]
|
|
103
105
|
return train_valid_callbacks
|
|
104
106
|
|
|
@@ -108,8 +110,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
108
110
|
if not training:
|
|
109
111
|
epoch_cbs = []
|
|
110
112
|
for name, cb in self.callbacks_:
|
|
111
|
-
if
|
|
112
|
-
|
|
113
|
+
if (
|
|
114
|
+
isinstance(
|
|
115
|
+
cb, (CroppedTrialEpochScoring, CroppedTimeSeriesEpochScoring)
|
|
116
|
+
)
|
|
117
|
+
and (hasattr(cb, "window_inds_"))
|
|
118
|
+
and (not cb.on_train)
|
|
119
|
+
):
|
|
113
120
|
epoch_cbs.append(cb)
|
|
114
121
|
# for trialwise decoding stuffs it might also be we don't have
|
|
115
122
|
# cropped loader, so no indices there
|
|
@@ -136,8 +143,11 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
136
143
|
i_window_stops = np.concatenate(i_window_stops)
|
|
137
144
|
window_ys = np.concatenate(window_ys)
|
|
138
145
|
return dict(
|
|
139
|
-
preds=preds,
|
|
140
|
-
|
|
146
|
+
preds=preds,
|
|
147
|
+
i_window_in_trials=i_window_in_trials,
|
|
148
|
+
i_window_stops=i_window_stops,
|
|
149
|
+
window_ys=window_ys,
|
|
150
|
+
)
|
|
141
151
|
|
|
142
152
|
# Changes the default target extractor to noop
|
|
143
153
|
@property
|
|
@@ -156,7 +166,9 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
156
166
|
(
|
|
157
167
|
"valid_loss",
|
|
158
168
|
BatchScoring(
|
|
159
|
-
valid_loss_score,
|
|
169
|
+
valid_loss_score,
|
|
170
|
+
name="valid_loss",
|
|
171
|
+
target_extractor=noop,
|
|
160
172
|
),
|
|
161
173
|
),
|
|
162
174
|
("print_log", PrintLog()),
|
|
@@ -179,17 +191,27 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
179
191
|
return
|
|
180
192
|
# get kwargs from signal:
|
|
181
193
|
signal_kwargs = dict()
|
|
182
|
-
|
|
194
|
+
# Using shape to work both with torch.tensor and numpy.array:
|
|
195
|
+
if (
|
|
196
|
+
isinstance(X, mne.BaseEpochs)
|
|
197
|
+
or (hasattr(X, "shape") and len(X.shape) >= 2)
|
|
198
|
+
or isinstance(X, SliceDataset)
|
|
199
|
+
):
|
|
183
200
|
if y is None:
|
|
184
|
-
raise ValueError("y must be specified if X is
|
|
185
|
-
signal_kwargs[
|
|
201
|
+
raise ValueError("y must be specified if X is array-like.")
|
|
202
|
+
signal_kwargs["n_outputs"] = self._get_n_outputs(y, classes)
|
|
186
203
|
if isinstance(X, mne.BaseEpochs):
|
|
187
204
|
self.log.info("Using mne.Epochs to find signal-related parameters.")
|
|
188
205
|
signal_kwargs["n_times"] = len(X.times)
|
|
189
|
-
signal_kwargs["sfreq"] = X.info[
|
|
190
|
-
signal_kwargs["chs_info"] = X.info[
|
|
206
|
+
signal_kwargs["sfreq"] = X.info["sfreq"]
|
|
207
|
+
signal_kwargs["chs_info"] = X.info["chs"]
|
|
208
|
+
elif isinstance(X, SliceDataset):
|
|
209
|
+
self.log.info("Using SliceDataset to find signal-related parameters.")
|
|
210
|
+
Xshape = X[0].shape
|
|
211
|
+
signal_kwargs["n_times"] = Xshape[-1]
|
|
212
|
+
signal_kwargs["n_chans"] = Xshape[-2]
|
|
191
213
|
else:
|
|
192
|
-
self.log.info("Using
|
|
214
|
+
self.log.info("Using array-like to find signal-related parameters.")
|
|
193
215
|
signal_kwargs["n_times"] = X.shape[-1]
|
|
194
216
|
signal_kwargs["n_chans"] = X.shape[-2]
|
|
195
217
|
elif is_dataset(X):
|
|
@@ -198,21 +220,17 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
198
220
|
Xshape = X0.shape
|
|
199
221
|
signal_kwargs["n_times"] = Xshape[-1]
|
|
200
222
|
signal_kwargs["n_chans"] = Xshape[-2]
|
|
201
|
-
if (
|
|
202
|
-
|
|
203
|
-
all(ds.targets_from == 'metadata' for ds in X.datasets)
|
|
223
|
+
if isinstance(X, BaseConcatDataset) and all(
|
|
224
|
+
ds.targets_from == "metadata" for ds in X.datasets
|
|
204
225
|
):
|
|
205
226
|
y_target = X.get_metadata().target
|
|
206
|
-
signal_kwargs[
|
|
207
|
-
elif (
|
|
208
|
-
isinstance(X, WindowsDataset) and
|
|
209
|
-
X.targets_from == "metadata"
|
|
210
|
-
):
|
|
227
|
+
signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
|
|
228
|
+
elif isinstance(X, WindowsDataset) and X.targets_from == "metadata":
|
|
211
229
|
y_target = X.windows.metadata.target
|
|
212
|
-
signal_kwargs[
|
|
230
|
+
signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
|
|
213
231
|
else:
|
|
214
232
|
self.log.warning(
|
|
215
|
-
"Can only infer signal shape of
|
|
233
|
+
"Can only infer signal shape of array-like and Datasets, "
|
|
216
234
|
f"got {type(X)!r}."
|
|
217
235
|
)
|
|
218
236
|
return
|
|
@@ -227,15 +245,13 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
227
245
|
if k in all_module_kwargs:
|
|
228
246
|
module_kwargs[k] = v
|
|
229
247
|
else:
|
|
230
|
-
self.log.warning(
|
|
231
|
-
f"Module {self.module!r} "
|
|
232
|
-
f"is missing parameter {k!r}."
|
|
233
|
-
)
|
|
248
|
+
self.log.warning(f"Module {self.module!r} is missing parameter {k!r}.")
|
|
234
249
|
|
|
235
250
|
# save kwargs to self:
|
|
236
251
|
self.log.info(
|
|
237
252
|
f"Passing additional parameters {module_kwargs!r} "
|
|
238
|
-
f"to module {self.module!r}."
|
|
253
|
+
f"to module {self.module!r}."
|
|
254
|
+
)
|
|
239
255
|
module_kwargs = {f"module__{k}": v for k, v in module_kwargs.items()}
|
|
240
256
|
self.set_params(**module_kwargs)
|
|
241
257
|
|
|
@@ -275,7 +291,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
275
291
|
|
|
276
292
|
"""
|
|
277
293
|
if isinstance(X, mne.BaseEpochs):
|
|
278
|
-
X = X.get_data(units=
|
|
294
|
+
X = X.get_data(units="uV")
|
|
279
295
|
return super().get_dataset(X, y)
|
|
280
296
|
|
|
281
297
|
def partial_fit(self, X, y=None, classes=None, **fit_params):
|
|
@@ -291,7 +307,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
291
307
|
|
|
292
308
|
* mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
|
|
293
309
|
``sfreq``, ``input_window_seconds``
|
|
294
|
-
*
|
|
310
|
+
* array-like: ``n_times``, ``n_chans``, ``n_outputs``
|
|
295
311
|
* WindowsDataset with ``targets_from='metadata'``
|
|
296
312
|
(or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
|
|
297
313
|
* other Dataset: ``n_times``, ``n_chans``
|
|
@@ -345,7 +361,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
|
|
|
345
361
|
|
|
346
362
|
* mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
|
|
347
363
|
``sfreq``, ``input_window_seconds``
|
|
348
|
-
*
|
|
364
|
+
* array-like: ``n_times``, ``n_chans``, ``n_outputs``
|
|
349
365
|
* WindowsDataset with ``targets_from='metadata'``
|
|
350
366
|
(or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
|
|
351
367
|
* other Dataset: ``n_times``, ``n_chans``
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from .functions import (
|
|
2
|
+
_get_gaussian_kernel1d,
|
|
3
|
+
drop_path,
|
|
4
|
+
hilbert_freq,
|
|
5
|
+
identity,
|
|
6
|
+
plv_time,
|
|
7
|
+
safe_log,
|
|
8
|
+
square,
|
|
9
|
+
)
|
|
10
|
+
from .initialization import glorot_weight_zero_bias, rescale_parameter
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"_get_gaussian_kernel1d",
|
|
14
|
+
"drop_path",
|
|
15
|
+
"hilbert_freq",
|
|
16
|
+
"identity",
|
|
17
|
+
"plv_time",
|
|
18
|
+
"safe_log",
|
|
19
|
+
"square",
|
|
20
|
+
"glorot_weight_zero_bias",
|
|
21
|
+
"rescale_parameter",
|
|
22
|
+
]
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def square(x):
|
|
10
|
+
return x * x
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def safe_log(x, eps: float = 1e-6) -> torch.Tensor:
|
|
14
|
+
"""Prevents :math:`log(0)` by using :math:`log(max(x, eps))`."""
|
|
15
|
+
return torch.log(torch.clamp(x, min=eps))
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def identity(x):
|
|
19
|
+
return x
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def drop_path(
|
|
23
|
+
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
|
24
|
+
):
|
|
25
|
+
"""Drop paths (Stochastic Depth) per sample.
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
Notes: This implementation is taken from timm library.
|
|
29
|
+
|
|
30
|
+
All credit goes to Ross Wightman.
|
|
31
|
+
|
|
32
|
+
Parameters
|
|
33
|
+
----------
|
|
34
|
+
x: torch.Tensor
|
|
35
|
+
input tensor
|
|
36
|
+
drop_prob : float, optional
|
|
37
|
+
survival rate (i.e. probability of being kept), by default 0.0
|
|
38
|
+
training : bool, optional
|
|
39
|
+
whether the model is in training mode, by default False
|
|
40
|
+
scale_by_keep : bool, optional
|
|
41
|
+
whether to scale output by (1/keep_prob) during training, by default True
|
|
42
|
+
|
|
43
|
+
Returns
|
|
44
|
+
-------
|
|
45
|
+
torch.Tensor
|
|
46
|
+
output tensor
|
|
47
|
+
|
|
48
|
+
Notes from Ross Wightman:
|
|
49
|
+
(when applied in main path of residual blocks)
|
|
50
|
+
This is the same as the DropConnect impl I created for EfficientNet,
|
|
51
|
+
etc. networks, however,
|
|
52
|
+
the original name is misleading as 'Drop Connect' is a different form
|
|
53
|
+
of dropout in a separate paper...
|
|
54
|
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
|
|
55
|
+
... I've opted for changing the layer and argument names to 'drop path'
|
|
56
|
+
rather than mix DropConnect as a layer name and use
|
|
57
|
+
'survival rate' as the argument.
|
|
58
|
+
|
|
59
|
+
"""
|
|
60
|
+
if drop_prob == 0.0 or not training:
|
|
61
|
+
return x
|
|
62
|
+
keep_prob = 1 - drop_prob
|
|
63
|
+
shape = (x.shape[0],) + (1,) * (
|
|
64
|
+
x.ndim - 1
|
|
65
|
+
) # work with diff dim tensors, not just 2D ConvNets
|
|
66
|
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
|
67
|
+
if keep_prob > 0.0 and scale_by_keep:
|
|
68
|
+
random_tensor.div_(keep_prob)
|
|
69
|
+
return x * random_tensor
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
|
|
73
|
+
"""
|
|
74
|
+
Generates a 1-dimensional Gaussian kernel based on the specified kernel
|
|
75
|
+
size and standard deviation (sigma).
|
|
76
|
+
This kernel is useful for Gaussian smoothing or filtering operations in
|
|
77
|
+
image processing. The function calculates a range limit to ensure the kernel
|
|
78
|
+
effectively covers the Gaussian distribution. It generates a tensor of
|
|
79
|
+
specified size and type, filled with values distributed according to a
|
|
80
|
+
Gaussian curve, normalized using a softmax function
|
|
81
|
+
to ensure all weights sum to 1.
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
Parameters
|
|
85
|
+
----------
|
|
86
|
+
kernel_size: int
|
|
87
|
+
sigma: float
|
|
88
|
+
|
|
89
|
+
Returns
|
|
90
|
+
-------
|
|
91
|
+
kernel1d: torch.Tensor
|
|
92
|
+
|
|
93
|
+
Notes
|
|
94
|
+
-----
|
|
95
|
+
Code copied and modified from TorchVision:
|
|
96
|
+
https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py#L725-L732
|
|
97
|
+
All rights reserved.
|
|
98
|
+
|
|
99
|
+
LICENSE in https://github.com/pytorch/vision/blob/main/LICENSE
|
|
100
|
+
|
|
101
|
+
"""
|
|
102
|
+
ksize_half = (kernel_size - 1) * 0.5
|
|
103
|
+
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
|
104
|
+
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
|
105
|
+
kernel1d = pdf / pdf.sum()
|
|
106
|
+
return kernel1d
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def hilbert_freq(x, forward_fourier=True):
|
|
110
|
+
r"""
|
|
111
|
+
Compute the Hilbert transform using PyTorch, separating the real and
|
|
112
|
+
imaginary parts.
|
|
113
|
+
|
|
114
|
+
The analytic signal :math:`x_a(t)` of a real-valued signal :math:`x(t)`
|
|
115
|
+
is defined as:
|
|
116
|
+
|
|
117
|
+
.. math::
|
|
118
|
+
|
|
119
|
+
x_a(t) = x(t) + i y(t) = \mathcal{F}^{-1} \{ U(f) \mathcal{F}\{x(t)\} \}
|
|
120
|
+
|
|
121
|
+
where:
|
|
122
|
+
- :math:`\mathcal{F}` is the Fourier transform,
|
|
123
|
+
- :math:`U(f)` is the unit step function,
|
|
124
|
+
- :math:`y(t)` is the Hilbert transform of :math:`x(t)`.
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
Parameters
|
|
128
|
+
----------
|
|
129
|
+
input : torch.Tensor
|
|
130
|
+
Input tensor. The expected shape depends on the `forward_fourier` parameter:
|
|
131
|
+
|
|
132
|
+
- If `forward_fourier` is True:
|
|
133
|
+
(..., seq_len)
|
|
134
|
+
- If `forward_fourier` is False:
|
|
135
|
+
(..., seq_len / 2 + 1, 2)
|
|
136
|
+
|
|
137
|
+
forward_fourier : bool, optional
|
|
138
|
+
Determines the format of the input tensor.
|
|
139
|
+
- If True, the input is in the forward Fourier domain.
|
|
140
|
+
- If False, the input contains separate real and imaginary parts.
|
|
141
|
+
Default is True.
|
|
142
|
+
|
|
143
|
+
Returns
|
|
144
|
+
-------
|
|
145
|
+
torch.Tensor
|
|
146
|
+
Output tensor with shape (..., seq_len, 2), where the last dimension represents
|
|
147
|
+
the real and imaginary parts of the Hilbert transform.
|
|
148
|
+
|
|
149
|
+
Examples
|
|
150
|
+
--------
|
|
151
|
+
>>> import torch
|
|
152
|
+
>>> input = torch.randn(10, 100) # Example input tensor
|
|
153
|
+
>>> output = hilbert_transform(input)
|
|
154
|
+
>>> print(output.shape)
|
|
155
|
+
torch.Size([10, 100, 2])
|
|
156
|
+
|
|
157
|
+
Notes
|
|
158
|
+
-----
|
|
159
|
+
The implementation is matching scipy implementation, but using torch.
|
|
160
|
+
https://github.com/scipy/scipy/blob/v1.14.1/scipy/signal/_signaltools.py#L2287-L2394
|
|
161
|
+
|
|
162
|
+
"""
|
|
163
|
+
if forward_fourier:
|
|
164
|
+
x = torch.fft.rfft(x, norm=None, dim=-1)
|
|
165
|
+
x = torch.view_as_real(x)
|
|
166
|
+
x = x * 2.0
|
|
167
|
+
x[..., 0, :] = x[..., 0, :] / 2.0 # Don't multiply the DC-term by 2
|
|
168
|
+
x = F.pad(
|
|
169
|
+
x, [0, 0, 0, x.shape[-2] - 2]
|
|
170
|
+
) # Fill Fourier coefficients to retain shape
|
|
171
|
+
x = torch.view_as_complex(x)
|
|
172
|
+
x = torch.fft.ifft(x, norm=None, dim=-1) # returns complex signal
|
|
173
|
+
x = torch.view_as_real(x)
|
|
174
|
+
|
|
175
|
+
return x
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
|
|
179
|
+
"""Compute the Phase Locking Value (PLV) metric in the time domain.
|
|
180
|
+
|
|
181
|
+
The Phase Locking Value (PLV) is a measure of the synchronization between
|
|
182
|
+
different channels by evaluating the consistency of phase differences
|
|
183
|
+
over time. It ranges from 0 (no synchronization) to 1 (perfect
|
|
184
|
+
synchronization) [1]_.
|
|
185
|
+
|
|
186
|
+
Parameters
|
|
187
|
+
----------
|
|
188
|
+
x : torch.Tensor
|
|
189
|
+
Input tensor containing the signal data.
|
|
190
|
+
- If `forward_fourier` is `True`, the shape should be `(..., channels, time)`.
|
|
191
|
+
- If `forward_fourier` is `False`, the shape should be `(..., channels, freqs, 2)`,
|
|
192
|
+
where the last dimension represents the real and imaginary parts.
|
|
193
|
+
forward_fourier : bool, optional
|
|
194
|
+
Specifies the format of the input tensor `x`.
|
|
195
|
+
- If `True`, `x` is assumed to be in the time domain.
|
|
196
|
+
- If `False`, `x` is assumed to be in the Fourier domain with separate real and
|
|
197
|
+
imaginary components.
|
|
198
|
+
Default is `True`.
|
|
199
|
+
epsilon : float, default 1e-6
|
|
200
|
+
Small numerical value to ensure positivity constraint on the complex part
|
|
201
|
+
|
|
202
|
+
Returns
|
|
203
|
+
-------
|
|
204
|
+
plv : torch.Tensor
|
|
205
|
+
The Phase Locking Value matrix with shape `(..., channels, channels)`. Each
|
|
206
|
+
element `[i, j]` represents the PLV between channel `i` and channel `j`.
|
|
207
|
+
|
|
208
|
+
References
|
|
209
|
+
----------
|
|
210
|
+
[1] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
|
|
211
|
+
Measuring phase synchrony in brain signals. Human brain mapping,
|
|
212
|
+
8(4), 194-208.
|
|
213
|
+
"""
|
|
214
|
+
# Compute the analytic signal using the Hilbert transform.
|
|
215
|
+
# x_a has separate real and imaginary parts.
|
|
216
|
+
analytic_signal = hilbert_freq(x, forward_fourier)
|
|
217
|
+
# Calculate the amplitude (magnitude) of the analytic signal.
|
|
218
|
+
# Adding a small epsilon (1e-6) to avoid division by zero.
|
|
219
|
+
amplitude = torch.sqrt(
|
|
220
|
+
analytic_signal[..., 0] ** 2 + analytic_signal[..., 1] ** 2 + 1e-6
|
|
221
|
+
)
|
|
222
|
+
# Normalize the analytic signal to obtain unit vectors (phasors).
|
|
223
|
+
unit_phasor = analytic_signal / amplitude.unsqueeze(-1)
|
|
224
|
+
|
|
225
|
+
# Compute the real part of the outer product between phasors of
|
|
226
|
+
# different channels.
|
|
227
|
+
real_real = torch.matmul(unit_phasor[..., 0], unit_phasor[..., 0].transpose(-2, -1))
|
|
228
|
+
|
|
229
|
+
# Compute the imaginary part of the outer product between phasors of
|
|
230
|
+
# different channels.
|
|
231
|
+
imag_imag = torch.matmul(unit_phasor[..., 1], unit_phasor[..., 1].transpose(-2, -1))
|
|
232
|
+
|
|
233
|
+
# Compute the cross-terms for the real and imaginary parts.
|
|
234
|
+
real_imag = torch.matmul(unit_phasor[..., 0], unit_phasor[..., 1].transpose(-2, -1))
|
|
235
|
+
imag_real = torch.matmul(unit_phasor[..., 1], unit_phasor[..., 0].transpose(-2, -1))
|
|
236
|
+
|
|
237
|
+
# Combine the real and imaginary parts to form the complex correlation.
|
|
238
|
+
correlation_real = real_real + imag_imag
|
|
239
|
+
correlation_imag = real_imag - imag_real
|
|
240
|
+
|
|
241
|
+
# Determine the number of time points (or frequency bins if in Fourier domain).
|
|
242
|
+
time = amplitude.shape[-1]
|
|
243
|
+
|
|
244
|
+
# Calculate the PLV by averaging the magnitude of the complex correlation over time.
|
|
245
|
+
# epsilon is small numerical value to ensure positivity constraint on the complex part
|
|
246
|
+
plv_matrix = (
|
|
247
|
+
1 / time * torch.sqrt(correlation_real**2 + correlation_imag**2 + epsilon)
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
return plv_matrix
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from torch import nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def glorot_weight_zero_bias(model):
|
|
7
|
+
"""Initialize parameters of all modules by initializing weights with
|
|
8
|
+
glorot
|
|
9
|
+
uniform/xavier initialization, and setting biases to zero. Weights from
|
|
10
|
+
batch norm layers are set to 1.
|
|
11
|
+
|
|
12
|
+
Parameters
|
|
13
|
+
----------
|
|
14
|
+
model: Module
|
|
15
|
+
"""
|
|
16
|
+
for module in model.modules():
|
|
17
|
+
if hasattr(module, "weight"):
|
|
18
|
+
if "BatchNorm" in module.__class__.__name__:
|
|
19
|
+
nn.init.constant_(module.weight, 1)
|
|
20
|
+
if hasattr(module, "bias"):
|
|
21
|
+
if module.bias is not None:
|
|
22
|
+
nn.init.constant_(module.bias, 0)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def rescale_parameter(param, layer_id):
|
|
26
|
+
r"""Recaling the l-th transformer layer.
|
|
27
|
+
|
|
28
|
+
Rescales the parameter tensor by the inverse square root of the layer id.
|
|
29
|
+
Made inplace. :math:`\frac{1}{\sqrt{2 \cdot \text{layer\_id}}}` [Beit2022]
|
|
30
|
+
|
|
31
|
+
In the labram, this is used to rescale the output matrices
|
|
32
|
+
(i.e., the last linear projection within each sub-layer) of the
|
|
33
|
+
self-attention module.
|
|
34
|
+
|
|
35
|
+
Parameters
|
|
36
|
+
----------
|
|
37
|
+
param: :class:`torch.Tensor`
|
|
38
|
+
tensor to be rescaled
|
|
39
|
+
layer_id: int
|
|
40
|
+
layer id in the neural network
|
|
41
|
+
|
|
42
|
+
References
|
|
43
|
+
----------
|
|
44
|
+
[Beit2022] Hangbo Bao, Li Dong, Songhao Piao, Furu We (2022). BEIT: BERT
|
|
45
|
+
Pre-Training of Image Transformers.
|
|
46
|
+
"""
|
|
47
|
+
param.div_(math.sqrt(2.0 * layer_id))
|
braindecode/models/__init__.py
CHANGED
|
@@ -1,30 +1,100 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Some predefined network architectures for EEG decoding.
|
|
3
3
|
"""
|
|
4
|
+
|
|
5
|
+
from .atcnet import ATCNet
|
|
6
|
+
from .attentionbasenet import AttentionBaseNet
|
|
4
7
|
from .base import EEGModuleMixin
|
|
5
|
-
from .
|
|
6
|
-
from .
|
|
8
|
+
from .biot import BIOT
|
|
9
|
+
from .contrawr import ContraWR
|
|
10
|
+
from .ctnet import CTNet
|
|
7
11
|
from .deep4 import Deep4Net
|
|
8
12
|
from .deepsleepnet import DeepSleepNet
|
|
9
|
-
from .
|
|
10
|
-
from .hybrid import HybridNet
|
|
11
|
-
from .shallow_fbcsp import ShallowFBCSPNet
|
|
12
|
-
from .eegresnet import EEGResNet
|
|
13
|
-
from .eeginception import EEGInception
|
|
13
|
+
from .eegconformer import EEGConformer
|
|
14
14
|
from .eeginception_erp import EEGInceptionERP
|
|
15
15
|
from .eeginception_mi import EEGInceptionMI
|
|
16
|
-
from .
|
|
17
|
-
from .
|
|
18
|
-
from .
|
|
16
|
+
from .eegitnet import EEGITNet
|
|
17
|
+
from .eegminer import EEGMiner
|
|
18
|
+
from .eegnet import EEGNetv1, EEGNetv4
|
|
19
|
+
from .eegnex import EEGNeX
|
|
20
|
+
from .eegresnet import EEGResNet
|
|
21
|
+
from .eegsimpleconv import EEGSimpleConv
|
|
22
|
+
from .eegtcnet import EEGTCNet
|
|
23
|
+
from .fbcnet import FBCNet
|
|
24
|
+
from .fblightconvnet import FBLightConvNet
|
|
25
|
+
from .fbmsnet import FBMSNet
|
|
26
|
+
from .hybrid import HybridNet
|
|
27
|
+
from .ifnet import IFNet
|
|
28
|
+
from .labram import Labram
|
|
29
|
+
from .msvtnet import MSVTNet
|
|
30
|
+
from .sccnet import SCCNet
|
|
31
|
+
from .shallow_fbcsp import ShallowFBCSPNet
|
|
32
|
+
from .signal_jepa import (
|
|
33
|
+
SignalJEPA,
|
|
34
|
+
SignalJEPA_Contextual,
|
|
35
|
+
SignalJEPA_PostLocal,
|
|
36
|
+
SignalJEPA_PreLocal,
|
|
37
|
+
)
|
|
38
|
+
from .sinc_shallow import SincShallowNet
|
|
19
39
|
from .sleep_stager_blanco_2020 import SleepStagerBlanco2020
|
|
40
|
+
from .sleep_stager_chambon_2018 import SleepStagerChambon2018
|
|
20
41
|
from .sleep_stager_eldele_2021 import SleepStagerEldele2021
|
|
42
|
+
from .sparcnet import SPARCNet
|
|
43
|
+
from .syncnet import SyncNet
|
|
44
|
+
from .tcn import BDTCN, TCN
|
|
21
45
|
from .tidnet import TIDNet
|
|
46
|
+
from .tsinception import TSceptionV1
|
|
22
47
|
from .usleep import USleep
|
|
23
|
-
from .util import
|
|
24
|
-
from .modules import TimeDistributed
|
|
25
|
-
|
|
26
|
-
from .util import _init_models_dict
|
|
48
|
+
from .util import _init_models_dict, models_mandatory_parameters
|
|
27
49
|
|
|
28
50
|
# Call this last in order to make sure the dataset list is populated with
|
|
29
51
|
# the models imported in this file.
|
|
30
52
|
_init_models_dict()
|
|
53
|
+
|
|
54
|
+
__all__ = [
|
|
55
|
+
"ATCNet",
|
|
56
|
+
"AttentionBaseNet",
|
|
57
|
+
"EEGModuleMixin",
|
|
58
|
+
"BIOT",
|
|
59
|
+
"ContraWR",
|
|
60
|
+
"CTNet",
|
|
61
|
+
"Deep4Net",
|
|
62
|
+
"DeepSleepNet",
|
|
63
|
+
"EEGConformer",
|
|
64
|
+
"EEGInceptionERP",
|
|
65
|
+
"EEGInceptionMI",
|
|
66
|
+
"EEGITNet",
|
|
67
|
+
"EEGMiner",
|
|
68
|
+
"EEGNetv1",
|
|
69
|
+
"EEGNetv4",
|
|
70
|
+
"EEGNeX",
|
|
71
|
+
"EEGResNet",
|
|
72
|
+
"EEGSimpleConv",
|
|
73
|
+
"EEGTCNet",
|
|
74
|
+
"FBCNet",
|
|
75
|
+
"FBLightConvNet",
|
|
76
|
+
"FBMSNet",
|
|
77
|
+
"HybridNet",
|
|
78
|
+
"IFNet",
|
|
79
|
+
"Labram",
|
|
80
|
+
"MSVTNet",
|
|
81
|
+
"SCCNet",
|
|
82
|
+
"ShallowFBCSPNet",
|
|
83
|
+
"SignalJEPA",
|
|
84
|
+
"SignalJEPA_Contextual",
|
|
85
|
+
"SignalJEPA_PostLocal",
|
|
86
|
+
"SignalJEPA_PreLocal",
|
|
87
|
+
"SincShallowNet",
|
|
88
|
+
"SleepStagerBlanco2020",
|
|
89
|
+
"SleepStagerChambon2018",
|
|
90
|
+
"SleepStagerEldele2021",
|
|
91
|
+
"SPARCNet",
|
|
92
|
+
"SyncNet",
|
|
93
|
+
"BDTCN",
|
|
94
|
+
"TCN",
|
|
95
|
+
"TIDNet",
|
|
96
|
+
"TSceptionV1",
|
|
97
|
+
"USleep",
|
|
98
|
+
"_init_models_dict",
|
|
99
|
+
"models_mandatory_parameters",
|
|
100
|
+
]
|