braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1037 @@
|
|
|
1
|
+
"""Get epochs from mne.Raw"""
|
|
2
|
+
|
|
3
|
+
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
|
|
4
|
+
# Lukas Gemein <l.gemein@gmail.com>
|
|
5
|
+
# Simon Brandt <simonbrandt@protonmail.com>
|
|
6
|
+
# David Sabbagh <dav.sabbagh@gmail.com>
|
|
7
|
+
# Henrik Bonsmann <henrikbons@gmail.com>
|
|
8
|
+
# Ann-Kathrin Kiessner <ann-kathrin.kiessner@gmx.de>
|
|
9
|
+
# Vytautas Jankauskas <vytauto.jankausko@gmail.com>
|
|
10
|
+
# Dan Wilson <dan.c.wil@gmail.com>
|
|
11
|
+
# Maciej Sliwowski <maciek.sliwowski@gmail.com>
|
|
12
|
+
# Mohammed Fattouh <mo.fattouh@gmail.com>
|
|
13
|
+
# Robin Schirrmeister <robintibor@gmail.com>
|
|
14
|
+
# Matthew Chen <matt.chen42601@gmail.com>
|
|
15
|
+
#
|
|
16
|
+
# License: BSD (3-clause)
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import warnings
|
|
21
|
+
from typing import Any, Callable
|
|
22
|
+
|
|
23
|
+
import mne
|
|
24
|
+
import numpy as np
|
|
25
|
+
import pandas as pd
|
|
26
|
+
from joblib import Parallel, delayed
|
|
27
|
+
from numpy.typing import ArrayLike
|
|
28
|
+
|
|
29
|
+
from ..datasets.base import (
|
|
30
|
+
BaseConcatDataset,
|
|
31
|
+
EEGWindowsDataset,
|
|
32
|
+
RawDataset,
|
|
33
|
+
WindowsDataset,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class _LazyDataFrame:
    """DataFrame-like object that lazily computes values (experimental).

    This class emulates some features of a pandas DataFrame, but computes
    the values on-the-fly when they are accessed. This is useful for
    very long DataFrames with repetitive values.
    Only the methods used by EEGWindowsDataset on its metadata are implemented.

    Parameters
    ----------
    length : int
        The length of the dataframe.
    functions : dict[str, Callable[[int], Any]]
        A dictionary mapping column names to functions that take an index and
        return the value of the column at that index.
    columns : list[str]
        The names of the columns in the dataframe.
    series : bool
        Whether the object should emulate a series or a dataframe.

    Raises
    ------
    ValueError
        If ``length`` is not a non-negative ``int``, if a column has no
        function, or if ``series`` is set with more than one column.
    """

    def __init__(
        self,
        length: int,
        functions: dict[str, Callable[[int], Any]],
        columns: list[str],
        series: bool = False,
    ):
        if not (isinstance(length, int) and length >= 0):
            raise ValueError("Length must be a positive integer.")
        if not all(c in functions for c in columns):
            raise ValueError("All columns must have a corresponding function.")
        if series and len(columns) != 1:
            raise ValueError("Series must have exactly one column.")
        self.length = length
        self.functions = functions
        self.columns = columns
        self.series = series

    @property
    def loc(self):
        """Return ``self`` so that ``df.loc[...]`` behaves like ``df[...]``."""
        return self

    def __len__(self):
        return self.length

    def __getitem__(self, key):
        """Return the value(s) selected by ``[row]`` or ``[row, column]``.

        ``row`` must be a single integer (Python or NumPy) or a null slice
        (``:``). ``column`` may be a column name, an iterable of column
        names, or a null slice meaning all columns.

        Returns
        -------
        _LazyDataFrame
            When ``row`` is a null slice (restricted to the chosen columns).
        Any
            The computed value when a single column name is requested or the
            object emulates a series.
        pandas.Series
            The computed values of the requested columns otherwise.

        Raises
        ------
        IndexError
            On a malformed key, an unknown column or an out-of-bounds row.
        NotImplementedError
            If ``row`` is neither an integer nor a null slice.
        """
        # Normalise the key to a (row, column) pair.
        if not isinstance(key, tuple):
            key = (key, self.columns)
        if len(key) == 1:
            key = (key[0], self.columns)
        if not len(key) == 2:
            raise IndexError(
                f"index must be either [row] or [row, column], got [{', '.join(map(str, key))}]."
            )
        row, col = key
        # Check the type first: comparing an array-like with ``slice(None)``
        # via ``==`` is ambiguous and may fail.
        if isinstance(col, slice) and col == slice(None):
            # all columns (i.e., call to df[row, :])
            col = self.columns
        one_col = False
        if isinstance(col, str):  # one column
            one_col = True
            col = [col]
        else:  # multiple columns
            col = list(col)
        if not all(c in self.columns for c in col):
            raise IndexError(
                f"All columns must be present in the dataframe with columns {self.columns}. Got {col}."
            )
        if isinstance(row, slice) and row == slice(None):
            # all rows (i.e., call to df[:] or df[:, col])
            return _LazyDataFrame(self.length, self.functions, col)
        if isinstance(row, np.integer):
            # Accept NumPy integer indices; hand a plain int to the functions.
            row = int(row)
        if not isinstance(row, int):
            raise NotImplementedError(
                "Row indexing only supports either a single integer or a null slice (i.e., df[:])."
            )
        if not (0 <= row < self.length):
            raise IndexError(f"Row index {row} is out of bounds.")
        if self.series or one_col:
            return self.functions[col[0]](row)
        return pd.Series({c: self.functions[c](row) for c in col})

    def to_numpy(self):
        """Return a lazy copy; it emulates a series if there is one column."""
        return _LazyDataFrame(
            length=self.length,
            functions=self.functions,
            columns=self.columns,
            series=len(self.columns) == 1,
        )

    def to_list(self):
        """Alias of :meth:`to_numpy`; the result stays lazy."""
        return self.to_numpy()
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class _FixedLengthWindowFunctions:
    """Per-index metadata functions for lazy fixed-length windowing.

    Meant to be used in combination with _LazyDataFrame (experimental):
    each method takes a window index ``i`` and returns the value of the
    metadata column of the same name for that window.
    """

    def __init__(
        self,
        start_offset_samples: int,
        last_potential_start: int,
        window_stride_samples: int,
        window_size_samples: int,
        target: Any,
    ):
        self.target_val = target
        self.window_size_samples = window_size_samples
        self.window_stride_samples = window_stride_samples
        self.last_potential_start = last_potential_start
        self.start_offset_samples = start_offset_samples

    @property
    def length(self) -> int:
        """Number of windows whose start lies in
        ``[start_offset_samples, last_potential_start]`` at the given stride."""
        n_candidate_starts = self.last_potential_start + 1 - self.start_offset_samples
        n_windows = np.ceil(n_candidate_starts / self.window_stride_samples)
        return int(n_windows)

    def i_window_in_trial(self, i: int) -> int:
        """Index of the window, i.e. ``i`` itself."""
        return i

    def i_start_in_trial(self, i: int) -> int:
        """First sample of window ``i``."""
        stride_shift = i * self.window_stride_samples
        return self.start_offset_samples + stride_shift

    def i_stop_in_trial(self, i: int) -> int:
        """Sample index one window size after the start of window ``i``."""
        window_start = self.start_offset_samples + i * self.window_stride_samples
        return window_start + self.window_size_samples

    def target(self, i: int) -> Any:
        """Target value, identical for every window."""
        return self.target_val
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _get_use_mne_epochs(use_mne_epochs, reject, picks, flat, drop_bad_windows):
    """Resolve whether windowing has to go through mne Epochs.

    mne Epochs are required as soon as ``reject``, ``picks`` or ``flat`` is
    given, or ``drop_bad_windows`` is ``True``.

    Parameters
    ----------
    use_mne_epochs : bool | None
        The user's choice. ``None`` means "decide from the other arguments".
    reject, picks, flat : object | None
        Options that require mne Epochs when not ``None``.
    drop_bad_windows : bool | None
        Requires mne Epochs when it is exactly ``True``.

    Returns
    -------
    bool
        The resolved value; a warning is emitted when mne Epochs are
        selected implicitly.

    Raises
    ------
    ValueError
        If ``use_mne_epochs`` is falsy (and not ``None``) although the other
        arguments require mne Epochs.
    """
    epochs_needed = drop_bad_windows is True or any(
        option is not None for option in (reject, picks, flat)
    )

    if use_mne_epochs is not None:
        # Explicit choice: honour it unless it contradicts the other options.
        if not use_mne_epochs and epochs_needed:
            raise ValueError(
                "Cannot set use_mne_epochs=False when using reject, picks, "
                "flat, or dropping bad windows."
            )
        return use_mne_epochs

    if epochs_needed:
        warnings.warn(
            "Using reject or picks or flat or dropping bad windows means mne "
            "Epochs are created, which will be substantially slower and may "
            "be deprecated in the future."
        )
    return epochs_needed
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
# XXX it's called concat_ds...
|
|
197
|
+
def create_windows_from_events(
    concat_ds: BaseConcatDataset[RawDataset],
    trial_start_offset_samples: int = 0,
    trial_stop_offset_samples: int = 0,
    window_size_samples: int | None = None,
    window_stride_samples: int | None = None,
    drop_last_window: bool = False,
    mapping: dict[str, int] | None = None,
    preload: bool = False,
    drop_bad_windows: bool | None = None,
    picks: str | ArrayLike | slice | None = None,
    reject: dict[str, float] | None = None,
    flat: dict[str, float] | None = None,
    on_missing: str = "error",
    accepted_bads_ratio: float = 0.0,
    use_mne_epochs: bool | None = None,
    n_jobs: int = 1,
    verbose: bool | str | int | None = "error",
) -> BaseConcatDataset[WindowsDataset | EEGWindowsDataset]:
    """Create windows based on events in mne.Raw.

    This function extracts windows of size window_size_samples in the interval
    [trial onset + trial_start_offset_samples, trial onset + trial duration +
    trial_stop_offset_samples] around each trial, with a separation of
    window_stride_samples between consecutive windows. If the last window
    around an event does not end at trial_stop_offset_samples and
    drop_last_window is set to False, an additional overlapping window that
    ends at trial_stop_offset_samples is created.

    Windows are extracted from the interval defined by the following
    (drawn for a negative trial_start_offset_samples and a positive
    trial_stop_offset_samples)::

                                                    trial onset +
                            trial onset                duration
        |--------------------|------------------------|-----------------------|
        trial onset +                                             trial onset +
        trial_start_offset_samples                                   duration +
                                                    trial_stop_offset_samples

    Parameters
    ----------
    concat_ds: BaseConcatDataset[RawDataset]
        A concat of base datasets each holding raw and description.
    trial_start_offset_samples: int
        Start offset from original trial onsets, in samples. Defaults to zero.
    trial_stop_offset_samples: int
        Stop offset from original trial stop, in samples. Defaults to zero.
    window_size_samples: int | None
        Window size. If None, the window size is inferred from the original
        trial size of the first trial and trial_start_offset_samples and
        trial_stop_offset_samples.
    window_stride_samples: int | None
        Stride between windows, in samples. If None, the window stride is
        inferred from the original trial size of the first trial and
        trial_start_offset_samples and trial_stop_offset_samples.
    drop_last_window: bool
        If False, an additional overlapping window that ends at
        trial_stop_offset_samples will be extracted around each event when the
        last window does not end exactly at trial_stop_offset_samples.
    mapping: dict[str, int] | None
        Mapping from event description to numerical target value. If None,
        all event descriptions found in the recordings of ``concat_ds`` are
        mapped to increasing integers starting from 0; the same mapping is
        used for every recording.
    preload: bool
        If True, preload the data of the Epochs objects. This is useful to
        reduce disk reading overhead when returning windows in a training
        scenario, however very large data might not fit into memory.
    drop_bad_windows: bool | None
        If True, call `.drop_bad()` on the resulting mne.Epochs object. This
        step allows identifying e.g., windows that fall outside of the
        continuous recording. It is suggested to run this step here as otherwise
        the BaseConcatDataset has to be updated as well. If None, it is set
        to True when mne Epochs are created.
    picks: str | list | slice | None
        Channels to include. If None, all available channels are used. See
        mne.Epochs.
    reject: dict | None
        Epoch rejection parameters based on peak-to-peak amplitude. If None, no
        rejection is done based on peak-to-peak amplitude. See mne.Epochs.
    flat: dict | None
        Epoch rejection parameters based on flatness of signals. If None, no
        rejection based on flatness is done. See mne.Epochs.
    on_missing: str
        What to do if one or several event ids are not found in the recording.
        Valid keys are 'error' | 'warning' | 'ignore'. See mne.Epochs.
    accepted_bads_ratio: float, optional
        Acceptable proportion of trials with inconsistent length in a raw. If
        the number of trials whose length is exceeded by the window size is
        smaller than this, then only the corresponding trials are dropped, but
        the computation continues. Otherwise, an error is raised. Defaults to
        0.0 (raise an error).
    use_mne_epochs: bool | None
        If False, return EEGWindowsDataset objects.
        If True, return mne.Epochs objects encapsulated in WindowsDataset objects,
        which is substantially slower than EEGWindowsDataset.
        If None, it is decided from reject, picks, flat and drop_bad_windows.
    n_jobs: int
        Number of jobs to use to parallelize the windowing.
    verbose: bool | str | int | None
        Control verbosity of the logging output when calling mne.Epochs.

    Returns
    -------
    windows_datasets: BaseConcatDataset[WindowsDataset | EEGWindowsDataset]
        Concatenated datasets of WindowsDataset containing the extracted windows.
    """
    _check_windowing_arguments(
        trial_start_offset_samples,
        trial_stop_offset_samples,
        window_size_samples,
        window_stride_samples,
    )

    # If user did not specify mapping, we extract all events from all datasets
    # and map them to increasing integers starting from 0
    infer_mapping = mapping is None
    mapping = dict() if infer_mapping else mapping
    infer_window_size_stride = window_size_samples is None

    if infer_mapping:
        # Build the mapping once, here, instead of relying on the workers to
        # fill a shared dict: with process-based joblib backends each worker
        # gets its own copy of ``mapping``, so recordings would otherwise be
        # labelled inconsistently when n_jobs > 1. The ids are assigned in the
        # same order as the sequential (n_jobs=1) path would assign them.
        for ds in concat_ds.datasets:
            for event_name in np.unique(ds.raw.annotations.description):
                mapping.setdefault(event_name, len(mapping))

    if drop_bad_windows is not None:
        warnings.warn(
            "Drop bad windows only has an effect if mne epochs are created, "
            "and this argument may be removed in the future."
        )

    use_mne_epochs = _get_use_mne_epochs(
        use_mne_epochs, reject, picks, flat, drop_bad_windows
    )
    if use_mne_epochs and drop_bad_windows is None:
        drop_bad_windows = True

    list_of_windows_ds = Parallel(n_jobs=n_jobs)(
        delayed(_create_windows_from_events)(
            ds,
            infer_mapping,
            infer_window_size_stride,
            trial_start_offset_samples,
            trial_stop_offset_samples,
            window_size_samples,
            window_stride_samples,
            drop_last_window,
            mapping,
            preload,
            drop_bad_windows,
            picks,
            reject,
            flat,
            on_missing,
            accepted_bads_ratio,
            verbose,
            use_mne_epochs,
        )
        for ds in concat_ds.datasets
    )
    return BaseConcatDataset(list_of_windows_ds)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def create_fixed_length_windows(
    concat_ds: BaseConcatDataset[RawDataset],
    start_offset_samples: int = 0,
    stop_offset_samples: int | None = None,
    window_size_samples: int | None = None,
    window_stride_samples: int | None = None,
    drop_last_window: bool | None = None,
    mapping: dict[str, int] | None = None,
    preload: bool = False,
    picks: str | ArrayLike | slice | None = None,
    reject: dict[str, float] | None = None,
    flat: dict[str, float] | None = None,
    targets_from: str = "metadata",
    last_target_only: bool = True,
    lazy_metadata: bool = False,
    on_missing: str = "error",
    n_jobs: int = 1,
    verbose: bool | str | int | None = "error",
) -> BaseConcatDataset[EEGWindowsDataset]:
    """Cut every recording of a concat dataset into sliding windows.

    Parameters
    ----------
    concat_ds: ConcatDataset[RawDataset]
        A concat of base datasets each holding raw and description.
    start_offset_samples: int
        Start offset from beginning of recording in samples.
    stop_offset_samples: int | None
        Stop offset from beginning of recording in samples. If None, set to be
        the end of the recording.
    window_size_samples: int | None
        Window size in samples. If None, set to be the maximum possible window
        size, i.e. the length of the recording once offsets are accounted for.
    window_stride_samples: int | None
        Stride between windows in samples. If None, set to be equal to
        window_size_samples, so windows will not overlap.
    drop_last_window: bool | None
        Whether or not have a last overlapping window, when windows do not
        equally divide the continuous signal. Must be set to a bool if window
        size and stride are not None.
    mapping: dict[str, int] | None
        Mapping from event description to target value.
    preload: bool
        If True, preload the data of the Epochs objects.
    picks: str | list | slice | None
        Channels to include. If None, all available channels are used. See
        mne.Epochs.
    reject: dict | None
        Epoch rejection parameters based on peak-to-peak amplitude. If None, no
        rejection is done based on peak-to-peak amplitude. See mne.Epochs.
    flat: dict | None
        Epoch rejection parameters based on flatness of signals. If None, no
        rejection based on flatness is done. See mne.Epochs.
    targets_from: str
        Forwarded unchanged to the per-recording windowing helper
        (presumably selects where targets are read from; see that helper).
    last_target_only: bool
        Forwarded unchanged to the per-recording windowing helper
        (presumably only relevant for some values of targets_from).
    lazy_metadata: bool
        If True, metadata is not computed immediately, but only when accessed
        by using the _LazyDataFrame (experimental).
    on_missing: str
        What to do if one or several event ids are not found in the recording.
        Valid keys are 'error' | 'warning' | 'ignore'. See mne.Epochs.
    n_jobs: int
        Number of jobs to use to parallelize the windowing.
    verbose: bool | str | int | None
        Control verbosity of the logging output when calling mne.Epochs.

    Returns
    -------
    windows_datasets: BaseConcatDataset[EEGWindowsDataset]
        Concatenated datasets of WindowsDataset containing the extracted windows.

    Raises
    ------
    ValueError
        If window_size_samples is larger than the length of any recording.
    """
    checked_arguments = _check_and_set_fixed_length_window_arguments(
        start_offset_samples,
        stop_offset_samples,
        window_size_samples,
        window_stride_samples,
        drop_last_window,
        lazy_metadata,
    )
    stop_offset_samples, drop_last_window = checked_arguments

    recordings = concat_ds.datasets
    lengths = np.array([recording.raw.n_times for recording in recordings])

    if window_size_samples is None:
        # One window per recording: recordings of unequal length then yield
        # windows of unequal length.
        lengths_differ = (np.diff(lengths) != 0).any()
        if lengths_differ:
            warnings.warn(
                "Recordings have different lengths, they will not be batch-able!"
            )
    elif any(window_size_samples > lengths):
        raise ValueError(
            f"Window size {window_size_samples} exceeds trial duration {lengths.min()}."
        )

    window_one_recording = delayed(_create_fixed_length_windows)
    jobs = (
        window_one_recording(
            recording,
            start_offset_samples,
            stop_offset_samples,
            window_size_samples,
            window_stride_samples,
            drop_last_window,
            mapping,
            preload,
            picks,
            reject,
            flat,
            targets_from,
            last_target_only,
            lazy_metadata,
            on_missing,
            verbose,
        )
        for recording in recordings
    )
    list_of_windows_ds = Parallel(n_jobs=n_jobs)(jobs)
    return BaseConcatDataset(list_of_windows_ds)
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def _create_windows_from_events(
    ds,
    infer_mapping,
    infer_window_size_stride,
    trial_start_offset_samples,
    trial_stop_offset_samples,
    window_size_samples=None,
    window_stride_samples=None,
    drop_last_window=False,
    mapping=None,
    preload=False,
    drop_bad_windows=True,
    picks=None,
    reject=None,
    flat=None,
    on_missing="error",
    accepted_bads_ratio=0.0,
    verbose="error",
    use_mne_epochs=False,
):
    """Create WindowsDataset from RawDataset based on events.

    Parameters
    ----------
    ds : RawDataset
        Dataset containing continuous data and description.
    infer_mapping : bool
        If True, the event descriptions of ``ds`` that are not yet in
        ``mapping`` are added to it (in place) and mapped to increasing
        integers following the existing entries.
    infer_window_size_stride : bool
        If True, infer the stride from the original trial size of the first
        trial and trial_start_offset_samples and trial_stop_offset_samples.

    See `create_windows_from_events` for description of other parameters.

    Returns
    -------
    WindowsDataset | EEGWindowsDataset :
        Windowed dataset: a WindowsDataset wrapping mne.Epochs if
        ``use_mne_epochs`` is True, an EEGWindowsDataset otherwise.

    Raises
    ------
    ValueError
        If the stop of the last trial plus ``trial_stop_offset_samples``
        lies beyond the end of the recording.
    NotImplementedError
        If the computed window starts are not strictly increasing.
    """
    # catch window_kwargs to store to dataset
    # (must stay the first statement: locals() then holds only the arguments)
    window_kwargs = [
        (create_windows_from_events.__name__, _get_windowing_kwargs(locals())),
    ]
    if infer_mapping:
        unique_events = np.unique(ds.raw.annotations.description)
        new_unique_events = [x for x in unique_events if x not in mapping]
        # mapping event descriptions to integers from 0 on
        max_id_existing_mapping = len(mapping)
        mapping.update(
            {
                event_name: i_event_type + max_id_existing_mapping
                for i_event_type, event_name in enumerate(new_unique_events)
            }
        )

    events, events_id = mne.events_from_annotations(ds.raw, mapping, verbose=verbose)
    onsets = events[:, 0]
    # Onsets are relative to the beginning of the recording
    filtered_durations = np.array(
        [a["duration"] for a in ds.raw.annotations if a["description"] in events_id]
    )

    stops = onsets + (filtered_durations * ds.raw.info["sfreq"]).astype(int)
    # XXX This could probably be simplified by using chunk_duration in
    #     `events_from_annotations`

    last_samp = ds.raw.first_samp + ds.raw.n_times - 1
    # `stops` is used exclusively (i.e. `start:stop`), so add back 1
    if stops[-1] + trial_stop_offset_samples > last_samp + 1:
        raise ValueError(
            '"trial_stop_offset_samples" too large. Stop of last trial '
            f'({stops[-1]}) + "trial_stop_offset_samples" '
            f"({trial_stop_offset_samples}) must be smaller than length of"
            f" recording ({len(ds)})."
        )

    if infer_window_size_stride:
        # window size is trial size
        if window_size_samples is None:
            window_size_samples = (
                stops[0]
                + trial_stop_offset_samples
                - (onsets[0] + trial_start_offset_samples)
            )
            window_stride_samples = window_size_samples
        this_trial_sizes = (stops + trial_stop_offset_samples) - (
            onsets + trial_start_offset_samples
        )
        # Maybe actually this is not necessary?
        # We could also just say we just assume window size=trial size
        # in case not given, without this condition...
        # but then would have to change functions overall
        checker_trials_size = this_trial_sizes == window_size_samples

        if not np.all(checker_trials_size):
            trials_drops = int(len(this_trial_sizes) - sum(checker_trials_size))
            warnings.warn(
                f"Dropping trials with different windows size {trials_drops}",
            )
            # keep only the trials whose size matches the inferred window size
            events = events[checker_trials_size]
            onsets = onsets[checker_trials_size]
            stops = stops[checker_trials_size]

    description = events[:, -1]

    if not use_mne_epochs:
        # EEGWindowsDataset indexes into the raw data directly, so make the
        # sample indices relative to the first sample of the recording
        onsets = onsets - ds.raw.first_samp
        stops = stops - ds.raw.first_samp
    i_trials, i_window_in_trials, starts, stops = _compute_window_inds(
        onsets,
        stops,
        trial_start_offset_samples,
        trial_stop_offset_samples,
        window_size_samples,
        window_stride_samples,
        drop_last_window,
        accepted_bads_ratio,
    )

    if any(np.diff(starts) <= 0):
        raise NotImplementedError("Trial overlap not implemented.")

    # one event row per window: [start sample, window size, target of its trial]
    events = [
        [start, window_size_samples, description[i_trials[i_start]]]
        for i_start, start in enumerate(starts)
    ]
    events = np.array(events)

    description = events[:, -1]

    metadata = pd.DataFrame(
        {
            "i_window_in_trial": i_window_in_trials,
            "i_start_in_trial": starts,
            "i_stop_in_trial": stops,
            "target": description,
        }
    )
    if use_mne_epochs:
        # window size - 1, since tmax is inclusive
        mne_epochs = mne.Epochs(
            ds.raw,
            events,
            events_id,
            baseline=None,
            tmin=0,
            tmax=(window_size_samples - 1) / ds.raw.info["sfreq"],
            metadata=metadata,
            preload=preload,
            picks=picks,
            reject=reject,
            flat=flat,
            on_missing=on_missing,
            verbose=verbose,
        )
        if drop_bad_windows:
            mne_epochs.drop_bad()
        windows_ds = WindowsDataset(
            mne_epochs,
            ds.description,
        )
    else:
        windows_ds = EEGWindowsDataset(
            ds.raw,
            metadata=metadata,
            description=ds.description,
        )
    # add window_kwargs and raw_preproc_kwargs to windows dataset
    setattr(windows_ds, "window_kwargs", window_kwargs)
    kwargs_name = "raw_preproc_kwargs"
    if hasattr(ds, kwargs_name):
        setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
    return windows_ds
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def _create_fixed_length_windows(
    ds,
    start_offset_samples,
    stop_offset_samples,
    window_size_samples,
    window_stride_samples,
    drop_last_window,
    mapping=None,
    preload=False,
    picks=None,
    reject=None,
    flat=None,
    targets_from="metadata",
    last_target_only=True,
    lazy_metadata=False,
    on_missing="error",
    verbose="error",
):
    """Slice one continuous recording into fixed-length sliding windows.

    Parameters
    ----------
    ds : RawDataset
        Dataset holding the continuous signal (``ds.raw``) and its
        ``description``.

    See `create_fixed_length_windows` for description of other parameters.
    ``preload``, ``picks``, ``reject``, ``flat``, ``on_missing`` and
    ``verbose`` are not read by this function; they are only recorded in
    ``window_kwargs``.

    Returns
    -------
    EEGWindowsDataset :
        Windowed dataset. It carries a ``window_kwargs`` attribute and, when
        ``ds`` has one, a copy of its ``raw_preproc_kwargs`` attribute.
    """
    # Keep this as the very first statement: at this point ``locals()`` only
    # holds the call arguments, which is what gets stored on the dataset.
    window_kwargs = [
        (create_fixed_length_windows.__name__, _get_windowing_kwargs(locals())),
    ]

    if stop_offset_samples is None:
        stop = ds.raw.n_times
    else:
        stop = stop_offset_samples

    # Without an explicit size, a single window spans the whole range.
    if window_size_samples is None:
        window_size_samples = stop - start_offset_samples
    if window_stride_samples is None:
        window_stride_samples = window_size_samples

    last_potential_start = stop - window_size_samples

    # Take the target(s) from the dataset description when a target is named.
    if ds.target_name is None:
        target = -1
    else:
        target = ds.description[ds.target_name]
    if mapping is not None:
        if isinstance(target, pd.Series):
            # several targets at once
            target = target.replace(mapping).to_list()
        else:
            # one scalar target
            target = mapping[target]

    if lazy_metadata:
        factory = _FixedLengthWindowFunctions(
            start_offset_samples,
            last_potential_start,
            window_stride_samples,
            window_size_samples,
            target,
        )
        column_functions = {
            "i_window_in_trial": factory.i_window_in_trial,
            "i_start_in_trial": factory.i_start_in_trial,
            "i_stop_in_trial": factory.i_stop_in_trial,
            "target": factory.target,
        }
        metadata = _LazyDataFrame(
            length=factory.length,
            functions=column_functions,
            columns=list(column_functions),
        )
    else:
        # Every start at which a complete window still fits before ``stop``.
        starts = np.arange(
            start_offset_samples, last_potential_start + 1, window_stride_samples
        )

        if not drop_last_window and starts[-1] < last_potential_start:
            # The strided windows leave a tail uncovered: add one more window
            # that ends exactly at ``stop``.
            starts = np.append(starts, last_potential_start)

        n_windows = len(starts)
        metadata = pd.DataFrame(
            {
                "i_window_in_trial": np.arange(n_windows),
                "i_start_in_trial": starts,
                "i_stop_in_trial": starts + window_size_samples,
                "target": [target] * n_windows,
            }
        )

    dataset_kwargs = {
        "targets_from": targets_from,
        "last_target_only": last_target_only,
    }
    window_kwargs.append((EEGWindowsDataset.__name__, dataset_kwargs))

    windows_ds = EEGWindowsDataset(
        ds.raw,
        metadata=metadata,
        description=ds.description,
        targets_from=targets_from,
        last_target_only=last_target_only,
    )
    # Attach the windowing arguments and propagate any raw preprocessing
    # arguments recorded on the input dataset.
    windows_ds.window_kwargs = window_kwargs
    if hasattr(ds, "raw_preproc_kwargs"):
        windows_ds.raw_preproc_kwargs = ds.raw_preproc_kwargs
    return windows_ds
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
def create_windows_from_target_channels(
    concat_ds: BaseConcatDataset[RawDataset],
    window_size_samples=None,
    preload=False,
    picks=None,
    reject=None,
    flat=None,
    n_jobs=1,
    last_target_only=True,
    verbose="error",
) -> BaseConcatDataset[EEGWindowsDataset]:
    """Window every recording of ``concat_ds`` using its target channels.

    Each dataset in ``concat_ds.datasets`` is handed to
    `_create_windows_from_target_channels`, run through ``Parallel`` /
    ``delayed`` with ``n_jobs`` jobs.

    Parameters
    ----------
    concat_ds : BaseConcatDataset
        Concatenation of the recordings to window.
    window_size_samples : int | None
        Window length in samples, forwarded to the per-recording helper.
    preload, picks, reject, flat, last_target_only, verbose
        Forwarded unchanged to the per-recording helper.
    n_jobs : int
        Number of jobs given to ``Parallel``.

    Returns
    -------
    BaseConcatDataset :
        Concatenation of the per-recording windowed datasets.
    """
    make_windows = delayed(_create_windows_from_target_channels)
    pending_jobs = (
        make_windows(
            ds,
            window_size_samples,
            preload=preload,
            picks=picks,
            reject=reject,
            flat=flat,
            last_target_only=last_target_only,
            # ``on_missing`` is not exposed by this function and is fixed here.
            on_missing="error",
            verbose=verbose,
        )
        for ds in concat_ds.datasets
    )
    list_of_windows_ds = Parallel(n_jobs=n_jobs)(pending_jobs)
    return BaseConcatDataset(list_of_windows_ds)
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
def _create_windows_from_target_channels(
    ds,
    window_size_samples,
    preload=False,
    picks=None,
    reject=None,
    flat=None,
    last_target_only=True,
    on_missing="error",
    verbose="error",
):
    """Window one recording at the samples where its ``misc`` target is set.

    A window of ``window_size_samples`` samples is created for each sample of
    the first ``misc`` channel that is not NaN; the window stops right after
    that sample.

    Parameters
    ----------
    ds : RawDataset
        Dataset holding the continuous signal (``ds.raw``) and its
        ``description``.

    See `create_fixed_length_windows` for description of other parameters.
    ``preload``, ``picks``, ``reject``, ``flat``, ``on_missing`` and
    ``verbose`` are not read by this function; they are only recorded in
    ``window_kwargs``.

    Returns
    -------
    EEGWindowsDataset :
        Windowed dataset created with ``targets_from="channels"``.
    """
    # Keep this as the very first statement: at this point ``locals()`` only
    # holds the call arguments, which is what gets stored on the dataset.
    window_kwargs = [
        (create_windows_from_target_channels.__name__, _get_windowing_kwargs(locals())),
    ]
    recording_stop = ds.raw.n_times + ds.raw.first_samp

    target = ds.raw.get_data(picks="misc")
    # TODO: handle multi targets present only for some events
    has_target = ~np.isnan(target[0, :])
    # Window stops are exclusive, hence one past each sample with a target.
    stops = np.flatnonzero(has_target) + 1
    # Keep stops that leave room for a full window and lie strictly before
    # the end of the recording.
    # NOTE(review): `<` drops a target on the very last sample when
    # first_samp == 0 - confirm this is intended.
    fits_in_recording = (stops < recording_stop) & (stops >= window_size_samples)
    stops = stops[fits_in_recording].astype(int)

    n_windows = len(stops)
    metadata = pd.DataFrame(
        {
            "i_window_in_trial": np.arange(n_windows),
            "i_start_in_trial": stops - window_size_samples,
            "i_stop_in_trial": stops,
            # NOTE(review): every row references the complete target array;
            # presumably unused because targets are read from the channels.
            "target": [target] * n_windows,
        }
    )

    targets_from = "channels"
    dataset_kwargs = {
        "targets_from": targets_from,
        "last_target_only": last_target_only,
    }
    window_kwargs.append((EEGWindowsDataset.__name__, dataset_kwargs))

    windows_ds = EEGWindowsDataset(
        ds.raw,
        metadata=metadata,
        description=ds.description,
        targets_from=targets_from,
        last_target_only=last_target_only,
    )
    # Attach the windowing arguments and propagate any raw preprocessing
    # arguments recorded on the input dataset.
    windows_ds.window_kwargs = window_kwargs
    if hasattr(ds, "raw_preproc_kwargs"):
        windows_ds.raw_preproc_kwargs = ds.raw_preproc_kwargs
    return windows_ds
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
def _compute_window_inds(
    starts,
    stops,
    start_offset,
    stop_offset,
    size,
    stride,
    drop_last_window,
    accepted_bads_ratio,
):
    """Compute window start and stop indices.

    Create window starts from trial onsets (shifted by start_offset) to trial
    end (shifted by stop_offset) separated by stride, as long as window size
    fits into trial.

    The inputs are never modified: the offsets are applied to copies.

    Parameters
    ----------
    starts: int | array-like
        Trial starts in samples.
    stops: int | array-like
        Trial stops in samples.
    start_offset: int
        Start offset from original trial onsets in samples.
    stop_offset: int
        Stop offset from original trial stop in samples.
    size: int
        Window size.
    stride: int
        Stride between windows.
    drop_last_window: bool
        If False, an extra window ending exactly at the (shifted) trial stop
        is added whenever the strided windows do not already end there. If
        True, the remaining samples are dropped.
    accepted_bads_ratio: float
        Acceptable proportion of bad trials within a raw. If the number of
        trials whose length is exceeded by the window size is smaller than
        this, then only the corresponding trials are dropped, but the
        computation continues. Otherwise, an error is raised.

    Returns
    -------
    result_lists: (list, list, list, numpy.ndarray)
        Trial, i_window_in_trial, start sample and stop sample of windows.
        NOTE(review): when too-short trials are dropped, the trial indices
        count the remaining trials only - confirm callers expect that.

    Raises
    ------
    ValueError
        If the window size exceeds the duration of more trials than
        ``accepted_bads_ratio`` allows.
    """
    # ``+`` (not ``+=``) builds new arrays, so the caller's data is left
    # untouched; ``atleast_1d`` also accepts scalars and plain sequences.
    starts = np.atleast_1d(starts) + start_offset
    stops = np.atleast_1d(stops) + stop_offset

    durations = stops - starts
    bads_mask = size > durations
    if bads_mask.any():
        min_duration = durations.min()
        n_bads = bads_mask.sum()
        if n_bads <= accepted_bads_ratio * len(starts):
            starts = starts[~bads_mask]
            stops = stops[~bads_mask]
            warnings.warn(
                f"Trials {np.where(bads_mask)[0]} are being dropped as the "
                f"window size ({size}) exceeds their duration {min_duration}."
            )
        else:
            current_ratio = n_bads / len(starts)
            raise ValueError(
                f"Window size {size} exceeds trial duration "
                f"({min_duration}) for too many trials "
                f"({current_ratio * 100}%). Set "
                f"accepted_bads_ratio to at least {current_ratio}"
                " and restart training to be able to continue."
            )

    i_window_in_trials, i_trials, window_starts = [], [], []
    for i_trial, (start, stop) in enumerate(zip(starts, stops)):
        # Candidate starts lie between the shifted trial onset and stop,
        # spaced by ``stride``; only those where a full window still fits
        # inside the trial are kept.
        fitting_starts = [
            s for s in np.arange(start, stop, stride) if s + size <= stop
        ]
        window_starts.extend(fitting_starts)
        i_window_in_trials.extend(range(len(fitting_starts)))
        i_trials.extend([i_trial] * len(fitting_starts))

        # If the last window does not end at the shifted trial stop, add an
        # overlapping window that does.
        if not drop_last_window and window_starts[-1] + size != stop:
            window_starts.append(stop - size)
            i_window_in_trials.append(i_window_in_trials[-1] + 1)
            i_trials.append(i_trial)

    # Window stops follow from the window size (they are not trial stops).
    window_stops = np.array(window_starts) + size
    if not (len(i_window_in_trials) == len(window_starts) == len(window_stops)):
        raise ValueError(
            f"{len(i_window_in_trials)} == {len(window_starts)} == {len(window_stops)}"
        )

    return i_trials, i_window_in_trials, window_starts, window_stops
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
def _check_windowing_arguments(
    trial_start_offset_samples,
    trial_stop_offset_samples,
    window_size_samples,
    window_stride_samples,
):
    """Validate the types and values of the basic windowing arguments.

    The checks raise ``AssertionError`` explicitly instead of relying on
    ``assert`` statements, so they stay active when Python runs with ``-O``
    and every failure carries a message.

    Parameters
    ----------
    trial_start_offset_samples : int
        Start offset in samples.
    trial_stop_offset_samples : int | None
        Stop offset in samples.
    window_size_samples : int | None
        Window size in samples. Must be given together with
        ``window_stride_samples`` or not at all.
    window_stride_samples : int | None
        Stride between windows in samples.

    Raises
    ------
    AssertionError
        If an argument has the wrong type, if only one of window size and
        stride is given, or if either of them is not strictly positive.
    """
    int_types = (int, np.integer)

    def require(condition, message):
        if not condition:
            raise AssertionError(message)

    require(
        isinstance(trial_start_offset_samples, int_types),
        "trial start offset has to be an integer, got "
        f"{type(trial_start_offset_samples).__name__}",
    )
    require(
        trial_stop_offset_samples is None
        or isinstance(trial_stop_offset_samples, int_types),
        "trial stop offset has to be an integer or None, got "
        f"{type(trial_stop_offset_samples).__name__}",
    )
    require(
        window_size_samples is None or isinstance(window_size_samples, int_types),
        "window size has to be an integer or None, got "
        f"{type(window_size_samples).__name__}",
    )
    require(
        window_stride_samples is None
        or isinstance(window_stride_samples, int_types),
        "window stride has to be an integer or None, got "
        f"{type(window_stride_samples).__name__}",
    )
    require(
        (window_size_samples is None) == (window_stride_samples is None),
        "window size and window stride have to be both set or both None",
    )
    if window_size_samples is not None:
        require(window_size_samples > 0, "window size has to be larger than 0")
        require(window_stride_samples > 0, "window stride has to be larger than 0")
|
|
969
|
+
|
|
970
|
+
|
|
971
|
+
def _check_and_set_fixed_length_window_arguments(
    start_offset_samples,
    stop_offset_samples,
    window_size_samples,
    window_stride_samples,
    drop_last_window,
    lazy_metadata,
):
    """Validate fixed-length windowing arguments and normalise two of them.

    Runs `_check_windowing_arguments`, warns about the changed meaning of a
    stop offset of 0 and about the use of offsets in general, and checks that
    ``drop_last_window`` is consistent with window size and stride.

    Returns
    -------
    stop_offset_samples : int | None
        The given value, with 0 replaced by ``None``.
    drop_last_window : bool | None
        The given value, with ``False`` replaced by ``None`` when neither
        window size nor stride is set.

    Raises
    ------
    ValueError
        If window size and stride are set but ``drop_last_window`` is
        ``None``, or if ``lazy_metadata`` is requested while
        ``drop_last_window`` is falsy.
    """
    _check_windowing_arguments(
        start_offset_samples,
        stop_offset_samples,
        window_size_samples,
        window_stride_samples,
    )

    if stop_offset_samples == 0:
        warnings.warn(
            "Meaning of `trial_stop_offset_samples`=0 has changed, use `None` "
            "to indicate end of trial/recording. Using `None`."
        )
        stop_offset_samples = None

    offsets_in_use = start_offset_samples != 0 or stop_offset_samples is not None
    if offsets_in_use:
        warnings.warn(
            "Usage of offset_sample args in create_fixed_length_windows is deprecated and"
            " will be removed in future versions. Please use "
            'braindecode.preprocessing.preprocess.Preprocessor("crop", tmin, tmax)'
            " instead."
        )

    size_unset = window_size_samples is None
    stride_unset = window_stride_samples is None

    if not size_unset and not stride_unset and drop_last_window is None:
        raise ValueError(
            "drop_last_window must be set if both window_size_samples &"
            " window_stride_samples have also been set"
        )
    if size_unset and stride_unset and drop_last_window is False:
        # With a single whole-recording window there is no last window to
        # drop; ``None`` is what the consistency check below expects.
        drop_last_window = None

    assert size_unset == stride_unset == (drop_last_window is None)

    if not drop_last_window and lazy_metadata:
        raise ValueError(
            "Cannot have drop_last_window=False and lazy_metadata=True at the same time."
        )
    return stop_offset_samples, drop_last_window
|
|
1031
|
+
|
|
1032
|
+
|
|
1033
|
+
def _get_windowing_kwargs(windowing_func_locals):
    """Return the windowing arguments of a function without its ``ds`` entry.

    Parameters
    ----------
    windowing_func_locals : dict
        Mapping of argument names to values, typically the result of
        ``locals()`` taken at the top of a windowing function. It is left
        unmodified.

    Returns
    -------
    dict
        A new dict with every entry of ``windowing_func_locals`` except the
        one named ``"ds"``. A missing ``"ds"`` entry is tolerated.
    """
    # Build a filtered copy instead of popping from the caller's mapping.
    return {
        name: value
        for name, value in windowing_func_locals.items()
        if name != "ds"
    }
|