braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
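The file list above shows the shape of the 1.0.0 reorganization: the deprecated `braindecode/datautil/` shims (`mne.py`, `preprocess.py`, `windowers.py`, `xy.py`) are deleted, and `braindecode/models/functions.py` and `braindecode/models/modules.py` give way to the new top-level `braindecode/functional/` and `braindecode/modules/` packages. A minimal migration sketch, assuming the canonical import paths implied by the file list (verify against the 1.0.0 release notes before relying on them):

```python
# braindecode 0.8.1 still allowed the deprecated shim imports, e.g.:
#   from braindecode.datautil.preprocess import preprocess, Preprocessor
# In 1.0.0 the shim modules are gone; import from the canonical package:
from braindecode.preprocessing import Preprocessor, preprocess

# Reusable layers moved out of braindecode.models.modules into the new
# braindecode.modules package (path assumed from the file list above):
import braindecode.modules
```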
braindecode/datasets/base.py
CHANGED
```diff
@@ -10,24 +10,28 @@ Dataset classes.
 #
 # License: BSD (3-clause)
 
-import os
+from __future__ import annotations
+
 import json
+import os
 import shutil
-from typing import Iterable
 import warnings
+from collections.abc import Callable
 from glob import glob
+from typing import Iterable, no_type_check
 
+import mne.io
 import numpy as np
 import pandas as pd
-from torch.utils.data import Dataset, ConcatDataset
+from torch.utils.data import ConcatDataset, Dataset
 
 
-def _create_description(description):
+def _create_description(description) -> pd.Series:
     if description is not None:
-        if (not isinstance(description, pd.Series)
-                and not isinstance(description, dict)):
-            raise ValueError(f"'{description}' has to be either a "
-                             f"pandas.Series or a dict.")
+        if not isinstance(description, pd.Series) and not isinstance(description, dict):
+            raise ValueError(
+                f"'{description}' has to be either a pandas.Series or a dict."
+            )
         if isinstance(description, dict):
             description = pd.Series(description)
     return description
```
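The reworked `_create_description` keeps the 0.8.1 contract and only changes formatting and the return annotation. A minimal sketch of that contract (the helper is private API, imported here only for illustration):

```python
import pandas as pd

from braindecode.datasets.base import _create_description

desc = _create_description({"subject": 1, "session": "a"})
assert isinstance(desc, pd.Series)        # dicts are normalised to a Series
assert _create_description(None) is None  # None passes through unchanged
# _create_description("oops")             # any other type raises ValueError
```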
```diff
@@ -52,8 +56,14 @@ class BaseDataset(Dataset):
     transform : callable | None
         On-the-fly transform applied to the example before it is returned.
     """
-    def __init__(self, raw, description=None, target_name=None,
-                 transform=None):
+
+    def __init__(
+        self,
+        raw: mne.io.BaseRaw,
+        description: dict | pd.Series | None = None,
+        target_name: str | tuple[str, ...] | None = None,
+        transform: Callable | None = None,
+    ):
         self.raw = raw
         self._description = _create_description(description)
         self.transform = transform
```
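The new multi-line signature makes the accepted types explicit without changing behaviour. A usage sketch (the file name is hypothetical; any `mne.io.BaseRaw` works):

```python
import mne

from braindecode.datasets import BaseDataset

raw = mne.io.read_raw_fif("subject_01_raw.fif", preload=False)  # hypothetical file
ds = BaseDataset(
    raw,
    description={"subject": 1, "session": "a"},
    target_name="subject",  # should be a key of `description`, see the warning below
)
```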
```diff
@@ -82,14 +92,14 @@ class BaseDataset(Dataset):
     @transform.setter
     def transform(self, value):
         if value is not None and not callable(value):
-            raise ValueError('Transform needs to be a callable.')
+            raise ValueError("Transform needs to be a callable.")
         self._transform = value
 
     @property
-    def description(self):
+    def description(self) -> pd.Series:
         return self._description
 
-    def set_description(self, description, overwrite=False):
+    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
         """Update (add or overwrite) the dataset description.
 
         Parameters
```
```diff
@@ -104,8 +114,10 @@ class BaseDataset(Dataset):
         for key, value in description.items():
             # if the key is already in the existing description, drop it
             if self._description is not None and key in self._description:
-                assert overwrite, (f"'{key}' already in description. Please "
-                                   f"rename or set overwrite to True.")
+                assert overwrite, (
+                    f"'{key}' already in description. Please "
+                    f"rename or set overwrite to True."
+                )
                 self._description.pop(key)
         if self._description is None:
             self._description = description
```
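The reflowed `assert` preserves the 0.8.1 behaviour: re-using an existing description key without `overwrite=True` fails. For example:

```python
ds.set_description({"age": 27})                   # adds a new key
ds.set_description({"age": 28}, overwrite=True)   # replaces the existing key
# ds.set_description({"age": 29})                 # AssertionError without overwrite
```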
```diff
@@ -114,7 +126,7 @@ class BaseDataset(Dataset):
 
     def _target_name(self, target_name):
         if target_name is not None and not isinstance(target_name, (str, tuple, list)):
-            raise ValueError('target_name has to be None, str, tuple or list')
+            raise ValueError("target_name has to be None, str, tuple or list")
         if target_name is None:
             return target_name
         else:
@@ -128,9 +140,12 @@ class BaseDataset(Dataset):
         # check if target name(s) can be read from description
         for name in target_name:
             if self.description is None or name not in self.description:
-                warnings.warn(f"'{name}' not in description. '__getitem__'"
-                              f"will fail unless an appropriate target is"
-                              f" added to description.", UserWarning)
+                warnings.warn(
+                    f"'{name}' not in description. '__getitem__'"
+                    f"will fail unless an appropriate target is"
+                    f" added to description.",
+                    UserWarning,
+                )
         # return a list of str if there are multiple targets and a str otherwise
         return target_name if len(target_name) > 1 else target_name[0]
 
```
```diff
@@ -168,24 +183,31 @@ class EEGWindowsDataset(BaseDataset):
     as well as `targets`.
     """
 
-    def __init__(self, raw, metadata, description=None, transform=None,
-                 targets_from='metadata', last_target_only=True):
+    def __init__(
+        self,
+        raw: mne.io.BaseRaw | mne.BaseEpochs,
+        metadata: pd.DataFrame,
+        description: dict | pd.Series | None = None,
+        transform: Callable | None = None,
+        targets_from: str = "metadata",
+        last_target_only: bool = True,
+    ):
         self.raw = raw
         self.metadata = metadata
         self._description = _create_description(description)
 
         self.transform = transform
         self.last_target_only = last_target_only
-        if targets_from not in ('metadata', 'channels'):
-            raise ValueError('Wrong value for parameter `targets_from`.')
+        if targets_from not in ("metadata", "channels"):
+            raise ValueError("Wrong value for parameter `targets_from`.")
         self.targets_from = targets_from
         self.crop_inds = metadata.loc[
-            :, ['i_window_in_trial', 'i_start_in_trial',
-                'i_stop_in_trial']].to_numpy()
-        if self.targets_from == 'metadata':
-            self.y = metadata.loc[:, 'target'].to_list()
+            :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+        ].to_numpy()
+        if self.targets_from == "metadata":
+            self.y = metadata.loc[:, "target"].to_list()
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int):
         """Get a window and its target.
 
         Parameters
@@ -209,16 +231,16 @@ class EEGWindowsDataset(BaseDataset):
 
         i_window_in_trial, i_start, i_stop = crop_inds
         X = self.raw._getitem((slice(None), slice(i_start, i_stop)), return_times=False)
-        X = X.astype('float32')
+        X = X.astype("float32")
         # ensure we don't give the user the option
         # to accidentally modify the underlying array
         X = X.copy()
         if self.transform is not None:
             X = self.transform(X)
-        if self.targets_from == 'metadata':
+        if self.targets_from == "metadata":
             y = self.y[index]
         else:
-            misc_mask = np.array(self.raw.get_channel_types()) == 'misc'
+            misc_mask = np.array(self.raw.get_channel_types()) == "misc"
             if self.last_target_only:
                 y = X[misc_mask, -1]
             else:
```
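The `__getitem__` hunk documents the two target modes. A sketch of what indexing returns, assuming `windows_ds` is an `EEGWindowsDataset` (for instance one produced by the windowers in `braindecode/preprocessing/windowers.py`):

```python
X, y, crop_inds = windows_ds[0]
# X: float32 copy of the raw crop, shape (n_channels, n_times)
# crop_inds: [i_window_in_trial, i_start_in_trial, i_stop_in_trial]

# With targets_from="metadata" (default), y is read from the metadata row.
# With targets_from="channels", y is taken from channels typed "misc";
# last_target_only=True keeps only the final time sample.
```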
```diff
@@ -240,14 +262,14 @@ class EEGWindowsDataset(BaseDataset):
     @transform.setter
     def transform(self, value):
         if value is not None and not callable(value):
-            raise ValueError('Transform needs to be a callable.')
+            raise ValueError("Transform needs to be a callable.")
         self._transform = value
 
     @property
-    def description(self):
+    def description(self) -> pd.Series:
         return self._description
 
-    def set_description(self, description, overwrite=False):
+    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
         """Update (add or overwrite) the dataset description.
 
         Parameters
@@ -262,8 +284,10 @@ class EEGWindowsDataset(BaseDataset):
         for key, value in description.items():
             # if they key is already in the existing description, drop it
             if key in self._description:
-                assert overwrite, (f"'{key}' already in description. Please "
-                                   f"rename or set overwrite to True.")
+                assert overwrite, (
+                    f"'{key}' already in description. Please "
+                    f"rename or set overwrite to True."
+                )
                 self._description.pop(key)
         self._description = pd.concat([self.description, description])
 
```
```diff
@@ -294,23 +318,30 @@ class WindowsDataset(BaseDataset):
         Defines whether targets will be extracted from mne.Epochs metadata or mne.Epochs `misc`
         channels (time series targets). It can be `metadata` (default) or `channels`.
     """
-    def __init__(self, windows, description=None, transform=None,
-                 targets_from='metadata', last_target_only=True):
+
+    def __init__(
+        self,
+        windows: mne.BaseEpochs,
+        description: dict | pd.Series | None = None,
+        transform: Callable | None = None,
+        targets_from: str = "metadata",
+        last_target_only: bool = True,
+    ):
         self.windows = windows
         self._description = _create_description(description)
         self.transform = transform
         self.last_target_only = last_target_only
-        if targets_from not in ('metadata', 'channels'):
-            raise ValueError('Wrong value for parameter `targets_from`.')
+        if targets_from not in ("metadata", "channels"):
+            raise ValueError("Wrong value for parameter `targets_from`.")
         self.targets_from = targets_from
 
         self.crop_inds = self.windows.metadata.loc[
-            :, ['i_window_in_trial', 'i_start_in_trial',
-                'i_stop_in_trial']].to_numpy()
-        if self.targets_from == 'metadata':
-            self.y = self.windows.metadata.loc[:, 'target'].to_list()
+            :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
+        ].to_numpy()
+        if self.targets_from == "metadata":
+            self.y = self.windows.metadata.loc[:, "target"].to_list()
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int):
         """Get a window and its target.
 
         Parameters
```
```diff
@@ -327,13 +358,13 @@ class WindowsDataset(BaseDataset):
         np.ndarray
             Crop indices.
         """
-        X = self.windows.get_data(item=index)[0].astype('float32')
+        X = self.windows.get_data(item=index)[0].astype("float32")
         if self.transform is not None:
             X = self.transform(X)
-        if self.targets_from == 'metadata':
+        if self.targets_from == "metadata":
             y = self.y[index]
         else:
-            misc_mask = np.array(self.windows.get_channel_types()) == 'misc'
+            misc_mask = np.array(self.windows.get_channel_types()) == "misc"
             if self.last_target_only:
                 y = X[misc_mask, -1]
             else:
@@ -345,7 +376,7 @@ class WindowsDataset(BaseDataset):
         crop_inds = self.crop_inds[index].tolist()
         return X, y, crop_inds
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.windows.events)
 
     @property
```
```diff
@@ -355,14 +386,14 @@ class WindowsDataset(BaseDataset):
     @transform.setter
     def transform(self, value):
         if value is not None and not callable(value):
-            raise ValueError('Transform needs to be a callable.')
+            raise ValueError("Transform needs to be a callable.")
         self._transform = value
 
     @property
-    def description(self):
+    def description(self) -> pd.Series:
         return self._description
 
-    def set_description(self, description, overwrite=False):
+    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
         """Update (add or overwrite) the dataset description.
 
         Parameters
```
```diff
@@ -377,16 +408,19 @@ class WindowsDataset(BaseDataset):
         for key, value in description.items():
             # if they key is already in the existing description, drop it
             if key in self._description:
-                assert overwrite, (f"'{key}' already in description. Please "
-                                   f"rename or set overwrite to True.")
+                assert overwrite, (
+                    f"'{key}' already in description. Please "
+                    f"rename or set overwrite to True."
+                )
                 self._description.pop(key)
         self._description = pd.concat([self.description, description])
 
 
 class BaseConcatDataset(ConcatDataset):
-    """A base class for concatenated datasets. Holds either mne.Raw or
-    mne.Epoch in self.datasets and has a pandas DataFrame with additional
-    description.
+    """A base class for concatenated datasets.
+
+    Holds either mne.Raw or mne.Epoch in self.datasets and has
+    a pandas DataFrame with additional description.
 
     Parameters
     ----------
```
```diff
@@ -394,8 +428,15 @@ class BaseConcatDataset(ConcatDataset):
         list of BaseDataset, BaseConcatDataset or WindowsDataset
     target_transform : callable | None
         Optional function to call on targets before returning them.
+
     """
-    def __init__(self, list_of_ds, target_transform=None):
+
+    def __init__(
+        self,
+        list_of_ds: list[BaseDataset | BaseConcatDataset | WindowsDataset]
+        | None = None,
+        target_transform: Callable | None = None,
+    ):
         # if we get a list of BaseConcatDataset, get all the individual datasets
         if list_of_ds and isinstance(list_of_ds[0], BaseConcatDataset):
             list_of_ds = [d for ds in list_of_ds for d in ds.datasets]
```
```diff
@@ -415,7 +456,7 @@ class BaseConcatDataset(ConcatDataset):
 
         return X, y
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int | list):
         """
         Parameters
         ----------
```
```diff
@@ -433,9 +474,16 @@ class BaseConcatDataset(ConcatDataset):
             item = item[:1] + (self.target_transform(item[1]),) + item[2:]
         return item
 
-    def split(self, by=None, property=None, split_ids=None):
-        """Split the dataset based on information listed in its description
-        DataFrame or based on indices.
+    @no_type_check  # TODO, it's a mess
+    def split(
+        self,
+        by: str | list[int] | list[list[int]] | dict[str, list[int]] | None = None,
+        property: str | None = None,
+        split_ids: list[int] | list[list[int]] | dict[str, list[int]] | None = None,
+    ) -> dict[str, BaseConcatDataset]:
+        """Split the dataset based on information listed in its description.
+
+        The format could be based on a DataFrame or based on indices.
 
         Parameters
         ----------
```
```diff
@@ -448,8 +496,8 @@ class BaseConcatDataset(ConcatDataset):
             If a dict then each key will be used in the returned
             splits dict and each value should be a list of int.
         property : str
-            Some property which is listed in info DataFrame.
+            Some property which is listed in the info DataFrame.
         split_ids : list | dict
             List of indices to be combined in a subset.
             It can be a list of int or a list of list of int.
 
```
```diff
@@ -459,20 +507,22 @@ class BaseConcatDataset(ConcatDataset):
             A dictionary with the name of the split (a string) as key and the
             dataset as value.
         """
-        args_not_none = [by is not None, property is not None,
-                         split_ids is not None]
+
+        args_not_none = [by is not None, property is not None, split_ids is not None]
         if sum(args_not_none) != 1:
             raise ValueError("Splitting requires exactly one argument.")
 
         if property is not None or split_ids is not None:
-            warnings.warn("Keyword arguments `property` and `split_ids` "
-                          "are deprecated and will be removed in the future. "
-                          "Use `by` instead.", DeprecationWarning)
+            warnings.warn(
+                "Keyword arguments `property` and `split_ids` "
+                "are deprecated and will be removed in the future. "
+                "Use `by` instead.",
+                DeprecationWarning,
+            )
         by = property if property is not None else split_ids
         if isinstance(by, str):
             split_ids = {
-                k: list(v)
-                for k, v in self.description.groupby(by).groups.items()
+                k: list(v) for k, v in self.description.groupby(by).groups.items()
             }
         elif isinstance(by, dict):
             split_ids = by
```
```diff
@@ -483,11 +533,15 @@ class BaseConcatDataset(ConcatDataset):
             # assume list(list(int))
             split_ids = {split_i: split for split_i, split in enumerate(by)}
 
-        return {str(split_name): BaseConcatDataset(
-            [self.datasets[ds_ind] for ds_ind in ds_inds],
-            target_transform=self.target_transform) for split_name, ds_inds in split_ids.items()}
+        return {
+            str(split_name): BaseConcatDataset(
+                [self.datasets[ds_ind] for ds_ind in ds_inds],
+                target_transform=self.target_transform,
+            )
+            for split_name, ds_inds in split_ids.items()
+        }
 
-    def get_metadata(self):
+    def get_metadata(self) -> pd.DataFrame:
         """Concatenate the metadata and description of the wrapped Epochs.
 
         Returns
```
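Taken together, these hunks cover the whole `split` path: exactly one argument must be given, `property` and `split_ids` are deprecated aliases for `by`, and split names are stringified. A usage sketch, where `concat_ds` is a `BaseConcatDataset` whose description has a `subject` column (the returned keys follow `str(split_name)` as in the hunk above):

```python
splits = concat_ds.split(by="subject")                        # one split per unique value
splits = concat_ds.split(by=[[0, 1], [2, 3]])                 # splits "0" and "1"
splits = concat_ds.split(by={"train": [0, 1], "valid": [2]})  # named splits
train_set, valid_set = splits["train"], splits["valid"]
```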
```diff
@@ -497,13 +551,20 @@ class BaseConcatDataset(ConcatDataset):
             BaseConcatDataset, with the metadata and description information
             for each window.
         """
-        if not all([isinstance(ds, WindowsDataset) for ds in self.datasets]):
-            raise TypeError('Metadata dataframe can only be computed when all '
-                            'datasets are WindowsDataset.')
+        if not all(
+            [
+                isinstance(ds, (WindowsDataset, EEGWindowsDataset))
+                for ds in self.datasets
+            ]
+        ):
+            raise TypeError(
+                "Metadata dataframe can only be computed when all "
+                "datasets are WindowsDataset."
+            )
 
         all_dfs = list()
         for ds in self.datasets:
-            if hasattr(ds, 'windows'):
+            if hasattr(ds, "windows"):
                 df = ds.windows.metadata
             else:
                 df = ds.metadata
```
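`get_metadata` keeps its error message but now also accepts the newer `EEGWindowsDataset`. A sketch of the expected use, assuming `concat_ds` holds only windowed datasets:

```python
md = concat_ds.get_metadata()
# One row per window: the per-window columns (i_window_in_trial,
# i_start_in_trial, i_stop_in_trial, target) joined with the
# per-recording description columns.
print(md.head())
```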
```diff
@@ -529,7 +590,7 @@ class BaseConcatDataset(ConcatDataset):
     @target_transform.setter
     def target_transform(self, fn):
         if not (callable(fn) or fn is None):
-            raise TypeError('target_transform must be a callable.')
+            raise TypeError("target_transform must be a callable.")
         self._target_transform = fn
 
     def _outdated_save(self, path, overwrite=False):
```
```diff
@@ -547,39 +608,50 @@ class BaseConcatDataset(ConcatDataset):
             Whether to delete old files (.json, .fif, -epo.fif) in specified
             directory prior to saving.
         """
-        warnings.warn('This function only exists for backwards compatibility '
-                      'purposes. DO NOT USE!', UserWarning)
+        warnings.warn(
+            "This function only exists for backwards compatibility "
+            "purposes. DO NOT USE!",
+            UserWarning,
+        )
         if isinstance(self.datasets[0], EEGWindowsDataset):
-            raise NotImplementedError('Outdated save not implemented for new window datasets.')
+            raise NotImplementedError(
+                "Outdated save not implemented for new window datasets."
+            )
         if len(self.datasets) == 0:
             raise ValueError("Expect at least one dataset")
-        if not (hasattr(self.datasets[0], 'raw') or hasattr(
-                self.datasets[0], 'windows')):
-            raise ValueError('dataset should have either raw or windows '
-                             'attribute')
+        if not (
+            hasattr(self.datasets[0], "raw") or hasattr(self.datasets[0], "windows")
+        ):
+            raise ValueError("dataset should have either raw or windows attribute")
         file_name_templates = ["{}-raw.fif", "{}-epo.fif"]
-        description_file_name = os.path.join(path, 'description.json')
-        target_file_name = os.path.join(path, 'target_name.json')
+        description_file_name = os.path.join(path, "description.json")
+        target_file_name = os.path.join(path, "target_name.json")
         if not overwrite:
-            from braindecode.datautil.serialization import (
-                _check_save_dir_empty)
+            from braindecode.datautil.serialization import (  # Import here to avoid circular import
+                _check_save_dir_empty,
+            )
+
             _check_save_dir_empty(path)
         else:
             for file_name_template in file_name_templates:
-                file_names = glob(os.path.join(
-                    path, f"*{file_name_template.lstrip('{}')}"))
+                file_names = glob(
+                    os.path.join(path, f"*{file_name_template.lstrip('{}')}")
+                )
                 _ = [os.remove(f) for f in file_names]
             if os.path.isfile(target_file_name):
                 os.remove(target_file_name)
             if os.path.isfile(description_file_name):
                 os.remove(description_file_name)
-            for kwarg_name in ['raw_preproc_kwargs', 'window_kwargs',
-                               'window_preproc_kwargs']:
-                kwarg_path = os.path.join(path, '.'.join([kwarg_name, 'json']))
+            for kwarg_name in [
+                "raw_preproc_kwargs",
+                "window_kwargs",
+                "window_preproc_kwargs",
+            ]:
+                kwarg_path = os.path.join(path, ".".join([kwarg_name, "json"]))
                 if os.path.exists(kwarg_path):
                     os.remove(kwarg_path)
 
-        is_raw = hasattr(self.datasets[0], 'raw')
+        is_raw = hasattr(self.datasets[0], "raw")
 
         if is_raw:
             file_name_template = file_name_templates[0]
```
```diff
@@ -594,21 +666,26 @@ class BaseConcatDataset(ConcatDataset):
                 ds.windows.save(full_file_path, overwrite=overwrite)
 
         self.description.to_json(description_file_name)
-        for kwarg_name in ['raw_preproc_kwargs', 'window_kwargs',
-                           'window_preproc_kwargs']:
+        for kwarg_name in [
+            "raw_preproc_kwargs",
+            "window_kwargs",
+            "window_preproc_kwargs",
+        ]:
             if hasattr(self, kwarg_name):
-                kwargs_path = os.path.join(path, '.'.join([kwarg_name, 'json']))
+                kwargs_path = os.path.join(path, ".".join([kwarg_name, "json"]))
                 kwargs = getattr(self, kwarg_name)
                 if kwargs is not None:
-                    json.dump(kwargs, open(kwargs_path, 'w'))
+                    json.dump(kwargs, open(kwargs_path, "w"))
 
     @property
-    def description(self):
+    def description(self) -> pd.DataFrame:
         df = pd.DataFrame([ds.description for ds in self.datasets])
         df.reset_index(inplace=True, drop=True)
         return df
 
-    def set_description(self, description, overwrite=False):
+    def set_description(
+        self, description: dict | pd.DataFrame, overwrite: bool = False
+    ):
```
```diff
@@ -625,7 +702,7 @@ class BaseConcatDataset(ConcatDataset):
             for ds, value_ in zip(self.datasets, value):
                 ds.set_description({key: value_}, overwrite=overwrite)
 
-    def save(self, path, overwrite=False, offset=0):
+    def save(self, path: str, overwrite: bool = False, offset: int = 0):
        """Save datasets to files by creating one subdirectory for each dataset:
         path/
             0/
```
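The typed `save` signature is behaviourally unchanged; the docstring still describes one numbered subdirectory per dataset. A sketch:

```python
import tempfile

save_dir = tempfile.mkdtemp()
concat_ds.save(save_dir, overwrite=False)
# Expected layout (see the _save_* helpers below):
#   save_dir/0/0-raw.fif        (or 0-epo.fif for windowed data)
#   save_dir/0/description.json
#   save_dir/1/...
```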
```diff
@@ -659,10 +736,10 @@ class BaseConcatDataset(ConcatDataset):
         """
         if len(self.datasets) == 0:
             raise ValueError("Expect at least one dataset")
-        if not (hasattr(self.datasets[0], 'raw') or hasattr(
-                self.datasets[0], 'windows')):
-            raise ValueError('dataset should have either raw or windows '
-                             'attribute')
+        if not (
+            hasattr(self.datasets[0], "raw") or hasattr(self.datasets[0], "windows")
+        ):
+            raise ValueError("dataset should have either raw or windows attribute")
         path_contents = os.listdir(path)
         n_sub_dirs = len([os.path.isdir(e) for e in path_contents])
         for i_ds, ds in enumerate(self.datasets):
```
```diff
@@ -676,9 +753,10 @@ class BaseConcatDataset(ConcatDataset):
                     shutil.rmtree(sub_dir)
                 else:
                     raise FileExistsError(
-                        f'Subdirectory {sub_dir} already exists. Please select'
-                        f' a different directory, set overwrite=True, or '
-                        f'resolve manually.')
+                        f"Subdirectory {sub_dir} already exists. Please select"
+                        f" a different directory, set overwrite=True, or "
+                        f"resolve manually."
+                    )
             # save_dir/{i_ds+offset}/
             os.makedirs(sub_dir)
             # save_dir/{i_ds+offset}/{i_ds+offset}-{raw_or_epo}.fif
```
```diff
@@ -696,59 +774,67 @@ class BaseConcatDataset(ConcatDataset):
             if overwrite:
                 # the following will be True for all datasets preprocessed and
                 # stored in parallel with braindecode.preprocessing.preprocess
-                if i_ds+1+offset < n_sub_dirs:
-                    warnings.warn(f"The number of saved datasets ({i_ds+1+offset}) "
-                                  f"does not match the number of existing "
-                                  f"subdirectories ({n_sub_dirs}). You may now "
-                                  f"encounter a mix of differently preprocessed "
-                                  f"datasets!", UserWarning)
+                if i_ds + 1 + offset < n_sub_dirs:
+                    warnings.warn(
+                        f"The number of saved datasets ({i_ds + 1 + offset}) "
+                        f"does not match the number of existing "
+                        f"subdirectories ({n_sub_dirs}). You may now "
+                        f"encounter a mix of differently preprocessed "
+                        f"datasets!",
+                        UserWarning,
+                    )
         # if path contains files or directories that were not touched, raise
         # warning
         if path_contents:
-            warnings.warn(f'Chosen directory {path} contains other '
-                          f'subdirectories or files {path_contents}.')
+            warnings.warn(
+                f"Chosen directory {path} contains other "
+                f"subdirectories or files {path_contents}."
+            )
 
     @staticmethod
     def _save_signals(sub_dir, ds, i_ds, offset):
-        raw_or_epo = 'raw' if hasattr(ds, 'raw') else 'epo'
-        fif_file_name = f'{i_ds + offset}-{raw_or_epo}.fif'
+        raw_or_epo = "raw" if hasattr(ds, "raw") else "epo"
+        fif_file_name = f"{i_ds + offset}-{raw_or_epo}.fif"
         fif_file_path = os.path.join(sub_dir, fif_file_name)
-        raw_or_windows = 'raw' if raw_or_epo == 'raw' else 'windows'
+        raw_or_windows = "raw" if raw_or_epo == "raw" else "windows"
 
         # The following appears to be necessary to avoid a CI failure when
         # preprocessing WindowsDatasets with serialization enabled. The failure
         # comes from `mne.epochs._check_consistency` which ensures the Epochs's
         # object `times` attribute is not writeable.
-        getattr(ds, raw_or_windows).times.flags['WRITEABLE'] = False
+        getattr(ds, raw_or_windows).times.flags["WRITEABLE"] = False
 
         getattr(ds, raw_or_windows).save(fif_file_path)
 
     @staticmethod
     def _save_metadata(sub_dir, ds):
-        if hasattr(ds, 'metadata'):
-            metadata_file_path = os.path.join(sub_dir, 'metadata_df.pkl')
+        if hasattr(ds, "metadata"):
+            metadata_file_path = os.path.join(sub_dir, "metadata_df.pkl")
             ds.metadata.to_pickle(metadata_file_path)
 
     @staticmethod
     def _save_description(sub_dir, description):
-        description_file_path = os.path.join(sub_dir, 'description.json')
+        description_file_path = os.path.join(sub_dir, "description.json")
         description.to_json(description_file_path)
 
     @staticmethod
     def _save_kwargs(sub_dir, ds):
-        for kwargs_name in ['raw_preproc_kwargs', 'window_kwargs',
-                            'window_preproc_kwargs']:
+        for kwargs_name in [
+            "raw_preproc_kwargs",
+            "window_kwargs",
+            "window_preproc_kwargs",
+        ]:
             if hasattr(ds, kwargs_name):
-                kwargs_file_name = '.'.join([kwargs_name, 'json'])
+                kwargs_file_name = ".".join([kwargs_name, "json"])
                 kwargs_file_path = os.path.join(sub_dir, kwargs_file_name)
                 kwargs = getattr(ds, kwargs_name)
                 if kwargs is not None:
-                    with open(kwargs_file_path, 'w') as f:
+                    with open(kwargs_file_path, "w") as f:
                         json.dump(kwargs, f)
 
     @staticmethod
     def _save_target_name(sub_dir, ds):
-        if hasattr(ds, 'target_name'):
-            target_file_path = os.path.join(sub_dir, 'target_name.json')
-            with open(target_file_path, 'w') as f:
-                json.dump({'target_name': ds.target_name}, f)
+        if hasattr(ds, "target_name"):
+            target_file_path = os.path.join(sub_dir, "target_name.json")
+            with open(target_file_path, "w") as f:
+                json.dump({"target_name": ds.target_name}, f)
```