braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/datautil/util.py
CHANGED
@@ -2,19 +2,6 @@
 #
 # License: BSD (3-clause)
 
-import logging
-from typing import Any, Literal
-
-import mne
-import numpy as np
-from skorch.helper import SliceDataset
-from skorch.utils import is_dataset
-
-from braindecode.datasets.base import BaseConcatDataset, WindowsDataset
-from braindecode.models.util import SigArgName
-
-log = logging.getLogger(__name__)
-
 
 def ms_to_samples(ms, fs):
     """
@@ -22,15 +9,16 @@ def ms_to_samples(ms, fs):
 
     Parameters
     ----------
-    ms
+    ms: number
        Milliseconds
-    fs
+    fs: number
        Sampling rate
 
    Returns
    -------
-    n_samples
+    n_samples: int
        Number of samples
+
    """
    return ms * fs / 1000.0
 
@@ -41,114 +29,13 @@ def samples_to_ms(n_samples, fs):
 
    Parameters
    ----------
-    n_samples
+    n_samples: number
        Number of samples
-    fs
+    fs: number
        Sampling rate
 
    Returns
    -------
-    milliseconds
+    milliseconds: int
    """
    return n_samples * 1000.0 / fs
-
-
-def _get_n_outputs(y, classes, mode):
-    if mode == "classification":
-        classes_y = np.unique(y)
-        if classes is not None:
-            assert set(classes_y) <= set(classes)
-        else:
-            classes = classes_y
-        return len(classes)
-    elif mode == "regression":
-        if y is None:
-            return None
-        if y.ndim == 1:
-            return 1
-        else:
-            return y.shape[-1]
-    else:
-        raise ValueError(f"Unknown mode {mode}")
-
-
-def infer_signal_properties(
-    X,
-    y=None,
-    mode: Literal["classification", "regression"] = "classification",
-    classes: list | None = None,
-) -> dict[SigArgName, Any]:
-    """Infers signal properties from the data.
-
-    The extracted signal properties are:
-
-    + n_chans: number of channels
-    + n_times: number of time points
-    + n_outputs: number of outputs
-    + chs_info: channel information
-    + sfreq: sampling frequency
-
-    The returned dictionary can serve as kwargs for model initialization.
-
-    Depending on the type of input passed, not all properties can be inferred.
-
-    Parameters
-    ----------
-    X : array-like or mne.BaseEpochs or Dataset
-        Input data
-    y : array-like or None
-        Targets
-    mode : "classification" or "regression"
-        Mode of the task
-    classes : list or None
-        List of classes for classification
-
-    Returns
-    -------
-    signal_kwargs : dict
-        Dictionary with signal-properties. Can serve as kwargs for model
-        initialization.
-    """
-    signal_kwargs: dict[SigArgName, Any] = {}
-    # Using shape to work both with torch.tensor and numpy.array:
-    if (
-        isinstance(X, mne.BaseEpochs)
-        or (hasattr(X, "shape") and len(X.shape) >= 2)
-        or isinstance(X, SliceDataset)
-    ):
-        if y is None:
-            raise ValueError("y must be specified if X is array-like.")
-        signal_kwargs["n_outputs"] = _get_n_outputs(y, classes, mode)
-        if isinstance(X, mne.BaseEpochs):
-            log.info("Using mne.Epochs to find signal-related parameters.")
-            signal_kwargs["n_times"] = len(X.times)
-            signal_kwargs["sfreq"] = X.info["sfreq"]
-            signal_kwargs["chs_info"] = X.info["chs"]
-        elif isinstance(X, SliceDataset):
-            log.info("Using SliceDataset to find signal-related parameters.")
-            Xshape = X[0].shape
-            signal_kwargs["n_times"] = Xshape[-1]
-            signal_kwargs["n_chans"] = Xshape[-2]
-        else:
-            log.info("Using array-like to find signal-related parameters.")
-            signal_kwargs["n_times"] = X.shape[-1]
-            signal_kwargs["n_chans"] = X.shape[-2]
-    elif is_dataset(X):
-        log.info(f"Using Dataset {X!r} to find signal-related parameters.")
-        X0 = X[0][0]
-        Xshape = X0.shape
-        signal_kwargs["n_times"] = Xshape[-1]
-        signal_kwargs["n_chans"] = Xshape[-2]
-        if isinstance(X, BaseConcatDataset) and all(
-            ds.targets_from == "metadata" for ds in X.datasets
-        ):
-            y_target = X.get_metadata().target
-            signal_kwargs["n_outputs"] = _get_n_outputs(y_target, classes, mode)
-        elif isinstance(X, WindowsDataset) and X.targets_from == "metadata":
-            y_target = X.windows.metadata.target
-            signal_kwargs["n_outputs"] = _get_n_outputs(y_target, classes, mode)
-    else:
-        log.warning(
-            f"Can only infer signal shape of array-like and Datasets, got {type(X)!r}."
-        )
-    return signal_kwargs
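The two helpers that remain in this module are plain unit conversions (ms * fs / 1000 and its inverse, as shown in the diff). A minimal usage sketch with illustrative values:

    from braindecode.datautil.util import ms_to_samples, samples_to_ms

    # 500 ms at a 250 Hz sampling rate spans 125 samples ...
    print(ms_to_samples(ms=500, fs=250))         # 125.0
    # ... and 125 samples at 250 Hz last 500 ms.
    print(samples_to_ms(n_samples=125, fs=250))  # 500.0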
braindecode/eegneuralnet.py
CHANGED
@@ -7,7 +7,6 @@
 import abc
 import inspect
 import logging
-from typing import Literal
 
 import mne
 import numpy as np
@@ -15,10 +14,10 @@ import torch
 from sklearn.metrics import get_scorer
 from skorch import NeuralNet
 from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
-from skorch.
-
-from braindecode.datautil import infer_signal_properties
+from skorch.helper import SliceDataset
+from skorch.utils import is_dataset, noop, to_numpy, train_loss_score, valid_loss_score
 
+from .datasets.base import BaseConcatDataset, WindowsDataset
 from .models.util import models_dict
 from .training.scoring import (
     CroppedTimeSeriesEpochScoring,
@@ -53,6 +52,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
 
        If the module is already initialized and no parameter was changed, it
        will be left as is.
+
        """
        kwargs = self.get_params_for("module")
        module = _get_model(self.module)
@@ -174,9 +174,8 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
            ("print_log", PrintLog()),
        ]
 
-    @property
    @abc.abstractmethod
-    def
+    def _get_n_outputs(self, y, classes):
        pass
 
    def _set_signal_args(self, X, y, classes):
@@ -192,8 +191,50 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
            return
        if classes is None:
            classes = getattr(self, "classes", None)
-
-
+        # get kwargs from signal:
+        signal_kwargs = dict()
+        # Using shape to work both with torch.tensor and numpy.array:
+        if (
+            isinstance(X, mne.BaseEpochs)
+            or (hasattr(X, "shape") and len(X.shape) >= 2)
+            or isinstance(X, SliceDataset)
+        ):
+            if y is None:
+                raise ValueError("y must be specified if X is array-like.")
+            signal_kwargs["n_outputs"] = self._get_n_outputs(y, classes)
+            if isinstance(X, mne.BaseEpochs):
+                self.log.info("Using mne.Epochs to find signal-related parameters.")
+                signal_kwargs["n_times"] = len(X.times)
+                signal_kwargs["sfreq"] = X.info["sfreq"]
+                signal_kwargs["chs_info"] = X.info["chs"]
+            elif isinstance(X, SliceDataset):
+                self.log.info("Using SliceDataset to find signal-related parameters.")
+                Xshape = X[0].shape
+                signal_kwargs["n_times"] = Xshape[-1]
+                signal_kwargs["n_chans"] = Xshape[-2]
+            else:
+                self.log.info("Using array-like to find signal-related parameters.")
+                signal_kwargs["n_times"] = X.shape[-1]
+                signal_kwargs["n_chans"] = X.shape[-2]
+        elif is_dataset(X):
+            self.log.info(f"Using Dataset {X!r} to find signal-related parameters.")
+            X0 = X[0][0]
+            Xshape = X0.shape
+            signal_kwargs["n_times"] = Xshape[-1]
+            signal_kwargs["n_chans"] = Xshape[-2]
+            if isinstance(X, BaseConcatDataset) and all(
+                ds.targets_from == "metadata" for ds in X.datasets
+            ):
+                y_target = X.get_metadata().target
+                signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
+            elif isinstance(X, WindowsDataset) and X.targets_from == "metadata":
+                y_target = X.windows.metadata.target
+                signal_kwargs["n_outputs"] = self._get_n_outputs(y_target, classes)
+        else:
+            self.log.warning(
+                "Can only infer signal shape of array-like and Datasets, "
+                f"got {type(X)!r}."
+            )
            return
 
        # kick out missing kwargs:
@@ -208,18 +249,6 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
        else:
            self.log.warning(f"Module {self.module!r} is missing parameter {k!r}.")
 
-        # kick out inferred signal kwargs if user specifies kwargs:
-        user_specified_kwargs = self.get_params_for("module").items()
-        if len(user_specified_kwargs) > 0:
-            self.log.info(
-                f"Overriding inferred parameters with user "
-                f"specified parameters{user_specified_kwargs!r}."
-            )
-        for k, v in self.get_params_for("module").items():
-            if k in module_kwargs:
-                module_kwargs.pop(k)
-            module_kwargs[k] = v
-
        # save kwargs to self:
        self.log.info(
            f"Passing additional parameters {module_kwargs!r} "
@@ -229,8 +258,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
        self.set_params(**module_kwargs)
 
    def get_dataset(self, X, y=None):
-        """Get a dataset that contains the input data and is passed to
-
+        """Get a dataset that contains the input data and is passed to
        the iterator.
 
        Override this if you want to initialize your dataset
@@ -262,6 +290,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
        -------
        dataset
            The initialized dataset.
+
        """
        if isinstance(X, mne.BaseEpochs):
            X = X.get_data(units="uV")
@@ -314,6 +343,7 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
        **fit_params : dict
            Additional parameters passed to the ``forward`` method of
            the module and to the ``self.train_split`` call.
+
        """
        # this needs to be executed before the net is initialized:
        if not self.signal_args_set_:
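The signal-property inference inlined above keys everything off the array's shape, with channels on the second-to-last axis and time on the last. A standalone sketch of that same convention (dimensions are made up for illustration):

    import numpy as np

    # Dummy EEG windows: (n_windows, n_chans, n_times).
    X = np.random.randn(100, 22, 1125)
    y = np.random.randint(0, 4, size=100)

    signal_kwargs = {
        "n_times": X.shape[-1],          # time is the last axis -> 1125
        "n_chans": X.shape[-2],          # channels are second-to-last -> 22
        "n_outputs": len(np.unique(y)),  # 4 distinct labels -> 4 outputs
    }
    print(signal_kwargs)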
braindecode/functional/functions.py
CHANGED

@@ -24,13 +24,14 @@ def drop_path(
 ):
    """Drop paths (Stochastic Depth) per sample.
 
+
    Notes: This implementation is taken from timm library.
 
    All credit goes to Ross Wightman.
 
    Parameters
    ----------
-    x
+    x: torch.Tensor
        input tensor
    drop_prob : float, optional
        survival rate (i.e. probability of being kept), by default 0.0
@@ -50,10 +51,11 @@ def drop_path(
    etc. networks, however,
    the original name is misleading as 'Drop Connect' is a different form
    of dropout in a separate paper...
-    See discussion
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    ... I've opted for changing the layer and argument names to 'drop path'
    rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
+
    """
    if drop_prob == 0.0 or not training:
        return x
@@ -69,8 +71,7 @@ def drop_path(
 
 def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
    """
-    Generates a 1-dimensional Gaussian kernel based on the specified kernel
-
+    Generates a 1-dimensional Gaussian kernel based on the specified kernel
    size and standard deviation (sigma).
    This kernel is useful for Gaussian smoothing or filtering operations in
    image processing. The function calculates a range limit to ensure the kernel
@@ -79,14 +80,15 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
    Gaussian curve, normalized using a softmax function
    to ensure all weights sum to 1.
 
+
    Parameters
    ----------
-    kernel_size
-    sigma
+    kernel_size: int
+    sigma: float
 
    Returns
    -------
-    kernel1d
+    kernel1d: torch.Tensor
 
    Notes
    -----
@@ -95,6 +97,7 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
    All rights reserved.
 
    LICENSE in https://github.com/pytorch/vision/blob/main/LICENSE
+
    """
    ksize_half = (kernel_size - 1) * 0.5
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
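The _get_gaussian_kernel1d docstring describes the recipe: a symmetric linspace grid, Gaussian log-weights over it, and a softmax so the weights sum to 1. A standalone sketch of that recipe (not the library's exact code):

    import torch

    def gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
        # Symmetric grid of kernel_size points centered on zero.
        ksize_half = (kernel_size - 1) * 0.5
        x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
        # softmax(-x^2 / (2 sigma^2)) is a Gaussian normalized to sum to 1.
        return torch.softmax(-0.5 * (x / sigma) ** 2, dim=0)

    print(gaussian_kernel1d(5, sigma=1.0))  # five weights summing to 1.0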
braindecode/functional/initialization.py
CHANGED

@@ -4,14 +4,13 @@ from torch import nn
 
 
 def glorot_weight_zero_bias(model):
-    """Initialize parameters of all modules by initializing weights with
-
+    """Initialize parameters of all modules by initializing weights with
    glorot uniform/xavier initialization, and setting biases to zero. Weights from
    batch norm layers are set to 1.
 
    Parameters
    ----------
-    model
+    model: Module
    """
    for module in model.modules():
        if hasattr(module, "weight"):
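A minimal sketch of what the docstring promises (xavier-uniform weights, zero biases, batch-norm weights set to 1), applied to a toy module; the asserted outcomes follow the docstring, not an inspection of the implementation:

    import torch
    from torch import nn

    from braindecode.functional.initialization import glorot_weight_zero_bias

    model = nn.Sequential(
        nn.Conv1d(8, 16, kernel_size=3),
        nn.BatchNorm1d(16),
    )
    glorot_weight_zero_bias(model)

    assert torch.all(model[0].bias == 0)    # biases zeroed
    assert torch.all(model[1].weight == 1)  # batch-norm weights set to 1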
braindecode/models/__init__.py
CHANGED
@@ -1,4 +1,6 @@
-"""
+"""
+Some predefined network architectures for EEG decoding.
+"""
 
 from .atcnet import ATCNet
 from .attentionbasenet import AttentionBaseNet
@@ -6,7 +8,6 @@ from .attn_sleep import AttnSleep
 from .base import EEGModuleMixin
 from .bendr import BENDR
 from .biot import BIOT
-from .brainmodule import BrainModule
 from .contrawr import ContraWR
 from .ctnet import CTNet
 from .deep4 import Deep4Net
@@ -31,7 +32,6 @@ from .luna import LUNA
 from .medformer import MEDFormer
 from .msvtnet import MSVTNet
 from .patchedtransformer import PBT
-from .reve import REVE
 from .sccnet import SCCNet
 from .shallow_fbcsp import ShallowFBCSPNet
 from .signal_jepa import (
@@ -71,7 +71,6 @@ __all__ = [
    "CTNet",
    "Deep4Net",
    "DeepSleepNet",
-    "BrainModule",
    "EEGConformer",
    "EEGInceptionERP",
    "EEGInceptionMI",
@@ -94,7 +93,6 @@ __all__ = [
    "MEDFormer",
    "MSVTNet",
    "PBT",
-    "REVE",
    "SCCNet",
    "ShallowFBCSPNet",
    "SignalJEPA",
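In practice this means BrainModule and REVE disappear from the package's public API in 1.3.0.dev177628147, while the remaining exports are unchanged:

    from braindecode.models import ATCNet, PBT, SCCNet  # still exported

    # Removed in this version; either import now raises ImportError:
    # from braindecode.models import BrainModule, REVE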
braindecode/models/atcnet.py
CHANGED
@@ -13,9 +13,9 @@ from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear
 
 
 class ATCNet(EEGModuleMixin, nn.Module):
-
+    """ATCNet from Altaheri et al. (2022) [1]_.
 
-    :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Attention
+    :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Small Attention`
 
    .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
       :align: center
@@ -83,8 +83,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
 
    - :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**
 
-      *Operations
-
+      - *Operations.*
      - Two :class:`braindecode.modules.CausalConv1d` layers per block with dilation ``1, 2, 4, …``
      - Across blocks of `torch.nn.ELU` + `torch.nn.BatchNorm1d` + `torch.nn.Dropout`) +
        a residual (identity or 1x1 mapping).
@@ -95,12 +94,10 @@ class ATCNet(EEGModuleMixin, nn.Module):
 
    - **Aggregation & Classifier**
 
-      *Operations
-
+      - *Operations.*
      - Either (a) map each window feature ``(B, F2)`` to logits via :class:`braindecode.modules.MaxNormLinear`
-
+        and **average** across windows (default, matching official code), or
      - (b) **concatenate** all window features ``(B, n·F2)`` and apply a single :class:`MaxNormLinear`.
-
        The max-norm constraint regularizes the readout.
 
    .. rubric:: Convolutional Details
@@ -117,6 +114,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
      producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
      This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.
 
+
    .. rubric:: Attention / Sequential Modules
 
    - **Type.** Multi-head self-attention with ``H`` heads and per-head dim ``d_h`` implemented
@@ -143,13 +141,26 @@ class ATCNet(EEGModuleMixin, nn.Module):
    - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
      ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
    - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
-    - ``
+    - ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
    - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
      longer inputs (see minimum length above). The implementation warns and *rescales*
      kernels/pools/windows if inputs are too short.
    - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
      the official code; ``concat=True`` mirrors the paper's concatenation variant.
 
+
+    Notes
+    -----
+    - Inputs substantially shorter than the implied minimum length trigger **automatic
+      downscaling** of kernels, pools, windows, and TCN kernel size to maintain validity.
+    - The attention–TCN sequence operates **per window**; the last causal step is used as the
+      window feature, aligning the temporal semantics across windows.
+
+    .. versionadded:: 1.1
+
+       - More detailed documentation of the model.
+
+
    Parameters
    ----------
    input_window_seconds : float, optional
@@ -183,10 +194,10 @@ class ATCNet(EEGModuleMixin, nn.Module):
        table 1 of the paper [1]_. Defaults to 0.3 as in [1]_.
    n_windows : int
        Number of sliding windows, denoted n in [1]_. Defaults to 5 as in [1]_.
-
+    att_head_dim : int
        Embedding dimension used in each self-attention head, denoted dh in
        table 1 of the paper [1]_. Defaults to 8 as in [1]_.
-
+    att_num_heads : int
        Number of attention heads, denoted H in table 1 of the paper [1]_.
        Defaults to 2 as in [1]_.
    att_dropout : float
@@ -214,17 +225,6 @@ class ATCNet(EEGModuleMixin, nn.Module):
        Maximum L2-norm constraint imposed on weights of the last
        fully-connected layer. Defaults to 0.25.
 
-    Notes
-    -----
-    - Inputs substantially shorter than the implied minimum length trigger **automatic
-      downscaling** of kernels, pools, windows, and TCN kernel size to maintain validity.
-    - The attention–TCN sequence operates **per window**; the last causal step is used as the
-      window feature, aligning the temporal semantics across windows.
-
-    .. versionadded:: 1.1
-
-       - More detailed documentation of the model.
-
    References
    ----------
    .. [1] H. Altaheri, G. Muhammad, M. Alsulaiman (2022).
@@ -248,13 +248,13 @@ class ATCNet(EEGModuleMixin, nn.Module):
        conv_block_depth_mult=2,
        conv_block_dropout=0.3,
        n_windows=5,
-
-
+        att_head_dim=8,
+        att_num_heads=2,
        att_drop_prob=0.5,
        tcn_depth=2,
        tcn_kernel_size=4,
        tcn_drop_prob=0.3,
-        tcn_activation:
+        tcn_activation: nn.Module = nn.ELU,
        concat=False,
        max_norm_const=0.25,
        chs_info=None,
@@ -316,8 +316,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
        self.conv_block_depth_mult = conv_block_depth_mult
        self.conv_block_dropout = conv_block_dropout
        self.n_windows = n_windows
-        self.
-        self.
+        self.att_head_dim = att_head_dim
+        self.att_num_heads = att_num_heads
        self.att_dropout = att_drop_prob
        self.tcn_depth = tcn_depth
        self.tcn_kernel_size = tcn_kernel_size
@@ -356,8 +356,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
            [
                _AttentionBlock(
                    in_shape=self.F2,
-                    head_dim=self.
-                    num_heads=
+                    head_dim=self.att_head_dim,
+                    num_heads=att_num_heads,
                    dropout=att_drop_prob,
                )
                for _ in range(self.n_windows)
@@ -460,8 +460,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
 
 
 class _ConvBlock(nn.Module):
-
-
+    """Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet
    architecture [2]_.
 
    References
@@ -563,8 +562,7 @@ class _ConvBlock(nn.Module):
 
 
 class _AttentionBlock(nn.Module):
-
-
+    """Multi Head self Attention (MHA) block used in ATCNet [1]_, inspired from
    [2]_.
 
    References
@@ -638,9 +636,7 @@ class _AttentionBlock(nn.Module):
 
 
 class _TCNResidualBlock(nn.Module):
-
-
-    Inspired from
+    """Modified TCN Residual block as proposed in [1]_. Inspired from
    Temporal Convolutional Networks (TCN) [2]_.
 
    References
@@ -660,7 +656,7 @@ class _TCNResidualBlock(nn.Module):
        kernel_size=4,
        n_filters=32,
        dropout=0.3,
-        activation:
+        activation: nn.Module = nn.ELU,
        dilation=1,
    ):
        super().__init__()
@@ -736,7 +732,7 @@ class _MHA(nn.Module):
        num_heads: int,
        dropout: float = 0.0,
    ):
-        """Multi-head Attention 
+        """Multi-head Attention
 
        The difference between this module and torch.nn.MultiheadAttention is
        that this module supports embedding dimensions different then input
@@ -779,20 +775,20 @@ class _MHA(nn.Module):
    def forward(
        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
    ) -> torch.Tensor:
-        """Compute MHA(Q, K, V) 
+        """Compute MHA(Q, K, V)
 
        Parameters
        ----------
-        Q
+        Q: torch.Tensor of size (batch_size, seq_len, input_dim)
            Input query (Q) sequence.
-        K
+        K: torch.Tensor of size (batch_size, seq_len, input_dim)
            Input key (K) sequence.
-        V
+        V: torch.Tensor of size (batch_size, seq_len, input_dim)
            Input value (V) sequence.
 
        Returns
        -------
-        O
+        O: torch.Tensor of size (batch_size, seq_len, output_dim)
            Output MHA(Q, K, V)
        """
        assert Q.shape[-1] == K.shape[-1] == V.shape[-1] == self.input_dim
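The constructor now takes the attention arguments under the names att_head_dim and att_num_heads. A hedged instantiation sketch (signal dimensions are illustrative, and the n_chans/n_outputs/n_times/sfreq names are the usual EEGModuleMixin parameters rather than something shown in this diff):

    from braindecode.models import ATCNet

    model = ATCNet(
        n_chans=22,          # e.g. a 22-channel motor-imagery montage
        n_outputs=4,
        n_times=1125,        # 4.5 s at 250 Hz
        sfreq=250.0,
        att_head_dim=8,      # per-head embedding dim, d_h in table 1 of [1]
        att_num_heads=2,     # number of attention heads, H
        concat=False,        # average per-window logits (default)
    )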