braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/training/scoring.py
CHANGED
|
@@ -6,20 +6,19 @@
|
|
|
6
6
|
#
|
|
7
7
|
# License: BSD-3
|
|
8
8
|
|
|
9
|
-
from contextlib import contextmanager
|
|
10
9
|
import warnings
|
|
10
|
+
from contextlib import contextmanager
|
|
11
11
|
|
|
12
12
|
import numpy as np
|
|
13
13
|
import torch
|
|
14
14
|
from mne.utils.check import check_version
|
|
15
15
|
from skorch.callbacks.scoring import EpochScoring
|
|
16
|
-
from skorch.utils import to_numpy
|
|
17
16
|
from skorch.dataset import unpack_data
|
|
17
|
+
from skorch.utils import to_numpy
|
|
18
18
|
from torch.utils.data import DataLoader
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
def trial_preds_from_window_preds(
|
|
22
|
-
preds, i_window_in_trials, i_stop_in_trials):
|
|
21
|
+
def trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials):
|
|
23
22
|
"""
|
|
24
23
|
Assigning window predictions to trials while removing duplicate
|
|
25
24
|
predictions.
|
|
@@ -41,7 +40,8 @@ def trial_preds_from_window_preds(
|
|
|
41
40
|
|
|
42
41
|
"""
|
|
43
42
|
assert len(preds) == len(i_window_in_trials) == len(i_stop_in_trials), (
|
|
44
|
-
f
|
|
43
|
+
f"{len(preds)}, {len(i_window_in_trials)}, {len(i_stop_in_trials)}"
|
|
44
|
+
)
|
|
45
45
|
|
|
46
46
|
# Algorithm for assigning window predictions to trials
|
|
47
47
|
# while removing duplicate predictions:
|
|
@@ -64,11 +64,11 @@ def trial_preds_from_window_preds(
|
|
|
64
64
|
i_last_stop = None
|
|
65
65
|
i_last_window = -1
|
|
66
66
|
for window_preds, i_window, i_stop in zip(
|
|
67
|
-
|
|
67
|
+
preds, i_window_in_trials, i_stop_in_trials
|
|
68
|
+
):
|
|
68
69
|
window_preds = np.array(window_preds)
|
|
69
70
|
if i_window != (i_last_window + 1):
|
|
70
|
-
assert i_window == 0,
|
|
71
|
-
"window numbers in new trial should start from 0")
|
|
71
|
+
assert i_window == 0, "window numbers in new trial should start from 0"
|
|
72
72
|
preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
|
|
73
73
|
cur_trial_preds = []
|
|
74
74
|
i_last_stop = None
|
|
@@ -117,16 +117,17 @@ class CroppedTrialEpochScoring(EpochScoring):
|
|
|
117
117
|
"""
|
|
118
118
|
Class to compute scores for trials from a model that predicts (super)crops.
|
|
119
119
|
"""
|
|
120
|
+
|
|
120
121
|
# XXX needs a docstring !!!
|
|
121
122
|
|
|
122
123
|
def __init__(
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
124
|
+
self,
|
|
125
|
+
scoring,
|
|
126
|
+
lower_is_better=True,
|
|
127
|
+
on_train=False,
|
|
128
|
+
name=None,
|
|
129
|
+
target_extractor=to_numpy,
|
|
130
|
+
use_caching=True,
|
|
130
131
|
):
|
|
131
132
|
super().__init__(
|
|
132
133
|
scoring=scoring,
|
|
@@ -147,8 +148,7 @@ class CroppedTrialEpochScoring(EpochScoring):
|
|
|
147
148
|
if not self.on_train:
|
|
148
149
|
self.window_inds_ = []
|
|
149
150
|
|
|
150
|
-
def on_batch_end(
|
|
151
|
-
self, net, batch, y_pred, training, **kwargs):
|
|
151
|
+
def on_batch_end(self, net, batch, y_pred, training, **kwargs):
|
|
152
152
|
# Skorch saves the predictions without moving them from GPU
|
|
153
153
|
# https://github.com/skorch-dev/skorch/blob/fe71e3d55a4ae5f5f94ef7bdfc00fca3b3fd267f/skorch/callbacks/scoring.py#L385
|
|
154
154
|
# This can cause memory issues in case of a large number of predictions
|
|
@@ -164,41 +164,42 @@ class CroppedTrialEpochScoring(EpochScoring):
|
|
|
164
164
|
# Prevent that rng state of torch is changed by
|
|
165
165
|
# creation+usage of iterator
|
|
166
166
|
rng_state = torch.random.get_rng_state()
|
|
167
|
-
pred_results = net.predict_with_window_inds_and_ys(
|
|
168
|
-
dataset_train)
|
|
167
|
+
pred_results = net.predict_with_window_inds_and_ys(dataset_train)
|
|
169
168
|
torch.random.set_rng_state(rng_state)
|
|
170
169
|
else:
|
|
171
170
|
pred_results = {}
|
|
172
|
-
pred_results[
|
|
171
|
+
pred_results["i_window_in_trials"] = np.concatenate(
|
|
173
172
|
[i[0].cpu().numpy() for i in self.window_inds_]
|
|
174
173
|
)
|
|
175
|
-
pred_results[
|
|
174
|
+
pred_results["i_window_stops"] = np.concatenate(
|
|
176
175
|
[i[2].cpu().numpy() for i in self.window_inds_]
|
|
177
176
|
)
|
|
178
|
-
pred_results[
|
|
179
|
-
[y_pred.cpu().numpy() for y_pred in self.y_preds_]
|
|
180
|
-
|
|
181
|
-
|
|
177
|
+
pred_results["preds"] = np.concatenate(
|
|
178
|
+
[y_pred.cpu().numpy() for y_pred in self.y_preds_]
|
|
179
|
+
)
|
|
180
|
+
pred_results["window_ys"] = np.concatenate(
|
|
181
|
+
[y.cpu().numpy() for y in self.y_trues_]
|
|
182
|
+
)
|
|
182
183
|
|
|
183
184
|
# A new trial starts
|
|
184
185
|
# when the index of the window in trials
|
|
185
186
|
# does not increment by 1
|
|
186
187
|
# Add dummy infinity at start
|
|
187
|
-
window_0_per_trial_mask =
|
|
188
|
-
pred_results[
|
|
189
|
-
|
|
188
|
+
window_0_per_trial_mask = (
|
|
189
|
+
np.diff(pred_results["i_window_in_trials"], prepend=[np.inf]) != 1
|
|
190
|
+
)
|
|
191
|
+
trial_ys = pred_results["window_ys"][window_0_per_trial_mask]
|
|
190
192
|
trial_preds = trial_preds_from_window_preds(
|
|
191
|
-
pred_results[
|
|
192
|
-
pred_results[
|
|
193
|
-
pred_results[
|
|
193
|
+
pred_results["preds"],
|
|
194
|
+
pred_results["i_window_in_trials"],
|
|
195
|
+
pred_results["i_window_stops"],
|
|
196
|
+
)
|
|
194
197
|
|
|
195
198
|
# Average across the timesteps of each trial so we have per-trial
|
|
196
199
|
# predictions already, these will be just passed through the forward
|
|
197
200
|
# method of the classifier/regressor to the skorch scoring function.
|
|
198
201
|
# trial_preds is a list, each item is a 2d array classes x time
|
|
199
|
-
y_preds_per_trial = np.array(
|
|
200
|
-
[np.mean(p, axis=1) for p in trial_preds]
|
|
201
|
-
)
|
|
202
|
+
y_preds_per_trial = np.array([np.mean(p, axis=1) for p in trial_preds])
|
|
202
203
|
# Move into format expected by skorch (list of torch tensors)
|
|
203
204
|
y_preds_per_trial = [torch.tensor(y_preds_per_trial)]
|
|
204
205
|
|
|
@@ -206,9 +207,10 @@ class CroppedTrialEpochScoring(EpochScoring):
|
|
|
206
207
|
# that are also on same set
|
|
207
208
|
cbs = net.callbacks_
|
|
208
209
|
epoch_cbs = [
|
|
209
|
-
cb
|
|
210
|
-
|
|
211
|
-
|
|
210
|
+
cb
|
|
211
|
+
for name, cb in cbs
|
|
212
|
+
if isinstance(cb, CroppedTrialEpochScoring)
|
|
213
|
+
and (cb.on_train == self.on_train)
|
|
212
214
|
]
|
|
213
215
|
for cb in epoch_cbs:
|
|
214
216
|
cb.y_preds_ = y_preds_per_trial
|
|
@@ -218,7 +220,7 @@ class CroppedTrialEpochScoring(EpochScoring):
|
|
|
218
220
|
dataset = dataset_train if self.on_train else dataset_valid
|
|
219
221
|
|
|
220
222
|
with _cache_net_forward_iter(
|
|
221
|
-
|
|
223
|
+
net, self.use_caching, self.y_preds_
|
|
222
224
|
) as cached_net:
|
|
223
225
|
current_score = self._scoring(cached_net, dataset, self.y_trues_)
|
|
224
226
|
self._record_score(net.history, current_score)
|
|
@@ -231,6 +233,7 @@ class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
|
|
|
231
233
|
Class to compute scores for trials from a model that predicts (super)crops with
|
|
232
234
|
time series target.
|
|
233
235
|
"""
|
|
236
|
+
|
|
234
237
|
def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
|
|
235
238
|
assert self.use_caching
|
|
236
239
|
if not self.crops_to_trials_computed:
|
|
@@ -238,37 +241,40 @@ class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
|
|
|
238
241
|
# Prevent that rng state of torch is changed by
|
|
239
242
|
# creation+usage of iterator
|
|
240
243
|
rng_state = torch.random.get_rng_state()
|
|
241
|
-
pred_results = net.predict_with_window_inds_and_ys(
|
|
242
|
-
dataset_train)
|
|
244
|
+
pred_results = net.predict_with_window_inds_and_ys(dataset_train)
|
|
243
245
|
torch.random.set_rng_state(rng_state)
|
|
244
246
|
else:
|
|
245
247
|
pred_results = {}
|
|
246
|
-
pred_results[
|
|
248
|
+
pred_results["i_window_in_trials"] = np.concatenate(
|
|
247
249
|
[i[0].cpu().numpy() for i in self.window_inds_]
|
|
248
250
|
)
|
|
249
|
-
pred_results[
|
|
251
|
+
pred_results["i_window_stops"] = np.concatenate(
|
|
250
252
|
[i[2].cpu().numpy() for i in self.window_inds_]
|
|
251
253
|
)
|
|
252
|
-
pred_results[
|
|
253
|
-
[y_pred.cpu().numpy() for y_pred in self.y_preds_]
|
|
254
|
-
|
|
255
|
-
|
|
254
|
+
pred_results["preds"] = np.concatenate(
|
|
255
|
+
[y_pred.cpu().numpy() for y_pred in self.y_preds_]
|
|
256
|
+
)
|
|
257
|
+
pred_results["window_ys"] = np.concatenate(
|
|
258
|
+
[y.cpu().numpy() for y in self.y_trues_]
|
|
259
|
+
)
|
|
256
260
|
|
|
257
|
-
num_preds = pred_results[
|
|
261
|
+
num_preds = pred_results["preds"][-1].shape[-1]
|
|
258
262
|
# slice the targets to fit preds shape
|
|
259
|
-
pred_results[
|
|
260
|
-
targets[:, -num_preds:] for targets in pred_results[
|
|
263
|
+
pred_results["window_ys"] = [
|
|
264
|
+
targets[:, -num_preds:] for targets in pred_results["window_ys"]
|
|
261
265
|
]
|
|
262
266
|
|
|
263
267
|
trial_preds = trial_preds_from_window_preds(
|
|
264
|
-
pred_results[
|
|
265
|
-
pred_results[
|
|
266
|
-
pred_results[
|
|
268
|
+
pred_results["preds"],
|
|
269
|
+
pred_results["i_window_in_trials"],
|
|
270
|
+
pred_results["i_window_stops"],
|
|
271
|
+
)
|
|
267
272
|
|
|
268
273
|
trial_ys = trial_preds_from_window_preds(
|
|
269
|
-
pred_results[
|
|
270
|
-
pred_results[
|
|
271
|
-
pred_results[
|
|
274
|
+
pred_results["window_ys"],
|
|
275
|
+
pred_results["i_window_in_trials"],
|
|
276
|
+
pred_results["i_window_stops"],
|
|
277
|
+
)
|
|
272
278
|
|
|
273
279
|
# the output is a list of predictions/targets per trial where each item is a
|
|
274
280
|
# timeseries of predictions/targets of shape (n_classes x timesteps)
|
|
@@ -290,9 +296,10 @@ class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
|
|
|
290
296
|
# that are also on same set
|
|
291
297
|
cbs = net.callbacks_
|
|
292
298
|
epoch_cbs = [
|
|
293
|
-
cb
|
|
294
|
-
|
|
295
|
-
|
|
299
|
+
cb
|
|
300
|
+
for name, cb in cbs
|
|
301
|
+
if isinstance(cb, CroppedTimeSeriesEpochScoring)
|
|
302
|
+
and (cb.on_train == self.on_train)
|
|
296
303
|
]
|
|
297
304
|
masked_preds = [torch.tensor(masked_preds.T)]
|
|
298
305
|
for cb in epoch_cbs:
|
|
@@ -365,7 +372,7 @@ class PostEpochTrainScoring(EpochScoring):
|
|
|
365
372
|
for batch in iterator:
|
|
366
373
|
batch_X, batch_y = unpack_data(batch)
|
|
367
374
|
# TODO: remove after skorch 0.10 release
|
|
368
|
-
if not check_version(
|
|
375
|
+
if not check_version("skorch", min_version="0.10.1"):
|
|
369
376
|
yp = net.evaluation_step(batch_X, training=False)
|
|
370
377
|
# X, y unpacking has been pushed downstream in skorch 0.10
|
|
371
378
|
else:
|
|
@@ -394,9 +401,7 @@ class PostEpochTrainScoring(EpochScoring):
|
|
|
394
401
|
with _cache_net_forward_iter(
|
|
395
402
|
net, use_caching=True, y_preds=self.y_preds_
|
|
396
403
|
) as cached_net:
|
|
397
|
-
current_score = self._scoring(
|
|
398
|
-
cached_net, dataset_train, self.y_trues_
|
|
399
|
-
)
|
|
404
|
+
current_score = self._scoring(cached_net, dataset_train, self.y_trues_)
|
|
400
405
|
self._record_score(net.history, current_score)
|
|
401
406
|
|
|
402
407
|
|
|
@@ -432,12 +437,14 @@ def predict_trials(module, dataset, return_targets=True, batch_size=1, num_worke
|
|
|
432
437
|
module.eval()
|
|
433
438
|
# we have a cropped dataset if there exists at least one trial with more
|
|
434
439
|
# than one compute window
|
|
435
|
-
more_than_one_window = sum(dataset.get_metadata()[
|
|
440
|
+
more_than_one_window = sum(dataset.get_metadata()["i_window_in_trial"] != 0) > 0
|
|
436
441
|
if not more_than_one_window:
|
|
437
|
-
warnings.warn(
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
442
|
+
warnings.warn(
|
|
443
|
+
"This function was designed to predict trials from "
|
|
444
|
+
"cropped datasets, which typically have multiple compute "
|
|
445
|
+
"windows per trial. The given dataset has exactly one "
|
|
446
|
+
"window per trial."
|
|
447
|
+
)
|
|
441
448
|
loader = DataLoader(
|
|
442
449
|
dataset=dataset,
|
|
443
450
|
batch_size=batch_size,
|
|
@@ -463,7 +470,8 @@ def predict_trials(module, dataset, return_targets=True, batch_size=1, num_worke
|
|
|
463
470
|
if all_ys[0].shape == ():
|
|
464
471
|
all_ys = np.array(all_ys)
|
|
465
472
|
ys_per_trial = all_ys[
|
|
466
|
-
np.diff(torch.cat(all_inds[0::3]), prepend=[np.inf]) != 1
|
|
473
|
+
np.diff(torch.cat(all_inds[0::3]), prepend=[np.inf]) != 1
|
|
474
|
+
]
|
|
467
475
|
else:
|
|
468
476
|
ys_per_trial = trial_preds_from_window_preds(
|
|
469
477
|
preds=all_ys,
|
braindecode/util.py
CHANGED
|
@@ -12,6 +12,7 @@ import mne
|
|
|
12
12
|
import numpy as np
|
|
13
13
|
import torch
|
|
14
14
|
from sklearn.utils import check_random_state
|
|
15
|
+
from torch import Tensor
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
def set_random_seeds(seed, cuda, cudnn_benchmark=None):
|
|
@@ -51,7 +52,9 @@ def set_random_seeds(seed, cuda, cudnn_benchmark=None):
|
|
|
51
52
|
warn(
|
|
52
53
|
"torch.backends.cudnn.benchmark was set to True which may results in lack of "
|
|
53
54
|
"reproducibility. In some cases to ensure reproducibility you may need to "
|
|
54
|
-
"set torch.backends.cudnn.benchmark to False.",
|
|
55
|
+
"set torch.backends.cudnn.benchmark to False.",
|
|
56
|
+
UserWarning,
|
|
57
|
+
)
|
|
55
58
|
else:
|
|
56
59
|
raise ValueError(
|
|
57
60
|
f"cudnn_benchmark expected to be bool or None, got '{cudnn_benchmark}'"
|
|
@@ -60,19 +63,7 @@ def set_random_seeds(seed, cuda, cudnn_benchmark=None):
|
|
|
60
63
|
np.random.seed(seed)
|
|
61
64
|
|
|
62
65
|
|
|
63
|
-
def
|
|
64
|
-
X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs
|
|
65
|
-
):
|
|
66
|
-
warn("np_to_var has been renamed np_to_th, please use np_to_th instead")
|
|
67
|
-
return np_to_th(
|
|
68
|
-
X, requires_grad=requires_grad, dtype=dtype, pin_memory=pin_memory,
|
|
69
|
-
**tensor_kwargs
|
|
70
|
-
)
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def np_to_th(
|
|
74
|
-
X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs
|
|
75
|
-
):
|
|
66
|
+
def np_to_th(X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs):
|
|
76
67
|
"""
|
|
77
68
|
Convenience function to transform numpy array to `torch.Tensor`.
|
|
78
69
|
|
|
@@ -103,12 +94,7 @@ def np_to_th(
|
|
|
103
94
|
return X_tensor
|
|
104
95
|
|
|
105
96
|
|
|
106
|
-
def
|
|
107
|
-
warn("var_to_np has been renamed th_to_np, please use th_to_np instead")
|
|
108
|
-
return th_to_np(var)
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
def th_to_np(var):
|
|
97
|
+
def th_to_np(var: Tensor):
|
|
112
98
|
"""Convenience function to transform `torch.Tensor` to numpy
|
|
113
99
|
array.
|
|
114
100
|
|
|
@@ -209,15 +195,11 @@ def wrap_reshape_apply_fn(stat_fn, a, b, axis_a, axis_b):
|
|
|
209
195
|
)
|
|
210
196
|
assert np.array_equal(n_stat_axis_a, n_stat_axis_b)
|
|
211
197
|
stat_result = stat_fn(flat_topo_a, flat_topo_b)
|
|
212
|
-
topo_result = stat_result.reshape(
|
|
213
|
-
tuple(n_other_axis_a) + tuple(n_other_axis_b)
|
|
214
|
-
)
|
|
198
|
+
topo_result = stat_result.reshape(tuple(n_other_axis_a) + tuple(n_other_axis_b))
|
|
215
199
|
return topo_result
|
|
216
200
|
|
|
217
201
|
|
|
218
|
-
def get_balanced_batches(
|
|
219
|
-
n_trials, rng, shuffle, n_batches=None, batch_size=None
|
|
220
|
-
):
|
|
202
|
+
def get_balanced_batches(n_trials, rng, shuffle, n_batches=None, batch_size=None):
|
|
221
203
|
"""Create indices for batches balanced in size
|
|
222
204
|
(batches will have maximum size difference of 1).
|
|
223
205
|
Supply either batch size or number of batches. Resulting batches
|
|
@@ -268,9 +250,17 @@ def get_balanced_batches(
|
|
|
268
250
|
return batches
|
|
269
251
|
|
|
270
252
|
|
|
271
|
-
def create_mne_dummy_raw(
|
|
272
|
-
|
|
273
|
-
|
|
253
|
+
def create_mne_dummy_raw(
|
|
254
|
+
n_channels,
|
|
255
|
+
n_times,
|
|
256
|
+
sfreq,
|
|
257
|
+
include_anns=True,
|
|
258
|
+
description=None,
|
|
259
|
+
savedir=None,
|
|
260
|
+
save_format="fif",
|
|
261
|
+
overwrite=True,
|
|
262
|
+
random_state=None,
|
|
263
|
+
):
|
|
274
264
|
"""Create an mne.io.RawArray with fake data, and optionally save it.
|
|
275
265
|
|
|
276
266
|
This will overwrite already existing files.
|
|
@@ -305,20 +295,21 @@ def create_mne_dummy_raw(n_channels, n_times, sfreq, include_anns=True,
|
|
|
305
295
|
"""
|
|
306
296
|
random_state = check_random_state(random_state)
|
|
307
297
|
data = random_state.rand(n_channels, n_times)
|
|
308
|
-
ch_names = [f
|
|
309
|
-
ch_types = [
|
|
298
|
+
ch_names = [f"ch{i}" for i in range(n_channels)]
|
|
299
|
+
ch_types = ["eeg"] * n_channels
|
|
310
300
|
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
|
|
311
301
|
|
|
312
302
|
raw = mne.io.RawArray(data, info)
|
|
313
303
|
|
|
314
304
|
if include_anns:
|
|
315
305
|
n_anns = 10
|
|
316
|
-
inds = np.linspace(
|
|
317
|
-
int
|
|
306
|
+
inds = np.linspace(int(sfreq * 2), int(n_times - sfreq * 2), num=n_anns).astype(
|
|
307
|
+
int
|
|
308
|
+
)
|
|
318
309
|
onset = raw.times[inds]
|
|
319
310
|
duration = [1] * n_anns
|
|
320
311
|
if description is None:
|
|
321
|
-
description = [
|
|
312
|
+
description = ["test"] * n_anns
|
|
322
313
|
anns = mne.Annotations(onset, duration, description)
|
|
323
314
|
raw = raw.set_annotations(anns)
|
|
324
315
|
|
|
@@ -326,18 +317,17 @@ def create_mne_dummy_raw(n_channels, n_times, sfreq, include_anns=True,
|
|
|
326
317
|
if savedir is not None:
|
|
327
318
|
if not isinstance(save_format, list):
|
|
328
319
|
save_format = [save_format]
|
|
329
|
-
fname = os.path.join(savedir,
|
|
320
|
+
fname = os.path.join(savedir, "fake_eeg_raw")
|
|
330
321
|
|
|
331
|
-
if
|
|
332
|
-
fif_fname = fname +
|
|
322
|
+
if "fif" in save_format:
|
|
323
|
+
fif_fname = fname + ".fif"
|
|
333
324
|
raw.save(fif_fname, overwrite=overwrite)
|
|
334
|
-
save_fname[
|
|
335
|
-
if
|
|
336
|
-
h5_fname = fname +
|
|
337
|
-
with h5py.File(h5_fname,
|
|
338
|
-
f.create_dataset(
|
|
339
|
-
|
|
340
|
-
save_fname['hdf5'] = h5_fname
|
|
325
|
+
save_fname["fif"] = fif_fname
|
|
326
|
+
if "hdf5" in save_format:
|
|
327
|
+
h5_fname = fname + ".h5"
|
|
328
|
+
with h5py.File(h5_fname, "w") as f:
|
|
329
|
+
f.create_dataset("fake_raw", dtype="f8", data=raw.get_data())
|
|
330
|
+
save_fname["hdf5"] = h5_fname
|
|
341
331
|
|
|
342
332
|
return raw, save_fname
|
|
343
333
|
|
|
@@ -349,7 +339,9 @@ class ThrowAwayIndexLoader(object):
|
|
|
349
339
|
self.last_i = None
|
|
350
340
|
self.is_regression = is_regression
|
|
351
341
|
|
|
352
|
-
def __iter__(
|
|
342
|
+
def __iter__(
|
|
343
|
+
self,
|
|
344
|
+
):
|
|
353
345
|
normal_iter = self.loader.__iter__()
|
|
354
346
|
for batch in normal_iter:
|
|
355
347
|
if len(batch) == 3:
|
|
@@ -360,7 +352,7 @@ class ThrowAwayIndexLoader(object):
|
|
|
360
352
|
x, y = batch
|
|
361
353
|
|
|
362
354
|
# TODO: should be on dataset side
|
|
363
|
-
if hasattr(x,
|
|
355
|
+
if hasattr(x, "type"):
|
|
364
356
|
x = x.type(torch.float32)
|
|
365
357
|
if self.is_regression:
|
|
366
358
|
y = y.type(torch.float32)
|
|
@@ -370,23 +362,26 @@ class ThrowAwayIndexLoader(object):
|
|
|
370
362
|
|
|
371
363
|
|
|
372
364
|
def update_estimator_docstring(base_class, docstring):
|
|
373
|
-
base_doc = base_class.__doc__.replace(
|
|
374
|
-
idx = base_doc.find(
|
|
375
|
-
idx_end = idx + base_doc[idx:].find(
|
|
365
|
+
base_doc = base_class.__doc__.replace(" : ", ": ")
|
|
366
|
+
idx = base_doc.find("callbacks:")
|
|
367
|
+
idx_end = idx + base_doc[idx:].find("\n\n")
|
|
376
368
|
# remove callback description already included in braindecode docstring
|
|
377
|
-
filtered_doc = base_doc[:idx] + base_doc[idx_end + 6:]
|
|
378
|
-
splitted = docstring.split(
|
|
369
|
+
filtered_doc = base_doc[:idx] + base_doc[idx_end + 6 :]
|
|
370
|
+
splitted = docstring.split("Parameters\n ----------\n ")
|
|
379
371
|
out_docstring = (
|
|
380
|
-
splitted[0]
|
|
381
|
-
filtered_doc[
|
|
382
|
-
|
|
383
|
-
|
|
372
|
+
splitted[0]
|
|
373
|
+
+ filtered_doc[
|
|
374
|
+
filtered_doc.find("Parameters") : filtered_doc.find("Attributes")
|
|
375
|
+
]
|
|
376
|
+
+ splitted[1]
|
|
377
|
+
+ filtered_doc[filtered_doc.find("Attributes") :]
|
|
378
|
+
)
|
|
384
379
|
return out_docstring
|
|
385
380
|
|
|
386
381
|
|
|
387
382
|
def _update_moabb_docstring(base_class, docstring):
|
|
388
383
|
base_doc = base_class.__doc__
|
|
389
|
-
out_docstring = base_doc + f
|
|
384
|
+
out_docstring = base_doc + f"\n\n{docstring}"
|
|
390
385
|
return out_docstring
|
|
391
386
|
|
|
392
387
|
|
|
@@ -406,8 +401,9 @@ def read_all_file_names(directory, extension):
|
|
|
406
401
|
file_paths: list(str)
|
|
407
402
|
List of all files found in (sub)directories of path.
|
|
408
403
|
"""
|
|
409
|
-
assert extension.startswith(
|
|
410
|
-
file_paths = glob.glob(directory +
|
|
404
|
+
assert extension.startswith(".")
|
|
405
|
+
file_paths = glob.glob(directory + "**/*" + extension, recursive=True)
|
|
411
406
|
assert len(file_paths) > 0, (
|
|
412
|
-
f
|
|
407
|
+
f"something went wrong. Found no {extension} files in {directory}"
|
|
408
|
+
)
|
|
413
409
|
return file_paths
|
braindecode/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0
|
|
1
|
+
__version__ = "1.1.0"
|
|
@@ -2,8 +2,7 @@
|
|
|
2
2
|
Functions for visualisations, especially of the ConvNets.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from .gradients import compute_amplitude_gradients
|
|
6
5
|
from .confusion_matrices import plot_confusion_matrix
|
|
6
|
+
from .gradients import compute_amplitude_gradients
|
|
7
7
|
|
|
8
|
-
__all__ = ["compute_amplitude_gradients",
|
|
9
|
-
"plot_confusion_matrix"]
|
|
8
|
+
__all__ = ["compute_amplitude_gradients", "plot_confusion_matrix"]
|