braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1037 @@
1
+ """Get epochs from mne.Raw"""
2
+
3
+ # Authors: Hubert Banville <hubert.jbanville@gmail.com>
4
+ # Lukas Gemein <l.gemein@gmail.com>
5
+ # Simon Brandt <simonbrandt@protonmail.com>
6
+ # David Sabbagh <dav.sabbagh@gmail.com>
7
+ # Henrik Bonsmann <henrikbons@gmail.com>
8
+ # Ann-Kathrin Kiessner <ann-kathrin.kiessner@gmx.de>
9
+ # Vytautas Jankauskas <vytauto.jankausko@gmail.com>
10
+ # Dan Wilson <dan.c.wil@gmail.com>
11
+ # Maciej Sliwowski <maciek.sliwowski@gmail.com>
12
+ # Mohammed Fattouh <mo.fattouh@gmail.com>
13
+ # Robin Schirrmeister <robintibor@gmail.com>
14
+ # Matthew Chen <matt.chen42601@gmail.com>
15
+ #
16
+ # License: BSD (3-clause)
17
+
18
+ from __future__ import annotations
19
+
20
+ import warnings
21
+ from typing import Any, Callable
22
+
23
+ import mne
24
+ import numpy as np
25
+ import pandas as pd
26
+ from joblib import Parallel, delayed
27
+ from numpy.typing import ArrayLike
28
+
29
+ from ..datasets.base import (
30
+ BaseConcatDataset,
31
+ EEGWindowsDataset,
32
+ RawDataset,
33
+ WindowsDataset,
34
+ )
35
+
36
+
37
+ class _LazyDataFrame:
38
+ """
39
+ DataFrame-like object that lazily computes values (experimental).
40
+
41
+ This class emulates some features of a pandas DataFrame, but computes
42
+ the values on-the-fly when they are accessed. This is useful for
43
+ very long DataFrames with repetitive values.
44
+ Only the methods used by EEGWindowsDataset on its metadata are implemented.
45
+
46
+ Parameters:
47
+ -----------
48
+ length: int
49
+ The length of the dataframe.
50
+ functions: dict[str, Callable[[int], Any]]
51
+ A dictionary mapping column names to functions that take an index and
52
+ return the value of the column at that index.
53
+ columns: list[str]
54
+ The names of the columns in the dataframe.
55
+ series: bool
56
+ Whether the object should emulate a series or a dataframe.
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ length: int,
62
+ functions: dict[str, Callable[[int], Any]],
63
+ columns: list[str],
64
+ series: bool = False,
65
+ ):
66
+ if not (isinstance(length, int) and length >= 0):
67
+ raise ValueError("Length must be a positive integer.")
68
+ if not all(c in functions for c in columns):
69
+ raise ValueError("All columns must have a corresponding function.")
70
+ if series and len(columns) != 1:
71
+ raise ValueError("Series must have exactly one column.")
72
+ self.length = length
73
+ self.functions = functions
74
+ self.columns = columns
75
+ self.series = series
76
+
77
+ @property
78
+ def loc(self):
79
+ return self
80
+
81
+ def __len__(self):
82
+ return self.length
83
+
84
+ def __getitem__(self, key):
85
+ if not isinstance(key, tuple):
86
+ key = (key, self.columns)
87
+ if len(key) == 1:
88
+ key = (key[0], self.columns)
89
+ if not len(key) == 2:
90
+ raise IndexError(
91
+ f"index must be either [row] or [row, column], got [{', '.join(map(str, key))}]."
92
+ )
93
+ row, col = key
94
+ if col == slice(None): # all columns (i.e., call to df[row, :])
95
+ col = self.columns
96
+ one_col = False
97
+ if isinstance(col, str): # one column
98
+ one_col = True
99
+ col = [col]
100
+ else: # multiple columns
101
+ col = list(col)
102
+ if not all(c in self.columns for c in col):
103
+ raise IndexError(
104
+ f"All columns must be present in the dataframe with columns {self.columns}. Got {col}."
105
+ )
106
+ if row == slice(None): # all rows (i.e., call to df[:] or df[:, col])
107
+ return _LazyDataFrame(self.length, self.functions, col)
108
+ if not isinstance(row, int):
109
+ raise NotImplementedError(
110
+ "Row indexing only supports either a single integer or a null slice (i.e., df[:])."
111
+ )
112
+ if not (0 <= row < self.length):
113
+ raise IndexError(f"Row index {row} is out of bounds.")
114
+ if self.series or one_col:
115
+ return self.functions[col[0]](row)
116
+ return pd.Series({c: self.functions[c](row) for c in col})
117
+
118
+ def to_numpy(self):
119
+ return _LazyDataFrame(
120
+ length=self.length,
121
+ functions=self.functions,
122
+ columns=self.columns,
123
+ series=len(self.columns) == 1,
124
+ )
125
+
126
+ def to_list(self):
127
+ return self.to_numpy()
128
+
129
+
130
+ class _FixedLengthWindowFunctions:
131
+ """Class defining functions for lazy metadata generation in fixed length windowing
132
+ to be used in combination with _LazyDataFrame (experimental)."""
133
+
134
+ def __init__(
135
+ self,
136
+ start_offset_samples: int,
137
+ last_potential_start: int,
138
+ window_stride_samples: int,
139
+ window_size_samples: int,
140
+ target: Any,
141
+ ):
142
+ self.start_offset_samples = start_offset_samples
143
+ self.last_potential_start = last_potential_start
144
+ self.window_stride_samples = window_stride_samples
145
+ self.window_size_samples = window_size_samples
146
+ self.target_val = target
147
+
148
+ @property
149
+ def length(self) -> int:
150
+ return int(
151
+ np.ceil(
152
+ (self.last_potential_start + 1 - self.start_offset_samples)
153
+ / self.window_stride_samples
154
+ )
155
+ )
156
+
157
+ def i_window_in_trial(self, i: int) -> int:
158
+ return i
159
+
160
+ def i_start_in_trial(self, i: int) -> int:
161
+ return self.start_offset_samples + i * self.window_stride_samples
162
+
163
+ def i_stop_in_trial(self, i: int) -> int:
164
+ return (
165
+ self.start_offset_samples
166
+ + i * self.window_stride_samples
167
+ + self.window_size_samples
168
+ )
169
+
170
+ def target(self, i: int) -> Any:
171
+ return self.target_val
172
+
173
+
174
+ def _get_use_mne_epochs(use_mne_epochs, reject, picks, flat, drop_bad_windows):
175
+ should_use_mne_epochs = (
176
+ (reject is not None)
177
+ or (picks is not None)
178
+ or (flat is not None)
179
+ or (drop_bad_windows is True)
180
+ )
181
+ if use_mne_epochs is None:
182
+ if should_use_mne_epochs:
183
+ warnings.warn(
184
+ "Using reject or picks or flat or dropping bad windows means "
185
+ "mne Epochs are created, "
186
+ "which will be substantially slower and may be deprecated in the future."
187
+ )
188
+ return should_use_mne_epochs
189
+ if not use_mne_epochs and should_use_mne_epochs:
190
+ raise ValueError(
191
+ "Cannot set use_mne_epochs=False when using reject, picks, flat, or dropping bad windows."
192
+ )
193
+ return use_mne_epochs
194
+
195
+
196
+ # XXX it's called concat_ds...
197
def create_windows_from_events(
    concat_ds: BaseConcatDataset[RawDataset],
    trial_start_offset_samples: int = 0,
    trial_stop_offset_samples: int = 0,
    window_size_samples: int | None = None,
    window_stride_samples: int | None = None,
    drop_last_window: bool = False,
    mapping: dict[str, int] | None = None,
    preload: bool = False,
    drop_bad_windows: bool | None = None,
    picks: str | ArrayLike | slice | None = None,
    reject: dict[str, float] | None = None,
    flat: dict[str, float] | None = None,
    on_missing: str = "error",
    accepted_bads_ratio: float = 0.0,
    use_mne_epochs: bool | None = None,
    n_jobs: int = 1,
    verbose: bool | str | int | None = "error",
) -> BaseConcatDataset[WindowsDataset | EEGWindowsDataset]:
    """Create windows based on events in mne.Raw.

    Windows of size window_size_samples are extracted from the interval
    [trial onset + trial_start_offset_samples, trial onset + trial duration +
    trial_stop_offset_samples] around each trial, with consecutive windows
    separated by window_stride_samples. If the last window around an event
    does not end at trial_stop_offset_samples and drop_last_window is False,
    an additional overlapping window that ends at trial_stop_offset_samples
    is created.

    Windows are extracted from the interval defined by the following::

                                                trial onset +
                        trial onset                duration
        |--------------------|------------------------|-----------------------|
        trial onset +                                             trial onset +
        trial_start_offset_samples                                   duration +
                                                    trial_stop_offset_samples

    Parameters
    ----------
    concat_ds: BaseConcatDataset[RawDataset]
        A concat of base datasets each holding raw and description.
    trial_start_offset_samples: int
        Start offset from original trial onsets, in samples. Defaults to zero.
    trial_stop_offset_samples: int
        Stop offset from original trial stop, in samples. Defaults to zero.
    window_size_samples: int | None
        Window size. If None, the window size is inferred from the original
        trial size of the first trial and trial_start_offset_samples and
        trial_stop_offset_samples.
    window_stride_samples: int | None
        Stride between windows, in samples. If None, the window stride is
        inferred from the original trial size of the first trial and
        trial_start_offset_samples and trial_stop_offset_samples.
    drop_last_window: bool
        If False, an additional overlapping window that ends at
        trial_stop_offset_samples will be extracted around each event when
        the last window does not end exactly at trial_stop_offset_samples.
    mapping: dict(str: int) | None
        Mapping from event description to numerical target value. If None,
        the event descriptions found in the datasets are mapped to
        increasing integers starting from 0.
    preload: bool
        If True, preload the data of the Epochs objects. This is useful to
        reduce disk reading overhead when returning windows in a training
        scenario, however very large data might not fit into memory.
    drop_bad_windows: bool | None
        If True, call `.drop_bad()` on the resulting mne.Epochs object. This
        step allows identifying e.g., windows that fall outside of the
        continuous recording. It is suggested to run this step here as
        otherwise the BaseConcatDataset has to be updated as well. Only has
        an effect if mne epochs are created; passing any value other than
        None emits a warning. If None and mne epochs are created, it is set
        to True.
    picks: str | list | slice | None
        Channels to include. If None, all available channels are used. See
        mne.Epochs.
    reject: dict | None
        Epoch rejection parameters based on peak-to-peak amplitude. If None,
        no rejection is done based on peak-to-peak amplitude. See mne.Epochs.
    flat: dict | None
        Epoch rejection parameters based on flatness of signals. If None, no
        rejection based on flatness is done. See mne.Epochs.
    on_missing: str
        What to do if one or several event ids are not found in the
        recording. Valid keys are 'error' | 'warning' | 'ignore'. See
        mne.Epochs.
    accepted_bads_ratio: float, optional
        Acceptable proportion of trials with inconsistent length in a raw. If
        the number of trials whose length is exceeded by the window size is
        smaller than this, then only the corresponding trials are dropped,
        but the computation continues. Otherwise, an error is raised.
        Defaults to 0.0 (raise an error).
    use_mne_epochs: bool | None
        If False, return EEGWindowsDataset objects.
        If True, return mne.Epochs objects encapsulated in WindowsDataset
        objects, which is substantially slower than EEGWindowsDataset.
        If None, mne epochs are used only when reject, picks or flat is
        given or drop_bad_windows is True.
    n_jobs: int
        Number of jobs to use to parallelize the windowing.
    verbose: bool | str | int | None
        Control verbosity of the logging output when calling mne.Epochs.

    Returns
    -------
    windows_datasets: BaseConcatDataset[WindowsDataset | EEGWindowsDataset]
        Concatenated datasets of WindowsDataset (or EEGWindowsDataset)
        containing the extracted windows.
    """
    _check_windowing_arguments(
        trial_start_offset_samples,
        trial_stop_offset_samples,
        window_size_samples,
        window_stride_samples,
    )

    # Size and stride are inferred per dataset whenever no size was given.
    infer_window_size_stride = window_size_samples is None

    # Without a user mapping, events are collected from the datasets and
    # mapped to increasing integers starting from 0.
    infer_mapping = mapping is None
    if infer_mapping:
        mapping = dict()

    if drop_bad_windows is not None:
        warnings.warn(
            "Drop bad windows only has an effect if mne epochs are created, "
            "and this argument may be removed in the future."
        )

    use_mne_epochs = _get_use_mne_epochs(
        use_mne_epochs, reject, picks, flat, drop_bad_windows
    )
    if drop_bad_windows is None and use_mne_epochs:
        drop_bad_windows = True

    window_one_dataset = delayed(_create_windows_from_events)
    jobs = (
        window_one_dataset(
            ds,
            infer_mapping,
            infer_window_size_stride,
            trial_start_offset_samples,
            trial_stop_offset_samples,
            window_size_samples,
            window_stride_samples,
            drop_last_window,
            mapping,
            preload,
            drop_bad_windows,
            picks,
            reject,
            flat,
            on_missing,
            accepted_bads_ratio,
            verbose,
            use_mne_epochs,
        )
        for ds in concat_ds.datasets
    )
    list_of_windows_ds = Parallel(n_jobs=n_jobs)(jobs)
    return BaseConcatDataset(list_of_windows_ds)
347
+
348
+
349
def create_fixed_length_windows(
    concat_ds: BaseConcatDataset[RawDataset],
    start_offset_samples: int = 0,
    stop_offset_samples: int | None = None,
    window_size_samples: int | None = None,
    window_stride_samples: int | None = None,
    drop_last_window: bool | None = None,
    mapping: dict[str, int] | None = None,
    preload: bool = False,
    picks: str | ArrayLike | slice | None = None,
    reject: dict[str, float] | None = None,
    flat: dict[str, float] | None = None,
    targets_from: str = "metadata",
    last_target_only: bool = True,
    lazy_metadata: bool = False,
    on_missing: str = "error",
    n_jobs: int = 1,
    verbose: bool | str | int | None = "error",
) -> BaseConcatDataset[EEGWindowsDataset]:
    """Windower that creates sliding windows.

    Parameters
    ----------
    concat_ds: ConcatDataset[RawDataset]
        A concat of base datasets each holding raw and description.
    start_offset_samples: int
        Start offset from beginning of recording in samples.
    stop_offset_samples: int | None
        Stop offset from beginning of recording in samples. If None, set to
        be the end of the recording.
    window_size_samples: int | None
        Window size in samples. If None, set to be the maximum possible
        window size, i.e. the length of the recording once offsets are
        accounted for.
    window_stride_samples: int | None
        Stride between windows in samples. If None, set to be equal to
        window_size_samples, so windows will not overlap.
    drop_last_window: bool | None
        Whether or not to have a last overlapping window, when windows do
        not equally divide the continuous signal. Must be set to a bool if
        window size and stride are not None.
    mapping: dict(str: int)
        Mapping from event description to target value.
    preload: bool
        If True, preload the data of the Epochs objects.
    picks: str | list | slice | None
        Channels to include. If None, all available channels are used. See
        mne.Epochs.
    reject: dict | None
        Epoch rejection parameters based on peak-to-peak amplitude. If None,
        no rejection is done based on peak-to-peak amplitude. See mne.Epochs.
    flat: dict | None
        Epoch rejection parameters based on flatness of signals. If None, no
        rejection based on flatness is done. See mne.Epochs.
    targets_from: str
        Forwarded to EEGWindowsDataset (see that class for accepted values).
    last_target_only: bool
        Forwarded to EEGWindowsDataset (see that class for its meaning).
    lazy_metadata: bool
        If True, metadata is not computed immediately, but only when accessed
        by using the _LazyDataFrame (experimental).
    on_missing: str
        What to do if one or several event ids are not found in the
        recording. Valid keys are 'error' | 'warning' | 'ignore'. See
        mne.Epochs.
    n_jobs: int
        Number of jobs to use to parallelize the windowing.
    verbose: bool | str | int | None
        Control verbosity of the logging output when calling mne.Epochs.

    Returns
    -------
    windows_datasets: BaseConcatDataset[EEGWindowsDataset]
        Concatenated datasets of EEGWindowsDataset containing the extracted
        windows.

    Raises
    ------
    ValueError
        If window_size_samples is given and exceeds the length of at least
        one recording.
    """
    checked_arguments = _check_and_set_fixed_length_window_arguments(
        start_offset_samples,
        stop_offset_samples,
        window_size_samples,
        window_stride_samples,
        drop_last_window,
        lazy_metadata,
    )
    stop_offset_samples, drop_last_window = checked_arguments

    # check if recordings are of different lengths
    recording_lengths = np.array([ds.raw.n_times for ds in concat_ds.datasets])
    window_size_given = window_size_samples is not None
    if not window_size_given and (np.diff(recording_lengths) != 0).any():
        # One window per recording: windows inherit the differing lengths.
        warnings.warn("Recordings have different lengths, they will not be batch-able!")
    if window_size_given and any(window_size_samples > recording_lengths):
        raise ValueError(
            f"Window size {window_size_samples} exceeds trial duration {recording_lengths.min()}."
        )

    window_one_dataset = delayed(_create_fixed_length_windows)
    jobs = (
        window_one_dataset(
            ds,
            start_offset_samples,
            stop_offset_samples,
            window_size_samples,
            window_stride_samples,
            drop_last_window,
            mapping,
            preload,
            picks,
            reject,
            flat,
            targets_from,
            last_target_only,
            lazy_metadata,
            on_missing,
            verbose,
        )
        for ds in concat_ds.datasets
    )
    list_of_windows_ds = Parallel(n_jobs=n_jobs)(jobs)
    return BaseConcatDataset(list_of_windows_ds)
460
+
461
+
462
def _create_windows_from_events(
    ds,
    infer_mapping,
    infer_window_size_stride,
    trial_start_offset_samples,
    trial_stop_offset_samples,
    window_size_samples=None,
    window_stride_samples=None,
    drop_last_window=False,
    mapping=None,
    preload=False,
    drop_bad_windows=True,
    picks=None,
    reject=None,
    flat=None,
    on_missing="error",
    accepted_bads_ratio=0.0,
    verbose="error",
    use_mne_epochs=False,
):
    """Create a windowed dataset from a RawDataset based on events.

    Parameters
    ----------
    ds : RawDataset
        Dataset containing continuous data and description.
    infer_mapping : bool
        If True, the annotation descriptions of ``ds.raw`` that are not yet
        keys of ``mapping`` are added to it, numbered consecutively from
        ``len(mapping)``. ``mapping`` is updated in place.
    infer_window_size_stride : bool
        If True, infer the window size (when it is None) and the stride from
        the size of the first trial, ``trial_start_offset_samples`` and
        ``trial_stop_offset_samples``; trials whose size differs from the
        window size are dropped with a warning.

    See `create_windows_from_events` for description of other parameters.

    Returns
    -------
    WindowsDataset | EEGWindowsDataset :
        Windowed dataset: a WindowsDataset wrapping mne.Epochs when
        ``use_mne_epochs`` is True, an EEGWindowsDataset otherwise.

    Raises
    ------
    ValueError
        If the stop of the last trial plus ``trial_stop_offset_samples``
        lies beyond the end of the recording.
    NotImplementedError
        If the computed window starts are not strictly increasing.
    """
    # catch window_kwargs to store to dataset
    # NOTE: locals() must be read before any other local is defined so that
    # only the call arguments are captured.
    window_kwargs = [
        (create_windows_from_events.__name__, _get_windowing_kwargs(locals())),
    ]
    if infer_mapping:
        unique_events = np.unique(ds.raw.annotations.description)
        new_unique_events = [x for x in unique_events if x not in mapping]
        # mapping event descriptions to integers from 0 on
        max_id_existing_mapping = len(mapping)
        mapping.update(
            {
                event_name: i_event_type + max_id_existing_mapping
                for i_event_type, event_name in enumerate(new_unique_events)
            }
        )

    events, events_id = mne.events_from_annotations(ds.raw, mapping, verbose=verbose)
    onsets = events[:, 0]
    # Onsets are relative to the beginning of the recording
    filtered_durations = np.array(
        [a["duration"] for a in ds.raw.annotations if a["description"] in events_id]
    )

    # Durations are converted to a whole number of samples.
    stops = onsets + (filtered_durations * ds.raw.info["sfreq"]).astype(int)
    # XXX This could probably be simplified by using chunk_duration in
    #     `events_from_annotations`

    last_samp = ds.raw.first_samp + ds.raw.n_times - 1
    # `stops` is used exclusively (i.e. `start:stop`), so add back 1
    if stops[-1] + trial_stop_offset_samples > last_samp + 1:
        raise ValueError(
            '"trial_stop_offset_samples" too large. Stop of last trial '
            f'({stops[-1]}) + "trial_stop_offset_samples" '
            f"({trial_stop_offset_samples}) must be smaller than length of"
            f" recording ({len(ds)})."
        )

    if infer_window_size_stride:
        # window size is trial size
        if window_size_samples is None:
            window_size_samples = (
                stops[0]
                + trial_stop_offset_samples
                - (onsets[0] + trial_start_offset_samples)
            )
            window_stride_samples = window_size_samples
        this_trial_sizes = (stops + trial_stop_offset_samples) - (
            onsets + trial_start_offset_samples
        )
        # Maybe actually this is not necessary?
        # We could also just say we just assume window size=trial size
        # in case not given, without this condition...
        # but then would have to change functions overall
        checker_trials_size = this_trial_sizes == window_size_samples

        if not np.all(checker_trials_size):
            trials_drops = int(len(this_trial_sizes) - sum(checker_trials_size))
            warnings.warn(
                f"Dropping trials with different windows size {trials_drops}",
            )
            # Keep only the trials whose size matches the window size.
            events = events[checker_trials_size]
            onsets = onsets[checker_trials_size]
            stops = stops[checker_trials_size]

    description = events[:, -1]

    if not use_mne_epochs:
        # Without mne Epochs, sample indices are made relative to the start
        # of the raw data instead of including first_samp.
        onsets = onsets - ds.raw.first_samp
        stops = stops - ds.raw.first_samp
    i_trials, i_window_in_trials, starts, stops = _compute_window_inds(
        onsets,
        stops,
        trial_start_offset_samples,
        trial_stop_offset_samples,
        window_size_samples,
        window_stride_samples,
        drop_last_window,
        accepted_bads_ratio,
    )

    if any(np.diff(starts) <= 0):
        raise NotImplementedError("Trial overlap not implemented.")

    # One event row per window: [start, window size, target of its trial].
    events = [
        [start, window_size_samples, description[i_trials[i_start]]]
        for i_start, start in enumerate(starts)
    ]
    events = np.array(events)

    description = events[:, -1]

    metadata = pd.DataFrame(
        {
            "i_window_in_trial": i_window_in_trials,
            "i_start_in_trial": starts,
            "i_stop_in_trial": stops,
            "target": description,
        }
    )
    if use_mne_epochs:
        # window size - 1, since tmax is inclusive
        mne_epochs = mne.Epochs(
            ds.raw,
            events,
            events_id,
            baseline=None,
            tmin=0,
            tmax=(window_size_samples - 1) / ds.raw.info["sfreq"],
            metadata=metadata,
            preload=preload,
            picks=picks,
            reject=reject,
            flat=flat,
            on_missing=on_missing,
            verbose=verbose,
        )
        if drop_bad_windows:
            mne_epochs.drop_bad()
        windows_ds = WindowsDataset(
            mne_epochs,
            ds.description,
        )
    else:
        windows_ds = EEGWindowsDataset(
            ds.raw,
            metadata=metadata,
            description=ds.description,
        )
    # add window_kwargs and raw_preproc_kwargs to windows dataset
    setattr(windows_ds, "window_kwargs", window_kwargs)
    kwargs_name = "raw_preproc_kwargs"
    if hasattr(ds, kwargs_name):
        setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
    return windows_ds
637
+
638
+
639
def _create_fixed_length_windows(
    ds,
    start_offset_samples,
    stop_offset_samples,
    window_size_samples,
    window_stride_samples,
    drop_last_window,
    mapping=None,
    preload=False,
    picks=None,
    reject=None,
    flat=None,
    targets_from="metadata",
    last_target_only=True,
    lazy_metadata=False,
    on_missing="error",
    verbose="error",
):
    """Create an EEGWindowsDataset from a RawDataset with sliding windows.

    Parameters
    ----------
    ds : RawDataset
        Dataset containing continuous data and description.

    See `create_fixed_length_windows` for description of other parameters.

    Returns
    -------
    EEGWindowsDataset :
        Windowed dataset whose metadata holds, per window, its index, start
        sample, stop sample and target.
    """
    # Record the call arguments for the dataset. locals() has to be read
    # before any other local is defined so only the arguments are captured.
    window_kwargs = [
        (create_fixed_length_windows.__name__, _get_windowing_kwargs(locals())),
    ]

    # Without an explicit stop, windows may extend to the end of the recording.
    recording_stop = stop_offset_samples
    if recording_stop is None:
        recording_stop = ds.raw.n_times

    # Without an explicit size, a single window spans the whole usable range;
    # without an explicit stride, windows do not overlap.
    if window_size_samples is None:
        window_size_samples = recording_stop - start_offset_samples
    if window_stride_samples is None:
        window_stride_samples = window_size_samples

    last_potential_start = recording_stop - window_size_samples

    # The target comes from the dataset description when a target name is set.
    if ds.target_name is None:
        target = -1
    else:
        target = ds.description[ds.target_name]
    if mapping is not None:
        if isinstance(target, pd.Series):
            # several targets at once
            target = target.replace(mapping).to_list()
        else:
            # a single target value
            target = mapping[target]

    if lazy_metadata:
        lazy_columns = _FixedLengthWindowFunctions(
            start_offset_samples,
            last_potential_start,
            window_stride_samples,
            window_size_samples,
            target,
        )
        metadata = _LazyDataFrame(
            length=lazy_columns.length,
            functions={
                "i_window_in_trial": lazy_columns.i_window_in_trial,
                "i_start_in_trial": lazy_columns.i_start_in_trial,
                "i_stop_in_trial": lazy_columns.i_stop_in_trial,
                "target": lazy_columns.target,
            },
            columns=[
                "i_window_in_trial",
                "i_start_in_trial",
                "i_stop_in_trial",
                "target",
            ],
        )
    else:
        # Regular grid of starts, up to and including the last potential start.
        starts = np.arange(
            start_offset_samples, last_potential_start + 1, window_stride_samples
        )

        # When the grid does not reach the last potential start, add one
        # extra window that ends exactly at the stop, unless it is dropped.
        if starts[-1] < last_potential_start and not drop_last_window:
            starts = np.append(starts, last_potential_start)

        n_windows = len(starts)
        metadata = pd.DataFrame(
            {
                "i_window_in_trial": np.arange(n_windows),
                "i_start_in_trial": starts,
                "i_stop_in_trial": starts + window_size_samples,
                "target": n_windows * [target],
            }
        )

    dataset_kwargs = {"targets_from": targets_from, "last_target_only": last_target_only}
    window_kwargs.append((EEGWindowsDataset.__name__, dataset_kwargs))

    windows_ds = EEGWindowsDataset(
        ds.raw,
        metadata=metadata,
        description=ds.description,
        targets_from=targets_from,
        last_target_only=last_target_only,
    )
    # Attach the windowing arguments and, if present, the raw preprocessing
    # arguments of the source dataset.
    setattr(windows_ds, "window_kwargs", window_kwargs)
    preproc_attr = "raw_preproc_kwargs"
    if hasattr(ds, preproc_attr):
        setattr(windows_ds, preproc_attr, getattr(ds, preproc_attr))
    return windows_ds
756
+
757
+
758
def create_windows_from_target_channels(
    concat_ds: BaseConcatDataset[RawDataset],
    window_size_samples=None,
    preload=False,
    picks=None,
    reject=None,
    flat=None,
    n_jobs=1,
    last_target_only=True,
    verbose="error",
) -> BaseConcatDataset[EEGWindowsDataset]:
    """Window every recording of a concat dataset using its target channels.

    Each dataset in ``concat_ds.datasets`` is handed to
    ``_create_windows_from_target_channels`` and the per-recording results are
    gathered into a new ``BaseConcatDataset``.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        Collection whose ``datasets`` are windowed one by one.
    window_size_samples : int | None
        Window length in samples, forwarded unchanged to the helper.
        NOTE(review): the helper compares window stops against this value, so
        the ``None`` default looks unusable -- confirm with callers.
    preload, picks, reject, flat, verbose :
        Forwarded unchanged to the helper.
    n_jobs : int
        Passed to ``Parallel`` to control how the per-recording calls are run.
    last_target_only : bool
        Forwarded unchanged to the helper.

    Returns
    -------
    BaseConcatDataset
        Concat dataset wrapping the windowed datasets, in input order.
    """
    make_windows = delayed(_create_windows_from_target_channels)
    # ``on_missing`` is not exposed by this function; it is always "error".
    jobs = (
        make_windows(
            ds,
            window_size_samples,
            preload,
            picks,
            reject,
            flat,
            last_target_only=last_target_only,
            on_missing="error",
            verbose=verbose,
        )
        for ds in concat_ds.datasets
    )
    list_of_windows_ds = Parallel(n_jobs=n_jobs)(jobs)
    return BaseConcatDataset(list_of_windows_ds)
784
+
785
+
786
def _create_windows_from_target_channels(
    ds,
    window_size_samples,
    preload=False,
    picks=None,
    reject=None,
    flat=None,
    last_target_only=True,
    on_missing="error",
    verbose="error",
):
    """Create a windowed dataset whose targets come from the `misc` channels.

    One window is created for every sample at which the first `misc` channel
    of ``ds.raw`` holds a non-NaN value; the window ends right after that
    sample and spans ``window_size_samples`` samples.

    Parameters
    ----------
    ds : RawDataset
        Dataset containing continuous data and description.
    window_size_samples : int
        Window length in samples. Candidate windows that would stop before
        ``window_size_samples`` are discarded.
    last_target_only : bool
        Forwarded to ``EEGWindowsDataset``.
    preload, picks, reject, flat, on_missing, verbose :
        Not used by the body; they are only recorded in ``window_kwargs``.

    Returns
    -------
    EEGWindowsDataset :
        Windowed dataset carrying ``window_kwargs`` and, when ``ds`` has them,
        ``raw_preproc_kwargs``.
    """
    # Keep this as the very first statement: locals() must only contain the
    # call arguments when it is captured.
    window_kwargs = [
        (create_windows_from_target_channels.__name__, _get_windowing_kwargs(locals())),
    ]
    targets_from = "channels"
    window_kwargs.append(
        (
            EEGWindowsDataset.__name__,
            {"targets_from": targets_from, "last_target_only": last_target_only},
        )
    )

    raw = ds.raw
    recording_stop = raw.n_times + raw.first_samp

    misc_targets = raw.get_data(picks="misc")
    # TODO: handle multi targets present only for some events
    has_target = ~np.isnan(misc_targets[0, :])
    # a window stops one sample after the sample that carries the target
    window_stops = np.flatnonzero(has_target) + 1
    in_range = (window_stops < recording_stop) & (window_stops >= window_size_samples)
    window_stops = window_stops[in_range].astype(int)

    n_windows = len(window_stops)
    metadata = pd.DataFrame(
        {
            "i_window_in_trial": np.arange(n_windows),
            "i_start_in_trial": window_stops - window_size_samples,
            "i_stop_in_trial": window_stops,
            "target": n_windows * [misc_targets],
        }
    )

    windows_ds = EEGWindowsDataset(
        raw,
        metadata=metadata,
        description=ds.description,
        targets_from=targets_from,
        last_target_only=last_target_only,
    )
    windows_ds.window_kwargs = window_kwargs
    if hasattr(ds, "raw_preproc_kwargs"):
        windows_ds.raw_preproc_kwargs = ds.raw_preproc_kwargs
    return windows_ds
849
+
850
+
851
+ def _compute_window_inds(
852
+ starts,
853
+ stops,
854
+ start_offset,
855
+ stop_offset,
856
+ size,
857
+ stride,
858
+ drop_last_window,
859
+ accepted_bads_ratio,
860
+ ):
861
+ """Compute window start and stop indices.
862
+
863
+ Create window starts from trial onsets (shifted by start_offset) to trial
864
+ end (shifted by stop_offset) separated by stride, as long as window size
865
+ fits into trial.
866
+
867
+ Parameters
868
+ ----------
869
+ starts: array-like
870
+ Trial starts in samples.
871
+ stops: array-like
872
+ Trial stops in samples.
873
+ start_offset: int
874
+ Start offset from original trial onsets in samples.
875
+ stop_offset: int
876
+ Stop offset from original trial stop in samples.
877
+ size: int
878
+ Window size.
879
+ stride: int
880
+ Stride between windows.
881
+ drop_last_window: bool
882
+ Toggles of shifting last window within range or dropping last samples.
883
+ accepted_bads_ratio: float
884
+ Acceptable proportion of bad trials within a raw. If the number of
885
+ trials whose length is exceeded by the window size is smaller than
886
+ this, then only the corresponding trials are dropped, but the
887
+ computation continues. Otherwise, an error is raised.
888
+
889
+ Returns
890
+ -------
891
+ result_lists: (list, list, list, list)
892
+ Trial, i_window_in_trial, start sample and stop sample of windows.
893
+ """
894
+ starts = np.array([starts]) if isinstance(starts, int) else starts
895
+ stops = np.array([stops]) if isinstance(stops, int) else stops
896
+
897
+ starts += start_offset
898
+ stops += stop_offset
899
+ if any(size > (stops - starts)):
900
+ bads_mask = size > (stops - starts)
901
+ min_duration = (stops - starts).min()
902
+ if sum(bads_mask) <= accepted_bads_ratio * len(starts):
903
+ starts = starts[np.logical_not(bads_mask)]
904
+ stops = stops[np.logical_not(bads_mask)]
905
+ warnings.warn(
906
+ f"Trials {np.where(bads_mask)[0]} are being dropped as the "
907
+ f"window size ({size}) exceeds their duration {min_duration}."
908
+ )
909
+ else:
910
+ current_ratio = sum(bads_mask) / len(starts)
911
+ raise ValueError(
912
+ f"Window size {size} exceeds trial duration "
913
+ f"({min_duration}) for too many trials "
914
+ f"({current_ratio * 100}%). Set "
915
+ f"accepted_bads_ratio to at least {current_ratio}"
916
+ "and restart training to be able to continue."
917
+ )
918
+
919
+ i_window_in_trials, i_trials, window_starts = [], [], []
920
+ for start_i, (start, stop) in enumerate(zip(starts, stops)):
921
+ # Generate possible window starts, with given stride, between original
922
+ # trial onsets and stops (shifted by start_offset and stop_offset,
923
+ # respectively)
924
+ possible_starts = np.arange(start, stop, stride)
925
+
926
+ # Possible window start is actually a start, if window size fits in
927
+ # trial start and stop
928
+ for i_window, s in enumerate(possible_starts):
929
+ if (s + size) <= stop:
930
+ window_starts.append(s)
931
+ i_window_in_trials.append(i_window)
932
+ i_trials.append(start_i)
933
+
934
+ # If the last window start + window size is not the same as
935
+ # stop + stop_offset, create another window that overlaps and stops
936
+ # at onset + stop_offset
937
+ if not drop_last_window:
938
+ if window_starts[-1] + size != stop:
939
+ window_starts.append(stop - size)
940
+ i_window_in_trials.append(i_window_in_trials[-1] + 1)
941
+ i_trials.append(start_i)
942
+
943
+ # Set window stops to be event stops (rather than trial stops)
944
+ window_stops = np.array(window_starts) + size
945
+ if not (len(i_window_in_trials) == len(window_starts) == len(window_stops)):
946
+ raise ValueError(
947
+ f"{len(i_window_in_trials)} == {len(window_starts)} == {len(window_stops)}"
948
+ )
949
+
950
+ return i_trials, i_window_in_trials, window_starts, window_stops
951
+
952
+
953
+ def _check_windowing_arguments(
954
+ trial_start_offset_samples,
955
+ trial_stop_offset_samples,
956
+ window_size_samples,
957
+ window_stride_samples,
958
+ ):
959
+ assert isinstance(trial_start_offset_samples, (int, np.integer))
960
+ assert isinstance(trial_stop_offset_samples, (int, np.integer)) or (
961
+ trial_stop_offset_samples is None
962
+ )
963
+ assert isinstance(window_size_samples, (int, np.integer, type(None)))
964
+ assert isinstance(window_stride_samples, (int, np.integer, type(None)))
965
+ assert (window_size_samples is None) == (window_stride_samples is None)
966
+ if window_size_samples is not None:
967
+ assert window_size_samples > 0, "window size has to be larger than 0"
968
+ assert window_stride_samples > 0, "window stride has to be larger than 0"
969
+
970
+
971
def _check_and_set_fixed_length_window_arguments(
    start_offset_samples,
    stop_offset_samples,
    window_size_samples,
    window_stride_samples,
    drop_last_window,
    lazy_metadata,
):
    """Validate fixed-length windowing arguments and normalise two of them.

    After the generic checks of ``_check_windowing_arguments``:

    * ``stop_offset_samples == 0`` is replaced by ``None`` (with a warning);
    * a warning is emitted when any offset is still in use;
    * ``drop_last_window=False`` is replaced by ``None`` when neither window
      size nor window stride is given.

    Returns
    -------
    stop_offset_samples : int | None
        The possibly replaced stop offset.
    drop_last_window : bool | None
        The possibly replaced ``drop_last_window`` flag.

    Raises
    ------
    ValueError
        If window size and stride are both given but ``drop_last_window`` is
        None, or if ``lazy_metadata`` is requested while ``drop_last_window``
        is falsy.
    AssertionError
        If window size, window stride and ``drop_last_window`` are not
        consistently set or unset (or the generic checks fail).
    """
    _check_windowing_arguments(
        start_offset_samples,
        stop_offset_samples,
        window_size_samples,
        window_stride_samples,
    )

    if stop_offset_samples == 0:
        warnings.warn(
            "Meaning of `trial_stop_offset_samples`=0 has changed, use `None` "
            "to indicate end of trial/recording. Using `None`."
        )
        stop_offset_samples = None

    offsets_in_use = stop_offset_samples is not None or start_offset_samples != 0
    if offsets_in_use:
        warnings.warn(
            "Usage of offset_sample args in create_fixed_length_windows is deprecated and"
            " will be removed in future versions. Please use "
            'braindecode.preprocessing.preprocess.Preprocessor("crop", tmin, tmax)'
            " instead."
        )

    size_missing = window_size_samples is None
    stride_missing = window_stride_samples is None

    if not size_missing and not stride_missing and drop_last_window is None:
        raise ValueError(
            "drop_last_window must be set if both window_size_samples &"
            " window_stride_samples have also been set"
        )
    if size_missing and stride_missing and drop_last_window is False:
        # necessary for following assertion
        drop_last_window = None

    assert size_missing == stride_missing == (drop_last_window is None)

    if lazy_metadata and not drop_last_window:
        raise ValueError(
            "Cannot have drop_last_window=False and lazy_metadata=True at the same time."
        )
    return stop_offset_samples, drop_last_window
1031
+
1032
+
1033
+ def _get_windowing_kwargs(windowing_func_locals):
1034
+ input_kwargs = windowing_func_locals
1035
+ input_kwargs.pop("ds")
1036
+ windowing_kwargs = {k: v for k, v in input_kwargs.items()}
1037
+ return windowing_kwargs