braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,477 @@
1
+ # Authors: Maciej Sliwowski <maciek.sliwowski@gmail.com>
2
+ # Robin Tibor Schirrmeister <robintibor@gmail.com>
3
+ # Alexandre Gramfort <alexandre.gramfort@inria.fr>
4
+ # Lukas Gemein <l.gemein@gmail.com>
5
+ # Mohammed Fattouh <mo.fattouh@gmail.com>
6
+ #
7
+ # License: BSD-3
8
+
9
+ import warnings
10
+ from contextlib import contextmanager
11
+
12
+ import numpy as np
13
+ import torch
14
+ from skorch.callbacks.scoring import EpochScoring
15
+ from skorch.dataset import unpack_data
16
+ from skorch.utils import to_numpy
17
+ from torch.utils.data import DataLoader
18
+
19
+
20
def trial_preds_from_window_preds(preds, i_window_in_trials, i_stop_in_trials):
    """
    Assigning window predictions to trials while removing duplicate
    predictions.

    Consecutive windows belong to the same trial as long as their window
    index increases by exactly 1; any other jump starts a new trial, whose
    first window must have index 0. Within a trial, only the predictions of
    a window that lie after the stop of the previous window are kept, so
    overlapping time steps are counted once.

    Parameters
    ----------
    preds: list of ndarrays (at least 2darrays)
        List of window predictions, in each window prediction
        time is in axis=1
    i_window_in_trials: list
        Index/number of window in trial
    i_stop_in_trials: list
        stop position of window in trial

    Returns
    -------
    preds_per_trial: list of ndarrays
        Predictions in each trial, duplicates removed. An empty list if
        ``preds`` is empty.

    """
    assert len(preds) == len(i_window_in_trials) == len(i_stop_in_trials), (
        f"{len(preds)}, {len(i_window_in_trials)}, {len(i_stop_in_trials)}"
    )
    if len(preds) == 0:
        # Nothing to assign; np.concatenate below would fail on an empty list.
        return []

    # Algorithm: loop through windows. If the window index is one larger
    # than in the last iteration we are still in the same trial, otherwise
    # a new trial starts. Within a trial, only keep the predictions after
    # the stop of the previous window (duplicates removed) and add them to
    # the predictions of the current trial.
    preds_per_trial = []
    cur_trial_preds = []
    i_last_stop = None
    i_last_window = -1
    for window_preds, i_window, i_stop in zip(
        preds, i_window_in_trials, i_stop_in_trials
    ):
        window_preds = np.array(window_preds)
        if i_window != (i_last_window + 1):
            assert i_window == 0, "window numbers in new trial should start from 0"
            preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
            cur_trial_preds = []
            i_last_stop = None

        if i_last_stop is not None:
            # Keep only the last ``n_needed_preds`` time steps. Slicing from
            # an explicit, clipped start index (instead of ``[:, -n:]``)
            # prevents n == 0 from selecting the whole window again.
            n_needed_preds = int(i_stop - i_last_stop)
            n_times = window_preds.shape[1]
            start = max(n_times - n_needed_preds, 0)
            window_preds = window_preds[:, start:]
        cur_trial_preds.append(window_preds)
        i_last_window = i_window
        i_last_stop = i_stop
    # add last trial preds
    preds_per_trial.append(np.concatenate(cur_trial_preds, axis=1))
    return preds_per_trial
85
+
86
+
87
+ @contextmanager
88
+ def _cache_net_forward_iter(net, use_caching, y_preds):
89
+ """Caching context for ``skorch.NeuralNet`` instance.
90
+ Returns a modified version of the net whose ``forward_iter``
91
+ method will subsequently return cached predictions. Leaving the
92
+ context will undo the overwrite of the ``forward_iter`` method.
93
+ """
94
+ if not use_caching:
95
+ yield net
96
+ return
97
+ y_preds = iter(y_preds)
98
+
99
+ # pylint: disable=unused-argument
100
+ def cached_forward_iter(*args, device=net.device, **kwargs):
101
+ for yp in y_preds:
102
+ yield yp.to(device=device)
103
+
104
+ net.forward_iter = cached_forward_iter
105
+ try:
106
+ yield net
107
+ finally:
108
+ # By setting net.forward_iter we define an attribute
109
+ # `forward_iter` that precedes the bound method
110
+ # `forward_iter`. By deleting the entry from the attribute
111
+ # dict we undo this.
112
+ del net.__dict__["forward_iter"]
113
+
114
+
115
class CroppedTrialEpochScoring(EpochScoring):
    """
    Class to compute scores for trials from a model that predicts (super)crops.

    During an epoch the window (crop) predictions are cached. At the end of
    the epoch they are stitched back into one prediction time series per
    trial with :func:`trial_preds_from_window_preds`, averaged over time, and
    scored against one target per trial (the target of the trial's first
    window). The per-trial results are computed once and shared with the
    other ``CroppedTrialEpochScoring`` callbacks of the net that score the
    same set (time-series variants are excluded, they aggregate differently).

    Parameters
    ----------
    scoring : None, str, or callable
        Scoring passed on to :class:`skorch.callbacks.EpochScoring`.
    lower_is_better : bool (default=True)
        Whether lower scores should be considered better or worse.
    on_train : bool (default=False)
        If True, score the training set; its window predictions are
        recomputed with ``net.predict_with_window_inds_and_ys``. If False,
        score the validation set from the predictions cached during the
        epoch.
    name : str or None (default=None)
        Name of the score, passed on to :class:`skorch.callbacks.EpochScoring`.
    target_extractor : callable (default=to_numpy)
        Passed on to :class:`skorch.callbacks.EpochScoring`.
    use_caching : bool (default=True)
        Must be True: ``on_epoch_end`` relies on the cached predictions.

    Notes
    -----
    For validation scoring, ``window_inds_`` has to be filled from outside
    this class (presumably by the braindecode net - TODO confirm) with the
    window indices of every validation batch; entry ``[0]`` is used as the
    window index in the trial and entry ``[2]`` as the window stop.
    """

    def __init__(
        self,
        scoring,
        lower_is_better=True,
        on_train=False,
        name=None,
        target_extractor=to_numpy,
        use_caching=True,
    ):
        super().__init__(
            scoring=scoring,
            lower_is_better=lower_is_better,
            on_train=on_train,
            name=name,
            target_extractor=target_extractor,
            use_caching=use_caching,
        )
        if not self.on_train:
            self.window_inds_ = []

    def _initialize_cache(self):
        super()._initialize_cache()
        self.crops_to_trials_computed = False
        self.y_trues_ = []
        self.y_preds_ = []
        if not self.on_train:
            self.window_inds_ = []

    def on_batch_end(self, net, batch, y_pred, training, **kwargs):
        # Skorch saves the predictions without moving them from GPU
        # https://github.com/skorch-dev/skorch/blob/fe71e3d55a4ae5f5f94ef7bdfc00fca3b3fd267f/skorch/callbacks/scoring.py#L385
        # This can cause memory issues in case of a large number of predictions
        # Therefore here we move them to CPU already
        super().on_batch_end(net, batch, y_pred, training, **kwargs)
        if self.use_caching and training == self.on_train:
            self.y_preds_[-1] = self.y_preds_[-1].cpu()

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        assert self.use_caching
        dataset = dataset_train if self.on_train else dataset_valid
        if dataset is None:
            # Nothing to score on (e.g. no validation split); the caches
            # would be empty and np.concatenate below would fail.
            return

        if not self.crops_to_trials_computed:
            if self.on_train:
                # Prevent that rng state of torch is changed by
                # creation+usage of iterator (restored even on error)
                rng_state = torch.random.get_rng_state()
                try:
                    pred_results = net.predict_with_window_inds_and_ys(dataset_train)
                finally:
                    torch.random.set_rng_state(rng_state)
            else:
                pred_results = {}
                pred_results["i_window_in_trials"] = np.concatenate(
                    [i[0].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["i_window_stops"] = np.concatenate(
                    [i[2].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["preds"] = np.concatenate(
                    [y_pred.cpu().numpy() for y_pred in self.y_preds_]
                )
                pred_results["window_ys"] = np.concatenate(
                    [y.cpu().numpy() for y in self.y_trues_]
                )

            # A new trial starts when the index of the window in trials
            # does not increment by 1 (dummy infinity prepended so the very
            # first window always starts a trial)
            window_0_per_trial_mask = (
                np.diff(pred_results["i_window_in_trials"], prepend=[np.inf]) != 1
            )
            trial_ys = pred_results["window_ys"][window_0_per_trial_mask]
            trial_preds = trial_preds_from_window_preds(
                pred_results["preds"],
                pred_results["i_window_in_trials"],
                pred_results["i_window_stops"],
            )

            # Average across the timesteps of each trial so we have per-trial
            # predictions already, these will be just passed through the forward
            # method of the classifier/regressor to the skorch scoring function.
            # trial_preds is a list, each item is a 2d array classes x time
            y_preds_per_trial = np.array([np.mean(p, axis=1) for p in trial_preds])
            # Move into format expected by skorch (list of torch tensors)
            y_preds_per_trial = [torch.tensor(y_preds_per_trial)]

            # Store the computed trial preds for all Cropped Callbacks
            # that are also on same set. Time-series callbacks (subclass)
            # aggregate differently and must compute their own results.
            epoch_cbs = [
                cb
                for _, cb in net.callbacks_
                if isinstance(cb, CroppedTrialEpochScoring)
                and not isinstance(cb, CroppedTimeSeriesEpochScoring)
                and (cb.on_train == self.on_train)
            ]
            for cb in epoch_cbs:
                cb.y_preds_ = y_preds_per_trial
                cb.y_trues_ = trial_ys
                cb.crops_to_trials_computed = True

        with _cache_net_forward_iter(
            net, self.use_caching, self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset, self.y_trues_)
        self._record_score(net.history, current_score)
228
+
229
+
230
class CroppedTimeSeriesEpochScoring(CroppedTrialEpochScoring):
    """
    Class to compute scores for trials from a model that predicts (super)crops with
    time series target.

    Unlike :class:`CroppedTrialEpochScoring`, predictions are not averaged
    over time. Window predictions and window targets (the targets cropped to
    the number of predicted time steps) are both stitched into per-trial
    time series, concatenated over all trials, and every time step whose
    target is NaN is dropped before scoring. The results are shared with the
    other ``CroppedTimeSeriesEpochScoring`` callbacks of the net that score
    the same set.
    """

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        assert self.use_caching
        dataset = dataset_train if self.on_train else dataset_valid
        if dataset is None:
            # Nothing to score on (e.g. no validation split); the caches
            # would be empty and np.concatenate below would fail.
            return

        if not self.crops_to_trials_computed:
            if self.on_train:
                # Prevent that rng state of torch is changed by
                # creation+usage of iterator (restored even on error)
                rng_state = torch.random.get_rng_state()
                try:
                    pred_results = net.predict_with_window_inds_and_ys(dataset_train)
                finally:
                    torch.random.set_rng_state(rng_state)
            else:
                pred_results = {}
                pred_results["i_window_in_trials"] = np.concatenate(
                    [i[0].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["i_window_stops"] = np.concatenate(
                    [i[2].cpu().numpy() for i in self.window_inds_]
                )
                pred_results["preds"] = np.concatenate(
                    [y_pred.cpu().numpy() for y_pred in self.y_preds_]
                )
                pred_results["window_ys"] = np.concatenate(
                    [y.cpu().numpy() for y in self.y_trues_]
                )

            num_preds = pred_results["preds"][-1].shape[-1]
            # slice the targets to fit preds shape
            pred_results["window_ys"] = [
                window_y[:, -num_preds:] for window_y in pred_results["window_ys"]
            ]

            trial_preds = trial_preds_from_window_preds(
                pred_results["preds"],
                pred_results["i_window_in_trials"],
                pred_results["i_window_stops"],
            )

            trial_ys = trial_preds_from_window_preds(
                pred_results["window_ys"],
                pred_results["i_window_in_trials"],
                pred_results["i_window_stops"],
            )

            # the output is a list of predictions/targets per trial where each item is a
            # timeseries of predictions/targets of shape (n_classes x timesteps)

            # mask NaNs from targets
            preds = np.hstack(trial_preds)  # n_classes x timesteps in all trials
            targets = np.hstack(trial_ys)
            # create valid targets mask
            mask = ~np.isnan(targets)
            # select valid targets that have a matching predictions
            masked_targets = targets[mask]
            # For classification there is only one row in targets and n_classes rows in preds
            if mask.shape[0] != preds.shape[0]:
                masked_preds = preds[:, mask[0, :]]
            else:
                masked_preds = preds[mask]

            # Store the computed trial preds for all Cropped Callbacks
            # that are also on same set
            epoch_cbs = [
                cb
                for _, cb in net.callbacks_
                if isinstance(cb, CroppedTimeSeriesEpochScoring)
                and (cb.on_train == self.on_train)
            ]
            masked_preds = [torch.tensor(masked_preds.T)]
            for cb in epoch_cbs:
                cb.y_preds_ = masked_preds
                cb.y_trues_ = masked_targets.T
                cb.crops_to_trials_computed = True

        with _cache_net_forward_iter(
            net, self.use_caching, self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset, self.y_trues_)
        self._record_score(net.history, current_score)
316
+
317
+
318
class PostEpochTrainScoring(EpochScoring):
    """
    Epoch Scoring class that recomputes predictions after the epoch
    on the training set in evaluation mode.

    Note: For unknown reasons, this affects global random generator and
    therefore all results may change slightly if you add this scoring callback.

    Parameters
    ----------
    scoring : None, str, or callable (default=None)
        If None, use the ``score`` method of the model. If str, it
        should be a valid sklearn scorer (e.g. "f1", "accuracy"). If a
        callable, it should have the signature (model, X, y), and it
        should return a scalar. This works analogously to the
        ``scoring`` parameter in sklearn's ``GridSearchCV`` et al.
    lower_is_better : bool (default=True)
        Whether lower scores should be considered better or worse.
    name : str or None (default=None)
        If not an explicit string, tries to infer the name from the
        ``scoring`` argument.
    target_extractor : callable (default=to_numpy)
        This is called on y before it is passed to scoring.
    """

    def __init__(
        self,
        scoring,
        lower_is_better=True,
        name=None,
        target_extractor=to_numpy,
    ):
        super().__init__(
            scoring=scoring,
            lower_is_better=lower_is_better,
            on_train=True,
            name=name,
            target_extractor=target_extractor,
            use_caching=False,
        )

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        # Only the first PostEpochTrainScoring of the net recomputes the
        # predictions; it hands them to all others (see below).
        if len(self.y_preds_) == 0:
            dataset = net.get_dataset(dataset_train)
            # Prevent that rng state of torch is changed by
            # creation+usage of iterator.
            # Unfortunately calling __iter__() of a pytorch
            # DataLoader will change the random state, so the state is
            # restored in ``finally`` (also if prediction raises).
            rng_state = torch.random.get_rng_state()
            try:
                iterator = net.get_iterator(dataset, training=False)
                y_preds = []
                y_test = []
                for batch in iterator:
                    _, batch_y = unpack_data(batch)
                    yp = net.evaluation_step(batch, training=False)
                    yp = yp.to(device="cpu")
                    y_test.append(self.target_extractor(batch_y))
                    y_preds.append(yp)
                y_test = np.concatenate(y_test)
            finally:
                torch.random.set_rng_state(rng_state)

            # Adding the recomputed preds to all other
            # instances of PostEpochTrainScoring of this
            # Skorch-Net (NeuralNet, BraindecodeClassifier etc.)
            # (They will be reinitialized to empty lists by skorch
            # each epoch)
            epoch_cbs = [
                cb for _, cb in net.callbacks_ if isinstance(cb, PostEpochTrainScoring)
            ]
            for cb in epoch_cbs:
                cb.y_preds_ = y_preds
                cb.y_trues_ = y_test
        # y pred should be same as self.y_preds_
        # Unclear if this also leads to any
        # random generator call?
        with _cache_net_forward_iter(
            net, use_caching=True, y_preds=self.y_preds_
        ) as cached_net:
            current_score = self._scoring(cached_net, dataset_train, self.y_trues_)
        self._record_score(net.history, current_score)
400
+
401
+
402
def predict_trials(module, dataset, return_targets=True, batch_size=1, num_workers=0):
    """Create trialwise predictions and optionally also return trialwise
    labels from cropped dataset given module.

    Parameters
    ----------
    module: torch.nn.Module
        A pytorch model implementing forward. It is switched to evaluation
        mode with ``module.eval()`` and left in that mode.
    dataset: braindecode.datasets.BaseConcatDataset
        A braindecode dataset to be predicted. Its items are unpacked as
        ``(X, y, ind)`` and it must provide ``get_metadata()`` with an
        ``"i_window_in_trial"`` column.
    return_targets: bool
        If True, additionally returns the trial targets.
    batch_size: int
        The batch size used to iterate the dataset.
    num_workers: int
        Number of workers used in DataLoader to iterate the dataset.

    Returns
    -------
    trial_predictions: np.ndarray
        3-dimensional array (n_trials x n_classes x n_predictions), where
        the number of predictions depend on the chosen window size and the
        receptive field of the network.
    trial_labels: np.ndarray
        Only returned if ``return_targets`` is True. For scalar or vector
        window targets, the target of the first window of each trial, i.e.
        an array of shape (n_trials,) or (n_trials x n_targets). For
        sequence (2d) window targets, the duplicate-free target sequence of
        each trial, stitched together like the predictions.
    """
    # Ensure the model is in evaluation mode
    module.eval()
    # we have a cropped dataset if there exists at least one trial with more
    # than one compute window, i.e. some window index in trial is non-zero
    more_than_one_window = bool(
        (dataset.get_metadata()["i_window_in_trial"] != 0).any()
    )
    if not more_than_one_window:
        warnings.warn(
            "This function was designed to predict trials from "
            "cropped datasets, which typically have multiple compute "
            "windows per trial. The given dataset has exactly one "
            "window per trial."
        )
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    device = next(module.parameters()).device
    all_preds, all_ys, all_inds = [], [], []
    with torch.no_grad():
        for X, y, ind in loader:
            X = X.to(device)
            preds = module(X)
            all_preds.extend(preds.cpu().numpy().astype(np.float32))
            all_ys.extend(y.cpu().numpy().astype(np.float32))
            all_inds.extend(ind)
    # Every batch's ``ind`` contributes three tensors to ``all_inds``; the
    # first is used as the window index in the trial, the third as the
    # window stop. Concatenate them once instead of per use.
    i_window_in_trials = torch.cat(all_inds[0::3])
    i_stop_in_trials = torch.cat(all_inds[2::3])

    preds_per_trial = np.array(
        trial_preds_from_window_preds(
            preds=all_preds,
            i_window_in_trials=i_window_in_trials,
            i_stop_in_trials=i_stop_in_trials,
        )
    )
    if not return_targets:
        return preds_per_trial

    if all_ys[0].ndim < 2:
        # Scalar or vector target per window: keep the target of the first
        # window of each trial (a new trial starts whenever the window index
        # does not increment by 1; dummy infinity prepended for the start).
        window_0_per_trial_mask = (
            np.diff(i_window_in_trials, prepend=[np.inf]) != 1
        )
        ys_per_trial = np.array(all_ys)[window_0_per_trial_mask]
    else:
        # Sequence targets (time in axis=1): stitch like the predictions.
        ys_per_trial = np.array(
            trial_preds_from_window_preds(
                preds=all_ys,
                i_window_in_trials=i_window_in_trials,
                i_stop_in_trials=i_stop_in_trials,
            )
        )
    return preds_per_trial, ys_per_trial