braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/preprocessing/windowers.py
@@ -0,0 +1,1031 @@
+ """Get epochs from mne.Raw"""
+
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
+ #          Lukas Gemein <l.gemein@gmail.com>
+ #          Simon Brandt <simonbrandt@protonmail.com>
+ #          David Sabbagh <dav.sabbagh@gmail.com>
+ #          Henrik Bonsmann <henrikbons@gmail.com>
+ #          Ann-Kathrin Kiessner <ann-kathrin.kiessner@gmx.de>
+ #          Vytautas Jankauskas <vytauto.jankausko@gmail.com>
+ #          Dan Wilson <dan.c.wil@gmail.com>
+ #          Maciej Sliwowski <maciek.sliwowski@gmail.com>
+ #          Mohammed Fattouh <mo.fattouh@gmail.com>
+ #          Robin Schirrmeister <robintibor@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import warnings
+ from typing import Any, Callable
+
+ import mne
+ import numpy as np
+ import pandas as pd
+ from joblib import Parallel, delayed
+ from numpy.typing import ArrayLike
+
+ from ..datasets.base import BaseConcatDataset, EEGWindowsDataset, WindowsDataset
+
+
+ class _LazyDataFrame:
+     """
+     DataFrame-like object that lazily computes values (experimental).
+
+     This class emulates some features of a pandas DataFrame, but computes
+     the values on-the-fly when they are accessed. This is useful for
+     very long DataFrames with repetitive values.
+     Only the methods used by EEGWindowsDataset on its metadata are implemented.
+
+     Parameters
+     ----------
+     length : int
+         The length of the dataframe.
+     functions : dict[str, Callable[[int], Any]]
+         A dictionary mapping column names to functions that take an index and
+         return the value of the column at that index.
+     columns : list[str]
+         The names of the columns in the dataframe.
+     series : bool
+         Whether the object should emulate a series or a dataframe.
+     """
+
+     def __init__(
+         self,
+         length: int,
+         functions: dict[str, Callable[[int], Any]],
+         columns: list[str],
+         series: bool = False,
+     ):
+         if not (isinstance(length, int) and length >= 0):
+             raise ValueError("Length must be a non-negative integer.")
+         if not all(c in functions for c in columns):
+             raise ValueError("All columns must have a corresponding function.")
+         if series and len(columns) != 1:
+             raise ValueError("Series must have exactly one column.")
+         self.length = length
+         self.functions = functions
+         self.columns = columns
+         self.series = series
+
+     @property
+     def loc(self):
+         return self
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, key):
+         if not isinstance(key, tuple):
+             key = (key, self.columns)
+         if len(key) == 1:
+             key = (key[0], self.columns)
+         if not len(key) == 2:
+             raise IndexError(
+                 f"index must be either [row] or [row, column], got [{', '.join(map(str, key))}]."
+             )
+         row, col = key
+         if col == slice(None):  # all columns (i.e., call to df[row, :])
+             col = self.columns
+         one_col = False
+         if isinstance(col, str):  # one column
+             one_col = True
+             col = [col]
+         else:  # multiple columns
+             col = list(col)
+         if not all(c in self.columns for c in col):
+             raise IndexError(
+                 f"All columns must be present in the dataframe with columns {self.columns}. Got {col}."
+             )
+         if row == slice(None):  # all rows (i.e., call to df[:] or df[:, col])
+             return _LazyDataFrame(self.length, self.functions, col)
+         if not isinstance(row, int):
+             raise NotImplementedError(
+                 "Row indexing only supports either a single integer or a null slice (i.e., df[:])."
+             )
+         if not (0 <= row < self.length):
+             raise IndexError(f"Row index {row} is out of bounds.")
+         if self.series or one_col:
+             return self.functions[col[0]](row)
+         return pd.Series({c: self.functions[c](row) for c in col})
+
+     def to_numpy(self):
+         return _LazyDataFrame(
+             length=self.length,
+             functions=self.functions,
+             columns=self.columns,
+             series=len(self.columns) == 1,
+         )
+
+     def to_list(self):
+         return self.to_numpy()
+
+
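A quick sketch of how this lazy table behaves (values hypothetical, assuming the class as defined above): indexing goes through `loc`/`__getitem__`, and each cell is computed on access rather than stored.

    lazy_df = _LazyDataFrame(
        length=3,
        functions={"target": lambda i: i * 10},
        columns=["target"],
    )
    assert len(lazy_df) == 3
    assert lazy_df.loc[1, "target"] == 10  # computed on access, not stored
    assert lazy_df[2, "target"] == 20      # same path without .loc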
+ class _FixedLengthWindowFunctions:
+     """Class defining functions for lazy metadata generation in fixed length windowing
+     to be used in combination with _LazyDataFrame (experimental)."""
+
+     def __init__(
+         self,
+         start_offset_samples: int,
+         last_potential_start: int,
+         window_stride_samples: int,
+         window_size_samples: int,
+         target: Any,
+     ):
+         self.start_offset_samples = start_offset_samples
+         self.last_potential_start = last_potential_start
+         self.window_stride_samples = window_stride_samples
+         self.window_size_samples = window_size_samples
+         self.target_val = target
+
+     @property
+     def length(self) -> int:
+         return int(
+             np.ceil(
+                 (self.last_potential_start + 1 - self.start_offset_samples)
+                 / self.window_stride_samples
+             )
+         )
+
+     def i_window_in_trial(self, i: int) -> int:
+         return i
+
+     def i_start_in_trial(self, i: int) -> int:
+         return self.start_offset_samples + i * self.window_stride_samples
+
+     def i_stop_in_trial(self, i: int) -> int:
+         return (
+             self.start_offset_samples
+             + i * self.window_stride_samples
+             + self.window_size_samples
+         )
+
+     def target(self, i: int) -> Any:
+         return self.target_val
+
+
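To see the arithmetic, a worked example with hypothetical numbers: a 250-sample recording, 100-sample windows and a 50-sample stride give ceil((150 + 1) / 50) = 4 windows.

    fns = _FixedLengthWindowFunctions(
        start_offset_samples=0,
        last_potential_start=150,  # 250 - 100
        window_stride_samples=50,
        window_size_samples=100,
        target=None,
    )
    assert fns.length == 4
    assert [fns.i_start_in_trial(i) for i in range(4)] == [0, 50, 100, 150]
    assert fns.i_stop_in_trial(3) == 250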
+ def _get_use_mne_epochs(use_mne_epochs, reject, picks, flat, drop_bad_windows):
+     should_use_mne_epochs = (
+         (reject is not None)
+         or (picks is not None)
+         or (flat is not None)
+         or (drop_bad_windows is True)
+     )
+     if use_mne_epochs is None:
+         if should_use_mne_epochs:
+             warnings.warn(
+                 "Using reject or picks or flat or dropping bad windows means "
+                 "mne Epochs are created, "
+                 "which will be substantially slower and may be deprecated in the future."
+             )
+         return should_use_mne_epochs
+     if not use_mne_epochs and should_use_mne_epochs:
+         raise ValueError(
+             "Cannot set use_mne_epochs=False when using reject, picks, flat, or dropping bad windows."
+         )
+     return use_mne_epochs
+
+
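The decision rule in a sketch (assuming the helper as defined above): passing any of reject, picks or flat, or drop_bad_windows=True, selects the slower mne.Epochs path; otherwise the lightweight EEGWindowsDataset path is used.

    assert _get_use_mne_epochs(None, None, None, None, None) is False
    # any rejection criterion implies mne.Epochs (and emits a warning):
    assert _get_use_mne_epochs(None, {"eeg": 1e-4}, None, None, None) is True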
+ # XXX it's called concat_ds...
+ def create_windows_from_events(
+     concat_ds: BaseConcatDataset,
+     trial_start_offset_samples: int = 0,
+     trial_stop_offset_samples: int = 0,
+     window_size_samples: int | None = None,
+     window_stride_samples: int | None = None,
+     drop_last_window: bool = False,
+     mapping: dict[str, int] | None = None,
+     preload: bool = False,
+     drop_bad_windows: bool | None = None,
+     picks: str | ArrayLike | slice | None = None,
+     reject: dict[str, float] | None = None,
+     flat: dict[str, float] | None = None,
+     on_missing: str = "error",
+     accepted_bads_ratio: float = 0.0,
+     use_mne_epochs: bool | None = None,
+     n_jobs: int = 1,
+     verbose: bool | str | int | None = "error",
+ ):
+     """Create windows based on events in mne.Raw.
+
+     This function extracts windows of size window_size_samples in the interval
+     [trial onset + trial_start_offset_samples, trial onset + trial duration +
+     trial_stop_offset_samples] around each trial, with a separation of
+     window_stride_samples between consecutive windows. If the last window
+     around an event does not end at trial_stop_offset_samples and
+     drop_last_window is set to False, an additional overlapping window that
+     ends at trial_stop_offset_samples is created.
+
+     Windows are extracted from the interval defined by the following::
+
+                                                   trial onset +
+                            trial onset            trial duration
+         |--------------------|------------------------|-----------------------|
+         trial onset -                                          trial onset +
+         trial_start_offset_samples                             trial duration +
+                                                                trial_stop_offset_samples
+
+     Parameters
+     ----------
+     concat_ds: BaseConcatDataset
+         A concat of base datasets each holding raw and description.
+     trial_start_offset_samples: int
+         Start offset from original trial onsets, in samples. Defaults to zero.
+     trial_stop_offset_samples: int
+         Stop offset from original trial stop, in samples. Defaults to zero.
+     window_size_samples: int | None
+         Window size. If None, the window size is inferred from the original
+         trial size of the first trial and trial_start_offset_samples and
+         trial_stop_offset_samples.
+     window_stride_samples: int | None
+         Stride between windows, in samples. If None, the window stride is
+         inferred from the original trial size of the first trial and
+         trial_start_offset_samples and trial_stop_offset_samples.
+     drop_last_window: bool
+         If False, an additional overlapping window that ends at
+         trial_stop_offset_samples will be extracted around each event when the
+         last window does not end exactly at trial_stop_offset_samples.
+     mapping: dict(str: int)
+         Mapping from event description to numerical target value.
+     preload: bool
+         If True, preload the data of the Epochs objects. This is useful to
+         reduce disk reading overhead when returning windows in a training
+         scenario, however very large data might not fit into memory.
+     drop_bad_windows: bool
+         If True, call `.drop_bad()` on the resulting mne.Epochs object. This
+         step allows identifying e.g., windows that fall outside of the
+         continuous recording. It is suggested to run this step here as
+         otherwise the BaseConcatDataset has to be updated as well.
+     picks: str | list | slice | None
+         Channels to include. If None, all available channels are used. See
+         mne.Epochs.
+     reject: dict | None
+         Epoch rejection parameters based on peak-to-peak amplitude. If None,
+         no rejection is done based on peak-to-peak amplitude. See mne.Epochs.
+     flat: dict | None
+         Epoch rejection parameters based on flatness of signals. If None, no
+         rejection based on flatness is done. See mne.Epochs.
+     on_missing: str
+         What to do if one or several event ids are not found in the recording.
+         Valid keys are ‘error’ | ‘warning’ | ‘ignore’. See mne.Epochs.
+     accepted_bads_ratio: float, optional
+         Acceptable proportion of trials with inconsistent length in a raw. If
+         the number of trials whose length is exceeded by the window size is
+         smaller than this, then only the corresponding trials are dropped, but
+         the computation continues. Otherwise, an error is raised. Defaults to
+         0.0 (raise an error).
+     use_mne_epochs: bool
+         If False, return EEGWindowsDataset objects.
+         If True, return mne.Epochs objects encapsulated in WindowsDataset
+         objects, which is substantially slower than EEGWindowsDataset.
+     n_jobs: int
+         Number of jobs to use to parallelize the windowing.
+     verbose: bool | str | int | None
+         Control verbosity of the logging output when calling mne.Epochs.
+
+     Returns
+     -------
+     windows_datasets: BaseConcatDataset
+         Concatenated datasets of WindowsDataset containing the extracted windows.
+     """
+     _check_windowing_arguments(
+         trial_start_offset_samples,
+         trial_stop_offset_samples,
+         window_size_samples,
+         window_stride_samples,
+     )
+
+     # If user did not specify mapping, we extract all events from all datasets
+     # and map them to increasing integers starting from 0
+     infer_mapping = mapping is None
+     mapping = dict() if infer_mapping else mapping
+     infer_window_size_stride = window_size_samples is None
+
+     if drop_bad_windows is not None:
+         warnings.warn(
+             "Drop bad windows only has an effect if mne epochs are created, "
+             "and this argument may be removed in the future."
+         )
+
+     use_mne_epochs = _get_use_mne_epochs(
+         use_mne_epochs, reject, picks, flat, drop_bad_windows
+     )
+     if use_mne_epochs and drop_bad_windows is None:
+         drop_bad_windows = True
+
+     list_of_windows_ds = Parallel(n_jobs=n_jobs)(
+         delayed(_create_windows_from_events)(
+             ds,
+             infer_mapping,
+             infer_window_size_stride,
+             trial_start_offset_samples,
+             trial_stop_offset_samples,
+             window_size_samples,
+             window_stride_samples,
+             drop_last_window,
+             mapping,
+             preload,
+             drop_bad_windows,
+             picks,
+             reject,
+             flat,
+             on_missing,
+             accepted_bads_ratio,
+             verbose,
+             use_mne_epochs,
+         )
+         for ds in concat_ds.datasets
+     )
+     return BaseConcatDataset(list_of_windows_ds)
+
+
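A minimal usage sketch, assuming `concat_ds` is a BaseConcatDataset of continuous recordings (for instance built with braindecode.datasets.MOABBDataset); with size and stride left as None, one window per trial is created.

    from braindecode.preprocessing import create_windows_from_events

    windows_ds = create_windows_from_events(
        concat_ds,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        drop_last_window=False,
    )
    X, y, inds = windows_ds[0]  # window signal, target, crop indices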
+ def create_fixed_length_windows(
+     concat_ds: BaseConcatDataset,
+     start_offset_samples: int = 0,
+     stop_offset_samples: int | None = None,
+     window_size_samples: int | None = None,
+     window_stride_samples: int | None = None,
+     drop_last_window: bool | None = None,
+     mapping: dict[str, int] | None = None,
+     preload: bool = False,
+     picks: str | ArrayLike | slice | None = None,
+     reject: dict[str, float] | None = None,
+     flat: dict[str, float] | None = None,
+     targets_from: str = "metadata",
+     last_target_only: bool = True,
+     lazy_metadata: bool = False,
+     on_missing: str = "error",
+     n_jobs: int = 1,
+     verbose: bool | str | int | None = "error",
+ ):
+     """Windower that creates sliding windows.
+
+     Parameters
+     ----------
+     concat_ds: ConcatDataset
+         A concat of base datasets each holding raw and description.
+     start_offset_samples: int
+         Start offset from beginning of recording in samples.
+     stop_offset_samples: int | None
+         Stop offset from beginning of recording in samples. If None, set to be
+         the end of the recording.
+     window_size_samples: int | None
+         Window size in samples. If None, set to be the maximum possible window
+         size, i.e. the length of the recording, once offsets are accounted for.
+     window_stride_samples: int | None
+         Stride between windows in samples. If None, set to be equal to
+         window_size_samples, so windows will not overlap.
+     drop_last_window: bool | None
+         Whether or not to have a last overlapping window, when windows do not
+         equally divide the continuous signal. Must be set to a bool if window
+         size and stride are not None.
+     mapping: dict(str: int)
+         Mapping from event description to target value.
+     preload: bool
+         If True, preload the data of the Epochs objects.
+     picks: str | list | slice | None
+         Channels to include. If None, all available channels are used. See
+         mne.Epochs.
+     reject: dict | None
+         Epoch rejection parameters based on peak-to-peak amplitude. If None,
+         no rejection is done based on peak-to-peak amplitude. See mne.Epochs.
+     flat: dict | None
+         Epoch rejection parameters based on flatness of signals. If None, no
+         rejection based on flatness is done. See mne.Epochs.
+     lazy_metadata: bool
+         If True, metadata is not computed immediately, but only when accessed
+         by using the _LazyDataFrame (experimental).
+     on_missing: str
+         What to do if one or several event ids are not found in the recording.
+         Valid keys are ‘error’ | ‘warning’ | ‘ignore’. See mne.Epochs.
+     n_jobs: int
+         Number of jobs to use to parallelize the windowing.
+     verbose: bool | str | int | None
+         Control verbosity of the logging output when calling mne.Epochs.
+
+     Returns
+     -------
+     windows_datasets: BaseConcatDataset
+         Concatenated datasets of WindowsDataset containing the extracted windows.
+     """
+     stop_offset_samples, drop_last_window = (
+         _check_and_set_fixed_length_window_arguments(
+             start_offset_samples,
+             stop_offset_samples,
+             window_size_samples,
+             window_stride_samples,
+             drop_last_window,
+             lazy_metadata,
+         )
+     )
+
+     # check if recordings are of different lengths
+     lengths = np.array([ds.raw.n_times for ds in concat_ds.datasets])
+     if (np.diff(lengths) != 0).any() and window_size_samples is None:
+         warnings.warn("Recordings have different lengths, they will not be batch-able!")
+     if (window_size_samples is not None) and any(window_size_samples > lengths):
+         raise ValueError(
+             f"Window size {window_size_samples} exceeds trial duration {lengths.min()}."
+         )
+
+     list_of_windows_ds = Parallel(n_jobs=n_jobs)(
+         delayed(_create_fixed_length_windows)(
+             ds,
+             start_offset_samples,
+             stop_offset_samples,
+             window_size_samples,
+             window_stride_samples,
+             drop_last_window,
+             mapping,
+             preload,
+             picks,
+             reject,
+             flat,
+             targets_from,
+             last_target_only,
+             lazy_metadata,
+             on_missing,
+             verbose,
+         )
+         for ds in concat_ds.datasets
+     )
+     return BaseConcatDataset(list_of_windows_ds)
+
+
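A minimal usage sketch for sliding windows, with hypothetical numbers (30-second windows at 100 Hz):

    from braindecode.preprocessing import create_fixed_length_windows

    windows_ds = create_fixed_length_windows(
        concat_ds,
        window_size_samples=3000,
        window_stride_samples=3000,  # stride == size, i.e. no overlap
        drop_last_window=False,      # shift a final window to the very end
    )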
+ def _create_windows_from_events(
+     ds,
+     infer_mapping,
+     infer_window_size_stride,
+     trial_start_offset_samples,
+     trial_stop_offset_samples,
+     window_size_samples=None,
+     window_stride_samples=None,
+     drop_last_window=False,
+     mapping=None,
+     preload=False,
+     drop_bad_windows=True,
+     picks=None,
+     reject=None,
+     flat=None,
+     on_missing="error",
+     accepted_bads_ratio=0.0,
+     verbose="error",
+     use_mne_epochs=False,
+ ):
+     """Create WindowsDataset from BaseDataset based on events.
+
+     Parameters
+     ----------
+     ds : BaseDataset
+         Dataset containing continuous data and description.
+     infer_mapping : bool
+         If True, extract all events from all datasets and map them to
+         increasing integers starting from 0.
+     infer_window_size_stride : bool
+         If True, infer the stride from the original trial size of the first
+         trial and trial_start_offset_samples and trial_stop_offset_samples.
+
+     See `create_windows_from_events` for description of other parameters.
+
+     Returns
+     -------
+     EEGWindowsDataset :
+         Windowed dataset.
+     """
+     # catch window_kwargs to store to dataset
+     window_kwargs = [
+         (create_windows_from_events.__name__, _get_windowing_kwargs(locals())),
+     ]
+     if infer_mapping:
+         unique_events = np.unique(ds.raw.annotations.description)
+         new_unique_events = [x for x in unique_events if x not in mapping]
+         # mapping event descriptions to integers from 0 on
+         max_id_existing_mapping = len(mapping)
+         mapping.update(
+             {
+                 event_name: i_event_type + max_id_existing_mapping
+                 for i_event_type, event_name in enumerate(new_unique_events)
+             }
+         )
+
+     events, events_id = mne.events_from_annotations(ds.raw, mapping)
+     onsets = events[:, 0]
+     # Onsets are relative to the beginning of the recording
+     filtered_durations = np.array(
+         [a["duration"] for a in ds.raw.annotations if a["description"] in events_id]
+     )
+
+     stops = onsets + (filtered_durations * ds.raw.info["sfreq"]).astype(int)
+     # XXX This could probably be simplified by using chunk_duration in
+     # `events_from_annotations`
+
+     last_samp = ds.raw.first_samp + ds.raw.n_times - 1
+     # `stops` is used exclusively (i.e. `start:stop`), so add back 1
+     if stops[-1] + trial_stop_offset_samples > last_samp + 1:
+         raise ValueError(
+             '"trial_stop_offset_samples" too large. Stop of last trial '
+             f'({stops[-1]}) + "trial_stop_offset_samples" '
+             f"({trial_stop_offset_samples}) must be smaller than length of"
+             f" recording ({len(ds)})."
+         )
+
+     if infer_window_size_stride:
+         # window size is trial size
+         if window_size_samples is None:
+             window_size_samples = (
+                 stops[0]
+                 + trial_stop_offset_samples
+                 - (onsets[0] + trial_start_offset_samples)
+             )
+             window_stride_samples = window_size_samples
+         this_trial_sizes = (stops + trial_stop_offset_samples) - (
+             onsets + trial_start_offset_samples
+         )
+         # Maybe actually this is not necessary?
+         # We could also just say we just assume window size=trial size
+         # in case not given, without this condition...
+         # but then would have to change functions overall
+         checker_trials_size = this_trial_sizes == window_size_samples
+
+         if not np.all(checker_trials_size):
+             trials_drops = int(len(this_trial_sizes) - sum(checker_trials_size))
+             warnings.warn(
+                 f"Dropping {trials_drops} trial(s) whose size differs from "
+                 f"the window size ({window_size_samples} samples)."
+             )
+             bads_size_trials = checker_trials_size
+             events = events[checker_trials_size]
+             onsets = onsets[checker_trials_size]
+             stops = stops[checker_trials_size]
+
+     description = events[:, -1]
+
+     if not use_mne_epochs:
+         onsets = onsets - ds.raw.first_samp
+         stops = stops - ds.raw.first_samp
+     i_trials, i_window_in_trials, starts, stops = _compute_window_inds(
+         onsets,
+         stops,
+         trial_start_offset_samples,
+         trial_stop_offset_samples,
+         window_size_samples,
+         window_stride_samples,
+         drop_last_window,
+         accepted_bads_ratio,
+     )
+
+     if any(np.diff(starts) <= 0):
+         raise NotImplementedError("Trial overlap not implemented.")
+
+     events = [
+         [start, window_size_samples, description[i_trials[i_start]]]
+         for i_start, start in enumerate(starts)
+     ]
+     events = np.array(events)
+
+     description = events[:, -1]
+
+     metadata = pd.DataFrame(
+         {
+             "i_window_in_trial": i_window_in_trials,
+             "i_start_in_trial": starts,
+             "i_stop_in_trial": stops,
+             "target": description,
+         }
+     )
+     if use_mne_epochs:
+         # window size - 1, since tmax is inclusive
+         mne_epochs = mne.Epochs(
+             ds.raw,
+             events,
+             events_id,
+             baseline=None,
+             tmin=0,
+             tmax=(window_size_samples - 1) / ds.raw.info["sfreq"],
+             metadata=metadata,
+             preload=preload,
+             picks=picks,
+             reject=reject,
+             flat=flat,
+             on_missing=on_missing,
+             verbose=verbose,
+         )
+         if drop_bad_windows:
+             mne_epochs.drop_bad()
+         windows_ds = WindowsDataset(
+             mne_epochs,
+             ds.description,
+         )
+     else:
+         windows_ds = EEGWindowsDataset(
+             ds.raw,
+             metadata=metadata,
+             description=ds.description,
+         )
+     # add window_kwargs and raw_preproc_kwargs to windows dataset
+     setattr(windows_ds, "window_kwargs", window_kwargs)
+     kwargs_name = "raw_preproc_kwargs"
+     if hasattr(ds, kwargs_name):
+         setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
+     return windows_ds
+
+
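The mapping-inference step above leans on mne.events_from_annotations; a small self-contained sketch of that mechanic on a toy recording (all values hypothetical):

    import mne
    import numpy as np

    info = mne.create_info(["C3"], sfreq=100.0, ch_types="eeg")
    raw = mne.io.RawArray(np.zeros((1, 1000)), info)
    raw.set_annotations(mne.Annotations(
        onset=[1.0, 4.0], duration=[2.0, 2.0], description=["left", "right"]))
    events, event_id = mne.events_from_annotations(raw, {"left": 0, "right": 1})
    onsets = events[:, 0]              # [100, 400], in samples
    stops = onsets + int(2.0 * 100.0)  # durations * sfreq, as in the code above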
+ def _create_fixed_length_windows(
+     ds,
+     start_offset_samples,
+     stop_offset_samples,
+     window_size_samples,
+     window_stride_samples,
+     drop_last_window,
+     mapping=None,
+     preload=False,
+     picks=None,
+     reject=None,
+     flat=None,
+     targets_from="metadata",
+     last_target_only=True,
+     lazy_metadata=False,
+     on_missing="error",
+     verbose="error",
+ ):
+     """Create WindowsDataset from BaseDataset with sliding windows.
+
+     Parameters
+     ----------
+     ds : BaseDataset
+         Dataset containing continuous data and description.
+
+     See `create_fixed_length_windows` for description of other parameters.
+
+     Returns
+     -------
+     WindowsDataset :
+         Windowed dataset.
+     """
+     # catch window_kwargs to store to dataset
+     window_kwargs = [
+         (create_fixed_length_windows.__name__, _get_windowing_kwargs(locals())),
+     ]
+     stop = ds.raw.n_times if stop_offset_samples is None else stop_offset_samples
+
+     # assume window should be whole recording
+     if window_size_samples is None:
+         window_size_samples = stop - start_offset_samples
+     if window_stride_samples is None:
+         window_stride_samples = window_size_samples
+
+     last_potential_start = stop - window_size_samples
+
+     # get targets from dataset description if they exist
+     target = -1 if ds.target_name is None else ds.description[ds.target_name]
+     if mapping is not None:
+         # in case of multiple targets
+         if isinstance(target, pd.Series):
+             target = target.replace(mapping).to_list()
+         # in case of single value target
+         else:
+             target = mapping[target]
+
+     if lazy_metadata:
+         factory = _FixedLengthWindowFunctions(
+             start_offset_samples,
+             last_potential_start,
+             window_stride_samples,
+             window_size_samples,
+             target,
+         )
+         metadata = _LazyDataFrame(
+             length=factory.length,
+             functions={
+                 "i_window_in_trial": factory.i_window_in_trial,
+                 "i_start_in_trial": factory.i_start_in_trial,
+                 "i_stop_in_trial": factory.i_stop_in_trial,
+                 "target": factory.target,
+             },
+             columns=[
+                 "i_window_in_trial",
+                 "i_start_in_trial",
+                 "i_stop_in_trial",
+                 "target",
+             ],
+         )
+     else:
+         # already includes last incomplete window start
+         starts = np.arange(
+             start_offset_samples, last_potential_start + 1, window_stride_samples
+         )
+
+         if not drop_last_window and starts[-1] < last_potential_start:
+             # if last window does not end at trial stop, make it stop there
+             starts = np.append(starts, last_potential_start)
+
+         metadata = pd.DataFrame(
+             {
+                 "i_window_in_trial": np.arange(len(starts)),
+                 "i_start_in_trial": starts,
+                 "i_stop_in_trial": starts + window_size_samples,
+                 "target": len(starts) * [target],
+             }
+         )
+
+     window_kwargs.append(
+         (
+             EEGWindowsDataset.__name__,
+             {"targets_from": targets_from, "last_target_only": last_target_only},
+         )
+     )
+     windows_ds = EEGWindowsDataset(
+         ds.raw,
+         metadata=metadata,
+         description=ds.description,
+         targets_from=targets_from,
+         last_target_only=last_target_only,
+     )
+     # add window_kwargs and raw_preproc_kwargs to windows dataset
+     setattr(windows_ds, "window_kwargs", window_kwargs)
+     kwargs_name = "raw_preproc_kwargs"
+     if hasattr(ds, kwargs_name):
+         setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
+     return windows_ds
+
+
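The eager (non-lazy) branch above reduces to a short computation; a worked example with hypothetical numbers:

    import numpy as np

    n_times, size, stride = 1000, 400, 300
    last_potential_start = n_times - size                    # 600
    starts = np.arange(0, last_potential_start + 1, stride)  # [0, 300, 600]
    stops = starts + size                                    # [400, 700, 1000]

Since starts[-1] already equals last_potential_start here, no extra shifted window is appended.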
+ def create_windows_from_target_channels(
+     concat_ds,
+     window_size_samples=None,
+     preload=False,
+     picks=None,
+     reject=None,
+     flat=None,
+     n_jobs=1,
+     last_target_only=True,
+     verbose="error",
+ ):
+     list_of_windows_ds = Parallel(n_jobs=n_jobs)(
+         delayed(_create_windows_from_target_channels)(
+             ds,
+             window_size_samples,
+             preload,
+             picks,
+             reject,
+             flat,
+             last_target_only,
+             "error",
+             verbose,
+         )
+         for ds in concat_ds.datasets
+     )
+     return BaseConcatDataset(list_of_windows_ds)
+
+
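A hedged usage sketch: this entry point targets regression-style setups where targets live in `misc` channels of the raw (the dataset itself is assumed, e.g. a BCICompetitionIVDataset4-style recording):

    from braindecode.preprocessing import create_windows_from_target_channels

    windows_ds = create_windows_from_target_channels(
        concat_ds,
        window_size_samples=1000,
        last_target_only=True,  # one target per window, taken at its end
    )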
+ def _create_windows_from_target_channels(
+     ds,
+     window_size_samples,
+     preload=False,
+     picks=None,
+     reject=None,
+     flat=None,
+     last_target_only=True,
+     on_missing="error",
+     verbose="error",
+ ):
+     """Create WindowsDataset from BaseDataset using targets from `misc` channels of mne.Raw.
+
+     Parameters
+     ----------
+     ds : BaseDataset
+         Dataset containing continuous data and description.
+
+     See `create_fixed_length_windows` for description of other parameters.
+
+     Returns
+     -------
+     WindowsDataset :
+         Windowed dataset.
+     """
+     window_kwargs = [
+         (create_windows_from_target_channels.__name__, _get_windowing_kwargs(locals())),
+     ]
+     stop = ds.raw.n_times + ds.raw.first_samp
+
+     target = ds.raw.get_data(picks="misc")
+     # TODO: handle multi targets present only for some events
+     stops = np.nonzero((~np.isnan(target[0, :])))[0] + 1
+     stops = stops[(stops < stop) & (stops >= window_size_samples)]
+     stops = stops.astype(int)
+     metadata = pd.DataFrame(
+         {
+             "i_window_in_trial": np.arange(len(stops)),
+             "i_start_in_trial": stops - window_size_samples,
+             "i_stop_in_trial": stops,
+             "target": len(stops) * [target],
+         }
+     )
+
+     targets_from = "channels"
+     window_kwargs.append(
+         (
+             EEGWindowsDataset.__name__,
+             {"targets_from": targets_from, "last_target_only": last_target_only},
+         )
+     )
+     windows_ds = EEGWindowsDataset(
+         ds.raw,
+         metadata=metadata,
+         description=ds.description,
+         targets_from=targets_from,
+         last_target_only=last_target_only,
+     )
+     setattr(windows_ds, "window_kwargs", window_kwargs)
+     kwargs_name = "raw_preproc_kwargs"
+     if hasattr(ds, kwargs_name):
+         setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
+     return windows_ds
+
+
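The stop-index derivation in a sketch: a target defined (non-NaN) at samples 4 and 8 with window_size_samples=3 yields windows [2, 5) and [6, 9).

    import numpy as np

    target = np.full((1, 10), np.nan)
    target[0, [4, 8]] = 1.0
    stops = np.nonzero(~np.isnan(target[0, :]))[0] + 1  # [5, 9]
    stops = stops[(stops < 10) & (stops >= 3)]          # both kept
    starts = stops - 3                                  # [2, 6]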
+ def _compute_window_inds(
+     starts,
+     stops,
+     start_offset,
+     stop_offset,
+     size,
+     stride,
+     drop_last_window,
+     accepted_bads_ratio,
+ ):
+     """Compute window start and stop indices.
+
+     Create window starts from trial onsets (shifted by start_offset) to trial
+     end (shifted by stop_offset) separated by stride, as long as window size
+     fits into trial.
+
+     Parameters
+     ----------
+     starts: array-like
+         Trial starts in samples.
+     stops: array-like
+         Trial stops in samples.
+     start_offset: int
+         Start offset from original trial onsets in samples.
+     stop_offset: int
+         Stop offset from original trial stop in samples.
+     size: int
+         Window size.
+     stride: int
+         Stride between windows.
+     drop_last_window: bool
+         If False, shift the last window within range so it ends at the trial
+         stop; if True, drop the leftover samples.
+     accepted_bads_ratio: float
+         Acceptable proportion of bad trials within a raw. If the number of
+         trials whose length is exceeded by the window size is smaller than
+         this, then only the corresponding trials are dropped, but the
+         computation continues. Otherwise, an error is raised.
+
+     Returns
+     -------
+     result_lists: (list, list, list, list)
+         Trial, i_window_in_trial, start sample and stop sample of windows.
+     """
+     starts = np.array([starts]) if isinstance(starts, int) else starts
+     stops = np.array([stops]) if isinstance(stops, int) else stops
+
+     starts += start_offset
+     stops += stop_offset
+     if any(size > (stops - starts)):
+         bads_mask = size > (stops - starts)
+         min_duration = (stops - starts).min()
+         if sum(bads_mask) <= accepted_bads_ratio * len(starts):
+             starts = starts[np.logical_not(bads_mask)]
+             stops = stops[np.logical_not(bads_mask)]
+             warnings.warn(
+                 f"Trials {np.where(bads_mask)[0]} are being dropped as the "
+                 f"window size ({size}) exceeds their duration {min_duration}."
+             )
+         else:
+             current_ratio = sum(bads_mask) / len(starts)
+             raise ValueError(
+                 f"Window size {size} exceeds trial duration "
+                 f"({min_duration}) for too many trials "
+                 f"({current_ratio * 100}%). Set "
+                 f"accepted_bads_ratio to at least {current_ratio} "
+                 "and restart training to be able to continue."
+             )
+
+     i_window_in_trials, i_trials, window_starts = [], [], []
+     for start_i, (start, stop) in enumerate(zip(starts, stops)):
+         # Generate possible window starts, with given stride, between original
+         # trial onsets and stops (shifted by start_offset and stop_offset,
+         # respectively)
+         possible_starts = np.arange(start, stop, stride)
+
+         # A possible window start is an actual start if the window also fits
+         # between trial start and stop
+         for i_window, s in enumerate(possible_starts):
+             if (s + size) <= stop:
+                 window_starts.append(s)
+                 i_window_in_trials.append(i_window)
+                 i_trials.append(start_i)
+
+         # If the last window start + window size is not the same as
+         # stop + stop_offset, create another window that overlaps and stops
+         # at onset + stop_offset
+         if not drop_last_window:
+             if window_starts[-1] + size != stop:
+                 window_starts.append(stop - size)
+                 i_window_in_trials.append(i_window_in_trials[-1] + 1)
+                 i_trials.append(start_i)
+
+     # Set window stops to be event stops (rather than trial stops)
+     window_stops = np.array(window_starts) + size
+     if not (len(i_window_in_trials) == len(window_starts) == len(window_stops)):
+         raise ValueError(
+             f"{len(i_window_in_trials)} == {len(window_starts)} == {len(window_stops)}"
+         )
+
+     return i_trials, i_window_in_trials, window_starts, window_stops
+
+
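A worked example of the index computation (hypothetical numbers): one trial spanning samples [0, 260) with size 100 and stride 50; because 150 + 100 != 260, a final shifted window is appended.

    import numpy as np

    i_trials, i_windows, starts, stops = _compute_window_inds(
        starts=np.array([0]),
        stops=np.array([260]),
        start_offset=0,
        stop_offset=0,
        size=100,
        stride=50,
        drop_last_window=False,
        accepted_bads_ratio=0.0,
    )
    assert list(starts) == [0, 50, 100, 150, 160]  # 160 ends flush at 260
    assert list(stops) == [100, 150, 200, 250, 260]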
+ def _check_windowing_arguments(
+     trial_start_offset_samples,
+     trial_stop_offset_samples,
+     window_size_samples,
+     window_stride_samples,
+ ):
+     assert isinstance(trial_start_offset_samples, (int, np.integer))
+     assert isinstance(trial_stop_offset_samples, (int, np.integer)) or (
+         trial_stop_offset_samples is None
+     )
+     assert isinstance(window_size_samples, (int, np.integer, type(None)))
+     assert isinstance(window_stride_samples, (int, np.integer, type(None)))
+     assert (window_size_samples is None) == (window_stride_samples is None)
+     if window_size_samples is not None:
+         assert window_size_samples > 0, "window size has to be larger than 0"
+         assert window_stride_samples > 0, "window stride has to be larger than 0"
+
+
+ def _check_and_set_fixed_length_window_arguments(
+     start_offset_samples,
+     stop_offset_samples,
+     window_size_samples,
+     window_stride_samples,
+     drop_last_window,
+     lazy_metadata,
+ ):
+     """Raise warnings for incorrect input arguments and set default values for
+     stop_offset_samples and drop_last_window, if necessary.
+     """
+     _check_windowing_arguments(
+         start_offset_samples,
+         stop_offset_samples,
+         window_size_samples,
+         window_stride_samples,
+     )
+
+     if stop_offset_samples == 0:
+         warnings.warn(
+             "Meaning of `trial_stop_offset_samples`=0 has changed, use `None` "
+             "to indicate end of trial/recording. Using `None`."
+         )
+         stop_offset_samples = None
+
+     if start_offset_samples != 0 or stop_offset_samples is not None:
+         warnings.warn(
+             "Usage of offset_sample args in create_fixed_length_windows is deprecated and"
+             " will be removed in future versions. Please use "
+             'braindecode.preprocessing.preprocess.Preprocessor("crop", tmin, tmax)'
+             " instead."
+         )
+
+     if (
+         window_size_samples is not None
+         and window_stride_samples is not None
+         and drop_last_window is None
+     ):
+         raise ValueError(
+             "drop_last_window must be set if both window_size_samples &"
+             " window_stride_samples have also been set"
+         )
+     elif (
+         window_size_samples is None
+         and window_stride_samples is None
+         and drop_last_window is False
+     ):
+         # necessary for following assertion
+         drop_last_window = None
+
+     assert (
+         (window_size_samples is None)
+         == (window_stride_samples is None)
+         == (drop_last_window is None)
+     )
+     if not drop_last_window and lazy_metadata:
+         raise ValueError(
+             "Cannot have drop_last_window=False and lazy_metadata=True at the same time."
+         )
+     return stop_offset_samples, drop_last_window
+
+
+ def _get_windowing_kwargs(windowing_func_locals):
+     input_kwargs = windowing_func_locals
+     input_kwargs.pop("ds")
+     windowing_kwargs = {k: v for k, v in input_kwargs.items()}
+     return windowing_kwargs
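Finally, the kwargs captured via _get_windowing_kwargs end up on each windowed dataset, which allows checking provenance after the fact; a short sketch assuming `windows_ds` from one of the usage examples above:

    for ds in windows_ds.datasets:
        func_name, kwargs = ds.window_kwargs[0]
        print(func_name, kwargs.get("window_size_samples"))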