braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
"""Get epochs from mne.Raw
|
|
2
|
-
"""
|
|
1
|
+
"""Get epochs from mne.Raw"""
|
|
3
2
|
|
|
4
3
|
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
|
|
5
4
|
# Lukas Gemein <l.gemein@gmail.com>
|
|
@@ -15,23 +14,199 @@
|
|
|
15
14
|
#
|
|
16
15
|
# License: BSD (3-clause)
|
|
17
16
|
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
18
19
|
import warnings
|
|
20
|
+
from typing import Any, Callable
|
|
19
21
|
|
|
20
|
-
import numpy as np
|
|
21
22
|
import mne
|
|
23
|
+
import numpy as np
|
|
22
24
|
import pandas as pd
|
|
23
25
|
from joblib import Parallel, delayed
|
|
26
|
+
from numpy.typing import ArrayLike
|
|
24
27
|
|
|
25
|
-
from ..datasets.base import
|
|
28
|
+
from ..datasets.base import BaseConcatDataset, EEGWindowsDataset, WindowsDataset
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class _LazyDataFrame:
|
|
32
|
+
"""
|
|
33
|
+
DataFrame-like object that lazily computes values (experimental).
|
|
34
|
+
|
|
35
|
+
This class emulates some features of a pandas DataFrame, but computes
|
|
36
|
+
the values on-the-fly when they are accessed. This is useful for
|
|
37
|
+
very long DataFrames with repetitive values.
|
|
38
|
+
Only the methods used by EEGWindowsDataset on its metadata are implemented.
|
|
39
|
+
|
|
40
|
+
Parameters:
|
|
41
|
+
-----------
|
|
42
|
+
length: int
|
|
43
|
+
The length of the dataframe.
|
|
44
|
+
functions: dict[str, Callable[[int], Any]]
|
|
45
|
+
A dictionary mapping column names to functions that take an index and
|
|
46
|
+
return the value of the column at that index.
|
|
47
|
+
columns: list[str]
|
|
48
|
+
The names of the columns in the dataframe.
|
|
49
|
+
series: bool
|
|
50
|
+
Whether the object should emulate a series or a dataframe.
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
def __init__(
|
|
54
|
+
self,
|
|
55
|
+
length: int,
|
|
56
|
+
functions: dict[str, Callable[[int], Any]],
|
|
57
|
+
columns: list[str],
|
|
58
|
+
series: bool = False,
|
|
59
|
+
):
|
|
60
|
+
if not (isinstance(length, int) and length >= 0):
|
|
61
|
+
raise ValueError("Length must be a positive integer.")
|
|
62
|
+
if not all(c in functions for c in columns):
|
|
63
|
+
raise ValueError("All columns must have a corresponding function.")
|
|
64
|
+
if series and len(columns) != 1:
|
|
65
|
+
raise ValueError("Series must have exactly one column.")
|
|
66
|
+
self.length = length
|
|
67
|
+
self.functions = functions
|
|
68
|
+
self.columns = columns
|
|
69
|
+
self.series = series
|
|
70
|
+
|
|
71
|
+
@property
|
|
72
|
+
def loc(self):
|
|
73
|
+
return self
|
|
74
|
+
|
|
75
|
+
def __len__(self):
|
|
76
|
+
return self.length
|
|
77
|
+
|
|
78
|
+
def __getitem__(self, key):
|
|
79
|
+
if not isinstance(key, tuple):
|
|
80
|
+
key = (key, self.columns)
|
|
81
|
+
if len(key) == 1:
|
|
82
|
+
key = (key[0], self.columns)
|
|
83
|
+
if not len(key) == 2:
|
|
84
|
+
raise IndexError(
|
|
85
|
+
f"index must be either [row] or [row, column], got [{', '.join(map(str, key))}]."
|
|
86
|
+
)
|
|
87
|
+
row, col = key
|
|
88
|
+
if col == slice(None): # all columns (i.e., call to df[row, :])
|
|
89
|
+
col = self.columns
|
|
90
|
+
one_col = False
|
|
91
|
+
if isinstance(col, str): # one column
|
|
92
|
+
one_col = True
|
|
93
|
+
col = [col]
|
|
94
|
+
else: # multiple columns
|
|
95
|
+
col = list(col)
|
|
96
|
+
if not all(c in self.columns for c in col):
|
|
97
|
+
raise IndexError(
|
|
98
|
+
f"All columns must be present in the dataframe with columns {self.columns}. Got {col}."
|
|
99
|
+
)
|
|
100
|
+
if row == slice(None): # all rows (i.e., call to df[:] or df[:, col])
|
|
101
|
+
return _LazyDataFrame(self.length, self.functions, col)
|
|
102
|
+
if not isinstance(row, int):
|
|
103
|
+
raise NotImplementedError(
|
|
104
|
+
"Row indexing only supports either a single integer or a null slice (i.e., df[:])."
|
|
105
|
+
)
|
|
106
|
+
if not (0 <= row < self.length):
|
|
107
|
+
raise IndexError(f"Row index {row} is out of bounds.")
|
|
108
|
+
if self.series or one_col:
|
|
109
|
+
return self.functions[col[0]](row)
|
|
110
|
+
return pd.Series({c: self.functions[c](row) for c in col})
|
|
111
|
+
|
|
112
|
+
def to_numpy(self):
|
|
113
|
+
return _LazyDataFrame(
|
|
114
|
+
length=self.length,
|
|
115
|
+
functions=self.functions,
|
|
116
|
+
columns=self.columns,
|
|
117
|
+
series=len(self.columns) == 1,
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
def to_list(self):
|
|
121
|
+
return self.to_numpy()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class _FixedLengthWindowFunctions:
|
|
125
|
+
"""Class defining functions for lazy metadata generation in fixed length windowing
|
|
126
|
+
to be used in combination with _LazyDataFrame (experimental)."""
|
|
127
|
+
|
|
128
|
+
def __init__(
|
|
129
|
+
self,
|
|
130
|
+
start_offset_samples: int,
|
|
131
|
+
last_potential_start: int,
|
|
132
|
+
window_stride_samples: int,
|
|
133
|
+
window_size_samples: int,
|
|
134
|
+
target: Any,
|
|
135
|
+
):
|
|
136
|
+
self.start_offset_samples = start_offset_samples
|
|
137
|
+
self.last_potential_start = last_potential_start
|
|
138
|
+
self.window_stride_samples = window_stride_samples
|
|
139
|
+
self.window_size_samples = window_size_samples
|
|
140
|
+
self.target_val = target
|
|
141
|
+
|
|
142
|
+
@property
|
|
143
|
+
def length(self) -> int:
|
|
144
|
+
return int(
|
|
145
|
+
np.ceil(
|
|
146
|
+
(self.last_potential_start + 1 - self.start_offset_samples)
|
|
147
|
+
/ self.window_stride_samples
|
|
148
|
+
)
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
def i_window_in_trial(self, i: int) -> int:
|
|
152
|
+
return i
|
|
153
|
+
|
|
154
|
+
def i_start_in_trial(self, i: int) -> int:
|
|
155
|
+
return self.start_offset_samples + i * self.window_stride_samples
|
|
156
|
+
|
|
157
|
+
def i_stop_in_trial(self, i: int) -> int:
|
|
158
|
+
return (
|
|
159
|
+
self.start_offset_samples
|
|
160
|
+
+ i * self.window_stride_samples
|
|
161
|
+
+ self.window_size_samples
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
def target(self, i: int) -> Any:
|
|
165
|
+
return self.target_val
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _get_use_mne_epochs(use_mne_epochs, reject, picks, flat, drop_bad_windows):
|
|
169
|
+
should_use_mne_epochs = (
|
|
170
|
+
(reject is not None)
|
|
171
|
+
or (picks is not None)
|
|
172
|
+
or (flat is not None)
|
|
173
|
+
or (drop_bad_windows is True)
|
|
174
|
+
)
|
|
175
|
+
if use_mne_epochs is None:
|
|
176
|
+
if should_use_mne_epochs:
|
|
177
|
+
warnings.warn(
|
|
178
|
+
"Using reject or picks or flat or dropping bad windows means "
|
|
179
|
+
"mne Epochs are created, "
|
|
180
|
+
"which will be substantially slower and may be deprecated in the future."
|
|
181
|
+
)
|
|
182
|
+
return should_use_mne_epochs
|
|
183
|
+
if not use_mne_epochs and should_use_mne_epochs:
|
|
184
|
+
raise ValueError(
|
|
185
|
+
"Cannot set use_mne_epochs=False when using reject, picks, flat, or dropping bad windows."
|
|
186
|
+
)
|
|
187
|
+
return use_mne_epochs
|
|
26
188
|
|
|
27
189
|
|
|
28
190
|
# XXX it's called concat_ds...
|
|
29
191
|
def create_windows_from_events(
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
192
|
+
concat_ds: BaseConcatDataset,
|
|
193
|
+
trial_start_offset_samples: int = 0,
|
|
194
|
+
trial_stop_offset_samples: int = 0,
|
|
195
|
+
window_size_samples: int | None = None,
|
|
196
|
+
window_stride_samples: int | None = None,
|
|
197
|
+
drop_last_window: bool = False,
|
|
198
|
+
mapping: dict[str, int] | None = None,
|
|
199
|
+
preload: bool = False,
|
|
200
|
+
drop_bad_windows: bool | None = None,
|
|
201
|
+
picks: str | ArrayLike | slice | None = None,
|
|
202
|
+
reject: dict[str, float] | None = None,
|
|
203
|
+
flat: dict[str, float] | None = None,
|
|
204
|
+
on_missing: str = "error",
|
|
205
|
+
accepted_bads_ratio: float = 0.0,
|
|
206
|
+
use_mne_epochs: bool | None = None,
|
|
207
|
+
n_jobs: int = 1,
|
|
208
|
+
verbose: bool | str | int | None = "error",
|
|
209
|
+
):
|
|
35
210
|
"""Create windows based on events in mne.Raw.
|
|
36
211
|
|
|
37
212
|
This function extracts windows of size window_size_samples in the interval
|
|
@@ -100,6 +275,10 @@ def create_windows_from_events(
|
|
|
100
275
|
smaller than this, then only the corresponding trials are dropped, but
|
|
101
276
|
the computation continues. Otherwise, an error is raised. Defaults to
|
|
102
277
|
0.0 (raise an error).
|
|
278
|
+
use_mne_epochs: bool | None
|
|
279
|
+
If False, return EEGWindowsDataset objects.
|
|
280
|
+
If True, return mne.Epochs objects encapsulated in WindowsDataset objects,
|
|
281
|
+
which is substantially slower than EEGWindowsDataset.
|
|
103
282
|
n_jobs: int
|
|
104
283
|
Number of jobs to use to parallelize the windowing.
|
|
105
284
|
verbose: bool | str | int | None
|
|
@@ -111,8 +290,11 @@ def create_windows_from_events(
|
|
|
111
290
|
Concatenated datasets of WindowsDataset containing the extracted windows.
|
|
112
291
|
"""
|
|
113
292
|
_check_windowing_arguments(
|
|
114
|
-
trial_start_offset_samples,
|
|
115
|
-
|
|
293
|
+
trial_start_offset_samples,
|
|
294
|
+
trial_stop_offset_samples,
|
|
295
|
+
window_size_samples,
|
|
296
|
+
window_stride_samples,
|
|
297
|
+
)
|
|
116
298
|
|
|
117
299
|
# If user did not specify mapping, we extract all events from all datasets
|
|
118
300
|
# and map them to increasing integers starting from 0
|
|
@@ -121,34 +303,62 @@ def create_windows_from_events(
|
|
|
121
303
|
infer_window_size_stride = window_size_samples is None
|
|
122
304
|
|
|
123
305
|
if drop_bad_windows is not None:
|
|
124
|
-
warnings.warn(
|
|
125
|
-
|
|
306
|
+
warnings.warn(
|
|
307
|
+
"Drop bad windows only has an effect if mne epochs are created, "
|
|
308
|
+
"and this argument may be removed in the future."
|
|
309
|
+
)
|
|
126
310
|
|
|
127
|
-
use_mne_epochs = (
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
'which will be substantially slower and may be deprecated in the future.')
|
|
133
|
-
if drop_bad_windows is None:
|
|
134
|
-
drop_bad_windows = True
|
|
311
|
+
use_mne_epochs = _get_use_mne_epochs(
|
|
312
|
+
use_mne_epochs, reject, picks, flat, drop_bad_windows
|
|
313
|
+
)
|
|
314
|
+
if use_mne_epochs and drop_bad_windows is None:
|
|
315
|
+
drop_bad_windows = True
|
|
135
316
|
|
|
136
317
|
list_of_windows_ds = Parallel(n_jobs=n_jobs)(
|
|
137
318
|
delayed(_create_windows_from_events)(
|
|
138
|
-
ds,
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
319
|
+
ds,
|
|
320
|
+
infer_mapping,
|
|
321
|
+
infer_window_size_stride,
|
|
322
|
+
trial_start_offset_samples,
|
|
323
|
+
trial_stop_offset_samples,
|
|
324
|
+
window_size_samples,
|
|
325
|
+
window_stride_samples,
|
|
326
|
+
drop_last_window,
|
|
327
|
+
mapping,
|
|
328
|
+
preload,
|
|
329
|
+
drop_bad_windows,
|
|
330
|
+
picks,
|
|
331
|
+
reject,
|
|
332
|
+
flat,
|
|
333
|
+
on_missing,
|
|
334
|
+
accepted_bads_ratio,
|
|
335
|
+
verbose,
|
|
336
|
+
use_mne_epochs,
|
|
337
|
+
)
|
|
338
|
+
for ds in concat_ds.datasets
|
|
339
|
+
)
|
|
143
340
|
return BaseConcatDataset(list_of_windows_ds)
|
|
144
341
|
|
|
145
342
|
|
|
146
343
|
def create_fixed_length_windows(
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
344
|
+
concat_ds: BaseConcatDataset,
|
|
345
|
+
start_offset_samples: int = 0,
|
|
346
|
+
stop_offset_samples: int | None = None,
|
|
347
|
+
window_size_samples: int | None = None,
|
|
348
|
+
window_stride_samples: int | None = None,
|
|
349
|
+
drop_last_window: bool | None = None,
|
|
350
|
+
mapping: dict[str, int] | None = None,
|
|
351
|
+
preload: bool = False,
|
|
352
|
+
picks: str | ArrayLike | slice | None = None,
|
|
353
|
+
reject: dict[str, float] | None = None,
|
|
354
|
+
flat: dict[str, float] | None = None,
|
|
355
|
+
targets_from: str = "metadata",
|
|
356
|
+
last_target_only: bool = True,
|
|
357
|
+
lazy_metadata: bool = False,
|
|
358
|
+
on_missing: str = "error",
|
|
359
|
+
n_jobs: int = 1,
|
|
360
|
+
verbose: bool | str | int | None = "error",
|
|
361
|
+
):
|
|
152
362
|
"""Windower that creates sliding windows.
|
|
153
363
|
|
|
154
364
|
Parameters
|
|
@@ -183,6 +393,9 @@ def create_fixed_length_windows(
|
|
|
183
393
|
flat: dict | None
|
|
184
394
|
Epoch rejection parameters based on flatness of signals. If None, no
|
|
185
395
|
rejection based on flatness is done. See mne.Epochs.
|
|
396
|
+
lazy_metadata: bool
|
|
397
|
+
If True, metadata is not computed immediately, but only when accessed
|
|
398
|
+
by using the _LazyDataFrame (experimental).
|
|
186
399
|
on_missing: str
|
|
187
400
|
What to do if one or several event ids are not found in the recording.
|
|
188
401
|
Valid keys are ‘error’ | ‘warning’ | ‘ignore’. See mne.Epochs.
|
|
@@ -196,35 +409,70 @@ def create_fixed_length_windows(
|
|
|
196
409
|
windows_datasets: BaseConcatDataset
|
|
197
410
|
Concatenated datasets of WindowsDataset containing the extracted windows.
|
|
198
411
|
"""
|
|
199
|
-
stop_offset_samples, drop_last_window =
|
|
200
|
-
|
|
201
|
-
|
|
412
|
+
stop_offset_samples, drop_last_window = (
|
|
413
|
+
_check_and_set_fixed_length_window_arguments(
|
|
414
|
+
start_offset_samples,
|
|
415
|
+
stop_offset_samples,
|
|
416
|
+
window_size_samples,
|
|
417
|
+
window_stride_samples,
|
|
418
|
+
drop_last_window,
|
|
419
|
+
lazy_metadata,
|
|
420
|
+
)
|
|
421
|
+
)
|
|
202
422
|
|
|
203
423
|
# check if recordings are of different lengths
|
|
204
424
|
lengths = np.array([ds.raw.n_times for ds in concat_ds.datasets])
|
|
205
425
|
if (np.diff(lengths) != 0).any() and window_size_samples is None:
|
|
206
|
-
warnings.warn(
|
|
426
|
+
warnings.warn("Recordings have different lengths, they will not be batch-able!")
|
|
207
427
|
if (window_size_samples is not None) and any(window_size_samples > lengths):
|
|
208
|
-
raise ValueError(
|
|
209
|
-
|
|
428
|
+
raise ValueError(
|
|
429
|
+
f"Window size {window_size_samples} exceeds trial duration {lengths.min()}."
|
|
430
|
+
)
|
|
210
431
|
|
|
211
432
|
list_of_windows_ds = Parallel(n_jobs=n_jobs)(
|
|
212
433
|
delayed(_create_fixed_length_windows)(
|
|
213
|
-
ds,
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
434
|
+
ds,
|
|
435
|
+
start_offset_samples,
|
|
436
|
+
stop_offset_samples,
|
|
437
|
+
window_size_samples,
|
|
438
|
+
window_stride_samples,
|
|
439
|
+
drop_last_window,
|
|
440
|
+
mapping,
|
|
441
|
+
preload,
|
|
442
|
+
picks,
|
|
443
|
+
reject,
|
|
444
|
+
flat,
|
|
445
|
+
targets_from,
|
|
446
|
+
last_target_only,
|
|
447
|
+
lazy_metadata,
|
|
448
|
+
on_missing,
|
|
449
|
+
verbose,
|
|
450
|
+
)
|
|
451
|
+
for ds in concat_ds.datasets
|
|
452
|
+
)
|
|
217
453
|
return BaseConcatDataset(list_of_windows_ds)
|
|
218
454
|
|
|
219
455
|
|
|
220
456
|
def _create_windows_from_events(
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
457
|
+
ds,
|
|
458
|
+
infer_mapping,
|
|
459
|
+
infer_window_size_stride,
|
|
460
|
+
trial_start_offset_samples,
|
|
461
|
+
trial_stop_offset_samples,
|
|
462
|
+
window_size_samples=None,
|
|
463
|
+
window_stride_samples=None,
|
|
464
|
+
drop_last_window=False,
|
|
465
|
+
mapping=None,
|
|
466
|
+
preload=False,
|
|
467
|
+
drop_bad_windows=True,
|
|
468
|
+
picks=None,
|
|
469
|
+
reject=None,
|
|
470
|
+
flat=None,
|
|
471
|
+
on_missing="error",
|
|
472
|
+
accepted_bads_ratio=0.0,
|
|
473
|
+
verbose="error",
|
|
474
|
+
use_mne_epochs=False,
|
|
475
|
+
):
|
|
228
476
|
"""Create WindowsDataset from BaseDataset based on events.
|
|
229
477
|
|
|
230
478
|
Parameters
|
|
@@ -254,47 +502,61 @@ def _create_windows_from_events(
|
|
|
254
502
|
new_unique_events = [x for x in unique_events if x not in mapping]
|
|
255
503
|
# mapping event descriptions to integers from 0 on
|
|
256
504
|
max_id_existing_mapping = len(mapping)
|
|
257
|
-
mapping.update(
|
|
505
|
+
mapping.update(
|
|
506
|
+
{
|
|
258
507
|
event_name: i_event_type + max_id_existing_mapping
|
|
259
508
|
for i_event_type, event_name in enumerate(new_unique_events)
|
|
260
|
-
|
|
509
|
+
}
|
|
510
|
+
)
|
|
261
511
|
|
|
262
512
|
events, events_id = mne.events_from_annotations(ds.raw, mapping)
|
|
263
513
|
onsets = events[:, 0]
|
|
264
514
|
# Onsets are relative to the beginning of the recording
|
|
265
|
-
filtered_durations = np.array(
|
|
266
|
-
a[
|
|
267
|
-
|
|
268
|
-
])
|
|
515
|
+
filtered_durations = np.array(
|
|
516
|
+
[a["duration"] for a in ds.raw.annotations if a["description"] in events_id]
|
|
517
|
+
)
|
|
269
518
|
|
|
270
|
-
stops = onsets + (filtered_durations * ds.raw.info[
|
|
519
|
+
stops = onsets + (filtered_durations * ds.raw.info["sfreq"]).astype(int)
|
|
271
520
|
# XXX This could probably be simplified by using chunk_duration in
|
|
272
521
|
# `events_from_annotations`
|
|
273
522
|
|
|
274
|
-
last_samp = ds.raw.first_samp + ds.raw.n_times
|
|
275
|
-
|
|
523
|
+
last_samp = ds.raw.first_samp + ds.raw.n_times - 1
|
|
524
|
+
# `stops` is used exclusively (i.e. `start:stop`), so add back 1
|
|
525
|
+
if stops[-1] + trial_stop_offset_samples > last_samp + 1:
|
|
276
526
|
raise ValueError(
|
|
277
527
|
'"trial_stop_offset_samples" too large. Stop of last trial '
|
|
278
528
|
f'({stops[-1]}) + "trial_stop_offset_samples" '
|
|
279
|
-
f
|
|
280
|
-
f
|
|
529
|
+
f"({trial_stop_offset_samples}) must be smaller than length of"
|
|
530
|
+
f" recording ({len(ds)})."
|
|
531
|
+
)
|
|
281
532
|
|
|
282
533
|
if infer_window_size_stride:
|
|
283
534
|
# window size is trial size
|
|
284
535
|
if window_size_samples is None:
|
|
285
|
-
window_size_samples =
|
|
286
|
-
|
|
536
|
+
window_size_samples = (
|
|
537
|
+
stops[0]
|
|
538
|
+
+ trial_stop_offset_samples
|
|
539
|
+
- (onsets[0] + trial_start_offset_samples)
|
|
540
|
+
)
|
|
287
541
|
window_stride_samples = window_size_samples
|
|
288
542
|
this_trial_sizes = (stops + trial_stop_offset_samples) - (
|
|
289
|
-
onsets + trial_start_offset_samples
|
|
543
|
+
onsets + trial_start_offset_samples
|
|
544
|
+
)
|
|
290
545
|
# Maybe actually this is not necessary?
|
|
291
546
|
# We could also just say we just assume window size=trial size
|
|
292
547
|
# in case not given, without this condition...
|
|
293
548
|
# but then would have to change functions overall
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
549
|
+
checker_trials_size = this_trial_sizes == window_size_samples
|
|
550
|
+
|
|
551
|
+
if not np.all(checker_trials_size):
|
|
552
|
+
trials_drops = int(len(this_trial_sizes) - sum(checker_trials_size))
|
|
553
|
+
warnings.warn(
|
|
554
|
+
f"Dropping trials with different windows size {trials_drops}",
|
|
555
|
+
)
|
|
556
|
+
bads_size_trials = checker_trials_size
|
|
557
|
+
events = events[checker_trials_size]
|
|
558
|
+
onsets = onsets[checker_trials_size]
|
|
559
|
+
stops = stops[checker_trials_size]
|
|
298
560
|
|
|
299
561
|
description = events[:, -1]
|
|
300
562
|
|
|
@@ -302,51 +564,90 @@ def _create_windows_from_events(
|
|
|
302
564
|
onsets = onsets - ds.raw.first_samp
|
|
303
565
|
stops = stops - ds.raw.first_samp
|
|
304
566
|
i_trials, i_window_in_trials, starts, stops = _compute_window_inds(
|
|
305
|
-
onsets,
|
|
306
|
-
|
|
307
|
-
|
|
567
|
+
onsets,
|
|
568
|
+
stops,
|
|
569
|
+
trial_start_offset_samples,
|
|
570
|
+
trial_stop_offset_samples,
|
|
571
|
+
window_size_samples,
|
|
572
|
+
window_stride_samples,
|
|
573
|
+
drop_last_window,
|
|
574
|
+
accepted_bads_ratio,
|
|
575
|
+
)
|
|
308
576
|
|
|
309
577
|
if any(np.diff(starts) <= 0):
|
|
310
|
-
raise NotImplementedError(
|
|
578
|
+
raise NotImplementedError("Trial overlap not implemented.")
|
|
311
579
|
|
|
312
|
-
events = [
|
|
313
|
-
|
|
580
|
+
events = [
|
|
581
|
+
[start, window_size_samples, description[i_trials[i_start]]]
|
|
582
|
+
for i_start, start in enumerate(starts)
|
|
583
|
+
]
|
|
314
584
|
events = np.array(events)
|
|
315
585
|
|
|
316
586
|
description = events[:, -1]
|
|
317
587
|
|
|
318
|
-
metadata = pd.DataFrame(
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
588
|
+
metadata = pd.DataFrame(
|
|
589
|
+
{
|
|
590
|
+
"i_window_in_trial": i_window_in_trials,
|
|
591
|
+
"i_start_in_trial": starts,
|
|
592
|
+
"i_stop_in_trial": stops,
|
|
593
|
+
"target": description,
|
|
594
|
+
}
|
|
595
|
+
)
|
|
324
596
|
if use_mne_epochs:
|
|
325
597
|
# window size - 1, since tmax is inclusive
|
|
326
598
|
mne_epochs = mne.Epochs(
|
|
327
|
-
ds.raw,
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
599
|
+
ds.raw,
|
|
600
|
+
events,
|
|
601
|
+
events_id,
|
|
602
|
+
baseline=None,
|
|
603
|
+
tmin=0,
|
|
604
|
+
tmax=(window_size_samples - 1) / ds.raw.info["sfreq"],
|
|
605
|
+
metadata=metadata,
|
|
606
|
+
preload=preload,
|
|
607
|
+
picks=picks,
|
|
608
|
+
reject=reject,
|
|
609
|
+
flat=flat,
|
|
610
|
+
on_missing=on_missing,
|
|
611
|
+
verbose=verbose,
|
|
612
|
+
)
|
|
331
613
|
if drop_bad_windows:
|
|
332
614
|
mne_epochs.drop_bad()
|
|
333
|
-
windows_ds = WindowsDataset(
|
|
615
|
+
windows_ds = WindowsDataset(
|
|
616
|
+
mne_epochs,
|
|
617
|
+
ds.description,
|
|
618
|
+
)
|
|
334
619
|
else:
|
|
335
620
|
windows_ds = EEGWindowsDataset(
|
|
336
|
-
ds.raw,
|
|
621
|
+
ds.raw,
|
|
622
|
+
metadata=metadata,
|
|
623
|
+
description=ds.description,
|
|
624
|
+
)
|
|
337
625
|
# add window_kwargs and raw_preproc_kwargs to windows dataset
|
|
338
|
-
setattr(windows_ds,
|
|
339
|
-
kwargs_name =
|
|
626
|
+
setattr(windows_ds, "window_kwargs", window_kwargs)
|
|
627
|
+
kwargs_name = "raw_preproc_kwargs"
|
|
340
628
|
if hasattr(ds, kwargs_name):
|
|
341
629
|
setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
|
|
342
630
|
return windows_ds
|
|
343
631
|
|
|
344
632
|
|
|
345
633
|
def _create_fixed_length_windows(
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
634
|
+
ds,
|
|
635
|
+
start_offset_samples,
|
|
636
|
+
stop_offset_samples,
|
|
637
|
+
window_size_samples,
|
|
638
|
+
window_stride_samples,
|
|
639
|
+
drop_last_window,
|
|
640
|
+
mapping=None,
|
|
641
|
+
preload=False,
|
|
642
|
+
picks=None,
|
|
643
|
+
reject=None,
|
|
644
|
+
flat=None,
|
|
645
|
+
targets_from="metadata",
|
|
646
|
+
last_target_only=True,
|
|
647
|
+
lazy_metadata=False,
|
|
648
|
+
on_missing="error",
|
|
649
|
+
verbose="error",
|
|
650
|
+
):
|
|
350
651
|
"""Create WindowsDataset from BaseDataset with sliding windows.
|
|
351
652
|
|
|
352
653
|
Parameters
|
|
@@ -365,8 +666,7 @@ def _create_fixed_length_windows(
|
|
|
365
666
|
window_kwargs = [
|
|
366
667
|
(create_fixed_length_windows.__name__, _get_windowing_kwargs(locals())),
|
|
367
668
|
]
|
|
368
|
-
stop = ds.raw.n_times
|
|
369
|
-
if stop_offset_samples is None else stop_offset_samples
|
|
669
|
+
stop = ds.raw.n_times if stop_offset_samples is None else stop_offset_samples
|
|
370
670
|
|
|
371
671
|
# assume window should be whole recording
|
|
372
672
|
if window_size_samples is None:
|
|
@@ -375,15 +675,6 @@ def _create_fixed_length_windows(
|
|
|
375
675
|
window_stride_samples = window_size_samples
|
|
376
676
|
|
|
377
677
|
last_potential_start = stop - window_size_samples
|
|
378
|
-
# already includes last incomplete window start
|
|
379
|
-
starts = np.arange(
|
|
380
|
-
start_offset_samples,
|
|
381
|
-
last_potential_start + 1,
|
|
382
|
-
window_stride_samples)
|
|
383
|
-
|
|
384
|
-
if not drop_last_window and starts[-1] < last_potential_start:
|
|
385
|
-
# if last window does not end at trial stop, make it stop there
|
|
386
|
-
starts = np.append(starts, last_potential_start)
|
|
387
678
|
|
|
388
679
|
# get targets from dataset description if they exist
|
|
389
680
|
target = -1 if ds.target_name is None else ds.description[ds.target_name]
|
|
@@ -395,18 +686,53 @@ def _create_fixed_length_windows(
|
|
|
395
686
|
else:
|
|
396
687
|
target = mapping[target]
|
|
397
688
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
689
|
+
if lazy_metadata:
|
|
690
|
+
factory = _FixedLengthWindowFunctions(
|
|
691
|
+
start_offset_samples,
|
|
692
|
+
last_potential_start,
|
|
693
|
+
window_stride_samples,
|
|
694
|
+
window_size_samples,
|
|
695
|
+
target,
|
|
696
|
+
)
|
|
697
|
+
metadata = _LazyDataFrame(
|
|
698
|
+
length=factory.length,
|
|
699
|
+
functions={
|
|
700
|
+
"i_window_in_trial": factory.i_window_in_trial,
|
|
701
|
+
"i_start_in_trial": factory.i_start_in_trial,
|
|
702
|
+
"i_stop_in_trial": factory.i_stop_in_trial,
|
|
703
|
+
"target": factory.target,
|
|
704
|
+
},
|
|
705
|
+
columns=[
|
|
706
|
+
"i_window_in_trial",
|
|
707
|
+
"i_start_in_trial",
|
|
708
|
+
"i_stop_in_trial",
|
|
709
|
+
"target",
|
|
710
|
+
],
|
|
711
|
+
)
|
|
712
|
+
else:
|
|
713
|
+
# already includes last incomplete window start
|
|
714
|
+
starts = np.arange(
|
|
715
|
+
start_offset_samples, last_potential_start + 1, window_stride_samples
|
|
716
|
+
)
|
|
717
|
+
|
|
718
|
+
if not drop_last_window and starts[-1] < last_potential_start:
|
|
719
|
+
# if last window does not end at trial stop, make it stop there
|
|
720
|
+
starts = np.append(starts, last_potential_start)
|
|
721
|
+
|
|
722
|
+
metadata = pd.DataFrame(
|
|
723
|
+
{
|
|
724
|
+
"i_window_in_trial": np.arange(len(starts)),
|
|
725
|
+
"i_start_in_trial": starts,
|
|
726
|
+
"i_stop_in_trial": starts + window_size_samples,
|
|
727
|
+
"target": len(starts) * [target],
|
|
728
|
+
}
|
|
729
|
+
)
|
|
404
730
|
|
|
405
731
|
window_kwargs.append(
|
|
406
|
-
(
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
732
|
+
(
|
|
733
|
+
EEGWindowsDataset.__name__,
|
|
734
|
+
{"targets_from": targets_from, "last_target_only": last_target_only},
|
|
735
|
+
)
|
|
410
736
|
)
|
|
411
737
|
windows_ds = EEGWindowsDataset(
|
|
412
738
|
ds.raw,
|
|
@@ -416,28 +742,52 @@ def _create_fixed_length_windows(
|
|
|
416
742
|
last_target_only=last_target_only,
|
|
417
743
|
)
|
|
418
744
|
# add window_kwargs and raw_preproc_kwargs to windows dataset
|
|
419
|
-
setattr(windows_ds,
|
|
420
|
-
kwargs_name =
|
|
745
|
+
setattr(windows_ds, "window_kwargs", window_kwargs)
|
|
746
|
+
kwargs_name = "raw_preproc_kwargs"
|
|
421
747
|
if hasattr(ds, kwargs_name):
|
|
422
748
|
setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
|
|
423
749
|
return windows_ds
|
|
424
750
|
|
|
425
751
|
|
|
426
752
|
def create_windows_from_target_channels(
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
753
|
+
concat_ds,
|
|
754
|
+
window_size_samples=None,
|
|
755
|
+
preload=False,
|
|
756
|
+
picks=None,
|
|
757
|
+
reject=None,
|
|
758
|
+
flat=None,
|
|
759
|
+
n_jobs=1,
|
|
760
|
+
last_target_only=True,
|
|
761
|
+
verbose="error",
|
|
762
|
+
):
|
|
430
763
|
list_of_windows_ds = Parallel(n_jobs=n_jobs)(
|
|
431
764
|
delayed(_create_windows_from_target_channels)(
|
|
432
|
-
ds,
|
|
433
|
-
|
|
765
|
+
ds,
|
|
766
|
+
window_size_samples,
|
|
767
|
+
preload,
|
|
768
|
+
picks,
|
|
769
|
+
reject,
|
|
770
|
+
flat,
|
|
771
|
+
last_target_only,
|
|
772
|
+
"error",
|
|
773
|
+
verbose,
|
|
774
|
+
)
|
|
775
|
+
for ds in concat_ds.datasets
|
|
776
|
+
)
|
|
434
777
|
return BaseConcatDataset(list_of_windows_ds)
|
|
435
778
|
|
|
436
779
|
|
|
437
780
|
def _create_windows_from_target_channels(
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
781
|
+
ds,
|
|
782
|
+
window_size_samples,
|
|
783
|
+
preload=False,
|
|
784
|
+
picks=None,
|
|
785
|
+
reject=None,
|
|
786
|
+
flat=None,
|
|
787
|
+
last_target_only=True,
|
|
788
|
+
on_missing="error",
|
|
789
|
+
verbose="error",
|
|
790
|
+
):
|
|
441
791
|
"""Create WindowsDataset from BaseDataset using targets `misc` channels from mne.Raw.
|
|
442
792
|
|
|
443
793
|
Parameters
|
|
@@ -457,24 +807,26 @@ def _create_windows_from_target_channels(
|
|
|
457
807
|
]
|
|
458
808
|
stop = ds.raw.n_times + ds.raw.first_samp
|
|
459
809
|
|
|
460
|
-
target = ds.raw.get_data(picks=
|
|
810
|
+
target = ds.raw.get_data(picks="misc")
|
|
461
811
|
# TODO: handle multi targets present only for some events
|
|
462
812
|
stops = np.nonzero((~np.isnan(target[0, :])))[0] + 1
|
|
463
813
|
stops = stops[(stops < stop) & (stops >= window_size_samples)]
|
|
464
814
|
stops = stops.astype(int)
|
|
465
|
-
metadata = pd.DataFrame(
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
815
|
+
metadata = pd.DataFrame(
|
|
816
|
+
{
|
|
817
|
+
"i_window_in_trial": np.arange(len(stops)),
|
|
818
|
+
"i_start_in_trial": stops - window_size_samples,
|
|
819
|
+
"i_stop_in_trial": stops,
|
|
820
|
+
"target": len(stops) * [target],
|
|
821
|
+
}
|
|
822
|
+
)
|
|
823
|
+
|
|
824
|
+
targets_from = "channels"
|
|
473
825
|
window_kwargs.append(
|
|
474
|
-
(
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
826
|
+
(
|
|
827
|
+
EEGWindowsDataset.__name__,
|
|
828
|
+
{"targets_from": targets_from, "last_target_only": last_target_only},
|
|
829
|
+
)
|
|
478
830
|
)
|
|
479
831
|
windows_ds = EEGWindowsDataset(
|
|
480
832
|
ds.raw,
|
|
@@ -483,16 +835,23 @@ def _create_windows_from_target_channels(
|
|
|
483
835
|
targets_from=targets_from,
|
|
484
836
|
last_target_only=last_target_only,
|
|
485
837
|
)
|
|
486
|
-
setattr(windows_ds,
|
|
487
|
-
kwargs_name =
|
|
838
|
+
setattr(windows_ds, "window_kwargs", window_kwargs)
|
|
839
|
+
kwargs_name = "raw_preproc_kwargs"
|
|
488
840
|
if hasattr(ds, kwargs_name):
|
|
489
841
|
setattr(windows_ds, kwargs_name, getattr(ds, kwargs_name))
|
|
490
842
|
return windows_ds
|
|
491
843
|
|
|
492
844
|
|
|
493
845
|
def _compute_window_inds(
|
|
494
|
-
|
|
495
|
-
|
|
846
|
+
starts,
|
|
847
|
+
stops,
|
|
848
|
+
start_offset,
|
|
849
|
+
stop_offset,
|
|
850
|
+
size,
|
|
851
|
+
stride,
|
|
852
|
+
drop_last_window,
|
|
853
|
+
accepted_bads_ratio,
|
|
854
|
+
):
|
|
496
855
|
"""Compute window start and stop indices.
|
|
497
856
|
|
|
498
857
|
Create window starts from trial onsets (shifted by start_offset) to trial
|
|
@@ -531,27 +890,31 @@ def _compute_window_inds(
|
|
|
531
890
|
|
|
532
891
|
starts += start_offset
|
|
533
892
|
stops += stop_offset
|
|
534
|
-
if any(size > (stops-starts)):
|
|
535
|
-
bads_mask = size > (stops-starts)
|
|
536
|
-
min_duration = (stops-starts).min()
|
|
893
|
+
if any(size > (stops - starts)):
|
|
894
|
+
bads_mask = size > (stops - starts)
|
|
895
|
+
min_duration = (stops - starts).min()
|
|
537
896
|
if sum(bads_mask) <= accepted_bads_ratio * len(starts):
|
|
538
897
|
starts = starts[np.logical_not(bads_mask)]
|
|
539
898
|
stops = stops[np.logical_not(bads_mask)]
|
|
540
899
|
warnings.warn(
|
|
541
|
-
f
|
|
542
|
-
f
|
|
900
|
+
f"Trials {np.where(bads_mask)[0]} are being dropped as the "
|
|
901
|
+
f"window size ({size}) exceeds their duration {min_duration}."
|
|
902
|
+
)
|
|
543
903
|
else:
|
|
544
904
|
current_ratio = sum(bads_mask) / len(starts)
|
|
545
|
-
raise ValueError(
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
905
|
+
raise ValueError(
|
|
906
|
+
f"Window size {size} exceeds trial duration "
|
|
907
|
+
f"({min_duration}) for too many trials "
|
|
908
|
+
f"({current_ratio * 100}%). Set "
|
|
909
|
+
f"accepted_bads_ratio to at least {current_ratio}"
|
|
910
|
+
"and restart training to be able to continue."
|
|
911
|
+
)
|
|
550
912
|
|
|
551
913
|
i_window_in_trials, i_trials, window_starts = [], [], []
|
|
552
914
|
for start_i, (start, stop) in enumerate(zip(starts, stops)):
|
|
553
|
-
# Generate possible window starts with given stride between original
|
|
554
|
-
# trial onsets (shifted by start_offset
|
|
915
|
+
# Generate possible window starts, with given stride, between original
|
|
916
|
+
# trial onsets and stops (shifted by start_offset and stop_offset,
|
|
917
|
+
# respectively)
|
|
555
918
|
possible_starts = np.arange(start, stop, stride)
|
|
556
919
|
|
|
557
920
|
# Possible window start is actually a start, if window size fits in
|
|
@@ -571,72 +934,98 @@ def _compute_window_inds(
|
|
|
571
934
|
i_window_in_trials.append(i_window_in_trials[-1] + 1)
|
|
572
935
|
i_trials.append(start_i)
|
|
573
936
|
|
|
574
|
-
#
|
|
937
|
+
# Set window stops to be event stops (rather than trial stops)
|
|
575
938
|
window_stops = np.array(window_starts) + size
|
|
576
939
|
if not (len(i_window_in_trials) == len(window_starts) == len(window_stops)):
|
|
577
|
-
raise ValueError(
|
|
578
|
-
|
|
940
|
+
raise ValueError(
|
|
941
|
+
f"{len(i_window_in_trials)} == {len(window_starts)} == {len(window_stops)}"
|
|
942
|
+
)
|
|
579
943
|
|
|
580
944
|
return i_trials, i_window_in_trials, window_starts, window_stops
|
|
581
945
|
|
|
582
946
|
|
|
583
947
|
def _check_windowing_arguments(
|
|
584
|
-
|
|
585
|
-
|
|
948
|
+
trial_start_offset_samples,
|
|
949
|
+
trial_stop_offset_samples,
|
|
950
|
+
window_size_samples,
|
|
951
|
+
window_stride_samples,
|
|
952
|
+
):
|
|
586
953
|
assert isinstance(trial_start_offset_samples, (int, np.integer))
|
|
587
|
-
assert
|
|
588
|
-
|
|
954
|
+
assert isinstance(trial_stop_offset_samples, (int, np.integer)) or (
|
|
955
|
+
trial_stop_offset_samples is None
|
|
956
|
+
)
|
|
589
957
|
assert isinstance(window_size_samples, (int, np.integer, type(None)))
|
|
590
958
|
assert isinstance(window_stride_samples, (int, np.integer, type(None)))
|
|
591
959
|
assert (window_size_samples is None) == (window_stride_samples is None)
|
|
592
960
|
if window_size_samples is not None:
|
|
593
|
-
assert window_size_samples > 0,
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
961
|
+
assert window_size_samples > 0, "window size has to be larger than 0"
|
|
962
|
+
assert window_stride_samples > 0, "window stride has to be larger than 0"
|
|
963
|
+
|
|
964
|
+
|
|
965
|
+
def _check_and_set_fixed_length_window_arguments(
|
|
966
|
+
start_offset_samples,
|
|
967
|
+
stop_offset_samples,
|
|
968
|
+
window_size_samples,
|
|
969
|
+
window_stride_samples,
|
|
970
|
+
drop_last_window,
|
|
971
|
+
lazy_metadata,
|
|
972
|
+
):
|
|
602
973
|
"""Raises warnings for incorrect input arguments and will set correct default values for
|
|
603
974
|
stop_offset_samples & drop_last_window, if necessary.
|
|
604
975
|
"""
|
|
605
976
|
_check_windowing_arguments(
|
|
606
|
-
start_offset_samples,
|
|
607
|
-
|
|
977
|
+
start_offset_samples,
|
|
978
|
+
stop_offset_samples,
|
|
979
|
+
window_size_samples,
|
|
980
|
+
window_stride_samples,
|
|
981
|
+
)
|
|
608
982
|
|
|
609
983
|
if stop_offset_samples == 0:
|
|
610
984
|
warnings.warn(
|
|
611
|
-
|
|
612
|
-
|
|
985
|
+
"Meaning of `trial_stop_offset_samples`=0 has changed, use `None` "
|
|
986
|
+
"to indicate end of trial/recording. Using `None`."
|
|
987
|
+
)
|
|
613
988
|
stop_offset_samples = None
|
|
614
989
|
|
|
615
990
|
if start_offset_samples != 0 or stop_offset_samples is not None:
|
|
616
|
-
warnings.warn(
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
991
|
+
warnings.warn(
|
|
992
|
+
"Usage of offset_sample args in create_fixed_length_windows is deprecated and"
|
|
993
|
+
" will be removed in future versions. Please use "
|
|
994
|
+
'braindecode.preprocessing.preprocess.Preprocessor("crop", tmin, tmax)'
|
|
995
|
+
" instead."
|
|
996
|
+
)
|
|
997
|
+
|
|
998
|
+
if (
|
|
999
|
+
window_size_samples is not None
|
|
1000
|
+
and window_stride_samples is not None
|
|
1001
|
+
and drop_last_window is None
|
|
1002
|
+
):
|
|
1003
|
+
raise ValueError(
|
|
1004
|
+
"drop_last_window must be set if both window_size_samples &"
|
|
1005
|
+
" window_stride_samples have also been set"
|
|
1006
|
+
)
|
|
1007
|
+
elif (
|
|
1008
|
+
window_size_samples is None
|
|
1009
|
+
and window_stride_samples is None
|
|
1010
|
+
and drop_last_window is False
|
|
1011
|
+
):
|
|
628
1012
|
# necessary for following assertion
|
|
629
1013
|
drop_last_window = None
|
|
630
1014
|
|
|
631
|
-
assert (
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
1015
|
+
assert (
|
|
1016
|
+
(window_size_samples is None)
|
|
1017
|
+
== (window_stride_samples is None)
|
|
1018
|
+
== (drop_last_window is None)
|
|
1019
|
+
)
|
|
1020
|
+
if not drop_last_window and lazy_metadata:
|
|
1021
|
+
raise ValueError(
|
|
1022
|
+
"Cannot have drop_last_window=False and lazy_metadata=True at the same time."
|
|
1023
|
+
)
|
|
635
1024
|
return stop_offset_samples, drop_last_window
|
|
636
1025
|
|
|
637
1026
|
|
|
638
1027
|
def _get_windowing_kwargs(windowing_func_locals):
|
|
639
1028
|
input_kwargs = windowing_func_locals
|
|
640
|
-
input_kwargs.pop(
|
|
1029
|
+
input_kwargs.pop("ds")
|
|
641
1030
|
windowing_kwargs = {k: v for k, v in input_kwargs.items()}
|
|
642
1031
|
return windowing_kwargs
|