braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/training/scoring.py
@@ -0,0 +1,477 @@

# Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
#          Robin Tibor Schirrmeister <robintibor@gmail.com>
#          Alexandre Gramfort <alexandre.gramfort@inria.fr>
#          Lukas Gemein <l.gemein@gmail.com>
#          Mohammed Fattouh <mo.fattouh@gmail.com>
#
# License: BSD-3

import warnings
from contextlib import contextmanager

import numpy as np
import torch
from skorch.callbacks.scoring import EpochScoring
from skorch.dataset import unpack_data
from skorch.utils import to_numpy
from torch.utils.data import DataLoader


def trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials):
    """
    Assign window predictions to trials while removing duplicate predictions.

    Parameters
    ----------
    preds: list of ndarrays (at least 2darrays)
        List of window predictions; in each window prediction,
        time is in axis=1.
    i_window_in_trials: list
        Index/number of window in trial.
    i_stop_in_trials: list
        Stop position of window in trial.

    Returns
    -------
    preds_per_trial: list of ndarrays
        Predictions in each trial, duplicates removed.
    """
    assert len(preds) == len(i_window_in_trials) == len(i_stop_in_trials), (
        f"{len(preds)}, {len(i_window_in_trials)}, {len(i_stop_in_trials)}"
    )

    # Algorithm for assigning window predictions to trials
    # while removing duplicate predictions:
    # Loop through windows. In each iteration you have predictions
    # (assumed: #classes x #timesteps, or at least #timesteps must be
    # in axis=1) and you have i_window_in_trial, i_stop_in_trial
    # (i_trial removed from variable names for brevity).
    # You first check whether i_window_in_trial is 1 larger than in the
    # last iteration; if so, you are still in the same trial.
    # Otherwise you are in a new trial.
    # If you are in the same trial, you check for duplicate predictions:
    # only take predictions that are after (inclusive)
    # the stop of the last iteration (i.e., the index of the final
    # prediction in the last iteration).
    # Then add the duplicate-removed predictions from this window
    # to the predictions for the current trial.
    preds_per_trial = []
    cur_trial_preds = []
    i_last_stop = None
    i_last_window = -1
    for window_preds, i_window, i_stop in zip(
        preds, i_window_in_trials, i_stop_in_trials
    ):
        window_preds = np.array(window_preds)
        if i_window != (i_last_window + 1):
            assert i_window == 0, "window numbers in new trial should start from 0"
            preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
            cur_trial_preds = []
            i_last_stop = None

        if i_last_stop is not None:
            # Remove duplicates
            n_needed_preds = i_stop - i_last_stop
            window_preds = window_preds[:, -n_needed_preds:]
        cur_trial_preds.append(window_preds)
        i_last_window = i_window
        i_last_stop = i_stop
    # add last trial preds
    preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
    return preds_per_trial

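For illustration, a minimal sketch of how this function stitches overlapping window predictions back into a single trial; the arrays and indices below are made up:

import numpy as np

# Hypothetical trial covered by two overlapping windows, each giving
# predictions of shape (2 classes, 4 timesteps).
w0 = np.arange(8.0).reshape(2, 4)        # window stopping at sample 4
w1 = np.arange(8.0).reshape(2, 4) + 10   # window stopping at sample 6

trials = trial_preds_from_window_preds(
    preds=[w0, w1],
    i_window_in_trials=[0, 1],  # consecutive indices -> same trial
    i_stop_in_trials=[4, 6],    # w1 adds only 6 - 4 = 2 new timesteps
)
assert trials[0].shape == (2, 6)  # w1's 2 duplicated timesteps were dropped
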
@contextmanager
def _cache_net_forward_iter(net, use_caching, y_preds):
    """Caching context for ``skorch.NeuralNet`` instance.

    Returns a modified version of the net whose ``forward_iter``
    method will subsequently return cached predictions. Leaving the
    context will undo the overwrite of the ``forward_iter`` method.
    """
    if not use_caching:
        yield net
        return
    y_preds = iter(y_preds)

    # pylint: disable=unused-argument
    def cached_forward_iter(*args, device=net.device, **kwargs):
        for yp in y_preds:
            yield yp.to(device=device)

    net.forward_iter = cached_forward_iter
    try:
        yield net
    finally:
        # By setting net.forward_iter we define an attribute
        # `forward_iter` that precedes the bound method
        # `forward_iter`. By deleting the entry from the attribute
        # dict we undo this.
        del net.__dict__["forward_iter"]

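The cleanup in the finally block works because an instance attribute shadows the class's bound method during attribute lookup, and deleting the instance-dict entry re-exposes the method. A standalone sketch (the Net class here is a stand-in, not a skorch object):

class Net:
    def forward_iter(self):
        return "real"

net = Net()
net.forward_iter = lambda: "cached"   # instance attribute shadows the method
assert net.forward_iter() == "cached"
del net.__dict__["forward_iter"]      # same trick as in the finally block
assert net.forward_iter() == "real"   # bound method is visible again
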
class CroppedTrialEpochScoring(EpochScoring):
    """
    Class to compute scores for trials from a model that predicts (super)crops.
    """

    # XXX needs a docstring !!!

    def __init__(
        self,
        scoring,
        lower_is_better=True,
        on_train=False,
        name=None,
        target_extractor=to_numpy,
        use_caching=True,
    ):
        super().__init__(
            scoring=scoring,
            lower_is_better=lower_is_better,
            on_train=on_train,
            name=name,
            target_extractor=target_extractor,
            use_caching=use_caching,
        )
        if not self.on_train:
            self.window_inds_ = []

    def _initialize_cache(self):
        super()._initialize_cache()
        self.crops_to_trials_computed = False
        self.y_trues_ = []
        self.y_preds_ = []
        if not self.on_train:
            self.window_inds_ = []

    def on_batch_end(self, net, batch, y_pred, training, **kwargs):
        # Skorch saves the predictions without moving them from the GPU
        # https://github.com/skorch-dev/skorch/blob/fe71e3d55a4ae5f5f94ef7bdfc00fca3b3fd267f/skorch/callbacks/scoring.py#L385
        # This can cause memory issues in case of a large number of
        # predictions; therefore we already move them to CPU here.
        super().on_batch_end(net, batch, y_pred, training, **kwargs)
        if self.use_caching and training == self.on_train:
            self.y_preds_[-1] = self.y_preds_[-1].cpu()

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        assert self.use_caching
        if not self.crops_to_trials_computed:
            if self.on_train:
                # Prevent the torch rng state from being changed by
                # creation+usage of the iterator
                rng_state = torch.random.get_rng_state()
                pred_results = net.predict_with_window_inds_and_ys(dataset_train)
                torch.random.set_rng_state(rng_state)
            else:
                pred_results = {}
                pred_results["i_window_in_trials"] = np.concatenate(
                    [i[0].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["i_window_stops"] = np.concatenate(
                    [i[2].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["preds"] = np.concatenate(
                    [y_pred.cpu().numpy() for y_pred in self.y_preds_]
                )
                pred_results["window_ys"] = np.concatenate(
                    [y.cpu().numpy() for y in self.y_trues_]
                )

            # A new trial starts when the index of the window in trials
            # does not increment by 1.
            # Add a dummy infinity at the start.
            window_0_per_trial_mask = (
                np.diff(pred_results["i_window_in_trials"], prepend=[np.inf]) != 1
            )
            trial_ys = pred_results["window_ys"][window_0_per_trial_mask]
            trial_preds = trial_preds_from_window_preds(
                pred_results["preds"],
                pred_results["i_window_in_trials"],
                pred_results["i_window_stops"],
            )

            # Average across the timesteps of each trial so we already have
            # per-trial predictions; these will just be passed through the
            # forward method of the classifier/regressor to the skorch
            # scoring function.
            # trial_preds is a list, each item is a 2d array classes x time
            y_preds_per_trial = np.array([np.mean(p, axis=1) for p in trial_preds])
            # Move into format expected by skorch (list of torch tensors)
            y_preds_per_trial = [torch.tensor(y_preds_per_trial)]

            # Store the computed trial preds for all Cropped Callbacks
            # that are also on the same set
            cbs = net.callbacks_
            epoch_cbs = [
                cb
                for name, cb in cbs
                if isinstance(cb, CroppedTrialEpochScoring)
                and (cb.on_train == self.on_train)
            ]
            for cb in epoch_cbs:
                cb.y_preds_ = y_preds_per_trial
                cb.y_trues_ = trial_ys
                cb.crops_to_trials_computed = True

        dataset = dataset_train if self.on_train else dataset_valid

        with _cache_net_forward_iter(
            net, self.use_caching, self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset, self.y_trues_)
        self._record_score(net.history, current_score)

        return

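As a usage sketch, such a callback is typically passed to a skorch-style net. The snippet below assumes braindecode's EEGClassifier and CroppedLoss (as in its cropped-decoding examples); model and train_set are placeholders:

import torch
from braindecode import EEGClassifier
from braindecode.training import CroppedLoss, CroppedTrialEpochScoring

# `model` is assumed to output (batch, n_classes, n_preds_per_window);
# `train_set` is a windowed braindecode dataset.
clf = EEGClassifier(
    model,
    cropped=True,
    criterion=CroppedLoss,
    criterion__loss_function=torch.nn.functional.nll_loss,
    callbacks=[
        ("train_trial_accuracy",
         CroppedTrialEpochScoring("accuracy", lower_is_better=False,
                                  on_train=True, name="train_trial_accuracy")),
    ],
)
clf.fit(train_set, y=None, epochs=4)

Note that with on_train=True the net must provide predict_with_window_inds_and_ys, which braindecode's EEGClassifier does.
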
class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
    """
    Class to compute scores for trials from a model that predicts (super)crops
    with a time series target.
    """

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        assert self.use_caching
        if not self.crops_to_trials_computed:
            if self.on_train:
                # Prevent the torch rng state from being changed by
                # creation+usage of the iterator
                rng_state = torch.random.get_rng_state()
                pred_results = net.predict_with_window_inds_and_ys(dataset_train)
                torch.random.set_rng_state(rng_state)
            else:
                pred_results = {}
                pred_results["i_window_in_trials"] = np.concatenate(
                    [i[0].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["i_window_stops"] = np.concatenate(
                    [i[2].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["preds"] = np.concatenate(
                    [y_pred.cpu().numpy() for y_pred in self.y_preds_]
                )
                pred_results["window_ys"] = np.concatenate(
                    [y.cpu().numpy() for y in self.y_trues_]
                )

            num_preds = pred_results["preds"][-1].shape[-1]
            # slice the targets to fit the preds shape
            pred_results["window_ys"] = [
                targets[:, -num_preds:] for targets in pred_results["window_ys"]
            ]

            trial_preds = trial_preds_from_window_preds(
                pred_results["preds"],
                pred_results["i_window_in_trials"],
                pred_results["i_window_stops"],
            )

            trial_ys = trial_preds_from_window_preds(
                pred_results["window_ys"],
                pred_results["i_window_in_trials"],
                pred_results["i_window_stops"],
            )

            # The output is a list of predictions/targets per trial where each
            # item is a time series of predictions/targets of shape
            # (n_classes x timesteps)

            # mask NaNs from targets
            preds = np.hstack(trial_preds)  # n_classes x timesteps in all trials
            targets = np.hstack(trial_ys)
            # create valid targets mask
            mask = ~np.isnan(targets)
            # select valid targets that have matching predictions
            masked_targets = targets[mask]
            # For classification there is only one row in targets and
            # n_classes rows in preds
            if mask.shape[0] != preds.shape[0]:
                masked_preds = preds[:, mask[0, :]]
            else:
                masked_preds = preds[mask]

            # Store the computed trial preds for all Cropped Callbacks
            # that are also on the same set
            cbs = net.callbacks_
            epoch_cbs = [
                cb
                for name, cb in cbs
                if isinstance(cb, CroppedTimeSeriesEpochScoring)
                and (cb.on_train == self.on_train)
            ]
            masked_preds = [torch.tensor(masked_preds.T)]
            for cb in epoch_cbs:
                cb.y_preds_ = masked_preds
                cb.y_trues_ = masked_targets.T
                cb.crops_to_trials_computed = True

        dataset = dataset_train if self.on_train else dataset_valid

        with _cache_net_forward_iter(
            net, self.use_caching, self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset, self.y_trues_)
        self._record_score(net.history, current_score)

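A small numpy sketch of the masking step above for the classification case, where the targets have a single row but the predictions have one row per class (the values are made up):

import numpy as np

preds = np.arange(12.0).reshape(3, 4)             # 3 classes x 4 timesteps
targets = np.array([[0.0, np.nan, 1.0, np.nan]])  # 1 x 4; NaN = unlabeled

mask = ~np.isnan(targets)                # valid-target positions
masked_targets = targets[mask]           # array([0., 1.])
if mask.shape[0] != preds.shape[0]:      # classification: one target row
    masked_preds = preds[:, mask[0, :]]  # keep matching columns, all classes
else:                                    # regression: shapes already match
    masked_preds = preds[mask]
assert masked_preds.shape == (3, 2)
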
class PostEpochTrainScoring(EpochScoring):
    """
    Epoch scoring class that recomputes predictions after the epoch
    on the training set in validation mode.

    Note: For unknown reasons, this affects the global random generator, and
    therefore all results may change slightly if you add this scoring callback.

    Parameters
    ----------
    scoring : None, str, or callable (default=None)
        If None, use the ``score`` method of the model. If str, it
        should be a valid sklearn scorer (e.g. "f1", "accuracy"). If a
        callable, it should have the signature (model, X, y), and it
        should return a scalar. This works analogously to the
        ``scoring`` parameter in sklearn's ``GridSearchCV`` et al.
    lower_is_better : bool (default=True)
        Whether lower scores should be considered better or worse.
    name : str or None (default=None)
        If not an explicit string, tries to infer the name from the
        ``scoring`` argument.
    target_extractor : callable (default=to_numpy)
        This is called on y before it is passed to scoring.
    """

    def __init__(
        self,
        scoring,
        lower_is_better=True,
        name=None,
        target_extractor=to_numpy,
    ):
        super().__init__(
            scoring=scoring,
            lower_is_better=lower_is_better,
            on_train=True,
            name=name,
            target_extractor=target_extractor,
            use_caching=False,
        )

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        if len(self.y_preds_) == 0:
            dataset = net.get_dataset(dataset_train)
            # Prevent the torch rng state from being changed by
            # creation+usage of the iterator.
            # Unfortunately, calling __iter__() of a pytorch
            # DataLoader will change the random state.
            # Note the line below setting the rng state back.
            rng_state = torch.random.get_rng_state()
            iterator = net.get_iterator(dataset, training=False)
            y_preds = []
            y_test = []
            for batch in iterator:
                _, batch_y = unpack_data(batch)
                yp = net.evaluation_step(batch, training=False)
                yp = yp.to(device="cpu")
                y_test.append(self.target_extractor(batch_y))
                y_preds.append(yp)
            y_test = np.concatenate(y_test)
            torch.random.set_rng_state(rng_state)

            # Add the recomputed preds to all other
            # instances of PostEpochTrainScoring of this
            # Skorch-Net (NeuralNet, BraindecodeClassifier, etc.)
            # (They will be reinitialized to empty lists by skorch
            # each epoch)
            cbs = net.callbacks_
            epoch_cbs = [
                cb for name, cb in cbs if isinstance(cb, PostEpochTrainScoring)
            ]
            for cb in epoch_cbs:
                cb.y_preds_ = y_preds
                cb.y_trues_ = y_test
        # y pred should be same as self.y_preds_
        # Unclear if this also leads to any
        # random generator call?
        with _cache_net_forward_iter(
            net, use_caching=True, y_preds=self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset_train, self.y_trues_)
        self._record_score(net.history, current_score)

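A short usage sketch, assuming the callback is importable from braindecode.training like the other scoring callbacks (the surrounding net and data are placeholders):

from braindecode.training import PostEpochTrainScoring

train_acc = PostEpochTrainScoring(
    "accuracy", lower_is_better=False, name="train_accuracy"
)
# then e.g.: EEGClassifier(model, ..., callbacks=[("train_accuracy", train_acc)])
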
def predict_trials(module, dataset, return_targets=True, batch_size=1, num_workers=0):
    """Create trialwise predictions and optionally also return trialwise
    labels from a cropped dataset given a module.

    Parameters
    ----------
    module: torch.nn.Module
        A pytorch model implementing forward.
    dataset: braindecode.datasets.BaseConcatDataset
        A braindecode dataset to be predicted.
    return_targets: bool
        If True, additionally returns the trial targets.
    batch_size: int
        The batch size used to iterate the dataset.
    num_workers: int
        Number of workers used in DataLoader to iterate the dataset.

    Returns
    -------
    trial_predictions: np.ndarray
        3-dimensional array (n_trials x n_classes x n_predictions), where
        the number of predictions depends on the chosen window size and the
        receptive field of the network.
    trial_labels: np.ndarray
        2-dimensional array (n_trials x n_targets) where the number of
        targets depends on the decoding paradigm and can be either a single
        value, multiple values, or a sequence.
    """
    # Ensure the model is in evaluation mode
    module.eval()
    # We have a cropped dataset if there exists at least one trial with more
    # than one compute window
    more_than_one_window = sum(dataset.get_metadata()["i_window_in_trial"] != 0) > 0
    if not more_than_one_window:
        warnings.warn(
            "This function was designed to predict trials from "
            "cropped datasets, which typically have multiple compute "
            "windows per trial. The given dataset has exactly one "
            "window per trial."
        )
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    device = next(module.parameters()).device
    all_preds, all_ys, all_inds = [], [], []
    with torch.no_grad():
        for X, y, ind in loader:
            X = X.to(device)
            preds = module(X)
            all_preds.extend(preds.cpu().numpy().astype(np.float32))
            all_ys.extend(y.cpu().numpy().astype(np.float32))
            all_inds.extend(ind)
    preds_per_trial = trial_preds_from_window_preds(
        preds=all_preds,
        i_window_in_trials=torch.cat(all_inds[0::3]),
        i_stop_in_trials=torch.cat(all_inds[2::3]),
    )
    preds_per_trial = np.array(preds_per_trial)
    if return_targets:
        if all_ys[0].shape == ():
            all_ys = np.array(all_ys)
            ys_per_trial = all_ys[
                np.diff(torch.cat(all_inds[0::3]), prepend=[np.inf]) != 1
            ]
        else:
            ys_per_trial = trial_preds_from_window_preds(
                preds=all_ys,
                i_window_in_trials=torch.cat(all_inds[0::3]),
                i_stop_in_trials=torch.cat(all_inds[2::3]),
            )
            ys_per_trial = np.array(ys_per_trial)
        return preds_per_trial, ys_per_trial
    return preds_per_trial
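
Finally, a usage sketch for predict_trials; the trained module and the cropped windows dataset are placeholders:

from braindecode.training import predict_trials  # assumed export location

# `model` is a trained torch.nn.Module, `windows_dataset` a cropped
# braindecode dataset whose windows carry (i_window_in_trial, start, stop).
preds, targets = predict_trials(model, windows_dataset, return_targets=True)
# preds: (n_trials, n_classes, n_preds_per_trial); targets: per-trial labels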