braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/training/scoring.py ADDED
@@ -0,0 +1,483 @@
+# Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
+#          Robin Tibor Schirrmeister <robintibor@gmail.com>
+#          Alexandre Gramfort <alexandre.gramfort@inria.fr>
+#          Lukas Gemein <l.gemein@gmail.com>
+#          Mohammed Fattouh <mo.fattouh@gmail.com>
+#
+# License: BSD-3
+
+import warnings
+from contextlib import contextmanager
+
+import numpy as np
+import torch
+from mne.utils.check import check_version
+from skorch.callbacks.scoring import EpochScoring
+from skorch.dataset import unpack_data
+from skorch.utils import to_numpy
+from torch.utils.data import DataLoader
+
+
+def trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials):
+    """
+    Assigning window predictions to trials while removing duplicate
+    predictions.
+
+    Parameters
+    ----------
+    preds: list of ndarrays (at least 2darrays)
+        List of window predictions, in each window prediction
+        time is in axis=1
+    i_window_in_trials: list
+        Index/number of window in trial
+    i_stop_in_trials: list
+        stop position of window in trial
+
+    Returns
+    -------
+    preds_per_trial: list of ndarrays
+        Predictions in each trial, duplicates removed
+
+    """
+    assert len(preds) == len(i_window_in_trials) == len(i_stop_in_trials), (
+        f"{len(preds)}, {len(i_window_in_trials)}, {len(i_stop_in_trials)}"
+    )
+
+    # Algorithm for assigning window predictions to trials
+    # while removing duplicate predictions:
+    # Loop through windows:
+    # In each iteration you have predictions (assumed: #classes x #timesteps,
+    # or at least #timesteps must be in axis=1)
+    # and you have i_window_in_trial, i_stop_in_trial
+    # (i_trial removed from variable names for brevity)
+    # You first check if the i_window_in_trial is 1 larger
+    # than in last iteration, then you are still in the same trial
+    # Otherwise you are in a new trial
+    # If you are in the same trial, you check for duplicate predictions
+    # Only take predictions that are after (inclusive)
+    # the stop of the last iteration (i.e., the index of final prediction
+    # in the last iteration)
+    # Then add the duplicate-removed predictions from this window
+    # to predictions for current trial
+    preds_per_trial = []
+    cur_trial_preds = []
+    i_last_stop = None
+    i_last_window = -1
+    for window_preds, i_window, i_stop in zip(
+        preds, i_window_in_trials, i_stop_in_trials
+    ):
+        window_preds = np.array(window_preds)
+        if i_window != (i_last_window + 1):
+            assert i_window == 0, "window numbers in new trial should start from 0"
+            preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
+            cur_trial_preds = []
+            i_last_stop = None
+
+        if i_last_stop is not None:
+            # Remove duplicates
+            n_needed_preds = i_stop - i_last_stop
+            window_preds = window_preds[:, -n_needed_preds:]
+        cur_trial_preds.append(window_preds)
+        i_last_window = i_window
+        i_last_stop = i_stop
+    # add last trial preds
+    preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
+    return preds_per_trial
+
+
+@contextmanager
+def _cache_net_forward_iter(net, use_caching, y_preds):
+    """Caching context for ``skorch.NeuralNet`` instance.
+    Returns a modified version of the net whose ``forward_iter``
+    method will subsequently return cached predictions. Leaving the
+    context will undo the overwrite of the ``forward_iter`` method.
+    """
+    if not use_caching:
+        yield net
+        return
+    y_preds = iter(y_preds)
+
+    # pylint: disable=unused-argument
+    def cached_forward_iter(*args, device=net.device, **kwargs):
+        for yp in y_preds:
+            yield yp.to(device=device)
+
+    net.forward_iter = cached_forward_iter
+    try:
+        yield net
+    finally:
+        # By setting net.forward_iter we define an attribute
+        # `forward_iter` that precedes the bound method
+        # `forward_iter`. By deleting the entry from the attribute
+        # dict we undo this.
+        del net.__dict__["forward_iter"]
+
+
+class CroppedTrialEpochScoring(EpochScoring):
+    """
+    Class to compute scores for trials from a model that predicts (super)crops.
+    """
+
+    # XXX needs a docstring !!!
+
+    def __init__(
+        self,
+        scoring,
+        lower_is_better=True,
+        on_train=False,
+        name=None,
+        target_extractor=to_numpy,
+        use_caching=True,
+    ):
+        super().__init__(
+            scoring=scoring,
+            lower_is_better=lower_is_better,
+            on_train=on_train,
+            name=name,
+            target_extractor=target_extractor,
+            use_caching=use_caching,
+        )
+        if not self.on_train:
+            self.window_inds_ = []
+
+    def _initialize_cache(self):
+        super()._initialize_cache()
+        self.crops_to_trials_computed = False
+        self.y_trues_ = []
+        self.y_preds_ = []
+        if not self.on_train:
+            self.window_inds_ = []
+
+    def on_batch_end(self, net, batch, y_pred, training, **kwargs):
+        # Skorch saves the predictions without moving them from GPU
+        # https://github.com/skorch-dev/skorch/blob/fe71e3d55a4ae5f5f94ef7bdfc00fca3b3fd267f/skorch/callbacks/scoring.py#L385
+        # This can cause memory issues in case of a large number of predictions
+        # Therefore here we move them to CPU already
+        super().on_batch_end(net, batch, y_pred, training, **kwargs)
+        if self.use_caching and training == self.on_train:
+            self.y_preds_[-1] = self.y_preds_[-1].cpu()
+
+    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
+        assert self.use_caching
+        if not self.crops_to_trials_computed:
+            if self.on_train:
+                # Prevent that rng state of torch is changed by
+                # creation+usage of iterator
+                rng_state = torch.random.get_rng_state()
+                pred_results = net.predict_with_window_inds_and_ys(dataset_train)
+                torch.random.set_rng_state(rng_state)
+            else:
+                pred_results = {}
+                pred_results["i_window_in_trials"] = np.concatenate(
+                    [i[0].cpu().numpy() for i in self.window_inds_]
+                )
+                pred_results["i_window_stops"] = np.concatenate(
+                    [i[2].cpu().numpy() for i in self.window_inds_]
+                )
+                pred_results["preds"] = np.concatenate(
+                    [y_pred.cpu().numpy() for y_pred in self.y_preds_]
+                )
+                pred_results["window_ys"] = np.concatenate(
+                    [y.cpu().numpy() for y in self.y_trues_]
+                )
+
+            # A new trial starts
+            # when the index of the window in trials
+            # does not increment by 1
+            # Add dummy infinity at start
+            window_0_per_trial_mask = (
+                np.diff(pred_results["i_window_in_trials"], prepend=[np.inf]) != 1
+            )
+            trial_ys = pred_results["window_ys"][window_0_per_trial_mask]
+            trial_preds = trial_preds_from_window_preds(
+                pred_results["preds"],
+                pred_results["i_window_in_trials"],
+                pred_results["i_window_stops"],
+            )
+
+            # Average across the timesteps of each trial so we have per-trial
+            # predictions already, these will be just passed through the forward
+            # method of the classifier/regressor to the skorch scoring function.
+            # trial_preds is a list, each item is a 2d array classes x time
+            y_preds_per_trial = np.array([np.mean(p, axis=1) for p in trial_preds])
+            # Move into format expected by skorch (list of torch tensors)
+            y_preds_per_trial = [torch.tensor(y_preds_per_trial)]
+
+            # Store the computed trial preds for all Cropped Callbacks
+            # that are also on same set
+            cbs = net.callbacks_
+            epoch_cbs = [
+                cb
+                for name, cb in cbs
+                if isinstance(cb, CroppedTrialEpochScoring)
+                and (cb.on_train == self.on_train)
+            ]
+            for cb in epoch_cbs:
+                cb.y_preds_ = y_preds_per_trial
+                cb.y_trues_ = trial_ys
+                cb.crops_to_trials_computed = True
+
+        dataset = dataset_train if self.on_train else dataset_valid
+
+        with _cache_net_forward_iter(
+            net, self.use_caching, self.y_preds_
+        ) as cached_net:
+            current_score = self._scoring(cached_net, dataset, self.y_trues_)
+        self._record_score(net.history, current_score)
+
+        return
+
+
+class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
+    """
+    Class to compute scores for trials from a model that predicts (super)crops with
+    time series target.
+    """
+
+    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
+        assert self.use_caching
+        if not self.crops_to_trials_computed:
+            if self.on_train:
+                # Prevent that rng state of torch is changed by
+                # creation+usage of iterator
+                rng_state = torch.random.get_rng_state()
+                pred_results = net.predict_with_window_inds_and_ys(dataset_train)
+                torch.random.set_rng_state(rng_state)
+            else:
+                pred_results = {}
+                pred_results["i_window_in_trials"] = np.concatenate(
+                    [i[0].cpu().numpy() for i in self.window_inds_]
+                )
+                pred_results["i_window_stops"] = np.concatenate(
+                    [i[2].cpu().numpy() for i in self.window_inds_]
+                )
+                pred_results["preds"] = np.concatenate(
+                    [y_pred.cpu().numpy() for y_pred in self.y_preds_]
+                )
+                pred_results["window_ys"] = np.concatenate(
+                    [y.cpu().numpy() for y in self.y_trues_]
+                )
+
+            num_preds = pred_results["preds"][-1].shape[-1]
+            # slice the targets to fit preds shape
+            pred_results["window_ys"] = [
+                targets[:, -num_preds:] for targets in pred_results["window_ys"]
+            ]
+
+            trial_preds = trial_preds_from_window_preds(
+                pred_results["preds"],
+                pred_results["i_window_in_trials"],
+                pred_results["i_window_stops"],
+            )
+
+            trial_ys = trial_preds_from_window_preds(
+                pred_results["window_ys"],
+                pred_results["i_window_in_trials"],
+                pred_results["i_window_stops"],
+            )
+
+            # the output is a list of predictions/targets per trial where each item is a
+            # timeseries of predictions/targets of shape (n_classes x timesteps)
+
+            # mask NaNs form targets
+            preds = np.hstack(trial_preds)  # n_classes x timesteps in all trials
+            targets = np.hstack(trial_ys)
+            # create valid targets mask
+            mask = ~np.isnan(targets)
+            # select valid targets that have a matching predictions
+            masked_targets = targets[mask]
+            # For classification there is only one row in targets and n_classes rows in preds
+            if mask.shape[0] != preds.shape[0]:
+                masked_preds = preds[:, mask[0, :]]
+            else:
+                masked_preds = preds[mask]
+
+            # Store the computed trial preds for all Cropped Callbacks
+            # that are also on same set
+            cbs = net.callbacks_
+            epoch_cbs = [
+                cb
+                for name, cb in cbs
+                if isinstance(cb, CroppedTimeSeriesEpochScoring)
+                and (cb.on_train == self.on_train)
+            ]
+            masked_preds = [torch.tensor(masked_preds.T)]
+            for cb in epoch_cbs:
+                cb.y_preds_ = masked_preds
+                cb.y_trues_ = masked_targets.T
+                cb.crops_to_trials_computed = True
+
+        dataset = dataset_train if self.on_train else dataset_valid
+
+        with _cache_net_forward_iter(
+            net, self.use_caching, self.y_preds_
+        ) as cached_net:
+            current_score = self._scoring(cached_net, dataset, self.y_trues_)
+        self._record_score(net.history, current_score)
+
+
+class PostEpochTrainScoring(EpochScoring):
+    """
+    Epoch Scoring class that recomputes predictions after the epoch
+    on the training in validation mode.
+
+    Note: For unknown reasons, this affects global random generator and
+    therefore all results may change slightly if you add this scoring callback.
+
+    Parameters
+    ----------
+    scoring : None, str, or callable (default=None)
+        If None, use the ``score`` method of the model. If str, it
+        should be a valid sklearn scorer (e.g. "f1", "accuracy"). If a
+        callable, it should have the signature (model, X, y), and it
+        should return a scalar. This works analogously to the
+        ``scoring`` parameter in sklearn's ``GridSearchCV`` et al.
+    lower_is_better : bool (default=True)
+        Whether lower scores should be considered better or worse.
+    name : str or None (default=None)
+        If not an explicit string, tries to infer the name from the
+        ``scoring`` argument.
+    target_extractor : callable (default=to_numpy)
+        This is called on y before it is passed to scoring.
+    """
+
+    def __init__(
+        self,
+        scoring,
+        lower_is_better=True,
+        name=None,
+        target_extractor=to_numpy,
+    ):
+        super().__init__(
+            scoring=scoring,
+            lower_is_better=lower_is_better,
+            on_train=True,
+            name=name,
+            target_extractor=target_extractor,
+            use_caching=False,
+        )
+
+    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
+        if len(self.y_preds_) == 0:
+            dataset = net.get_dataset(dataset_train)
+            # Prevent that rng state of torch is changed by
+            # creation+usage of iterator
+            # Unfortunatenly calling __iter__() of a pytorch
+            # DataLoader will change the random state
+            # Note line below setting rng state back
+            rng_state = torch.random.get_rng_state()
+            iterator = net.get_iterator(dataset, training=False)
+            y_preds = []
+            y_test = []
+            for batch in iterator:
+                batch_X, batch_y = unpack_data(batch)
+                # TODO: remove after skorch 0.10 release
+                if not check_version("skorch", min_version="0.10.1"):
+                    yp = net.evaluation_step(batch_X, training=False)
+                # X, y unpacking has been pushed downstream in skorch 0.10
+                else:
+                    yp = net.evaluation_step(batch, training=False)
+                yp = yp.to(device="cpu")
+                y_test.append(self.target_extractor(batch_y))
+                y_preds.append(yp)
+            y_test = np.concatenate(y_test)
+            torch.random.set_rng_state(rng_state)
+
+            # Adding the recomputed preds to all other
+            # instances of PostEpochTrainScoring of this
+            # Skorch-Net (NeuralNet, BraindecodeClassifier etc.)
+            # (They will be reinitialized to empty lists by skorch
+            # each epoch)
+            cbs = net.callbacks_
+            epoch_cbs = [
+                cb for name, cb in cbs if isinstance(cb, PostEpochTrainScoring)
+            ]
+            for cb in epoch_cbs:
+                cb.y_preds_ = y_preds
+                cb.y_trues_ = y_test
+        # y pred should be same as self.y_preds_
+        # Unclear if this also leads to any
+        # random generator call?
+        with _cache_net_forward_iter(
+            net, use_caching=True, y_preds=self.y_preds_
+        ) as cached_net:
+            current_score = self._scoring(cached_net, dataset_train, self.y_trues_)
+        self._record_score(net.history, current_score)
+
+
+def predict_trials(module, dataset, return_targets=True, batch_size=1, num_workers=0):
+    """Create trialwise predictions and optionally also return trialwise
+    labels from cropped dataset given module.
+
+    Parameters
+    ----------
+    module: torch.nn.Module
+        A pytorch model implementing forward.
+    dataset: braindecode.datasets.BaseConcatDataset
+        A braindecode dataset to be predicted.
+    return_targets: bool
+        If True, additionally returns the trial targets.
+    batch_size: int
+        The batch size used to iterate the dataset.
+    num_workers: int
+        Number of workers used in DataLoader to iterate the dataset.
+
+    Returns
+    -------
+    trial_predictions: np.ndarray
+        3-dimensional array (n_trials x n_classes x n_predictions), where
+        the number of predictions depend on the chosen window size and the
+        receptive field of the network.
+    trial_labels: np.ndarray
+        2-dimensional array (n_trials x n_targets) where the number of
+        targets depends on the decoding paradigm and can be either a single
+        value, multiple values, or a sequence.
+    """
+    # Ensure the model is in evaluation mode
+    module.eval()
+    # we have a cropped dataset if there exists at least one trial with more
+    # than one compute window
+    more_than_one_window = sum(dataset.get_metadata()["i_window_in_trial"] != 0) > 0
+    if not more_than_one_window:
+        warnings.warn(
+            "This function was designed to predict trials from "
+            "cropped datasets, which typically have multiple compute "
+            "windows per trial. The given dataset has exactly one "
+            "window per trial."
+        )
+    loader = DataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+    )
+    device = next(module.parameters()).device
+    all_preds, all_ys, all_inds = [], [], []
+    with torch.no_grad():
+        for X, y, ind in loader:
+            X = X.to(device)
+            preds = module(X)
+            all_preds.extend(preds.cpu().numpy().astype(np.float32))
+            all_ys.extend(y.cpu().numpy().astype(np.float32))
+            all_inds.extend(ind)
+    preds_per_trial = trial_preds_from_window_preds(
+        preds=all_preds,
+        i_window_in_trials=torch.cat(all_inds[0::3]),
+        i_stop_in_trials=torch.cat(all_inds[2::3]),
+    )
+    preds_per_trial = np.array(preds_per_trial)
+    if return_targets:
+        if all_ys[0].shape == ():
+            all_ys = np.array(all_ys)
+            ys_per_trial = all_ys[
+                np.diff(torch.cat(all_inds[0::3]), prepend=[np.inf]) != 1
+            ]
+        else:
+            ys_per_trial = trial_preds_from_window_preds(
+                preds=all_ys,
+                i_window_in_trials=torch.cat(all_inds[0::3]),
+                i_stop_in_trials=torch.cat(all_inds[2::3]),
+            )
+            ys_per_trial = np.array(ys_per_trial)
+        return preds_per_trial, ys_per_trial
+    return preds_per_trial
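
The core of this new module is the deduplication in trial_preds_from_window_preds: consecutive windows of the same trial overlap, so each window only contributes the predictions past the previous window's stop index. A minimal sketch of that behavior, assuming braindecode 1.0.0 as shipped in this wheel is installed (the toy shapes below are purely illustrative):

    import numpy as np
    from braindecode.training.scoring import trial_preds_from_window_preds

    # Two overlapping windows from a single trial, each 2 classes x 4 timesteps.
    # Window 1 stops one sample later than window 0, so after deduplication it
    # contributes only its final column.
    w0 = np.zeros((2, 4))
    w1 = np.ones((2, 4))
    per_trial = trial_preds_from_window_preds(
        [w0, w1], i_window_in_trials=[0, 1], i_stop_in_trials=[4, 5]
    )
    print(per_trial[0].shape)  # (2, 5): all of w0 plus the one new column of w1

CroppedTrialEpochScoring then averages each trial's deduplicated timeseries over axis=1 to get one prediction per trial before handing it to the skorch scorer.
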
braindecode/util.py CHANGED
@@ -12,6 +12,7 @@ import mne
 import numpy as np
 import torch
 from sklearn.utils import check_random_state
+from torch import Tensor
 
 
 def set_random_seeds(seed, cuda, cudnn_benchmark=None):
@@ -51,7 +52,9 @@ def set_random_seeds(seed, cuda, cudnn_benchmark=None):
                 warn(
                     "torch.backends.cudnn.benchmark was set to True which may results in lack of "
                     "reproducibility. In some cases to ensure reproducibility you may need to "
-                    "set torch.backends.cudnn.benchmark to False.", UserWarning)
+                    "set torch.backends.cudnn.benchmark to False.",
+                    UserWarning,
+                )
         else:
             raise ValueError(
                 f"cudnn_benchmark expected to be bool or None, got '{cudnn_benchmark}'"
@@ -60,19 +63,7 @@ def set_random_seeds(seed, cuda, cudnn_benchmark=None):
     np.random.seed(seed)
 
 
-def np_to_var(
-    X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs
-):
-    warn("np_to_var has been renamed np_to_th, please use np_to_th instead")
-    return np_to_th(
-        X, requires_grad=requires_grad, dtype=dtype, pin_memory=pin_memory,
-        **tensor_kwargs
-    )
-
-
-def np_to_th(
-    X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs
-):
+def np_to_th(X, requires_grad=False, dtype=None, pin_memory=False, **tensor_kwargs):
     """
     Convenience function to transform numpy array to `torch.Tensor`.
 
@@ -103,12 +94,7 @@ def np_to_th(
     return X_tensor
 
 
-def var_to_np(var):
-    warn("var_to_np has been renamed th_to_np, please use th_to_np instead")
-    return th_to_np(var)
-
-
-def th_to_np(var):
+def th_to_np(var: Tensor):
     """Convenience function to transform `torch.Tensor` to numpy
     array.
 
@@ -209,15 +195,11 @@ def wrap_reshape_apply_fn(stat_fn, a, b, axis_a, axis_b):
     )
     assert np.array_equal(n_stat_axis_a, n_stat_axis_b)
     stat_result = stat_fn(flat_topo_a, flat_topo_b)
-    topo_result = stat_result.reshape(
-        tuple(n_other_axis_a) + tuple(n_other_axis_b)
-    )
+    topo_result = stat_result.reshape(tuple(n_other_axis_a) + tuple(n_other_axis_b))
     return topo_result
 
 
-def get_balanced_batches(
-    n_trials, rng, shuffle, n_batches=None, batch_size=None
-):
+def get_balanced_batches(n_trials, rng, shuffle, n_batches=None, batch_size=None):
     """Create indices for batches balanced in size
     (batches will have maximum size difference of 1).
     Supply either batch size or number of batches. Resulting batches
@@ -268,9 +250,17 @@
     return batches
 
 
-def create_mne_dummy_raw(n_channels, n_times, sfreq, include_anns=True,
-                         description=None, savedir=None, save_format='fif',
-                         overwrite=True, random_state=None):
+def create_mne_dummy_raw(
+    n_channels,
+    n_times,
+    sfreq,
+    include_anns=True,
+    description=None,
+    savedir=None,
+    save_format="fif",
+    overwrite=True,
+    random_state=None,
+):
     """Create an mne.io.RawArray with fake data, and optionally save it.
 
     This will overwrite already existing files.
@@ -305,20 +295,21 @@
     """
     random_state = check_random_state(random_state)
     data = random_state.rand(n_channels, n_times)
-    ch_names = [f'ch{i}' for i in range(n_channels)]
-    ch_types = ['eeg'] * n_channels
+    ch_names = [f"ch{i}" for i in range(n_channels)]
+    ch_types = ["eeg"] * n_channels
     info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
 
     raw = mne.io.RawArray(data, info)
 
     if include_anns:
         n_anns = 10
-        inds = np.linspace(
-            int(sfreq * 2), int(n_times - sfreq * 2), num=n_anns).astype(int)
+        inds = np.linspace(int(sfreq * 2), int(n_times - sfreq * 2), num=n_anns).astype(
+            int
+        )
         onset = raw.times[inds]
         duration = [1] * n_anns
         if description is None:
-            description = ['test'] * n_anns
+            description = ["test"] * n_anns
         anns = mne.Annotations(onset, duration, description)
         raw = raw.set_annotations(anns)
 
@@ -326,18 +317,17 @@
     if savedir is not None:
         if not isinstance(save_format, list):
             save_format = [save_format]
-        fname = os.path.join(savedir, 'fake_eeg_raw')
+        fname = os.path.join(savedir, "fake_eeg_raw")
 
-        if 'fif' in save_format:
-            fif_fname = fname + '.fif'
+        if "fif" in save_format:
+            fif_fname = fname + ".fif"
             raw.save(fif_fname, overwrite=overwrite)
-            save_fname['fif'] = fif_fname
-        if 'hdf5' in save_format:
-            h5_fname = fname + '.h5'
-            with h5py.File(h5_fname, 'w') as f:
-                f.create_dataset(
-                    'fake_raw', dtype='f8', data=raw.get_data())
-            save_fname['hdf5'] = h5_fname
+            save_fname["fif"] = fif_fname
+        if "hdf5" in save_format:
+            h5_fname = fname + ".h5"
+            with h5py.File(h5_fname, "w") as f:
+                f.create_dataset("fake_raw", dtype="f8", data=raw.get_data())
+            save_fname["hdf5"] = h5_fname
 
     return raw, save_fname
 
@@ -349,7 +339,9 @@ class ThrowAwayIndexLoader(object):
         self.last_i = None
         self.is_regression = is_regression
 
-    def __iter__(self, ):
+    def __iter__(
+        self,
+    ):
         normal_iter = self.loader.__iter__()
         for batch in normal_iter:
             if len(batch) == 3:
@@ -360,7 +352,7 @@
                 x, y = batch
 
             # TODO: should be on dataset side
-            if hasattr(x, 'type'):
+            if hasattr(x, "type"):
                 x = x.type(torch.float32)
                 if self.is_regression:
                     y = y.type(torch.float32)
@@ -370,23 +362,26 @@
 
 
 def update_estimator_docstring(base_class, docstring):
-    base_doc = base_class.__doc__.replace(' : ', ': ')
-    idx = base_doc.find('callbacks:')
-    idx_end = idx + base_doc[idx:].find('\n\n')
+    base_doc = base_class.__doc__.replace(" : ", ": ")
+    idx = base_doc.find("callbacks:")
+    idx_end = idx + base_doc[idx:].find("\n\n")
     # remove callback descripiton already included in braindecode docstring
-    filtered_doc = base_doc[:idx] + base_doc[idx_end + 6:]
-    splitted = docstring.split('Parameters\n    ----------\n    ')
+    filtered_doc = base_doc[:idx] + base_doc[idx_end + 6 :]
+    splitted = docstring.split("Parameters\n    ----------\n    ")
     out_docstring = (
-        splitted[0] +
-        filtered_doc[filtered_doc.find('Parameters'):filtered_doc.find('Attributes')] +
-        splitted[1] +
-        filtered_doc[filtered_doc.find('Attributes'):])
+        splitted[0]
+        + filtered_doc[
+            filtered_doc.find("Parameters") : filtered_doc.find("Attributes")
+        ]
+        + splitted[1]
+        + filtered_doc[filtered_doc.find("Attributes") :]
+    )
     return out_docstring
 
 
 def _update_moabb_docstring(base_class, docstring):
     base_doc = base_class.__doc__
-    out_docstring = base_doc + f'\n\n{docstring}'
+    out_docstring = base_doc + f"\n\n{docstring}"
     return out_docstring
 
 
@@ -406,8 +401,9 @@ def read_all_file_names(directory, extension):
     file_paths: list(str)
         List of all files found in (sub)directories of path.
     """
-    assert extension.startswith('.')
-    file_paths = glob.glob(directory + '**/*' + extension, recursive=True)
+    assert extension.startswith(".")
+    file_paths = glob.glob(directory + "**/*" + extension, recursive=True)
     assert len(file_paths) > 0, (
-        f'something went wrong. Found no {extension} files in {directory}')
+        f"something went wrong. Found no {extension} files in {directory}"
+    )
     return file_paths
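
Beyond the pure formatting changes, this file drops the deprecated np_to_var/var_to_np aliases, so 0.8-era code has to call the th-named functions directly. A hedged migration sketch (the array contents are illustrative only):

    import numpy as np
    from braindecode.util import np_to_th, th_to_np

    X = np.random.rand(2, 3).astype(np.float32)
    # 0.8: t = np_to_var(X); X_back = var_to_np(t)  -- both removed in 1.0.0
    t = np_to_th(X, requires_grad=False)  # numpy -> torch.Tensor
    X_back = th_to_np(t)                  # torch.Tensor -> numpy
    assert np.allclose(X, X_back)
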