braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
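The only file-level diff reproduced below is braindecode/preprocessing/windowers.py (entry 81 above, +576 -187). Beyond a wholesale black-style reformatting of the module (double quotes, wrapped argument lists), it adds type annotations, an explicit use_mne_epochs argument on create_windows_from_events, and an experimental lazy_metadata option on create_fixed_length_windows backed by the new _LazyDataFrame helper. A minimal usage sketch of the two new options, based only on the signatures visible in the diff; the dataset name, subject id, and window sizes are illustrative assumptions, not taken from this diff:

from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import (
    create_fixed_length_windows,
    create_windows_from_events,
)

# Illustrative dataset; any BaseConcatDataset of raw recordings works.
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])

# use_mne_epochs is now an explicit argument: False forces the faster
# EEGWindowsDataset path, and raises if reject/picks/flat would require
# mne.Epochs (see _get_use_mne_epochs in the diff below).
windows = create_windows_from_events(
    dataset,
    trial_start_offset_samples=0,
    trial_stop_offset_samples=0,
    use_mne_epochs=False,
)

# lazy_metadata defers per-window metadata to the _LazyDataFrame helper;
# the argument checks in the diff require drop_last_window=True in that case.
lazy_windows = create_fixed_length_windows(
    dataset,
    window_size_samples=500,
    window_stride_samples=500,
    drop_last_window=True,
    lazy_metadata=True,
)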
braindecode/preprocessing/windowers.py

@@ -1,5 +1,4 @@
-"""Get epochs from mne.Raw
-"""
+"""Get epochs from mne.Raw"""
 
 # Authors: Hubert Banville <hubert.jbanville@gmail.com>
 #          Lukas Gemein <l.gemein@gmail.com>
@@ -15,23 +14,199 @@
 #
 # License: BSD (3-clause)
 
+from __future__ import annotations
+
 import warnings
+from typing import Any, Callable
 
-import numpy as np
 import mne
+import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
+from numpy.typing import ArrayLike
 
-from ..datasets.base import WindowsDataset, BaseConcatDataset, EEGWindowsDataset
+from ..datasets.base import BaseConcatDataset, EEGWindowsDataset, WindowsDataset
+
+
+class _LazyDataFrame:
+    """
+    DataFrame-like object that lazily computes values (experimental).
+
+    This class emulates some features of a pandas DataFrame, but computes
+    the values on-the-fly when they are accessed. This is useful for
+    very long DataFrames with repetitive values.
+    Only the methods used by EEGWindowsDataset on its metadata are implemented.
+
+    Parameters:
+    -----------
+    length: int
+        The length of the dataframe.
+    functions: dict[str, Callable[[int], Any]]
+        A dictionary mapping column names to functions that take an index and
+        return the value of the column at that index.
+    columns: list[str]
+        The names of the columns in the dataframe.
+    series: bool
+        Whether the object should emulate a series or a dataframe.
+    """
+
+    def __init__(
+        self,
+        length: int,
+        functions: dict[str, Callable[[int], Any]],
+        columns: list[str],
+        series: bool = False,
+    ):
+        if not (isinstance(length, int) and length >= 0):
+            raise ValueError("Length must be a positive integer.")
+        if not all(c in functions for c in columns):
+            raise ValueError("All columns must have a corresponding function.")
+        if series and len(columns) != 1:
+            raise ValueError("Series must have exactly one column.")
+        self.length = length
+        self.functions = functions
+        self.columns = columns
+        self.series = series
+
+    @property
+    def loc(self):
+        return self
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, key):
+        if not isinstance(key, tuple):
+            key = (key, self.columns)
+        if len(key) == 1:
+            key = (key[0], self.columns)
+        if not len(key) == 2:
+            raise IndexError(
+                f"index must be either [row] or [row, column], got [{', '.join(map(str, key))}]."
+            )
+        row, col = key
+        if col == slice(None):  # all columns (i.e., call to df[row, :])
+            col = self.columns
+        one_col = False
+        if isinstance(col, str):  # one column
+            one_col = True
+            col = [col]
+        else:  # multiple columns
+            col = list(col)
+        if not all(c in self.columns for c in col):
+            raise IndexError(
+                f"All columns must be present in the dataframe with columns {self.columns}. Got {col}."
+            )
+        if row == slice(None):  # all rows (i.e., call to df[:] or df[:, col])
+            return _LazyDataFrame(self.length, self.functions, col)
+        if not isinstance(row, int):
+            raise NotImplementedError(
+                "Row indexing only supports either a single integer or a null slice (i.e., df[:])."
+            )
+        if not (0 <= row < self.length):
+            raise IndexError(f"Row index {row} is out of bounds.")
+        if self.series or one_col:
+            return self.functions[col[0]](row)
+        return pd.Series({c: self.functions[c](row) for c in col})
+
+    def to_numpy(self):
+        return _LazyDataFrame(
+            length=self.length,
+            functions=self.functions,
+            columns=self.columns,
+            series=len(self.columns) == 1,
+        )
+
+    def to_list(self):
+        return self.to_numpy()
+
+
+class _FixedLengthWindowFunctions:
+    """Class defining functions for lazy metadata generation in fixed length windowing
+    to be used in combination with _LazyDataFrame (experimental)."""
+
+    def __init__(
+        self,
+        start_offset_samples: int,
+        last_potential_start: int,
+        window_stride_samples: int,
+        window_size_samples: int,
+        target: Any,
+    ):
+        self.start_offset_samples = start_offset_samples
+        self.last_potential_start = last_potential_start
+        self.window_stride_samples = window_stride_samples
+        self.window_size_samples = window_size_samples
+        self.target_val = target
+
+    @property
+    def length(self) -> int:
+        return int(
+            np.ceil(
+                (self.last_potential_start + 1 - self.start_offset_samples)
+                / self.window_stride_samples
+            )
+        )
+
+    def i_window_in_trial(self, i: int) -> int:
+        return i
+
+    def i_start_in_trial(self, i: int) -> int:
+        return self.start_offset_samples + i * self.window_stride_samples
+
+    def i_stop_in_trial(self, i: int) -> int:
+        return (
+            self.start_offset_samples
+            + i * self.window_stride_samples
+            + self.window_size_samples
+        )
+
+    def target(self, i: int) -> Any:
+        return self.target_val
+
+
+def _get_use_mne_epochs(use_mne_epochs, reject, picks, flat, drop_bad_windows):
+    should_use_mne_epochs = (
+        (reject is not None)
+        or (picks is not None)
+        or (flat is not None)
+        or (drop_bad_windows is True)
+    )
+    if use_mne_epochs is None:
+        if should_use_mne_epochs:
+            warnings.warn(
+                "Using reject or picks or flat or dropping bad windows means "
+                "mne Epochs are created, "
+                "which will be substantially slower and may be deprecated in the future."
+            )
+        return should_use_mne_epochs
+    if not use_mne_epochs and should_use_mne_epochs:
+        raise ValueError(
+            "Cannot set use_mne_epochs=False when using reject, picks, flat, or dropping bad windows."
+        )
+    return use_mne_epochs
 
 
 # XXX it's called concat_ds...
 def create_windows_from_events(
-        concat_ds, trial_start_offset_samples=0, trial_stop_offset_samples=0,
-        window_size_samples=None, window_stride_samples=None,
-        drop_last_window=False, mapping=None, preload=False,
-        drop_bad_windows=None, picks=None, reject=None, flat=None,
-        on_missing='error', accepted_bads_ratio=0.0, n_jobs=1, verbose='error'):
+    concat_ds: BaseConcatDataset,
+    trial_start_offset_samples: int = 0,
+    trial_stop_offset_samples: int = 0,
+    window_size_samples: int | None = None,
+    window_stride_samples: int | None = None,
+    drop_last_window: bool = False,
+    mapping: dict[str, int] | None = None,
+    preload: bool = False,
+    drop_bad_windows: bool | None = None,
+    picks: str | ArrayLike | slice | None = None,
+    reject: dict[str, float] | None = None,
+    flat: dict[str, float] | None = None,
+    on_missing: str = "error",
+    accepted_bads_ratio: float = 0.0,
+    use_mne_epochs: bool | None = None,
+    n_jobs: int = 1,
+    verbose: bool | str | int | None = "error",
+):
     """Create windows based on events in mne.Raw.
 
     This function extracts windows of size window_size_samples in the interval
@@ -100,6 +275,10 @@ def create_windows_from_events(
        smaller than this, then only the corresponding trials are dropped, but
        the computation continues. Otherwise, an error is raised. Defaults to
        0.0 (raise an error).
+    use_mne_epochs: bool
+        If False, return EEGWindowsDataset objects.
+        If True, return mne.Epochs objects encapsulated in WindowsDataset objects,
+        which is substantially slower that EEGWindowsDataset.
     n_jobs: int
        Number of jobs to use to parallelize the windowing.
     verbose: bool | str | int | None
@@ -111,8 +290,11 @@ def create_windows_from_events(
        Concatenated datasets of WindowsDataset containing the extracted windows.
     """
     _check_windowing_arguments(
-        trial_start_offset_samples, trial_stop_offset_samples,
-        window_size_samples, window_stride_samples)
+        trial_start_offset_samples,
+        trial_stop_offset_samples,
+        window_size_samples,
+        window_stride_samples,
+    )
 
     # If user did not specify mapping, we extract all events from all datasets
     # and map them to increasing integers starting from 0
@@ -121,34 +303,62 @@ def create_windows_from_events(
     infer_window_size_stride = window_size_samples is None
 
     if drop_bad_windows is not None:
-        warnings.warn('Drop bad windows only has an effect if mne epochs are created, '
-                      'and this argument may be removed in the future.')
+        warnings.warn(
+            "Drop bad windows only has an effect if mne epochs are created, "
+            "and this argument may be removed in the future."
+        )
 
-    use_mne_epochs = (reject is not None) or (picks is not None) or (flat is not None) or (
-        drop_bad_windows is True)
-    if use_mne_epochs:
-        warnings.warn('Using reject or picks or flat or dropping bad windows means '
-                      'mne Epochs are created, '
-                      'which will be substantially slower and may be deprecated in the future.')
-        if drop_bad_windows is None:
-            drop_bad_windows = True
+    use_mne_epochs = _get_use_mne_epochs(
+        use_mne_epochs, reject, picks, flat, drop_bad_windows
+    )
+    if use_mne_epochs and drop_bad_windows is None:
+        drop_bad_windows = True
 
     list_of_windows_ds = Parallel(n_jobs=n_jobs)(
         delayed(_create_windows_from_events)(
-            ds, infer_mapping, infer_window_size_stride,
-            trial_start_offset_samples, trial_stop_offset_samples,
-            window_size_samples, window_stride_samples, drop_last_window,
-            mapping, preload, drop_bad_windows, picks, reject, flat,
-            on_missing, accepted_bads_ratio, verbose, use_mne_epochs) for ds in concat_ds.datasets)
+            ds,
+            infer_mapping,
+            infer_window_size_stride,
+            trial_start_offset_samples,
+            trial_stop_offset_samples,
+            window_size_samples,
+            window_stride_samples,
+            drop_last_window,
+            mapping,
+            preload,
+            drop_bad_windows,
+            picks,
+            reject,
+            flat,
+            on_missing,
+            accepted_bads_ratio,
+            verbose,
+            use_mne_epochs,
+        )
+        for ds in concat_ds.datasets
+    )
     return BaseConcatDataset(list_of_windows_ds)
 
 
 def create_fixed_length_windows(
-        concat_ds, start_offset_samples=0, stop_offset_samples=None,
-        window_size_samples=None, window_stride_samples=None, drop_last_window=None,
-        mapping=None, preload=False, picks=None,
-        reject=None, flat=None, targets_from='metadata', last_target_only=True,
-        on_missing='error', n_jobs=1, verbose='error'):
+    concat_ds: BaseConcatDataset,
+    start_offset_samples: int = 0,
+    stop_offset_samples: int | None = None,
+    window_size_samples: int | None = None,
+    window_stride_samples: int | None = None,
+    drop_last_window: bool | None = None,
+    mapping: dict[str, int] | None = None,
+    preload: bool = False,
+    picks: str | ArrayLike | slice | None = None,
+    reject: dict[str, float] | None = None,
+    flat: dict[str, float] | None = None,
+    targets_from: str = "metadata",
+    last_target_only: bool = True,
+    lazy_metadata: bool = False,
+    on_missing: str = "error",
+    n_jobs: int = 1,
+    verbose: bool | str | int | None = "error",
+):
     """Windower that creates sliding windows.
 
     Parameters
@@ -183,6 +393,9 @@ def create_fixed_length_windows(
     flat: dict | None
        Epoch rejection parameters based on flatness of signals. If None, no
        rejection based on flatness is done. See mne.Epochs.
+    lazy_metadata: bool
+        If True, metadata is not computed immediately, but only when accessed
+        by using the _LazyDataFrame (experimental).
     on_missing: str
        What to do if one or several event ids are not found in the recording.
        Valid keys are ‘error’ | ‘warning’ | ‘ignore’. See mne.Epochs.
@@ -196,35 +409,70 @@ def create_fixed_length_windows(
     windows_datasets: BaseConcatDataset
        Concatenated datasets of WindowsDataset containing the extracted windows.
     """
-    stop_offset_samples, drop_last_window = _check_and_set_fixed_length_window_arguments(
-        start_offset_samples, stop_offset_samples, window_size_samples, window_stride_samples,
-        drop_last_window)
+    stop_offset_samples, drop_last_window = (
+        _check_and_set_fixed_length_window_arguments(
+            start_offset_samples,
+            stop_offset_samples,
+            window_size_samples,
+            window_stride_samples,
+            drop_last_window,
+            lazy_metadata,
+        )
+    )
 
     # check if recordings are of different lengths
     lengths = np.array([ds.raw.n_times for ds in concat_ds.datasets])
     if (np.diff(lengths) != 0).any() and window_size_samples is None:
-        warnings.warn('Recordings have different lengths, they will not be batch-able!')
+        warnings.warn("Recordings have different lengths, they will not be batch-able!")
     if (window_size_samples is not None) and any(window_size_samples > lengths):
-        raise ValueError(f'Window size {window_size_samples} exceeds trial '
-                         f'duration {lengths.min()}.')
+        raise ValueError(
+            f"Window size {window_size_samples} exceeds trial duration {lengths.min()}."
+        )
 
     list_of_windows_ds = Parallel(n_jobs=n_jobs)(
        delayed(_create_fixed_length_windows)(
-            ds, start_offset_samples, stop_offset_samples, window_size_samples,
-            window_stride_samples, drop_last_window, mapping, preload,
-            picks, reject, flat, targets_from, last_target_only,
-            on_missing, verbose) for ds in concat_ds.datasets)
+            ds,
+            start_offset_samples,
+            stop_offset_samples,
+            window_size_samples,
+            window_stride_samples,
+            drop_last_window,
+            mapping,
+            preload,
+            picks,
+            reject,
+            flat,
+            targets_from,
+            last_target_only,
+            lazy_metadata,
+            on_missing,
+            verbose,
+        )
+        for ds in concat_ds.datasets
+    )
     return BaseConcatDataset(list_of_windows_ds)
 
 
 def _create_windows_from_events(
-        ds, infer_mapping, infer_window_size_stride,
-        trial_start_offset_samples, trial_stop_offset_samples,
-        window_size_samples=None, window_stride_samples=None,
-        drop_last_window=False, mapping=None, preload=False,
-        drop_bad_windows=True, picks=None, reject=None, flat=None,
-        on_missing='error', accepted_bads_ratio=0.0, verbose='error',
-        use_mne_epochs=False):
+    ds,
+    infer_mapping,
+    infer_window_size_stride,
+    trial_start_offset_samples,
+    trial_stop_offset_samples,
+    window_size_samples=None,
+    window_stride_samples=None,
+    drop_last_window=False,
+    mapping=None,
+    preload=False,
+    drop_bad_windows=True,
+    picks=None,
+    reject=None,
+    flat=None,
+    on_missing="error",
+    accepted_bads_ratio=0.0,
+    verbose="error",
+    use_mne_epochs=False,
+):
     """Create WindowsDataset from BaseDataset based on events.
 
     Parameters
@@ -254,47 +502,61 @@ def _create_windows_from_events(
        new_unique_events = [x for x in unique_events if x not in mapping]
        # mapping event descriptions to integers from 0 on
        max_id_existing_mapping = len(mapping)
-        mapping.update({
+        mapping.update(
+            {
                event_name: i_event_type + max_id_existing_mapping
                for i_event_type, event_name in enumerate(new_unique_events)
-        })
+            }
+        )
 
     events, events_id = mne.events_from_annotations(ds.raw, mapping)
     onsets = events[:, 0]
     # Onsets are relative to the beginning of the recording
-    filtered_durations = np.array([
-        a['duration'] for a in ds.raw.annotations
-        if a['description'] in events_id
-    ])
+    filtered_durations = np.array(
+        [a["duration"] for a in ds.raw.annotations if a["description"] in events_id]
+    )
 
-    stops = onsets + (filtered_durations * ds.raw.info['sfreq']).astype(int)
+    stops = onsets + (filtered_durations * ds.raw.info["sfreq"]).astype(int)
     # XXX This could probably be simplified by using chunk_duration in
     # `events_from_annotations`
 
-    last_samp = ds.raw.first_samp + ds.raw.n_times
-    if stops[-1] + trial_stop_offset_samples > last_samp:
+    last_samp = ds.raw.first_samp + ds.raw.n_times - 1
+    # `stops` is used exclusively (i.e. `start:stop`), so add back 1
+    if stops[-1] + trial_stop_offset_samples > last_samp + 1:
        raise ValueError(
            '"trial_stop_offset_samples" too large. Stop of last trial '
            f'({stops[-1]}) + "trial_stop_offset_samples" '
-            f'({trial_stop_offset_samples}) must be smaller than length of'
-            f' recording ({len(ds)}).')
+            f"({trial_stop_offset_samples}) must be smaller than length of"
+            f" recording ({len(ds)})."
+        )
 
     if infer_window_size_stride:
        # window size is trial size
        if window_size_samples is None:
-            window_size_samples = stops[0] + trial_stop_offset_samples - (
-                onsets[0] + trial_start_offset_samples)
+            window_size_samples = (
+                stops[0]
+                + trial_stop_offset_samples
+                - (onsets[0] + trial_start_offset_samples)
+            )
            window_stride_samples = window_size_samples
        this_trial_sizes = (stops + trial_stop_offset_samples) - (
-            onsets + trial_start_offset_samples)
+            onsets + trial_start_offset_samples
+        )
        # Maybe actually this is not necessary?
        # We could also just say we just assume window size=trial size
        # in case not given, without this condition...
        # but then would have to change functions overall
-        # to deal with varying window sizes hmmhmh
-        assert np.all(this_trial_sizes == window_size_samples), (
-            'All trial sizes should be the same if you do not supply a window '
-            'size.')
+        checker_trials_size = this_trial_sizes == window_size_samples
+
+        if not np.all(checker_trials_size):
+            trials_drops = int(len(this_trial_sizes) - sum(checker_trials_size))
+            warnings.warn(
+                f"Dropping trials with different windows size {trials_drops}",
+            )
+            bads_size_trials = checker_trials_size
+            events = events[checker_trials_size]
+            onsets = onsets[checker_trials_size]
+            stops = stops[checker_trials_size]
 
     description = events[:, -1]

@@ -302,51 +564,90 @@ def _create_windows_from_events(
     onsets = onsets - ds.raw.first_samp
     stops = stops - ds.raw.first_samp
     i_trials, i_window_in_trials, starts, stops = _compute_window_inds(
-        onsets, stops, trial_start_offset_samples,
-        trial_stop_offset_samples, window_size_samples,
-        window_stride_samples, drop_last_window, accepted_bads_ratio)
+        onsets,
+        stops,
+        trial_start_offset_samples,
+        trial_stop_offset_samples,
+        window_size_samples,
+        window_stride_samples,
+        drop_last_window,
+        accepted_bads_ratio,
+    )
 
     if any(np.diff(starts) <= 0):
-        raise NotImplementedError('Trial overlap not implemented.')
+        raise NotImplementedError("Trial overlap not implemented.")
 
-    events = [[start, window_size_samples, description[i_trials[i_start]]]
-              for i_start, start in enumerate(starts)]
+    events = [
+        [start, window_size_samples, description[i_trials[i_start]]]
+        for i_start, start in enumerate(starts)
+    ]
     events = np.array(events)
 
     description = events[:, -1]
 
-    metadata = pd.DataFrame({
-        'i_window_in_trial': i_window_in_trials,
-        'i_start_in_trial': starts,
-        'i_stop_in_trial': stops,
-        'target': description
-    })
+    metadata = pd.DataFrame(
+        {
+            "i_window_in_trial": i_window_in_trials,
+            "i_start_in_trial": starts,
+            "i_stop_in_trial": stops,
+            "target": description,
+        }
+    )
     if use_mne_epochs:
        # window size - 1, since tmax is inclusive
        mne_epochs = mne.Epochs(
-            ds.raw, events, events_id, baseline=None, tmin=0,
-            tmax=(window_size_samples - 1) / ds.raw.info['sfreq'],
-            metadata=metadata, preload=preload, picks=picks, reject=reject,
-            flat=flat, on_missing=on_missing, verbose=verbose)
+            ds.raw,
+            events,
+            events_id,
+            baseline=None,
+            tmin=0,
+            tmax=(window_size_samples - 1) / ds.raw.info["sfreq"],
+            metadata=metadata,
+            preload=preload,
+            picks=picks,
+            reject=reject,
+            flat=flat,
+            on_missing=on_missing,
+            verbose=verbose,
+        )
        if drop_bad_windows:
            mne_epochs.drop_bad()
-        windows_ds = WindowsDataset(mne_epochs, ds.description,)
+        windows_ds = WindowsDataset(
+            mne_epochs,
+            ds.description,
+        )
     else:
        windows_ds = EEGWindowsDataset(
-            ds.raw, metadata=metadata, description=ds.description,)
+            ds.raw,
+            metadata=metadata,
+            description=ds.description,
+        )
     # add window_kwargs and raw_preproc_kwargs to windows dataset
-    setattr(windows_ds, 'window_kwargs', window_kwargs)
-    kwargs_name = 'raw_preproc_kwargs'
+    setattr(windows_ds, "window_kwargs", window_kwargs)
+    kwargs_name = "raw_preproc_kwargs"
     if hasattr(ds, kwargs_name):
        setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
     return windows_ds
 
 
 def _create_fixed_length_windows(
-        ds, start_offset_samples, stop_offset_samples, window_size_samples,
-        window_stride_samples, drop_last_window, mapping=None, preload=False,
-        picks=None, reject=None, flat=None, targets_from='metadata',
-        last_target_only=True, on_missing='error', verbose='error'):
+    ds,
+    start_offset_samples,
+    stop_offset_samples,
+    window_size_samples,
+    window_stride_samples,
+    drop_last_window,
+    mapping=None,
+    preload=False,
+    picks=None,
+    reject=None,
+    flat=None,
+    targets_from="metadata",
+    last_target_only=True,
+    lazy_metadata=False,
+    on_missing="error",
+    verbose="error",
+):
     """Create WindowsDataset from BaseDataset with sliding windows.
 
     Parameters
@@ -365,8 +666,7 @@ def _create_fixed_length_windows(
     window_kwargs = [
        (create_fixed_length_windows.__name__, _get_windowing_kwargs(locals())),
     ]
-    stop = ds.raw.n_times\
-        if stop_offset_samples is None else stop_offset_samples
+    stop = ds.raw.n_times if stop_offset_samples is None else stop_offset_samples
 
     # assume window should be whole recording
     if window_size_samples is None:
@@ -375,15 +675,6 @@ def _create_fixed_length_windows(
        window_stride_samples = window_size_samples
 
     last_potential_start = stop - window_size_samples
-    # already includes last incomplete window start
-    starts = np.arange(
-        start_offset_samples,
-        last_potential_start + 1,
-        window_stride_samples)
-
-    if not drop_last_window and starts[-1] < last_potential_start:
-        # if last window does not end at trial stop, make it stop there
-        starts = np.append(starts, last_potential_start)
 
     # get targets from dataset description if they exist
     target = -1 if ds.target_name is None else ds.description[ds.target_name]
@@ -395,18 +686,53 @@ def _create_fixed_length_windows(
        else:
            target = mapping[target]
 
-    metadata = pd.DataFrame({
-        'i_window_in_trial': np.arange(len(starts)),
-        'i_start_in_trial': starts,
-        'i_stop_in_trial': starts + window_size_samples,
-        'target': len(starts) * [target]
-    })
+    if lazy_metadata:
+        factory = _FixedLengthWindowFunctions(
+            start_offset_samples,
+            last_potential_start,
+            window_stride_samples,
+            window_size_samples,
+            target,
+        )
+        metadata = _LazyDataFrame(
+            length=factory.length,
+            functions={
+                "i_window_in_trial": factory.i_window_in_trial,
+                "i_start_in_trial": factory.i_start_in_trial,
+                "i_stop_in_trial": factory.i_stop_in_trial,
+                "target": factory.target,
+            },
+            columns=[
+                "i_window_in_trial",
+                "i_start_in_trial",
+                "i_stop_in_trial",
+                "target",
+            ],
+        )
+    else:
+        # already includes last incomplete window start
+        starts = np.arange(
+            start_offset_samples, last_potential_start + 1, window_stride_samples
+        )
+
+        if not drop_last_window and starts[-1] < last_potential_start:
+            # if last window does not end at trial stop, make it stop there
+            starts = np.append(starts, last_potential_start)
+
+        metadata = pd.DataFrame(
+            {
+                "i_window_in_trial": np.arange(len(starts)),
+                "i_start_in_trial": starts,
+                "i_stop_in_trial": starts + window_size_samples,
+                "target": len(starts) * [target],
+            }
+        )
 
     window_kwargs.append(
-        (EEGWindowsDataset.__name__, {
-            'targets_from': targets_from,
-            'last_target_only': last_target_only
-        })
+        (
+            EEGWindowsDataset.__name__,
+            {"targets_from": targets_from, "last_target_only": last_target_only},
+        )
     )
     windows_ds = EEGWindowsDataset(
        ds.raw,
@@ -416,28 +742,52 @@ def _create_fixed_length_windows(
        last_target_only=last_target_only,
     )
     # add window_kwargs and raw_preproc_kwargs to windows dataset
-    setattr(windows_ds, 'window_kwargs', window_kwargs)
-    kwargs_name = 'raw_preproc_kwargs'
+    setattr(windows_ds, "window_kwargs", window_kwargs)
+    kwargs_name = "raw_preproc_kwargs"
     if hasattr(ds, kwargs_name):
        setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
     return windows_ds
 
 
 def create_windows_from_target_channels(
-        concat_ds, window_size_samples=None, preload=False,
-        picks=None, reject=None, flat=None, n_jobs=1, last_target_only=True,
-        verbose='error'):
+    concat_ds,
+    window_size_samples=None,
+    preload=False,
+    picks=None,
+    reject=None,
+    flat=None,
+    n_jobs=1,
+    last_target_only=True,
+    verbose="error",
+):
     list_of_windows_ds = Parallel(n_jobs=n_jobs)(
        delayed(_create_windows_from_target_channels)(
-            ds, window_size_samples, preload, picks, reject,
-            flat, last_target_only, 'error', verbose) for ds in concat_ds.datasets)
+            ds,
+            window_size_samples,
+            preload,
+            picks,
+            reject,
+            flat,
+            last_target_only,
+            "error",
+            verbose,
+        )
+        for ds in concat_ds.datasets
+    )
     return BaseConcatDataset(list_of_windows_ds)
 
 
 def _create_windows_from_target_channels(
-        ds, window_size_samples, preload=False, picks=None,
-        reject=None, flat=None, last_target_only=True, on_missing='error',
-        verbose='error'):
+    ds,
+    window_size_samples,
+    preload=False,
+    picks=None,
+    reject=None,
+    flat=None,
+    last_target_only=True,
+    on_missing="error",
+    verbose="error",
+):
     """Create WindowsDataset from BaseDataset using targets `misc` channels from mne.Raw.
 
     Parameters
@@ -457,24 +807,26 @@ def _create_windows_from_target_channels(
     ]
     stop = ds.raw.n_times + ds.raw.first_samp
 
-    target = ds.raw.get_data(picks='misc')
+    target = ds.raw.get_data(picks="misc")
     # TODO: handle multi targets present only for some events
     stops = np.nonzero((~np.isnan(target[0, :])))[0] + 1
     stops = stops[(stops < stop) & (stops >= window_size_samples)]
     stops = stops.astype(int)
-    metadata = pd.DataFrame({
-        'i_window_in_trial': np.arange(len(stops)),
-        'i_start_in_trial': stops - window_size_samples,
-        'i_stop_in_trial': stops,
-        'target': len(stops) * [target]
-    })
-
-    targets_from = 'channels'
+    metadata = pd.DataFrame(
+        {
+            "i_window_in_trial": np.arange(len(stops)),
+            "i_start_in_trial": stops - window_size_samples,
+            "i_stop_in_trial": stops,
+            "target": len(stops) * [target],
+        }
+    )
+
+    targets_from = "channels"
     window_kwargs.append(
-        (EEGWindowsDataset.__name__, {
-            'targets_from': targets_from,
-            'last_target_only': last_target_only
-        })
+        (
+            EEGWindowsDataset.__name__,
+            {"targets_from": targets_from, "last_target_only": last_target_only},
+        )
     )
     windows_ds = EEGWindowsDataset(
        ds.raw,
@@ -483,16 +835,23 @@ def _create_windows_from_target_channels(
        targets_from=targets_from,
        last_target_only=last_target_only,
     )
-    setattr(windows_ds, 'window_kwargs', window_kwargs)
-    kwargs_name = 'raw_preproc_kwargs'
+    setattr(windows_ds, "window_kwargs", window_kwargs)
+    kwargs_name = "raw_preproc_kwargs"
     if hasattr(ds, kwargs_name):
        setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
     return windows_ds
 
 
 def _compute_window_inds(
-        starts, stops, start_offset, stop_offset, size, stride,
-        drop_last_window, accepted_bads_ratio):
+    starts,
+    stops,
+    start_offset,
+    stop_offset,
+    size,
+    stride,
+    drop_last_window,
+    accepted_bads_ratio,
+):
     """Compute window start and stop indices.
 
     Create window starts from trial onsets (shifted by start_offset) to trial
@@ -531,27 +890,31 @@ def _compute_window_inds(
 
     starts += start_offset
     stops += stop_offset
-    if any(size > (stops-starts)):
-        bads_mask = size > (stops-starts)
-        min_duration = (stops-starts).min()
+    if any(size > (stops - starts)):
+        bads_mask = size > (stops - starts)
+        min_duration = (stops - starts).min()
        if sum(bads_mask) <= accepted_bads_ratio * len(starts):
            starts = starts[np.logical_not(bads_mask)]
            stops = stops[np.logical_not(bads_mask)]
            warnings.warn(
-                f'Trials {np.where(bads_mask)[0]} are being dropped as the '
-                f'window size ({size}) exceeds their duration {min_duration}.')
+                f"Trials {np.where(bads_mask)[0]} are being dropped as the "
+                f"window size ({size}) exceeds their duration {min_duration}."
+            )
        else:
            current_ratio = sum(bads_mask) / len(starts)
-            raise ValueError(f'Window size {size} exceeds trial duration '
-                             f'({min_duration}) for too many trials '
-                             f'({current_ratio * 100}%). Set '
-                             f'accepted_bads_ratio to at least {current_ratio}'
-                             'and restart training to be able to continue.')
+            raise ValueError(
+                f"Window size {size} exceeds trial duration "
+                f"({min_duration}) for too many trials "
+                f"({current_ratio * 100}%). Set "
+                f"accepted_bads_ratio to at least {current_ratio}"
+                "and restart training to be able to continue."
+            )
 
     i_window_in_trials, i_trials, window_starts = [], [], []
     for start_i, (start, stop) in enumerate(zip(starts, stops)):
-        # Generate possible window starts with given stride between original
-        # trial onsets (shifted by start_offset) and stops
+        # Generate possible window starts, with given stride, between original
+        # trial onsets and stops (shifted by start_offset and stop_offset,
+        # respectively)
        possible_starts = np.arange(start, stop, stride)
 
        # Possible window start is actually a start, if window size fits in
@@ -571,72 +934,98 @@ def _compute_window_inds(
            i_window_in_trials.append(i_window_in_trials[-1] + 1)
            i_trials.append(start_i)
 
-    # Update stops to now be event stops instead of trial stops
+    # Set window stops to be event stops (rather than trial stops)
     window_stops = np.array(window_starts) + size
     if not (len(i_window_in_trials) == len(window_starts) == len(window_stops)):
-        raise ValueError(f'{len(i_window_in_trials)} == '
-                         f'{len(window_starts)} == {len(window_stops)}')
+        raise ValueError(
+            f"{len(i_window_in_trials)} == {len(window_starts)} == {len(window_stops)}"
+        )
 
     return i_trials, i_window_in_trials, window_starts, window_stops
 
 
 def _check_windowing_arguments(
-        trial_start_offset_samples, trial_stop_offset_samples,
-        window_size_samples, window_stride_samples):
+    trial_start_offset_samples,
+    trial_stop_offset_samples,
+    window_size_samples,
+    window_stride_samples,
+):
     assert isinstance(trial_start_offset_samples, (int, np.integer))
-    assert (isinstance(trial_stop_offset_samples, (int, np.integer)) or
-            (trial_stop_offset_samples is None))
+    assert isinstance(trial_stop_offset_samples, (int, np.integer)) or (
+        trial_stop_offset_samples is None
+    )
     assert isinstance(window_size_samples, (int, np.integer, type(None)))
     assert isinstance(window_stride_samples, (int, np.integer, type(None)))
     assert (window_size_samples is None) == (window_stride_samples is None)
     if window_size_samples is not None:
-        assert window_size_samples > 0, (
-            "window size has to be larger than 0")
-        assert window_stride_samples > 0, (
-            "window stride has to be larger than 0")
+        assert window_size_samples > 0, "window size has to be larger than 0"
+        assert window_stride_samples > 0, "window stride has to be larger than 0"
 
 
-def _check_and_set_fixed_length_window_arguments(start_offset_samples, stop_offset_samples,
-                                                 window_size_samples, window_stride_samples,
-                                                 drop_last_window):
+def _check_and_set_fixed_length_window_arguments(
+    start_offset_samples,
+    stop_offset_samples,
+    window_size_samples,
+    window_stride_samples,
+    drop_last_window,
+    lazy_metadata,
+):
     """Raises warnings for incorrect input arguments and will set correct default values for
     stop_offset_samples & drop_last_window, if necessary.
     """
     _check_windowing_arguments(
-        start_offset_samples, stop_offset_samples,
-        window_size_samples, window_stride_samples)
+        start_offset_samples,
+        stop_offset_samples,
+        window_size_samples,
+        window_stride_samples,
+    )
 
     if stop_offset_samples == 0:
        warnings.warn(
-            'Meaning of `trial_stop_offset_samples`=0 has changed, use `None` '
-            'to indicate end of trial/recording. Using `None`.')
+            "Meaning of `trial_stop_offset_samples`=0 has changed, use `None` "
+            "to indicate end of trial/recording. Using `None`."
+        )
        stop_offset_samples = None
 
     if start_offset_samples != 0 or stop_offset_samples is not None:
-        warnings.warn('Usage of offset_sample args in create_fixed_length_windows is deprecated and'
-                      ' will be removed in future versions. Please use '
-                      'braindecode.preprocessing.preprocess.Preprocessor("crop", tmin, tmax)'
-                      ' instead.')
-
-    if window_size_samples is not None and window_stride_samples is not None and \
-            drop_last_window is None:
-        raise ValueError('drop_last_window must be set if both window_size_samples &'
-                         ' window_stride_samples have also been set')
-    elif window_size_samples is None and\
-            window_stride_samples is None and\
-            drop_last_window is False:
+        warnings.warn(
+            "Usage of offset_sample args in create_fixed_length_windows is deprecated and"
+            " will be removed in future versions. Please use "
+            'braindecode.preprocessing.preprocess.Preprocessor("crop", tmin, tmax)'
+            " instead."
+        )
+
+    if (
+        window_size_samples is not None
+        and window_stride_samples is not None
+        and drop_last_window is None
+    ):
+        raise ValueError(
+            "drop_last_window must be set if both window_size_samples &"
+            " window_stride_samples have also been set"
+        )
+    elif (
+        window_size_samples is None
+        and window_stride_samples is None
+        and drop_last_window is False
+    ):
        # necessary for following assertion
        drop_last_window = None
 
-    assert (window_size_samples is None) == \
-           (window_stride_samples is None) == \
-           (drop_last_window is None)
-
+    assert (
+        (window_size_samples is None)
+        == (window_stride_samples is None)
+        == (drop_last_window is None)
+    )
+    if not drop_last_window and lazy_metadata:
+        raise ValueError(
+            "Cannot have drop_last_window=False and lazy_metadata=True at the same time."
+        )
     return stop_offset_samples, drop_last_window
 
 
 def _get_windowing_kwargs(windowing_func_locals):
     input_kwargs = windowing_func_locals
-    input_kwargs.pop('ds')
+    input_kwargs.pop("ds")
     windowing_kwargs = {k: v for k, v in input_kwargs.items()}
     return windowing_kwargs
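
One detail of the lazy path worth spelling out: with drop_last_window=True, the window count given in closed form by _FixedLengthWindowFunctions.length equals the length of the eager np.arange construction it replaces, and the i-th window starts at start_offset_samples + i * window_stride_samples. A quick sanity-check sketch with illustrative numbers (not taken from this diff):

import numpy as np

# Illustrative values: window starts allowed in [0, 999], one every 250 samples.
start_offset_samples, last_potential_start, window_stride_samples = 0, 999, 250

# Closed-form count, as in _FixedLengthWindowFunctions.length:
n_lazy = int(
    np.ceil((last_potential_start + 1 - start_offset_samples) / window_stride_samples)
)

# Eager count from the non-lazy path (drop_last_window=True, so no extra
# window is appended at last_potential_start):
starts = np.arange(
    start_offset_samples, last_potential_start + 1, window_stride_samples
)

assert n_lazy == len(starts) == 4  # starts == [0, 250, 500, 750]

This equivalence is presumably why _check_and_set_fixed_length_window_arguments rejects lazy_metadata=True together with drop_last_window=False: the extra window appended at last_potential_start on the eager path has an irregular start that the closed-form index arithmetic cannot represent.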