braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
2
|
+
# Pierre Guetschel <pierre.guetschel@gmail.com>
|
|
3
|
+
#
|
|
4
|
+
# License: BSD (3-clause)
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
import abc
|
|
8
|
+
import inspect
|
|
9
|
+
import logging
|
|
10
|
+
from typing import Literal
|
|
11
|
+
|
|
12
|
+
import mne
|
|
13
|
+
import numpy as np
|
|
14
|
+
import torch
|
|
15
|
+
from sklearn.metrics import get_scorer
|
|
16
|
+
from skorch import NeuralNet
|
|
17
|
+
from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
|
|
18
|
+
from skorch.utils import noop, to_numpy, train_loss_score, valid_loss_score
|
|
19
|
+
|
|
20
|
+
from braindecode.datautil import infer_signal_properties
|
|
21
|
+
|
|
22
|
+
from .models.util import models_dict
|
|
23
|
+
from .training.scoring import (
|
|
24
|
+
CroppedTimeSeriesEpochScoring,
|
|
25
|
+
CroppedTrialEpochScoring,
|
|
26
|
+
PostEpochTrainScoring,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# Module-level logger; ``_EEGNeuralNet.log`` derives per-class child loggers from it.
log = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _get_model(model: str):
|
|
33
|
+
"""Returns the corresponding class in case the model passed is a string."""
|
|
34
|
+
if isinstance(model, str):
|
|
35
|
+
if model in models_dict:
|
|
36
|
+
model = models_dict[model]
|
|
37
|
+
else:
|
|
38
|
+
raise ValueError(f"Unknown model name {model!r}.")
|
|
39
|
+
return model
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class _EEGNeuralNet(NeuralNet, abc.ABC):
    """Abstract skorch ``NeuralNet`` base class for braindecode estimators.

    Compared to :class:`skorch.NeuralNet`, this class

    * accepts the name of a braindecode model (a key of ``models_dict``) as
      ``module``;
    * expands callbacks given as plain strings (sklearn scorer names such as
      ``"accuracy"``) into a train and a valid scoring callback, using
      cropped-decoding scoring when ``self.cropped`` is true;
    * infers signal-related module parameters (e.g. ``n_chans``,
      ``n_times``) from the data on the first call to ``fit`` or
      ``partial_fit``;
    * converts :class:`mne.BaseEpochs` inputs to arrays in microvolts.

    Subclasses must implement :attr:`mode`.
    """

    # Class-level default. Set to True on the instance once signal-related
    # parameters have been inferred, so they are inferred only on the first
    # fit/partial_fit call.
    signal_args_set_ = False

    @property
    def log(self):
        """Child of the module logger, named after the concrete class."""
        return log.getChild(self.__class__.__name__)

    def initialize_module(self):
        """Initializes the module.

        A Braindecode model name can also be passed as module argument.

        If the module is already initialized and no parameter was changed, it
        will be left as is.
        """
        kwargs = self.get_params_for("module")
        module = _get_model(self.module)
        module = self.initialized_instance(module, kwargs)
        # pylint: disable=attribute-defined-outside-init
        self.module_ = module
        return self

    def _yield_callbacks(self):
        # Callbacks supplied as plain strings (e.g. 'accuracy') arrive here
        # with the name "str" (presumably skorch's class-derived name for
        # unnamed callbacks). Each one is expanded into a train and a valid
        # scoring callback; the valid one is only added when a validation
        # split exists.
        for name, cb, named_by_user in super()._yield_callbacks():
            if name == "str":
                train_cb, valid_cb = self._parse_str_callback(cb)
                yield train_cb
                if self.train_split is not None:
                    yield valid_cb
            else:
                yield name, cb, named_by_user

    def _parse_str_callback(self, cb_supplied_name):
        """Build train/valid scoring callbacks from an sklearn scorer name.

        Parameters
        ----------
        cb_supplied_name : str
            Name accepted by :func:`sklearn.metrics.get_scorer`.

        Returns
        -------
        list of tuple
            ``[(train_name, train_scoring, True), (valid_name, valid_scoring, True)]``
            with names ``"train_<cb_supplied_name>"`` and
            ``"valid_<cb_supplied_name>"``. Cropped trial scoring callbacks
            are used when ``self.cropped`` is true; otherwise
            ``PostEpochTrainScoring`` (train) and ``EpochScoring`` (valid).
        """
        scoring = get_scorer(cb_supplied_name)
        scoring_name = scoring._score_func.__name__
        assert scoring_name.endswith(("_score", "_error", "_deviance", "_loss"))
        # ``*_score`` metrics and sklearn's negated ``neg_*`` scorers are
        # "higher is better"; errors, deviances and losses are not.
        if scoring_name.endswith("_score") or cb_supplied_name.startswith("neg_"):
            lower_is_better = False
        else:
            lower_is_better = True
        train_name = f"train_{cb_supplied_name}"
        valid_name = f"valid_{cb_supplied_name}"
        if self.cropped:
            train_scoring = CroppedTrialEpochScoring(
                cb_supplied_name, lower_is_better, on_train=True, name=train_name
            )
            valid_scoring = CroppedTrialEpochScoring(
                cb_supplied_name, lower_is_better, on_train=False, name=valid_name
            )
        else:
            train_scoring = PostEpochTrainScoring(
                cb_supplied_name, lower_is_better, name=train_name
            )
            valid_scoring = EpochScoring(
                cb_supplied_name, lower_is_better, on_train=False, name=valid_name
            )
        named_by_user = True
        train_valid_callbacks = [
            (train_name, train_scoring, named_by_user),
            (valid_name, valid_scoring, named_by_user),
        ]
        return train_valid_callbacks

    def on_batch_end(self, net, *batch, training=False, **kwargs):
        """Forward the last batch's window indices to cropped valid scorers.

        On non-training batches, ``self._last_window_inds_`` is appended to
        the ``window_inds_`` list of every cropped, non-train scoring
        callback that has one, then reset to None. NOTE(review):
        ``_last_window_inds_`` is not set in this class; it is expected to be
        set elsewhere before this hook runs.
        """
        # If training is false, assume that our loader has indices for this
        # batch
        if not training:
            epoch_cbs = []
            for name, cb in self.callbacks_:
                if (
                    isinstance(
                        cb, (CroppedTrialEpochScoring, CroppedTimeSeriesEpochScoring)
                    )
                    and (hasattr(cb, "window_inds_"))
                    and (not cb.on_train)
                ):
                    epoch_cbs.append(cb)
            # for trialwise decoding stuffs it might also be we don't have
            # cropped loader, so no indices there
            if len(epoch_cbs) > 0:
                assert self._last_window_inds_ is not None
                for cb in epoch_cbs:
                    cb.window_inds_.append(self._last_window_inds_)
                self._last_window_inds_ = None

    def predict_with_window_inds_and_ys(self, dataset):
        """Predict on ``dataset`` and collect window indices and targets.

        Runs the initialized module (``self.module_``) in eval mode, without
        gradient tracking, over ``self.get_iterator(dataset, drop_index=False)``.

        Parameters
        ----------
        dataset
            Dataset whose iterator yields ``(X, y, i)`` batches.
            NOTE(review): ``i[0]`` is collected as the window-in-trial index
            and ``i[2]`` as the window stop; confirm against the windowing
            code.

        Returns
        -------
        dict
            Keys ``preds``, ``i_window_in_trials``, ``i_window_stops`` and
            ``window_ys``, each a numpy array concatenated over batches.
        """
        # Use the initialized module: ``self.module`` is the init parameter
        # and may be a model name or class rather than a module instance.
        module = self.module_
        module.eval()
        preds = []
        i_window_in_trials = []
        i_window_stops = []
        window_ys = []
        for X, y, i in self.get_iterator(dataset, drop_index=False):
            i_window_in_trials.append(i[0].cpu().numpy())
            i_window_stops.append(i[2].cpu().numpy())
            with torch.no_grad():
                preds.append(to_numpy(module.forward(X.to(self.device))))
            window_ys.append(y.cpu().numpy())
        preds = np.concatenate(preds)
        i_window_in_trials = np.concatenate(i_window_in_trials)
        i_window_stops = np.concatenate(i_window_stops)
        window_ys = np.concatenate(window_ys)
        return dict(
            preds=preds,
            i_window_in_trials=i_window_in_trials,
            i_window_stops=i_window_stops,
            window_ys=window_ys,
        )

    # Changes the default target extractor to noop
    @property
    def _default_callbacks(self):
        """Epoch timer, train/valid loss scoring and the print log."""
        return [
            ("epoch_timer", EpochTimer()),
            (
                "train_loss",
                BatchScoring(
                    train_loss_score,
                    name="train_loss",
                    on_train=True,
                    target_extractor=noop,
                ),
            ),
            (
                "valid_loss",
                BatchScoring(
                    valid_loss_score,
                    name="valid_loss",
                    target_extractor=noop,
                ),
            ),
            ("print_log", PrintLog()),
        ]

    @property
    @abc.abstractmethod
    def mode(self) -> Literal["classification", "regression"]:
        """Task type; passed as ``mode`` to ``infer_signal_properties``."""
        pass

    def _set_signal_args(self, X, y, classes):
        """Infer signal-related module parameters from the data and set them.

        Does nothing beyond logging when ``self.module`` is an already
        initialized :class:`torch.nn.Module`. Otherwise, each property
        returned by ``infer_signal_properties`` that is not None and that the
        module's ``__init__`` accepts is set on the net as a
        ``module__<name>`` parameter. Properties the module does not accept
        are reported with a warning. Parameters the user already set through
        ``module__<name>`` take precedence over inferred values.

        Parameters
        ----------
        X, y
            Data passed to ``fit``/``partial_fit``.
        classes : array-like or None
            Forwarded to ``infer_signal_properties``. When None,
            ``self.classes`` is used if that attribute exists.
        """
        is_init = isinstance(self.module, torch.nn.Module)
        if is_init:
            self.log.info(
                "The module passed is already initialized which is not recommended. "
                "Instead, you can pass the module class and its parameters separately.\n"
                "For more details, see "
                "https://skorch.readthedocs.io/en/stable/user/neuralnet.html#module \n"
                "Skipping setting signal-related parameters from data."
            )
            return
        if classes is None:
            classes = getattr(self, "classes", None)
        signal_kwargs = infer_signal_properties(X, y, mode=self.mode, classes=classes)
        if not signal_kwargs:
            return

        # Keep only inferred values the module constructor accepts:
        module_kwargs = {}
        module = _get_model(self.module)
        all_module_kwargs = inspect.signature(module.__init__).parameters.keys()
        for k, v in signal_kwargs.items():
            if v is None:
                continue
            if k in all_module_kwargs:
                module_kwargs[k] = v
            else:
                self.log.warning(f"Module {self.module!r} is missing parameter {k!r}.")

        # User-specified module parameters override inferred ones:
        user_specified_kwargs = self.get_params_for("module")
        if user_specified_kwargs:
            self.log.info(
                "Overriding inferred parameters with user "
                f"specified parameters {user_specified_kwargs!r}."
            )
        module_kwargs.update(user_specified_kwargs)

        # save kwargs to self:
        self.log.info(
            f"Passing additional parameters {module_kwargs!r} "
            f"to module {self.module!r}."
        )
        module_kwargs = {f"module__{k}": v for k, v in module_kwargs.items()}
        self.set_params(**module_kwargs)

    def get_dataset(self, X, y=None):
        """Get a dataset that contains the input data and is passed to the iterator.

        Override this if you want to initialize your dataset
        differently. :class:`mne.BaseEpochs` inputs are first converted to an
        array in microvolts.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * mne.Epochs
            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
            The same data types as for ``X`` are supported. If your X is
            a Dataset that contains the target, ``y`` may be set to
            None.

        Returns
        -------
        dataset
            The initialized dataset.
        """
        if isinstance(X, mne.BaseEpochs):
            X = X.get_data(units="uV")
        return super().get_dataset(X, y)

    def partial_fit(self, X, y=None, classes=None, **fit_params):
        """Fit the module.

        If the module is initialized, it is not re-initialized, which
        means that this method should be used if you want to continue
        training a model (warm start).
        If possible, signal-related parameters are inferred from the
        data and passed to the module at initialisation.
        Depending on the type of input passed, the following parameters
        are inferred:

        * mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
          ``sfreq``, ``input_window_seconds``
        * array-like: ``n_times``, ``n_chans``, ``n_outputs``
        * WindowsDataset with ``targets_from='metadata'``
          (or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
        * other Dataset: ``n_times``, ``n_chans``
        * other types: no parameters are inferred.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * mne.Epochs
            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
            The same data types as for ``X`` are supported. If your X is
            a Dataset that contains the target, ``y`` may be set to
            None.

        classes : array, shape (n_classes,)
            Forwarded to skorch's ``partial_fit`` and, on the first call,
            to the inference of signal-related parameters.

        **fit_params : dict
            Additional parameters passed to the ``forward`` method of
            the module and to the ``self.train_split`` call.
        """
        # this needs to be executed before the net is initialized:
        if not self.signal_args_set_:
            self._set_signal_args(X, y, classes)
            self.signal_args_set_ = True
        return super().partial_fit(X=X, y=y, classes=classes, **fit_params)

    def fit(self, X, y=None, **fit_params):
        """Initialize and fit the module.

        If the module was already initialized, by calling fit, the
        module will be re-initialized (unless ``warm_start`` is True).
        If possible, signal-related parameters are inferred from the
        data and passed to the module at initialisation.
        Depending on the type of input passed, the following parameters
        are inferred:

        * mne.Epochs: ``n_times``, ``n_chans``, ``n_outputs``, ``chs_info``,
          ``sfreq``, ``input_window_seconds``
        * array-like: ``n_times``, ``n_chans``, ``n_outputs``
        * WindowsDataset with ``targets_from='metadata'``
          (or BaseConcatDataset of such datasets): ``n_times``, ``n_chans``, ``n_outputs``
        * other Dataset: ``n_times``, ``n_chans``
        * other types: no parameters are inferred.

        Parameters
        ----------
        X : input data, compatible with skorch.dataset.Dataset
            By default, you should be able to pass:

            * mne.Epochs
            * numpy arrays
            * torch tensors
            * pandas DataFrame or Series
            * scipy sparse CSR matrices
            * a dictionary of the former three
            * a list/tuple of the former three
            * a Dataset

            If this doesn't work with your data, you have to pass a
            ``Dataset`` that can deal with the data.

        y : target data, compatible with skorch.dataset.Dataset
            The same data types as for ``X`` are supported. If your X is
            a Dataset that contains the target, ``y`` may be set to
            None.

        **fit_params : dict
            Additional parameters passed to the ``forward`` method of
            the module and to the ``self.train_split`` call.
        """
        # this needs to be executed before the net is initialized:
        if not self.signal_args_set_:
            self._set_signal_args(X, y, classes=None)
            self.signal_args_set_ = True
        return super().fit(X=X, y=y, **fit_params)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from .functions import (
|
|
2
|
+
_get_gaussian_kernel1d,
|
|
3
|
+
drop_path,
|
|
4
|
+
hilbert_freq,
|
|
5
|
+
identity,
|
|
6
|
+
plv_time,
|
|
7
|
+
safe_log,
|
|
8
|
+
square,
|
|
9
|
+
)
|
|
10
|
+
from .initialization import glorot_weight_zero_bias, rescale_parameter
|
|
11
|
+
|
|
12
|
+
# Public API of ``braindecode.functional``: the names exported by
# ``from braindecode.functional import *``.
# NOTE(review): ``_get_gaussian_kernel1d`` is private by naming convention but
# is exported here, presumably for use by other braindecode modules; confirm.
__all__ = [
    "_get_gaussian_kernel1d",
    "drop_path",
    "hilbert_freq",
    "identity",
    "plv_time",
    "safe_log",
    "square",
    "glorot_weight_zero_bias",
    "rescale_parameter",
]
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def square(x):
    """Square ``x`` element-wise.

    Computed as ``x * x``, so it works for any operand supporting ``*``
    (Python numbers, NumPy arrays, torch tensors).
    """
    squared = x * x
    return squared
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def safe_log(x, eps: float = 1e-6) -> torch.Tensor:
    """Natural logarithm with a floor at ``eps``.

    Entries of ``x`` smaller than ``eps`` are raised to ``eps`` before the
    log is taken, i.e. this returns :math:`log(max(x, eps))` and so never
    evaluates :math:`log(0)`.
    """
    floored = torch.clamp(x, min=eps)
    return torch.log(floored)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def identity(x):
    """Return the argument unchanged (the very same object)."""
    result = x
    return result
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Drop paths (Stochastic Depth) per sample.

    Notes: This implementation is taken from timm library.

    All credit goes to Ross Wightman.

    Parameters
    ----------
    x : torch.Tensor
        input tensor; the first dimension indexes samples.
    drop_prob : float, optional
        probability of dropping the path, i.e. of zeroing a whole sample.
        The survival (keep) probability is ``1 - drop_prob``. By default 0.0.
    training : bool, optional
        whether the model is in training mode, by default False. Outside of
        training, ``x`` is returned unchanged.
    scale_by_keep : bool, optional
        whether to scale kept samples by ``1 / keep_prob`` during training,
        by default True.

    Returns
    -------
    torch.Tensor
        output tensor. ``x`` itself when ``drop_prob == 0`` or not training.

    Notes from Ross Wightman:
    (when applied in main path of residual blocks)
    This is the same as the DropConnect impl I created for EfficientNet,
    etc. networks, however,
    the original name is misleading as 'Drop Connect' is a different form
    of dropout in a separate paper...
    See discussion : https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    ... I've opted for changing the layer and argument names to 'drop path'
    rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims, so
    # this works for tensors of any rank (not just 2D ConvNet activations).
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    # With keep_prob == 0 every sample is dropped; skip the division by zero.
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
|
|
71
|
+
"""
|
|
72
|
+
Generates a 1-dimensional Gaussian kernel based on the specified kernel.
|
|
73
|
+
|
|
74
|
+
size and standard deviation (sigma).
|
|
75
|
+
This kernel is useful for Gaussian smoothing or filtering operations in
|
|
76
|
+
image processing. The function calculates a range limit to ensure the kernel
|
|
77
|
+
effectively covers the Gaussian distribution. It generates a tensor of
|
|
78
|
+
specified size and type, filled with values distributed according to a
|
|
79
|
+
Gaussian curve, normalized using a softmax function
|
|
80
|
+
to ensure all weights sum to 1.
|
|
81
|
+
|
|
82
|
+
Parameters
|
|
83
|
+
----------
|
|
84
|
+
kernel_size : int
|
|
85
|
+
sigma : float
|
|
86
|
+
|
|
87
|
+
Returns
|
|
88
|
+
-------
|
|
89
|
+
kernel1d : torch.Tensor
|
|
90
|
+
|
|
91
|
+
Notes
|
|
92
|
+
-----
|
|
93
|
+
Code copied and modified from TorchVision:
|
|
94
|
+
https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py#L725-L732
|
|
95
|
+
All rights reserved.
|
|
96
|
+
|
|
97
|
+
LICENSE in https://github.com/pytorch/vision/blob/main/LICENSE
|
|
98
|
+
"""
|
|
99
|
+
ksize_half = (kernel_size - 1) * 0.5
|
|
100
|
+
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
|
101
|
+
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
|
102
|
+
kernel1d = pdf / pdf.sum()
|
|
103
|
+
return kernel1d
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def hilbert_freq(x, forward_fourier=True):
    r"""
    Compute the analytic signal (via the Hilbert transform) using PyTorch,
    returning its real and imaginary parts separately.

    The analytic signal :math:`x_a(t)` of a real-valued signal :math:`x(t)`
    is defined as:

    .. math::

        x_a(t) = x(t) + i y(t) = \mathcal{F}^{-1} \{ U(f) \mathcal{F}\{x(t)\} \}

    where:
    - :math:`\mathcal{F}` is the Fourier transform,
    - :math:`U(f)` is the unit step function,
    - :math:`y(t)` is the Hilbert transform of :math:`x(t)`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor. The expected shape depends on ``forward_fourier``:

        - If ``forward_fourier`` is True: a real signal of shape
          ``(..., seq_len)``.
        - If ``forward_fourier`` is False: its one-sided spectrum as returned
          by ``torch.view_as_real(torch.fft.rfft(signal))``, of shape
          ``(..., seq_len // 2 + 1, 2)``. The original length cannot be
          recovered from this shape, so ``seq_len`` is assumed to be even.

    forward_fourier : bool, optional
        If True (default), ``x`` is in the time domain and its real FFT is
        computed here. If False, ``x`` already holds the real FFT with
        separate real and imaginary parts.

    Returns
    -------
    torch.Tensor
        Output tensor with shape ``(..., seq_len, 2)``, where the last
        dimension holds the real part (the input signal) and the imaginary
        part (its Hilbert transform) of the analytic signal.

    Examples
    --------
    >>> import torch
    >>> x = torch.randn(10, 100)  # Example input tensor
    >>> output = hilbert_freq(x)
    >>> print(output.shape)
    torch.Size([10, 100, 2])

    Notes
    -----
    The implementation matches the scipy implementation (including the
    treatment of the Nyquist bin for even lengths and of odd lengths), but
    uses torch.
    https://github.com/scipy/scipy/blob/v1.14.1/scipy/signal/_signaltools.py#L2287-L2394
    """
    if forward_fourier:
        n_times = x.shape[-1]
        x = torch.fft.rfft(x, norm=None, dim=-1)
        x = torch.view_as_real(x)
    else:
        # Only the one-sided spectrum is available: assume an even length.
        n_times = 2 * (x.shape[-2] - 1)
    n_freqs = x.shape[-2]

    # Spectral weights of the analytic signal (as in scipy.signal.hilbert):
    # positive frequencies are doubled, while the DC term and, for even
    # lengths, the Nyquist term keep weight 1.
    weights = torch.full((n_freqs,), 2.0, dtype=x.dtype, device=x.device)
    weights[0] = 1.0
    if n_times % 2 == 0:
        weights[-1] = 1.0
    x = x * weights.unsqueeze(-1)

    # Negative frequencies are zero: pad the one-sided spectrum up to the
    # full signal length before the inverse FFT.
    x = F.pad(x, [0, 0, 0, n_times - n_freqs])
    x = torch.view_as_complex(x)
    x = torch.fft.ifft(x, norm=None, dim=-1)  # returns complex signal
    x = torch.view_as_real(x)

    return x
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def plv_time(x, forward_fourier=True, epsilon: float = 1e-6):
    """Phase Locking Value (PLV) between all channel pairs, in the time domain.

    The PLV measures how consistent the phase difference between two
    channels stays over time. It ranges from 0 (no synchronization) to 1
    (perfect synchronization) [Lachaux1999]_.

    Parameters
    ----------
    x : torch.Tensor
        Signal data.

        - With ``forward_fourier=True``: shape ``(..., channels, time)``.
        - With ``forward_fourier=False``: shape ``(..., channels, freqs, 2)``,
          the last dimension holding real and imaginary parts.

    forward_fourier : bool, optional
        Format of ``x``.

        - ``True``: ``x`` is in the time domain.
        - ``False``: ``x`` is in the Fourier domain with separate real and
          imaginary components.

        Default is ``True``.
    epsilon : float, default 1e-6
        Small value added under the square root of the final magnitude to
        keep it strictly positive.

    Returns
    -------
    plv : torch.Tensor
        PLV matrix of shape ``(..., channels, channels)``; entry ``[i, j]``
        is the PLV between channel ``i`` and channel ``j``.

    References
    ----------
    .. [Lachaux1999] Lachaux, J. P., Rodriguez, E., Martinerie, J., & Varela, F. J. (1999).
       Measuring phase synchrony in brain signals. Human brain mapping,
       8(4), 194-208.
    """
    # Analytic signal with real and imaginary parts stacked on the last axis.
    analytic = hilbert_freq(x, forward_fourier)
    re_part = analytic[..., 0]
    im_part = analytic[..., 1]

    # Instantaneous amplitude; the fixed 1e-6 keeps the division finite.
    magnitude = torch.sqrt(re_part**2 + im_part**2 + 1e-6)

    # Unit phasors: dividing out the amplitude keeps only the phase.
    phasor = analytic / magnitude.unsqueeze(-1)
    cos_part, sin_part = phasor.unbind(dim=-1)

    # Sum over time of phasor_i * conj(phasor_j), split into real/imag parts.
    real_real = cos_part @ cos_part.transpose(-2, -1)
    imag_imag = sin_part @ sin_part.transpose(-2, -1)
    real_imag = cos_part @ sin_part.transpose(-2, -1)
    imag_real = sin_part @ cos_part.transpose(-2, -1)
    corr_re = real_real + imag_imag
    corr_im = real_imag - imag_real

    # Magnitude of the summed correlation, averaged over the time samples.
    n_samples = magnitude.shape[-1]
    inv_n = 1 / n_samples
    return inv_n * torch.sqrt(corr_re**2 + corr_im**2 + epsilon)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from torch import nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def glorot_weight_zero_bias(model):
    """Initialize weights with Glorot uniform (Xavier) and zero the biases.

    Every submodule of ``model`` (including ``model`` itself) is visited:

    * weights of batch norm layers (class name containing ``"BatchNorm"``)
      are set to 1;
    * other weights with at least two dimensions are initialized with
      :func:`torch.nn.init.xavier_uniform_` (gain 1);
    * other weights with fewer than two dimensions (e.g. LayerNorm scales),
      for which Glorot fan-in/fan-out is undefined, are left unchanged;
    * every bias that is not None is set to 0.

    Weights or biases that are None (e.g. affine-free normalization layers)
    are skipped.

    Parameters
    ----------
    model : Module
    """
    for module in model.modules():
        weight = getattr(module, "weight", None)
        if weight is not None:
            if "BatchNorm" in module.__class__.__name__:
                nn.init.constant_(weight, 1)
            elif getattr(weight, "ndim", 0) >= 2:
                nn.init.xavier_uniform_(weight, gain=1)
        bias = getattr(module, "bias", None)
        if bias is not None:
            nn.init.constant_(bias, 0)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def rescale_parameter(param, layer_id):
    r"""Rescale a parameter of the l-th transformer layer, in place.

    Divides ``param`` in place by :math:`\sqrt{2 \cdot \text{layer\_id}}`,
    i.e. multiplies it by :math:`\frac{1}{\sqrt{2 \cdot \text{layer\_id}}}`
    [Beit2022].

    In the labram, this is used to rescale the output matrices
    (i.e., the last linear projection within each sub-layer) of the
    self-attention module.

    Parameters
    ----------
    param: :class:`torch.Tensor`
        tensor to be rescaled in place
    layer_id: int
        layer id in the neural network; must be positive

    Raises
    ------
    ValueError
        If ``layer_id`` is not positive, since the scale factor would then be
        infinite or undefined.

    References
    ----------
    [Beit2022] Hangbo Bao, Li Dong, Songhao Piao, Furu Wei (2022). BEIT: BERT
    Pre-Training of Image Transformers.
    """
    if layer_id <= 0:
        raise ValueError(f"layer_id must be positive, got {layer_id!r}.")
    param.div_(math.sqrt(2.0 * layer_id))
|