braindecode 1.2.0.dev184328194__py3-none-any.whl → 1.3.0.dev171178473__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- braindecode/augmentation/base.py +1 -1
- braindecode/augmentation/functional.py +154 -54
- braindecode/augmentation/transforms.py +2 -2
- braindecode/datasets/__init__.py +10 -2
- braindecode/datasets/base.py +116 -152
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +218 -0
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +2 -2
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +4 -3
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/serialization.py +18 -13
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +12 -8
- braindecode/models/atcnet.py +156 -17
- braindecode/models/attentionbasenet.py +148 -16
- braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +3 -1
- braindecode/models/ctnet.py +7 -4
- braindecode/models/deep4.py +6 -2
- braindecode/models/deepsleepnet.py +127 -5
- braindecode/models/eegconformer.py +114 -15
- braindecode/models/eeginception_erp.py +82 -7
- braindecode/models/eeginception_mi.py +2 -0
- braindecode/models/eegnet.py +64 -177
- braindecode/models/eegnex.py +113 -6
- braindecode/models/eegsimpleconv.py +2 -0
- braindecode/models/eegtcnet.py +1 -1
- braindecode/models/labram.py +188 -84
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/sccnet.py +81 -8
- braindecode/models/shallow_fbcsp.py +2 -0
- braindecode/models/signal_jepa.py +109 -27
- braindecode/models/sinc_shallow.py +10 -9
- braindecode/models/sleep_stager_blanco_2020.py +2 -0
- braindecode/models/sleep_stager_chambon_2018.py +2 -0
- braindecode/models/sparcnet.py +2 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +42 -41
- braindecode/models/tidnet.py +2 -0
- braindecode/models/tsinception.py +15 -3
- braindecode/models/usleep.py +108 -9
- braindecode/models/util.py +8 -5
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +24 -0
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/preprocess.py +42 -39
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +24 -19
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/METADATA +12 -3
- braindecode-1.3.0.dev171178473.dist-info/RECORD +106 -0
- braindecode/models/eegresnet.py +0 -362
- braindecode-1.2.0.dev184328194.dist-info/RECORD +0 -101
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/WHEEL +0 -0
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171178473.dist-info}/top_level.txt +0 -0
braindecode/datasets/base.py
CHANGED
@@ -16,14 +16,17 @@ import json
 import os
 import shutil
 import warnings
+from abc import abstractmethod
 from collections.abc import Callable
 from glob import glob
-from typing import Iterable, no_type_check
+from typing import Generic, Iterable, no_type_check

 import mne.io
 import numpy as np
 import pandas as pd
+from mne.utils.docs import deprecated
 from torch.utils.data import ConcatDataset, Dataset
+from typing_extensions import TypeVar


 def _create_description(description) -> pd.Series:
@@ -37,7 +40,64 @@ def _create_description(description) -> pd.Series:
     return description


-class BaseDataset(Dataset):
+class RecordDataset(Dataset[tuple[np.ndarray, int | str, tuple[int, int, int]]]):
+    def __init__(
+        self,
+        description: dict | pd.Series | None = None,
+        transform: Callable | None = None,
+    ):
+        self._description = _create_description(description)
+        self.transform = transform
+
+    @abstractmethod
+    def __len__(self) -> int:
+        pass
+
+    @property
+    def description(self) -> pd.Series:
+        return self._description
+
+    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
+        """Update (add or overwrite) the dataset description.
+
+        Parameters
+        ----------
+        description: dict | pd.Series
+            Description in the form key: value.
+        overwrite: bool
+            Has to be True if a key in description already exists in the
+            dataset description.
+        """
+        description = _create_description(description)
+        if self.description is None:
+            self._description = description
+        else:
+            for key, value in description.items():
+                # if the key is already in the existing description, drop it
+                if key in self._description:
+                    assert overwrite, (
+                        f"'{key}' already in description. Please "
+                        f"rename or set overwrite to True."
+                    )
+                    self._description.pop(key)
+            self._description = pd.concat([self.description, description])
+
+    @property
+    def transform(self) -> Callable | None:
+        return self._transform
+
+    @transform.setter
+    def transform(self, value: Callable | None):
+        if value is not None and not callable(value):
+            raise ValueError("Transform needs to be a callable.")
+        self._transform = value
+
+
+# Type of the datasets contained in BaseConcatDataset
+T = TypeVar("T", bound=RecordDataset)
+
+
+class RawDataset(RecordDataset):
     """Returns samples from an mne.io.Raw object along with a target.

     Dataset which serves samples from an mne.io.Raw object along with a target.
@@ -64,9 +124,8 @@ class BaseDataset(Dataset):
         target_name: str | tuple[str, ...] | None = None,
         transform: Callable | None = None,
     ):
+        super().__init__(description, transform)
         self.raw = raw
-        self._description = _create_description(description)
-        self.transform = transform

         # save target name for load/save later
         self.target_name = self._target_name(target_name)
@@ -85,45 +144,6 @@ class BaseDataset(Dataset):
     def __len__(self):
         return len(self.raw)

-    @property
-    def transform(self):
-        return self._transform
-
-    @transform.setter
-    def transform(self, value):
-        if value is not None and not callable(value):
-            raise ValueError("Transform needs to be a callable.")
-        self._transform = value
-
-    @property
-    def description(self) -> pd.Series:
-        return self._description
-
-    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
-        """Update (add or overwrite) the dataset description.
-
-        Parameters
-        ----------
-        description: dict | pd.Series
-            Description in the form key: value.
-        overwrite: bool
-            Has to be True if a key in description already exists in the
-            dataset description.
-        """
-        description = _create_description(description)
-        for key, value in description.items():
-            # if the key is already in the existing description, drop it
-            if self._description is not None and key in self._description:
-                assert overwrite, (
-                    f"'{key}' already in description. Please "
-                    f"rename or set overwrite to True."
-                )
-                self._description.pop(key)
-        if self._description is None:
-            self._description = description
-        else:
-            self._description = pd.concat([self.description, description])
-
     def _target_name(self, target_name):
         if target_name is not None and not isinstance(target_name, (str, tuple, list)):
             raise ValueError("target_name has to be None, str, tuple or list")
@@ -150,7 +170,17 @@ class BaseDataset(Dataset):
         return target_name if len(target_name) > 1 else target_name[0]


-class EEGWindowsDataset(BaseDataset):
+@deprecated(
+    "The BaseDataset class is deprecated. "
+    "If you want to instantiate a dataset containing raws, use RawDataset instead. "
+    "If you want to type a Braindecode dataset (i.e. RawDataset|EEGWindowsDataset|WindowsDataset), "
+    "use the RecordDataset class instead."
+)
+class BaseDataset(RawDataset):
+    pass
+
+
+class EEGWindowsDataset(RecordDataset):
     """Returns windows from an mne.Raw object, its window indices, along with a target.

     Dataset which serves windows from an mne.Epochs object along with their
@@ -161,12 +191,12 @@ class EEGWindowsDataset(BaseDataset):
     required to serve information about the windowing (e.g., useful for cropped
     training).
     See `braindecode.datautil.windowers` to directly create a `WindowsDataset`
-    from a `BaseDataset` object.
+    from a `RawDataset` object.

     Parameters
     ----------
     windows : mne.Raw or mne.Epochs (Epochs is outdated)
-        Windows obtained through the application of a windower to a ``BaseDataset``
+        Windows obtained through the application of a windower to a ``RawDataset``
         (see `braindecode.datautil.windowers`).
     description : dict | pandas.Series | None
         Holds additional info about the windows.
@@ -185,18 +215,17 @@ class EEGWindowsDataset(BaseDataset):

     def __init__(
         self,
-        raw: mne.io.BaseRaw
+        raw: mne.io.BaseRaw,
         metadata: pd.DataFrame,
         description: dict | pd.Series | None = None,
         transform: Callable | None = None,
         targets_from: str = "metadata",
         last_target_only: bool = True,
     ):
+        super().__init__(description, transform)
         self.raw = raw
         self.metadata = metadata
-        self._description = _create_description(description)

-        self.transform = transform
         self.last_target_only = last_target_only
         if targets_from not in ("metadata", "channels"):
             raise ValueError("Wrong value for parameter `targets_from`.")
@@ -255,44 +284,8 @@ class EEGWindowsDataset(BaseDataset):
     def __len__(self):
         return len(self.crop_inds)

-    @property
-    def transform(self):
-        return self._transform
-
-    @transform.setter
-    def transform(self, value):
-        if value is not None and not callable(value):
-            raise ValueError("Transform needs to be a callable.")
-        self._transform = value
-
-    @property
-    def description(self) -> pd.Series:
-        return self._description
-
-    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
-        """Update (add or overwrite) the dataset description.
-
-        Parameters
-        ----------
-        description: dict | pd.Series
-            Description in the form key: value.
-        overwrite: bool
-            Has to be True if a key in description already exists in the
-            dataset description.
-        """
-        description = _create_description(description)
-        for key, value in description.items():
-            # if they key is already in the existing description, drop it
-            if key in self._description:
-                assert overwrite, (
-                    f"'{key}' already in description. Please "
-                    f"rename or set overwrite to True."
-                )
-                self._description.pop(key)
-        self._description = pd.concat([self.description, description])

-
-class WindowsDataset(BaseDataset):
+class WindowsDataset(RecordDataset):
     """Returns windows from an mne.Epochs object along with a target.

     Dataset which serves windows from an mne.Epochs object along with their
@@ -303,12 +296,12 @@ class WindowsDataset(BaseDataset):
     required to serve information about the windowing (e.g., useful for cropped
     training).
     See `braindecode.datautil.windowers` to directly create a `WindowsDataset`
-    from a ``BaseDataset`` object.
+    from a ``RawDataset`` object.

     Parameters
     ----------
     windows : mne.Epochs
-        Windows obtained through the application of a windower to a BaseDataset
+        Windows obtained through the application of a windower to a RawDataset
         (see `braindecode.datautil.windowers`).
     description : dict | pandas.Series | None
         Holds additional info about the windows.
@@ -327,19 +320,20 @@ class WindowsDataset(BaseDataset):
         targets_from: str = "metadata",
         last_target_only: bool = True,
     ):
+        super().__init__(description, transform)
         self.windows = windows
-        self._description = _create_description(description)
-        self.transform = transform
         self.last_target_only = last_target_only
         if targets_from not in ("metadata", "channels"):
             raise ValueError("Wrong value for parameter `targets_from`.")
         self.targets_from = targets_from

-        self.crop_inds = self.windows.metadata.loc[
+        metadata = self.windows.metadata
+        assert metadata is not None, "WindowsDataset requires windows with metadata."
+        self.crop_inds = metadata.loc[
             :, ["i_window_in_trial", "i_start_in_trial", "i_stop_in_trial"]
         ].to_numpy()
         if self.targets_from == "metadata":
-            self.y = self.windows.metadata.loc[:, "target"].to_list()
+            self.y = metadata.loc[:, "target"].to_list()

     def __getitem__(self, index: int):
         """Get a window and its target.
@@ -379,44 +373,8 @@ class WindowsDataset(BaseDataset):
     def __len__(self) -> int:
         return len(self.windows.events)

-    @property
-    def transform(self):
-        return self._transform
-
-    @transform.setter
-    def transform(self, value):
-        if value is not None and not callable(value):
-            raise ValueError("Transform needs to be a callable.")
-        self._transform = value

-
-    def description(self) -> pd.Series:
-        return self._description
-
-    def set_description(self, description: dict | pd.Series, overwrite: bool = False):
-        """Update (add or overwrite) the dataset description.
-
-        Parameters
-        ----------
-        description: dict | pd.Series
-            Description in the form key: value.
-        overwrite: bool
-            Has to be True if a key in description already exists in the
-            dataset description.
-        """
-        description = _create_description(description)
-        for key, value in description.items():
-            # if they key is already in the existing description, drop it
-            if key in self._description:
-                assert overwrite, (
-                    f"'{key}' already in description. Please "
-                    f"rename or set overwrite to True."
-                )
-                self._description.pop(key)
-        self._description = pd.concat([self.description, description])
-
-
-class BaseConcatDataset(ConcatDataset):
+class BaseConcatDataset(ConcatDataset, Generic[T]):
     """A base class for concatenated datasets.

     Holds either mne.Raw or mne.Epoch in self.datasets and has
@@ -425,22 +383,27 @@ class BaseConcatDataset(ConcatDataset):
     Parameters
     ----------
     list_of_ds : list
-        list of BaseDataset
+        list of RecordDataset
     target_transform : callable | None
         Optional function to call on targets before returning them.

     """

+    datasets: list[T]
+
     def __init__(
         self,
-        list_of_ds: list[
-        | None = None,
+        list_of_ds: list[T | BaseConcatDataset[T]],
         target_transform: Callable | None = None,
     ):
         # if we get a list of BaseConcatDataset, get all the individual datasets
-        if list_of_ds and isinstance(list_of_ds[0], BaseConcatDataset):
-            list_of_ds = [d for ds in list_of_ds for d in ds.datasets]
-        super().__init__(list_of_ds)
+        flattened_list_of_ds: list[T] = []
+        for ds in list_of_ds:
+            if isinstance(ds, BaseConcatDataset):
+                flattened_list_of_ds.extend(ds.datasets)
+            else:
+                flattened_list_of_ds.append(ds)
+        super().__init__(flattened_list_of_ds)

         self.target_transform = target_transform

@@ -703,22 +666,23 @@ class BaseConcatDataset(ConcatDataset):
                 ds.set_description({key: value_}, overwrite=overwrite)

     def save(self, path: str, overwrite: bool = False, offset: int = 0):
-        """Save datasets to files by creating one subdirectory for each dataset
-
-        path/
-            0/
-                0-raw.fif | 0-epo.fif
-                description.json
-                raw_preproc_kwargs.json (if raws were preprocessed)
-                window_kwargs.json (if this is a windowed dataset)
-                window_preproc_kwargs.json (if windows were preprocessed)
-                target_name.json (if target_name is not None and dataset is raw)
-            1/
-                1-raw.fif | 1-epo.fif
-                description.json
-                raw_preproc_kwargs.json (if raws were preprocessed)
-                window_kwargs.json (if this is a windowed dataset)
-                window_preproc_kwargs.json (if windows were preprocessed)
+        """Save datasets to files by creating one subdirectory for each dataset::
+
+            path/
+                0/
+                    0-raw.fif | 0-epo.fif
+                    description.json
+                    raw_preproc_kwargs.json (if raws were preprocessed)
+                    window_kwargs.json (if this is a windowed dataset)
+                    window_preproc_kwargs.json (if windows were preprocessed)
+                    target_name.json (if target_name is not None and dataset is raw)
+                1/
+                    1-raw.fif | 1-epo.fif
+                    description.json
+                    raw_preproc_kwargs.json (if raws were preprocessed)
+                    window_kwargs.json (if this is a windowed dataset)
+                    window_preproc_kwargs.json (if windows were preprocessed)
+                    target_name.json (if target_name is not None and dataset is raw)

         Parameters
         ----------
@@ -815,7 +779,7 @@ class BaseConcatDataset(ConcatDataset):
     @staticmethod
     def _save_description(sub_dir, description):
        description_file_path = os.path.join(sub_dir, "description.json")
-        description.to_json(description_file_path)
+        description.to_json(description_file_path, default_handler=str)

     @staticmethod
     def _save_kwargs(sub_dir, ds):
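
Taken together, the base.py changes replace BaseDataset with a RawDataset/RecordDataset split and make BaseConcatDataset generic. A minimal sketch of the resulting API, assuming this dev build plus mne and numpy are installed; the synthetic RawArray recording and the save path are illustrative, not taken from the diff:

```python
# Illustrative only: exercises the new RecordDataset/RawDataset hierarchy.
import os

import mne
import numpy as np

from braindecode.datasets import BaseConcatDataset, RawDataset

# A tiny synthetic recording: 2 EEG channels, 10 s at 100 Hz.
info = mne.create_info(ch_names=["C3", "C4"], sfreq=100.0, ch_types="eeg")
raw = mne.io.RawArray(np.zeros((2, 1000)), info)

# RawDataset replaces the deprecated BaseDataset; description/transform
# handling now lives in the shared RecordDataset base class.
ds = RawDataset(raw, description={"subject": 1})
ds.set_description({"session": "A"})                # add a new key
ds.set_description({"subject": 2}, overwrite=True)  # overwriting needs the flag

# BaseConcatDataset is now Generic[T], so it can be annotated by record type.
concat: BaseConcatDataset[RawDataset] = BaseConcatDataset([ds])
print(len(concat.datasets), dict(concat.datasets[0].description))

# save() lays out one numbered subdirectory per dataset (see the docstring
# above); description.to_json(..., default_handler=str) now stringifies
# description values that are not natively JSON-serializable.
os.makedirs("./saved_example", exist_ok=True)
concat.save("./saved_example", overwrite=True)
```

Instantiating BaseDataset itself still works in this release but emits a deprecation warning through mne's @deprecated decorator.
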
|
braindecode/datasets/bcicomp.py
CHANGED
|
@@ -16,7 +16,7 @@ import numpy as np
|
|
|
16
16
|
from mne.utils import verbose
|
|
17
17
|
from scipy.io import loadmat
|
|
18
18
|
|
|
19
|
-
from braindecode.datasets import BaseConcatDataset,
|
|
19
|
+
from braindecode.datasets import BaseConcatDataset, RawDataset
|
|
20
20
|
|
|
21
21
|
DATASET_URL = (
|
|
22
22
|
"https://stacks.stanford.edu/file/druid:zk881ps0522/"
|
|
@@ -73,8 +73,8 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
             file_name=file_path.split("/")[-1],
             session="test",
         )
-        datasets.append(BaseDataset(raw_train, description=desc_train))
-        datasets.append(BaseDataset(raw_test, description=desc_test))
+        datasets.append(RawDataset(raw_train, description=desc_train))
+        datasets.append(RawDataset(raw_test, description=desc_test))
         super().__init__(datasets)

     @staticmethod
@@ -85,7 +85,7 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
        ----------
        path (None | str) – Location of where to look for the data storing location.
            If None, the environment variable or config parameter
-           MNE_DATASETS_(dataset)_PATH is used. If it doesn’t exist, the “~/mne_data”
+           MNE_DATASETS_(dataset)_PATH is used. If it doesn't exist, the “~/mne_data”
            directory is used. If the dataset is not found under the given path, the data
            will be automatically downloaded to the specified folder.
        force_update (bool) – Force update of the dataset even if a local copy exists.
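
The practical effect of the bcicomp.py change, sketched under the assumption that the loader's public entry point and its subject_ids parameter are unchanged from earlier releases (first use downloads data from the Stanford URL above):

```python
# Hypothetical usage; subject_ids and the printed keys are assumptions,
# only the RawDataset record type comes from this diff.
from braindecode.datasets import BCICompetitionIVDataset4

ds = BCICompetitionIVDataset4(subject_ids=[1])
first = ds.datasets[0]
print(type(first).__name__)          # RawDataset rather than BaseDataset
print(first.description["session"])  # e.g. "test" for the held-out recording
```
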
braindecode/datasets/bids.py
CHANGED
@@ -19,7 +19,7 @@ import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed

-from .base import BaseConcatDataset, BaseDataset, WindowsDataset
+from .base import BaseConcatDataset, RawDataset, WindowsDataset


 def _description_from_bids_path(bids_path: mne_bids.BIDSPath) -> dict[str, Any]:
@@ -186,12 +186,12 @@ class BIDSDataset(BaseConcatDataset):
         )
         super().__init__(all_base_ds)

-    def _get_dataset(self, bids_path: mne_bids.BIDSPath) -> BaseDataset:
+    def _get_dataset(self, bids_path: mne_bids.BIDSPath) -> RawDataset:
         description = _description_from_bids_path(bids_path)
         raw = mne_bids.read_raw_bids(bids_path, verbose=False)
         if self.preload:
             raw.load_data()
-        return BaseDataset(raw, description)
+        return RawDataset(raw, description)


 class BIDSEpochsDataset(BIDSDataset):
braindecode/datasets/experimental.py
ADDED

@@ -0,0 +1,218 @@
+from __future__ import annotations
+
+import random
+from pathlib import Path
+from typing import Callable, Sequence
+
+import mne_bids
+from torch.utils.data import IterableDataset, get_worker_info
+
+
+class BIDSIterableDataset(IterableDataset):
+    """Dataset for loading BIDS.
+
+    .. warning::
+        This class is experimental and may change in the future.
+
+    .. warning::
+        This dataset is not consistent with the Braindecode API.
+
+    This class has the same parameters as the :func:`mne_bids.find_matching_paths` function
+    as it will be used to find the files to load. The default ``extensions`` parameter was changed.
+
+    More information on BIDS (Brain Imaging Data Structure)
+    can be found at https://bids.neuroimaging.io
+
+    Examples
+    --------
+    >>> from braindecode.datasets import RecordDataset, BaseConcatDataset
+    >>> from braindecode.datasets.bids import BIDSIterableDataset, _description_from_bids_path
+    >>> from braindecode.preprocessing import create_fixed_length_windows
+    >>>
+    >>> def my_reader_fn(path):
+    ...     raw = mne_bids.read_raw_bids(path)
+    ...     desc = _description_from_bids_path(path)
+    ...     ds = RawDataset(raw, description=desc)
+    ...     windows_ds = create_fixed_length_windows(
+    ...         BaseConcatDataset([ds]),
+    ...         window_size_samples=400,
+    ...         window_stride_samples=200,
+    ...     )
+    ...     return windows_ds
+    >>>
+    >>> dataset = BIDSIterableDataset(
+    ...     reader_fn=my_reader_fn,
+    ...     root="root/of/my/bids/dataset/",
+    ... )
+
+    Parameters
+    ----------
+    reader_fn : Callable[[mne_bids.BIDSPath], Sequence]
+        A function that takes a BIDSPath and returns a dataset.
+    pool_size : int
+        The number of recordings to read and sample from.
+    bids_paths : list[mne_bids.BIDSPath] | None
+        A list of BIDSPaths to load. If None, will use the paths found by
+        :func:`mne_bids.find_matching_paths` and the arguments below.
+    root : pathlib.Path | str
+        The root of the BIDS path.
+    subjects : str | array-like of str | None
+        The subject ID. Corresponds to "sub".
+    sessions : str | array-like of str | None
+        The acquisition session. Corresponds to "ses".
+    tasks : str | array-like of str | None
+        The experimental task. Corresponds to "task".
+    acquisitions: str | array-like of str | None
+        The acquisition parameters. Corresponds to "acq".
+    runs : str | array-like of str | None
+        The run number. Corresponds to "run".
+    processings : str | array-like of str | None
+        The processing label. Corresponds to "proc".
+    recordings : str | array-like of str | None
+        The recording name. Corresponds to "rec".
+    spaces : str | array-like of str | None
+        The coordinate space for anatomical and sensor location
+        files (e.g., ``*_electrodes.tsv``, ``*_markers.mrk``).
+        Corresponds to "space".
+        Note that valid values for ``space`` must come from a list
+        of BIDS keywords as described in the BIDS specification.
+    splits : str | array-like of str | None
+        The split of the continuous recording file for ``.fif`` data.
+        Corresponds to "split".
+    descriptions : str | array-like of str | None
+        This corresponds to the BIDS entity ``desc``. It is used to provide
+        additional information for derivative data, e.g., preprocessed data
+        may be assigned ``description='cleaned'``.
+    suffixes : str | array-like of str | None
+        The filename suffix. This is the entity after the
+        last ``_`` before the extension. E.g., ``'channels'``.
+        The following filename suffix's are accepted:
+        'meg', 'markers', 'eeg', 'ieeg', 'T1w',
+        'participants', 'scans', 'electrodes', 'coordsystem',
+        'channels', 'events', 'headshape', 'digitizer',
+        'beh', 'physio', 'stim'
+    extensions : str | array-like of str | None
+        The extension of the filename. E.g., ``'.json'``.
+        By default, uses the ones accepted by :func:`mne_bids.read_raw_bids`.
+    datatypes : str | array-like of str | None
+        The BIDS data type, e.g., ``'anat'``, ``'func'``, ``'eeg'``, ``'meg'``,
+        ``'ieeg'``.
+    check : bool
+        If ``True``, only returns paths that conform to BIDS. If ``False``
+        (default), the ``.check`` attribute of the returned
+        :class:`mne_bids.BIDSPath` object will be set to ``True`` for paths that
+        do conform to BIDS, and to ``False`` for those that don't.
+    preload : bool
+        If True, preload the data. Defaults to False.
+    n_jobs : int
+        Number of jobs to run in parallel. Defaults to 1.
+    """
+
+    def __init__(
+        self,
+        reader_fn: Callable[[mne_bids.BIDSPath], Sequence],
+        pool_size: int = 4,
+        bids_paths: list[mne_bids.BIDSPath] | None = None,
+        root: Path | str | None = None,
+        subjects: str | list[str] | None = None,
+        sessions: str | list[str] | None = None,
+        tasks: str | list[str] | None = None,
+        acquisitions: str | list[str] | None = None,
+        runs: str | list[str] | None = None,
+        processings: str | list[str] | None = None,
+        recordings: str | list[str] | None = None,
+        spaces: str | list[str] | None = None,
+        splits: str | list[str] | None = None,
+        descriptions: str | list[str] | None = None,
+        suffixes: str | list[str] | None = None,
+        extensions: str | list[str] | None = [
+            ".con",
+            ".sqd",
+            ".pdf",
+            ".fif",
+            ".ds",
+            ".vhdr",
+            ".set",
+            ".edf",
+            ".bdf",
+            ".EDF",
+            ".snirf",
+            ".cdt",
+            ".mef",
+            ".nwb",
+        ],
+        datatypes: str | list[str] | None = None,
+        check: bool = False,
+    ):
+        if bids_paths is None:
+            bids_paths = mne_bids.find_matching_paths(
+                root=root,
+                subjects=subjects,
+                sessions=sessions,
+                tasks=tasks,
+                acquisitions=acquisitions,
+                runs=runs,
+                processings=processings,
+                recordings=recordings,
+                spaces=spaces,
+                splits=splits,
+                descriptions=descriptions,
+                suffixes=suffixes,
+                extensions=extensions,
+                datatypes=datatypes,
+                check=check,
+                ignore_json=True,
+            )
+        # Filter out _epo.fif files:
+        bids_paths = [
+            bids_path
+            for bids_path in bids_paths
+            if not (bids_path.suffix == "epo" and bids_path.extension == ".fif")
+        ]
+        self.bids_paths = bids_paths
+        self.reader_fn = reader_fn
+        self.pool_size = pool_size
+
+    def __add__(self, other):
+        assert isinstance(other, BIDSIterableDataset)
+        return BIDSIterableDataset(
+            reader_fn=self.reader_fn,
+            bids_paths=self.bids_paths + other.bids_paths,
+            pool_size=self.pool_size,
+        )
+
+    def __iadd__(self, other):
+        assert isinstance(other, BIDSIterableDataset)
+        self.bids_paths += other.bids_paths
+        return self
+
+    def __iter__(self):
+        worker_info = get_worker_info()
+        if worker_info is None:  # single-process data loading, return the full iterator
+            bids_paths = self.bids_paths
+        else:  # in a worker process
+            # split workload
+            bids_paths = self.bids_paths[worker_info.id :: worker_info.num_workers]
+
+        pool = []
+        end = False
+        paths_it = iter(random.sample(bids_paths, k=len(bids_paths)))
+        while not (end and len(pool) == 0):
+            while not end and len(pool) < self.pool_size:
+                try:
+                    bids_path = next(paths_it)
+                    ds = self.reader_fn(bids_path)
+                    if ds is None:
+                        print(f"Skipping {bids_path} as it is too short.")
+                        continue
+                    idx = iter(random.sample(range(len(ds)), k=len(ds)))
+                    pool.append((ds, idx))
+                except StopIteration:
+                    end = True
+            i_pool = random.randint(0, len(pool) - 1)
+            ds, idx = pool[i_pool]
+            try:
+                i_ds = next(idx)
+                yield ds[i_ds]
+            except StopIteration:
+                pool.pop(i_pool)