braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,483 @@
|
|
|
1
|
+
# Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
|
|
2
|
+
# Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
3
|
+
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
4
|
+
# Lukas Gemein <l.gemein@gmail.com>
|
|
5
|
+
# Mohammed Fattouh <mo.fattouh@gmail.com>
|
|
6
|
+
#
|
|
7
|
+
# License: BSD-3
|
|
8
|
+
|
|
9
|
+
import warnings
|
|
10
|
+
from contextlib import contextmanager
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import torch
|
|
14
|
+
from mne.utils.check import check_version
|
|
15
|
+
from skorch.callbacks.scoring import EpochScoring
|
|
16
|
+
from skorch.dataset import unpack_data
|
|
17
|
+
from skorch.utils import to_numpy
|
|
18
|
+
from torch.utils.data import DataLoader
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials):
    """
    Assigning window predictions to trials while removing duplicate
    predictions.

    Parameters
    ----------
    preds: list of ndarrays (at least 2darrays)
        List of window predictions, in each window prediction
        time is in axis=1
    i_window_in_trials: list
        Index/number of window in trial
    i_stop_in_trials: list
        stop position of window in trial

    Returns
    -------
    preds_per_trial: list of ndarrays
        Predictions in each trial, duplicates removed. An empty list is
        returned when ``preds`` is empty.

    """
    assert len(preds) == len(i_window_in_trials) == len(i_stop_in_trials), (
        f"{len(preds)}, {len(i_window_in_trials)}, {len(i_stop_in_trials)}"
    )
    # Nothing to assign; without this guard np.concatenate([]) below raises.
    if len(preds) == 0:
        return []

    # Algorithm for assigning window predictions to trials
    # while removing duplicate predictions:
    # Loop through windows:
    # In each iteration you have predictions (assumed: #classes x #timesteps,
    # or at least #timesteps must be in axis=1)
    # and you have i_window_in_trial, i_stop_in_trial
    # (i_trial removed from variable names for brevity)
    # You first check if the i_window_in_trial is 1 larger
    # than in last iteration, then you are still in the same trial
    # Otherwise you are in a new trial
    # If you are in the same trial, you check for duplicate predictions
    # Only take predictions that are after (inclusive)
    # the stop of the last iteration (i.e., the index of final prediction
    # in the last iteration)
    # Then add the duplicate-removed predictions from this window
    # to predictions for current trial
    preds_per_trial = []
    cur_trial_preds = []
    i_last_stop = None
    i_last_window = -1
    for window_preds, i_window, i_stop in zip(
        preds, i_window_in_trials, i_stop_in_trials
    ):
        window_preds = np.array(window_preds)
        if i_window != (i_last_window + 1):
            assert i_window == 0, "window numbers in new trial should start from 0"
            preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
            cur_trial_preds = []
            i_last_stop = None

        if i_last_stop is not None:
            # Remove duplicates: keep only the last ``n_needed_preds``
            # timesteps of this window. Slice from an explicit start index
            # instead of ``[:, -n_needed_preds:]``, because ``-0:`` would
            # keep the entire window (not nothing) when two consecutive
            # windows share the same stop. ``int`` also normalizes numpy /
            # 0-d torch scalars into a plain Python index.
            n_needed_preds = int(i_stop - i_last_stop)
            start = max(window_preds.shape[1] - n_needed_preds, 0)
            window_preds = window_preds[:, start:]
        cur_trial_preds.append(window_preds)
        i_last_window = i_window
        i_last_stop = i_stop
    # add last trial preds
    preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
    return preds_per_trial
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@contextmanager
|
|
89
|
+
def _cache_net_forward_iter(net, use_caching, y_preds):
|
|
90
|
+
"""Caching context for ``skorch.NeuralNet`` instance.
|
|
91
|
+
Returns a modified version of the net whose ``forward_iter``
|
|
92
|
+
method will subsequently return cached predictions. Leaving the
|
|
93
|
+
context will undo the overwrite of the ``forward_iter`` method.
|
|
94
|
+
"""
|
|
95
|
+
if not use_caching:
|
|
96
|
+
yield net
|
|
97
|
+
return
|
|
98
|
+
y_preds = iter(y_preds)
|
|
99
|
+
|
|
100
|
+
# pylint: disable=unused-argument
|
|
101
|
+
def cached_forward_iter(*args, device=net.device, **kwargs):
|
|
102
|
+
for yp in y_preds:
|
|
103
|
+
yield yp.to(device=device)
|
|
104
|
+
|
|
105
|
+
net.forward_iter = cached_forward_iter
|
|
106
|
+
try:
|
|
107
|
+
yield net
|
|
108
|
+
finally:
|
|
109
|
+
# By setting net.forward_iter we define an attribute
|
|
110
|
+
# `forward_iter` that precedes the bound method
|
|
111
|
+
# `forward_iter`. By deleting the entry from the attribute
|
|
112
|
+
# dict we undo this.
|
|
113
|
+
del net.__dict__["forward_iter"]
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class CroppedTrialEpochScoring(EpochScoring):
    """
    Class to compute scores for trials from a model that predicts (super)crops.

    At the end of an epoch, the window-wise predictions are stitched back
    into per-trial predictions with ``trial_preds_from_window_preds``
    (duplicate timesteps of overlapping windows removed), averaged over time
    (``axis=1``), and then fed to the skorch scoring function through a
    cached ``forward_iter``. The first such callback on a given set computes
    the trial predictions and shares them with all other
    ``CroppedTrialEpochScoring`` callbacks on the same set.

    Parameters
    ----------
    scoring : None, str, or callable
        Passed on to ``skorch.callbacks.EpochScoring``.
    lower_is_better : bool (default=True)
        Whether lower scores should be considered better or worse.
    on_train : bool (default=False)
        If True, score on the training set; predictions are then recomputed
        via ``net.predict_with_window_inds_and_ys(dataset_train)``. If False,
        score on the validation set using the cached batch predictions,
        targets and ``self.window_inds_`` (presumably appended to by the net
        during validation batches - not shown in this file).
    name : str or None (default=None)
        Name of the score; passed on to ``EpochScoring``.
    target_extractor : callable (default=to_numpy)
        Passed on to ``EpochScoring``.
    use_caching : bool (default=True)
        Passed on to ``EpochScoring``. ``on_epoch_end`` asserts that caching
        is enabled.
    """

    def __init__(
        self,
        scoring,
        lower_is_better=True,
        on_train=False,
        name=None,
        target_extractor=to_numpy,
        use_caching=True,
    ):
        super().__init__(
            scoring=scoring,
            lower_is_better=lower_is_better,
            on_train=on_train,
            name=name,
            target_extractor=target_extractor,
            use_caching=use_caching,
        )
        # Window indices are only collected for validation scoring; for
        # training scoring they come from predict_with_window_inds_and_ys.
        if not self.on_train:
            self.window_inds_ = []

    def _initialize_cache(self):
        super()._initialize_cache()
        # Reset the per-epoch state (presumably invoked by skorch at the
        # start of each epoch): trial predictions are not computed yet and
        # the caches start empty.
        self.crops_to_trials_computed = False
        self.y_trues_ = []
        self.y_preds_ = []
        if not self.on_train:
            self.window_inds_ = []

    def on_batch_end(self, net, batch, y_pred, training, **kwargs):
        # Skorch saves the predictions without moving them from GPU
        # https://github.com/skorch-dev/skorch/blob/fe71e3d55a4ae5f5f94ef7bdfc00fca3b3fd267f/skorch/callbacks/scoring.py#L385
        # This can cause memory issues in case of a large number of predictions
        # Therefore here we move them to CPU already
        super().on_batch_end(net, batch, y_pred, training, **kwargs)
        if self.use_caching and training == self.on_train:
            self.y_preds_[-1] = self.y_preds_[-1].cpu()

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        assert self.use_caching
        # Only the first cropped callback on this set does the work; it then
        # marks all matching callbacks as computed (see loop below).
        if not self.crops_to_trials_computed:
            if self.on_train:
                # Prevent that rng state of torch is changed by
                # creation+usage of iterator
                rng_state = torch.random.get_rng_state()
                pred_results = net.predict_with_window_inds_and_ys(dataset_train)
                torch.random.set_rng_state(rng_state)
            else:
                pred_results = {}
                # Each window_inds_ entry is indexed as [0] -> window index in
                # trial and [2] -> window stop in trial (per the keys below).
                pred_results["i_window_in_trials"] = np.concatenate(
                    [i[0].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["i_window_stops"] = np.concatenate(
                    [i[2].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["preds"] = np.concatenate(
                    [y_pred.cpu().numpy() for y_pred in self.y_preds_]
                )
                pred_results["window_ys"] = np.concatenate(
                    [y.cpu().numpy() for y in self.y_trues_]
                )

            # A new trial starts
            # when the index of the window in trials
            # does not increment by 1
            # Add dummy infinity at start
            window_0_per_trial_mask = (
                np.diff(pred_results["i_window_in_trials"], prepend=[np.inf]) != 1
            )
            # The target of the first window of each trial is used as the
            # trial target.
            trial_ys = pred_results["window_ys"][window_0_per_trial_mask]
            trial_preds = trial_preds_from_window_preds(
                pred_results["preds"],
                pred_results["i_window_in_trials"],
                pred_results["i_window_stops"],
            )

            # Average across the timesteps of each trial so we have per-trial
            # predictions already, these will be just passed through the forward
            # method of the classifier/regressor to the skorch scoring function.
            # trial_preds is a list, each item is a 2d array classes x time
            y_preds_per_trial = np.array([np.mean(p, axis=1) for p in trial_preds])
            # Move into format expected by skorch (list of torch tensors)
            y_preds_per_trial = [torch.tensor(y_preds_per_trial)]

            # Store the computed trial preds for all Cropped Callbacks
            # that are also on same set
            # NOTE(review): isinstance also matches CroppedTimeSeriesEpochScoring
            # subclass instances on the same set, whose caches would be
            # overwritten here - confirm this mix is never used.
            cbs = net.callbacks_
            epoch_cbs = [
                cb
                for name, cb in cbs
                if isinstance(cb, CroppedTrialEpochScoring)
                and (cb.on_train == self.on_train)
            ]
            for cb in epoch_cbs:
                cb.y_preds_ = y_preds_per_trial
                cb.y_trues_ = trial_ys
                cb.crops_to_trials_computed = True

        dataset = dataset_train if self.on_train else dataset_valid

        # Score with a net whose forward_iter replays the per-trial
        # predictions instead of running the model again.
        with _cache_net_forward_iter(
            net, self.use_caching, self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset, self.y_trues_)
        self._record_score(net.history, current_score)

        return
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
    """
    Class to compute scores for trials from a model that predicts (super)crops with
    time series target.

    Unlike ``CroppedTrialEpochScoring``, predictions are not averaged over
    time. Both predictions and targets are stitched back into per-trial time
    series, concatenated over all trials, and timesteps whose target is NaN
    are dropped before the skorch scoring function is applied.
    """

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        assert self.use_caching
        # Only the first cropped time-series callback on this set does the
        # work; it marks all matching callbacks as computed (see below).
        if not self.crops_to_trials_computed:
            if self.on_train:
                # Prevent that rng state of torch is changed by
                # creation+usage of iterator
                rng_state = torch.random.get_rng_state()
                pred_results = net.predict_with_window_inds_and_ys(dataset_train)
                torch.random.set_rng_state(rng_state)
            else:
                pred_results = {}
                # Each window_inds_ entry is indexed as [0] -> window index in
                # trial and [2] -> window stop in trial (per the keys below).
                pred_results["i_window_in_trials"] = np.concatenate(
                    [i[0].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["i_window_stops"] = np.concatenate(
                    [i[2].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["preds"] = np.concatenate(
                    [y_pred.cpu().numpy() for y_pred in self.y_preds_]
                )
                pred_results["window_ys"] = np.concatenate(
                    [y.cpu().numpy() for y in self.y_trues_]
                )

            # Number of predicted timesteps per window (last axis of a window's
            # predictions).
            num_preds = pred_results["preds"][-1].shape[-1]
            # slice the targets to fit preds shape
            pred_results["window_ys"] = [
                targets[:, -num_preds:] for targets in pred_results["window_ys"]
            ]

            trial_preds = trial_preds_from_window_preds(
                pred_results["preds"],
                pred_results["i_window_in_trials"],
                pred_results["i_window_stops"],
            )

            # Targets are stitched per trial with the same duplicate removal
            # as the predictions, so both stay aligned timestep by timestep.
            trial_ys = trial_preds_from_window_preds(
                pred_results["window_ys"],
                pred_results["i_window_in_trials"],
                pred_results["i_window_stops"],
            )

            # the output is a list of predictions/targets per trial where each item is a
            # timeseries of predictions/targets of shape (n_classes x timesteps)

            # mask NaNs from targets
            preds = np.hstack(trial_preds)  # n_classes x timesteps in all trials
            targets = np.hstack(trial_ys)
            # create valid targets mask
            mask = ~np.isnan(targets)
            # select valid targets that have a matching predictions
            masked_targets = targets[mask]
            # For classification there is only one row in targets and n_classes rows in preds
            if mask.shape[0] != preds.shape[0]:
                masked_preds = preds[:, mask[0, :]]
            else:
                masked_preds = preds[mask]

            # Store the computed trial preds for all Cropped Callbacks
            # that are also on same set
            cbs = net.callbacks_
            epoch_cbs = [
                cb
                for name, cb in cbs
                if isinstance(cb, CroppedTimeSeriesEpochScoring)
                and (cb.on_train == self.on_train)
            ]
            # Transposed so that timesteps come first, in the skorch format
            # (list of torch tensors).
            masked_preds = [torch.tensor(masked_preds.T)]
            for cb in epoch_cbs:
                cb.y_preds_ = masked_preds
                cb.y_trues_ = masked_targets.T
                cb.crops_to_trials_computed = True

        dataset = dataset_train if self.on_train else dataset_valid

        # Score with a net whose forward_iter replays the masked predictions.
        with _cache_net_forward_iter(
            net, self.use_caching, self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset, self.y_trues_)
        self._record_score(net.history, current_score)
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
class PostEpochTrainScoring(EpochScoring):
    """
    Epoch Scoring class that recomputes predictions after the epoch
    on the training set in evaluation mode.

    Note: For unknown reasons, this affects global random generator and
    therefore all results may change slightly if you add this scoring callback.
    The torch CPU random state is saved before and restored after the
    prediction pass (also if it fails), but no other generator is touched.

    Parameters
    ----------
    scoring : None, str, or callable (default=None)
        If None, use the ``score`` method of the model. If str, it
        should be a valid sklearn scorer (e.g. "f1", "accuracy"). If a
        callable, it should have the signature (model, X, y), and it
        should return a scalar. This works analogously to the
        ``scoring`` parameter in sklearn's ``GridSearchCV`` et al.
    lower_is_better : bool (default=True)
        Whether lower scores should be considered better or worse.
    name : str or None (default=None)
        If not an explicit string, tries to infer the name from the
        ``scoring`` argument.
    target_extractor : callable (default=to_numpy)
        This is called on y before it is passed to scoring.
    """

    def __init__(
        self,
        scoring,
        lower_is_better=True,
        name=None,
        target_extractor=to_numpy,
    ):
        super().__init__(
            scoring=scoring,
            lower_is_better=lower_is_better,
            on_train=True,
            name=name,
            target_extractor=target_extractor,
            use_caching=False,
        )

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        # Only the first PostEpochTrainScoring of this epoch recomputes the
        # predictions; it then shares them with all others (see below).
        if len(self.y_preds_) == 0:
            dataset = net.get_dataset(dataset_train)
            # X, y unpacking has been pushed downstream in skorch 0.10.
            # The installed version cannot change during the loop, so the
            # check is done once instead of once per batch.
            # TODO: remove after skorch 0.10 release
            pass_full_batch = check_version("skorch", min_version="0.10.1")
            # Prevent that rng state of torch is changed by
            # creation+usage of iterator.
            # Unfortunately calling __iter__() of a pytorch
            # DataLoader will change the random state, so the state is
            # restored in the ``finally`` clause, even if prediction fails.
            rng_state = torch.random.get_rng_state()
            try:
                iterator = net.get_iterator(dataset, training=False)
                y_preds = []
                y_test = []
                for batch in iterator:
                    batch_X, batch_y = unpack_data(batch)
                    if pass_full_batch:
                        yp = net.evaluation_step(batch, training=False)
                    else:
                        yp = net.evaluation_step(batch_X, training=False)
                    yp = yp.to(device="cpu")
                    y_test.append(self.target_extractor(batch_y))
                    y_preds.append(yp)
                y_test = np.concatenate(y_test)
            finally:
                torch.random.set_rng_state(rng_state)

            # Adding the recomputed preds to all other
            # instances of PostEpochTrainScoring of this
            # Skorch-Net (NeuralNet, BraindecodeClassifier etc.)
            # (They will be reinitialized to empty lists by skorch
            # each epoch)
            cbs = net.callbacks_
            epoch_cbs = [
                cb for name, cb in cbs if isinstance(cb, PostEpochTrainScoring)
            ]
            for cb in epoch_cbs:
                cb.y_preds_ = y_preds
                cb.y_trues_ = y_test
        # y pred should be same as self.y_preds_
        # Unclear if this also leads to any
        # random generator call?
        with _cache_net_forward_iter(
            net, use_caching=True, y_preds=self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset_train, self.y_trues_)
        self._record_score(net.history, current_score)
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def predict_trials(module, dataset, return_targets=True, batch_size=1, num_workers=0):
    """Create trialwise predictions and optionally also return trialwise
    labels from cropped dataset given module.

    Parameters
    ----------
    module: torch.nn.Module
        A pytorch model implementing forward.
    dataset: braindecode.datasets.BaseConcatDataset
        A braindecode dataset to be predicted.
    return_targets: bool
        If True, additionally returns the trial targets.
    batch_size: int
        The batch size used to iterate the dataset.
    num_workers: int
        Number of workers used in DataLoader to iterate the dataset.

    Returns
    -------
    trial_predictions: np.ndarray
        3-dimensional array (n_trials x n_classes x n_predictions), where
        the number of predictions depend on the chosen window size and the
        receptive field of the network.
    trial_labels: np.ndarray
        2-dimensional array (n_trials x n_targets) where the number of
        targets depends on the decoding paradigm and can be either a single
        value, multiple values, or a sequence.

    Notes
    -----
    ``module`` is switched to evaluation mode and is left in that mode.
    """
    # Ensure the model is in evaluation mode
    module.eval()
    # we have a cropped dataset if there exists at least one trial with more
    # than one compute window
    more_than_one_window = (dataset.get_metadata()["i_window_in_trial"] != 0).any()
    if not more_than_one_window:
        warnings.warn(
            "This function was designed to predict trials from "
            "cropped datasets, which typically have multiple compute "
            "windows per trial. The given dataset has exactly one "
            "window per trial."
        )
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    device = next(module.parameters()).device
    all_preds, all_ys, all_inds = [], [], []
    with torch.no_grad():
        for X, y, ind in loader:
            X = X.to(device)
            preds = module(X)
            all_preds.extend(preds.cpu().numpy().astype(np.float32))
            all_ys.extend(y.cpu().numpy().astype(np.float32))
            all_inds.extend(ind)
    # Every batch contributes three index tensors to all_inds: [0] is the
    # window index in trial, [2] the window stop in trial (the middle one is
    # presumably the window start). Concatenate each kind once and reuse it
    # for predictions and targets.
    i_window_in_trials = torch.cat(all_inds[0::3])
    i_stop_in_trials = torch.cat(all_inds[2::3])
    preds_per_trial = trial_preds_from_window_preds(
        preds=all_preds,
        i_window_in_trials=i_window_in_trials,
        i_stop_in_trials=i_stop_in_trials,
    )
    preds_per_trial = np.array(preds_per_trial)
    if return_targets:
        if all_ys[0].shape == ():
            # One scalar target per window: use the target of the first
            # window of each trial (where the window index does not
            # increment by 1).
            all_ys = np.array(all_ys)
            ys_per_trial = all_ys[np.diff(i_window_in_trials, prepend=[np.inf]) != 1]
        else:
            # Sequence targets are stitched per trial like the predictions.
            ys_per_trial = trial_preds_from_window_preds(
                preds=all_ys,
                i_window_in_trials=i_window_in_trials,
                i_stop_in_trials=i_stop_in_trials,
            )
            ys_per_trial = np.array(ys_per_trial)
        return preds_per_trial, ys_per_trial
    return preds_per_trial
|
braindecode/util.py
CHANGED
|
@@ -12,6 +12,7 @@ import mne
|
|
|
12
12
|
import numpy as np
|
|
13
13
|
import torch
|
|
14
14
|
from sklearn.utils import check_random_state
|
|
15
|
+
from torch import Tensor
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
def set_random_seeds(seed, cuda, cudnn_benchmark=None):
|
|
@@ -51,7 +52,9 @@ def set_random_seeds(seed, cuda, cudnn_benchmark=None):
|
|
|
51
52
|
warn(
|
|
52
53
|
"torch.backends.cudnn.benchmark was set to True which may results in lack of "
|
|
53
54
|
"reproducibility. In some cases to ensure reproducibility you may need to "
|
|
54
|
-
"set torch.backends.cudnn.benchmark to False.",
|
|
55
|
+
"set torch.backends.cudnn.benchmark to False.",
|
|
56
|
+
UserWarning,
|
|
57
|
+
)
|
|
55
58
|
else:
|
|
56
59
|
raise ValueError(
|
|
57
60
|
f"cudnn_benchmark expected to be bool or None, got '{cudnn_benchmark}'"
|
|
@@ -60,19 +63,7 @@ def set_random_seeds(seed, cuda, cudnn_benchmark=None):
|
|
|
60
63
|
np.random.seed(seed)
|
|
61
64
|
|
|
62
65
|
|
|
63
|
-
def
|
|
64
|
-
X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs
|
|
65
|
-
):
|
|
66
|
-
warn("np_to_var has been renamed np_to_th, please use np_to_th instead")
|
|
67
|
-
return np_to_th(
|
|
68
|
-
X, requires_grad=requires_grad, dtype=dtype, pin_memory=pin_memory,
|
|
69
|
-
**tensor_kwargs
|
|
70
|
-
)
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def np_to_th(
|
|
74
|
-
X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs
|
|
75
|
-
):
|
|
66
|
+
def np_to_th(X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs):
|
|
76
67
|
"""
|
|
77
68
|
Convenience function to transform numpy array to `torch.Tensor`.
|
|
78
69
|
|
|
@@ -103,12 +94,7 @@ def np_to_th(
|
|
|
103
94
|
return X_tensor
|
|
104
95
|
|
|
105
96
|
|
|
106
|
-
def
|
|
107
|
-
warn("var_to_np has been renamed th_to_np, please use th_to_np instead")
|
|
108
|
-
return th_to_np(var)
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
def th_to_np(var):
|
|
97
|
+
def th_to_np(var: Tensor):
|
|
112
98
|
"""Convenience function to transform `torch.Tensor` to numpy
|
|
113
99
|
array.
|
|
114
100
|
|
|
@@ -209,15 +195,11 @@ def wrap_reshape_apply_fn(stat_fn, a, b, axis_a, axis_b):
|
|
|
209
195
|
)
|
|
210
196
|
assert np.array_equal(n_stat_axis_a, n_stat_axis_b)
|
|
211
197
|
stat_result = stat_fn(flat_topo_a, flat_topo_b)
|
|
212
|
-
topo_result = stat_result.reshape(
|
|
213
|
-
tuple(n_other_axis_a) + tuple(n_other_axis_b)
|
|
214
|
-
)
|
|
198
|
+
topo_result = stat_result.reshape(tuple(n_other_axis_a) + tuple(n_other_axis_b))
|
|
215
199
|
return topo_result
|
|
216
200
|
|
|
217
201
|
|
|
218
|
-
def get_balanced_batches(
|
|
219
|
-
n_trials, rng, shuffle, n_batches=None, batch_size=None
|
|
220
|
-
):
|
|
202
|
+
def get_balanced_batches(n_trials, rng, shuffle, n_batches=None, batch_size=None):
|
|
221
203
|
"""Create indices for batches balanced in size
|
|
222
204
|
(batches will have maximum size difference of 1).
|
|
223
205
|
Supply either batch size or number of batches. Resulting batches
|
|
@@ -268,9 +250,17 @@ def get_balanced_batches(
|
|
|
268
250
|
return batches
|
|
269
251
|
|
|
270
252
|
|
|
271
|
-
def create_mne_dummy_raw(
|
|
272
|
-
|
|
273
|
-
|
|
253
|
+
def create_mne_dummy_raw(
|
|
254
|
+
n_channels,
|
|
255
|
+
n_times,
|
|
256
|
+
sfreq,
|
|
257
|
+
include_anns=True,
|
|
258
|
+
description=None,
|
|
259
|
+
savedir=None,
|
|
260
|
+
save_format="fif",
|
|
261
|
+
overwrite=True,
|
|
262
|
+
random_state=None,
|
|
263
|
+
):
|
|
274
264
|
"""Create an mne.io.RawArray with fake data, and optionally save it.
|
|
275
265
|
|
|
276
266
|
This will overwrite already existing files.
|
|
@@ -305,20 +295,21 @@ def create_mne_dummy_raw(n_channels, n_times, sfreq, include_anns=True,
|
|
|
305
295
|
"""
|
|
306
296
|
random_state = check_random_state(random_state)
|
|
307
297
|
data = random_state.rand(n_channels, n_times)
|
|
308
|
-
ch_names = [f
|
|
309
|
-
ch_types = [
|
|
298
|
+
ch_names = [f"ch{i}" for i in range(n_channels)]
|
|
299
|
+
ch_types = ["eeg"] * n_channels
|
|
310
300
|
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
|
|
311
301
|
|
|
312
302
|
raw = mne.io.RawArray(data, info)
|
|
313
303
|
|
|
314
304
|
if include_anns:
|
|
315
305
|
n_anns = 10
|
|
316
|
-
inds = np.linspace(
|
|
317
|
-
int
|
|
306
|
+
inds = np.linspace(int(sfreq * 2), int(n_times - sfreq * 2), num=n_anns).astype(
|
|
307
|
+
int
|
|
308
|
+
)
|
|
318
309
|
onset = raw.times[inds]
|
|
319
310
|
duration = [1] * n_anns
|
|
320
311
|
if description is None:
|
|
321
|
-
description = [
|
|
312
|
+
description = ["test"] * n_anns
|
|
322
313
|
anns = mne.Annotations(onset, duration, description)
|
|
323
314
|
raw = raw.set_annotations(anns)
|
|
324
315
|
|
|
@@ -326,18 +317,17 @@ def create_mne_dummy_raw(n_channels, n_times, sfreq, include_anns=True,
|
|
|
326
317
|
if savedir is not None:
|
|
327
318
|
if not isinstance(save_format, list):
|
|
328
319
|
save_format = [save_format]
|
|
329
|
-
fname = os.path.join(savedir,
|
|
320
|
+
fname = os.path.join(savedir, "fake_eeg_raw")
|
|
330
321
|
|
|
331
|
-
if
|
|
332
|
-
fif_fname = fname +
|
|
322
|
+
if "fif" in save_format:
|
|
323
|
+
fif_fname = fname + ".fif"
|
|
333
324
|
raw.save(fif_fname, overwrite=overwrite)
|
|
334
|
-
save_fname[
|
|
335
|
-
if
|
|
336
|
-
h5_fname = fname +
|
|
337
|
-
with h5py.File(h5_fname,
|
|
338
|
-
f.create_dataset(
|
|
339
|
-
|
|
340
|
-
save_fname['hdf5'] = h5_fname
|
|
325
|
+
save_fname["fif"] = fif_fname
|
|
326
|
+
if "hdf5" in save_format:
|
|
327
|
+
h5_fname = fname + ".h5"
|
|
328
|
+
with h5py.File(h5_fname, "w") as f:
|
|
329
|
+
f.create_dataset("fake_raw", dtype="f8", data=raw.get_data())
|
|
330
|
+
save_fname["hdf5"] = h5_fname
|
|
341
331
|
|
|
342
332
|
return raw, save_fname
|
|
343
333
|
|
|
@@ -349,7 +339,9 @@ class ThrowAwayIndexLoader(object):
|
|
|
349
339
|
self.last_i = None
|
|
350
340
|
self.is_regression = is_regression
|
|
351
341
|
|
|
352
|
-
def __iter__(
|
|
342
|
+
def __iter__(
|
|
343
|
+
self,
|
|
344
|
+
):
|
|
353
345
|
normal_iter = self.loader.__iter__()
|
|
354
346
|
for batch in normal_iter:
|
|
355
347
|
if len(batch) == 3:
|
|
@@ -360,7 +352,7 @@ class ThrowAwayIndexLoader(object):
|
|
|
360
352
|
x, y = batch
|
|
361
353
|
|
|
362
354
|
# TODO: should be on dataset side
|
|
363
|
-
if hasattr(x,
|
|
355
|
+
if hasattr(x, "type"):
|
|
364
356
|
x = x.type(torch.float32)
|
|
365
357
|
if self.is_regression:
|
|
366
358
|
y = y.type(torch.float32)
|
|
@@ -370,23 +362,26 @@ class ThrowAwayIndexLoader(object):
|
|
|
370
362
|
|
|
371
363
|
|
|
372
364
|
def update_estimator_docstring(base_class, docstring):
|
|
373
|
-
base_doc = base_class.__doc__.replace(
|
|
374
|
-
idx = base_doc.find(
|
|
375
|
-
idx_end = idx + base_doc[idx:].find(
|
|
365
|
+
base_doc = base_class.__doc__.replace(" : ", ": ")
|
|
366
|
+
idx = base_doc.find("callbacks:")
|
|
367
|
+
idx_end = idx + base_doc[idx:].find("\n\n")
|
|
376
368
|
# remove callback descripiton already included in braindecode docstring
|
|
377
|
-
filtered_doc = base_doc[:idx] + base_doc[idx_end + 6:]
|
|
378
|
-
splitted = docstring.split(
|
|
369
|
+
filtered_doc = base_doc[:idx] + base_doc[idx_end + 6 :]
|
|
370
|
+
splitted = docstring.split("Parameters\n ----------\n ")
|
|
379
371
|
out_docstring = (
|
|
380
|
-
splitted[0]
|
|
381
|
-
filtered_doc[
|
|
382
|
-
|
|
383
|
-
|
|
372
|
+
splitted[0]
|
|
373
|
+
+ filtered_doc[
|
|
374
|
+
filtered_doc.find("Parameters") : filtered_doc.find("Attributes")
|
|
375
|
+
]
|
|
376
|
+
+ splitted[1]
|
|
377
|
+
+ filtered_doc[filtered_doc.find("Attributes") :]
|
|
378
|
+
)
|
|
384
379
|
return out_docstring
|
|
385
380
|
|
|
386
381
|
|
|
387
382
|
def _update_moabb_docstring(base_class, docstring):
|
|
388
383
|
base_doc = base_class.__doc__
|
|
389
|
-
out_docstring = base_doc + f
|
|
384
|
+
out_docstring = base_doc + f"\n\n{docstring}"
|
|
390
385
|
return out_docstring
|
|
391
386
|
|
|
392
387
|
|
|
@@ -406,8 +401,9 @@ def read_all_file_names(directory, extension):
|
|
|
406
401
|
file_paths: list(str)
|
|
407
402
|
List of all files found in (sub)directories of path.
|
|
408
403
|
"""
|
|
409
|
-
assert extension.startswith(
|
|
410
|
-
file_paths = glob.glob(directory +
|
|
404
|
+
assert extension.startswith(".")
|
|
405
|
+
file_paths = glob.glob(directory + "**/*" + extension, recursive=True)
|
|
411
406
|
assert len(file_paths) > 0, (
|
|
412
|
-
f
|
|
407
|
+
f"something went wrong. Found no {extension} files in {directory}"
|
|
408
|
+
)
|
|
413
409
|
return file_paths
|