braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/datasets/base.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""
|
|
2
|
+
Dataset classes.
|
|
3
|
+
"""
|
|
2
4
|
|
|
3
5
|
# Authors: Hubert Banville <hubert.jbanville@gmail.com>
|
|
4
6
|
# Lukas Gemein <l.gemein@gmail.com>
|
|
@@ -26,7 +28,7 @@ from mne.utils.docs import deprecated
|
|
|
26
28
|
from torch.utils.data import ConcatDataset, Dataset
|
|
27
29
|
from typing_extensions import TypeVar
|
|
28
30
|
|
|
29
|
-
from .
|
|
31
|
+
from .hub import HubDatasetMixin
|
|
30
32
|
from .registry import register_dataset
|
|
31
33
|
|
|
32
34
|
|
|
@@ -63,9 +65,9 @@ class RecordDataset(Dataset[tuple[np.ndarray, int | str, tuple[int, int, int]]])
|
|
|
63
65
|
|
|
64
66
|
Parameters
|
|
65
67
|
----------
|
|
66
|
-
description
|
|
68
|
+
description: dict | pd.Series
|
|
67
69
|
Description in the form key: value.
|
|
68
|
-
overwrite
|
|
70
|
+
overwrite: bool
|
|
69
71
|
Has to be True if a key in description already exists in the
|
|
70
72
|
dataset description.
|
|
71
73
|
"""
|
|
@@ -399,6 +401,7 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
|
|
|
399
401
|
list of RecordDataset
|
|
400
402
|
target_transform : callable | None
|
|
401
403
|
Optional function to call on targets before returning them.
|
|
404
|
+
|
|
402
405
|
"""
|
|
403
406
|
|
|
404
407
|
datasets: list[T]
|
|
@@ -433,8 +436,8 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
|
|
|
433
436
|
|
|
434
437
|
def __getitem__(self, idx: int | list):
|
|
435
438
|
"""
|
|
436
|
-
|
|
437
|
-
|
|
439
|
+
Parameters
|
|
440
|
+
----------
|
|
438
441
|
idx : int | list
|
|
439
442
|
Index of window and target to return. If provided as a list of
|
|
440
443
|
ints, multiple windows and targets will be extracted and
|
|
@@ -569,8 +572,7 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
|
|
|
569
572
|
self._target_transform = fn
|
|
570
573
|
|
|
571
574
|
def _outdated_save(self, path, overwrite=False):
|
|
572
|
-
"""This is a copy of the old saving function, that had inconsistent
|
|
573
|
-
|
|
575
|
+
"""This is a copy of the old saving function, that had inconsistent
|
|
574
576
|
functionality for BaseDataset and WindowsDataset. It only exists to
|
|
575
577
|
assure backwards compatibility by still being able to run the old tests.
|
|
576
578
|
|
|
@@ -666,10 +668,10 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
|
|
|
666
668
|
|
|
667
669
|
Parameters
|
|
668
670
|
----------
|
|
669
|
-
description
|
|
671
|
+
description: dict | pd.DataFrame
|
|
670
672
|
Description in the form key: value where the length of the value
|
|
671
673
|
has to match the number of datasets.
|
|
672
|
-
overwrite
|
|
674
|
+
overwrite: bool
|
|
673
675
|
Has to be True if a key in description already exists in the
|
|
674
676
|
dataset description.
|
|
675
677
|
"""
|
|
@@ -717,14 +719,8 @@ class BaseConcatDataset(ConcatDataset, HubDatasetMixin, Generic[T]):
|
|
|
717
719
|
hasattr(self.datasets[0], "raw") or hasattr(self.datasets[0], "windows")
|
|
718
720
|
):
|
|
719
721
|
raise ValueError("dataset should have either raw or windows attribute")
|
|
720
|
-
|
|
721
|
-
# Create path if it doesn't exist
|
|
722
|
-
os.makedirs(path, exist_ok=True)
|
|
723
|
-
|
|
724
722
|
path_contents = os.listdir(path)
|
|
725
|
-
n_sub_dirs = len(
|
|
726
|
-
[e for e in path_contents if os.path.isdir(os.path.join(path, e))]
|
|
727
|
-
)
|
|
723
|
+
n_sub_dirs = len([os.path.isdir(e) for e in path_contents])
|
|
728
724
|
for i_ds, ds in enumerate(self.datasets):
|
|
729
725
|
# remove subdirectory from list of untouched files / subdirectories
|
|
730
726
|
if str(i_ds + offset) in path_contents:
|
braindecode/datasets/bbci.py
CHANGED
|
@@ -27,11 +27,11 @@ class BBCIDataset(object):
|
|
|
27
27
|
|
|
28
28
|
Parameters
|
|
29
29
|
----------
|
|
30
|
-
filename
|
|
31
|
-
load_sensor_names
|
|
30
|
+
filename: str
|
|
31
|
+
load_sensor_names: list of str, optional
|
|
32
32
|
Also speeds up loading if you only load some sensors.
|
|
33
33
|
None means load all sensors.
|
|
34
|
-
check_class_names
|
|
34
|
+
check_class_names: bool, optional
|
|
35
35
|
check if the class names are part of some known class names at
|
|
36
36
|
Translational NeuroTechnology Lab, AG Ball, Freiburg, Germany.
|
|
37
37
|
"""
|
|
@@ -134,13 +134,13 @@ class BBCIDataset(object):
|
|
|
134
134
|
|
|
135
135
|
Parameters
|
|
136
136
|
----------
|
|
137
|
-
filename
|
|
138
|
-
pattern
|
|
137
|
+
filename: str
|
|
138
|
+
pattern: str, optional
|
|
139
139
|
Only return those sensor names that match the given pattern.
|
|
140
140
|
|
|
141
141
|
Returns
|
|
142
142
|
-------
|
|
143
|
-
sensor_names
|
|
143
|
+
sensor_names: list of str
|
|
144
144
|
Sensor names that match the pattern or all sensor names in the file.
|
|
145
145
|
"""
|
|
146
146
|
with h5py.File(filename, "r") as h5file:
|
|
@@ -237,17 +237,17 @@ class BBCIDataset(object):
|
|
|
237
237
|
|
|
238
238
|
def _check_class_names(all_class_names, event_times_in_ms, event_classes):
|
|
239
239
|
"""
|
|
240
|
-
Checks if the class names are part of some known class names used in
|
|
241
|
-
|
|
240
|
+
Checks if the class names are part of some known class names used in
|
|
242
241
|
translational neurotechnology lab, AG Ball, Freiburg.
|
|
243
242
|
|
|
244
243
|
Logs warning in case class names are not known.
|
|
245
244
|
|
|
246
245
|
Parameters
|
|
247
246
|
----------
|
|
248
|
-
all_class_names
|
|
249
|
-
event_times_in_ms
|
|
250
|
-
event_classes
|
|
247
|
+
all_class_names: list of str
|
|
248
|
+
event_times_in_ms: list of number
|
|
249
|
+
event_classes: list of number
|
|
250
|
+
|
|
251
251
|
"""
|
|
252
252
|
if all_class_names == ["Right Hand", "Left Hand", "Rest", "Feet"]:
|
|
253
253
|
pass
|
|
@@ -665,15 +665,16 @@ def load_bbci_sets_from_folder(
|
|
|
665
665
|
|
|
666
666
|
Parameters
|
|
667
667
|
----------
|
|
668
|
-
folder
|
|
668
|
+
folder: str
|
|
669
669
|
Folder with .BBCI.mat files inside
|
|
670
|
-
runs
|
|
670
|
+
runs: list of int
|
|
671
671
|
If you only want to load specific runs.
|
|
672
672
|
Assumes filenames with such kind of part: S001R02 for Run 2.
|
|
673
673
|
Tries to match this regex: ``'S[0-9]{3,3}R[0-9]{2,2}_'``.
|
|
674
674
|
|
|
675
675
|
Returns
|
|
676
676
|
-------
|
|
677
|
+
|
|
677
678
|
"""
|
|
678
679
|
bbci_mat_files = sorted(glob(os.path.join(folder, "*.BBCI.mat")))
|
|
679
680
|
if runs != "all":
|
braindecode/datasets/bcicomp.py
CHANGED
|
@@ -33,16 +33,16 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
|
|
|
33
33
|
http://www.bbci.de/competition/iv/ for the dataset and competition description.
|
|
34
34
|
ECoG library containing the dataset: https://searchworks.stanford.edu/view/zk881ps0522
|
|
35
35
|
|
|
36
|
+
Notes
|
|
37
|
+
-----
|
|
38
|
+
When using this dataset please cite [1]_ .
|
|
39
|
+
|
|
36
40
|
Parameters
|
|
37
41
|
----------
|
|
38
42
|
subject_ids : list(int) | int | None
|
|
39
43
|
(list of) int of subject(s) to be loaded. If None, load all available
|
|
40
44
|
subjects. Should be in range 1-3.
|
|
41
45
|
|
|
42
|
-
Notes
|
|
43
|
-
-----
|
|
44
|
-
When using this dataset please cite [1]_ .
|
|
45
|
-
|
|
46
46
|
References
|
|
47
47
|
----------
|
|
48
48
|
.. [1] Miller, Kai J. "A library of human electrocorticographic data and analyses."
|
|
@@ -94,6 +94,7 @@ class BCICompetitionIVDataset4(BaseConcatDataset):
|
|
|
94
94
|
|
|
95
95
|
Returns
|
|
96
96
|
-------
|
|
97
|
+
|
|
97
98
|
"""
|
|
98
99
|
signature = "BCICompetitionIVDataset4"
|
|
99
100
|
folder_name = "BCI_Competion4_dataset4_data_fingerflexions"
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
# mypy: ignore-errors
|
|
2
1
|
"""Dataset for loading BIDS.
|
|
3
2
|
|
|
4
3
|
More information on BIDS (Brain Imaging Data Structure) can be found at https://bids.neuroimaging.io
|
|
@@ -20,19 +19,26 @@ import numpy as np
|
|
|
20
19
|
import pandas as pd
|
|
21
20
|
from joblib import Parallel, delayed
|
|
22
21
|
|
|
23
|
-
from
|
|
22
|
+
from .base import BaseConcatDataset, RawDataset, WindowsDataset
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
def _description_from_bids_path(bids_path: mne_bids.BIDSPath) -> dict[str, Any]:
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
26
|
+
return {
|
|
27
|
+
"path": bids_path.fpath,
|
|
28
|
+
"subject": bids_path.subject,
|
|
29
|
+
"session": bids_path.session,
|
|
30
|
+
"task": bids_path.task,
|
|
31
|
+
"acquisition": bids_path.acquisition,
|
|
32
|
+
"run": bids_path.run,
|
|
33
|
+
"processing": bids_path.processing,
|
|
34
|
+
"recording": bids_path.recording,
|
|
35
|
+
"space": bids_path.space,
|
|
36
|
+
"split": bids_path.split,
|
|
37
|
+
"description": bids_path.description,
|
|
38
|
+
"suffix": bids_path.suffix,
|
|
39
|
+
"extension": bids_path.extension,
|
|
40
|
+
"datatype": bids_path.datatype,
|
|
41
|
+
}
|
|
36
42
|
|
|
37
43
|
|
|
38
44
|
@dataclass
|
|
@@ -59,7 +65,7 @@ class BIDSDataset(BaseConcatDataset):
|
|
|
59
65
|
The acquisition session. Corresponds to "ses".
|
|
60
66
|
tasks : str | array-like of str | None
|
|
61
67
|
The experimental task. Corresponds to "task".
|
|
62
|
-
acquisitions
|
|
68
|
+
acquisitions: str | array-like of str | None
|
|
63
69
|
The acquisition parameters. Corresponds to "acq".
|
|
64
70
|
runs : str | array-like of str | None
|
|
65
71
|
The run number. Corresponds to "run".
|
|
@@ -25,13 +25,14 @@ class BIDSIterableDataset(IterableDataset):
|
|
|
25
25
|
|
|
26
26
|
Examples
|
|
27
27
|
--------
|
|
28
|
-
>>> from braindecode.datasets import
|
|
29
|
-
>>> from braindecode.datasets.bids import BIDSIterableDataset
|
|
28
|
+
>>> from braindecode.datasets import RecordDataset, BaseConcatDataset
|
|
29
|
+
>>> from braindecode.datasets.bids import BIDSIterableDataset, _description_from_bids_path
|
|
30
30
|
>>> from braindecode.preprocessing import create_fixed_length_windows
|
|
31
31
|
>>>
|
|
32
32
|
>>> def my_reader_fn(path):
|
|
33
33
|
... raw = mne_bids.read_raw_bids(path)
|
|
34
|
-
...
|
|
34
|
+
... desc = _description_from_bids_path(path)
|
|
35
|
+
... ds = RawDataset(raw, description=desc)
|
|
35
36
|
... windows_ds = create_fixed_length_windows(
|
|
36
37
|
... BaseConcatDataset([ds]),
|
|
37
38
|
... window_size_samples=400,
|
|
@@ -47,8 +48,7 @@ class BIDSIterableDataset(IterableDataset):
|
|
|
47
48
|
Parameters
|
|
48
49
|
----------
|
|
49
50
|
reader_fn : Callable[[mne_bids.BIDSPath], Sequence]
|
|
50
|
-
A function that takes a BIDSPath and returns a dataset
|
|
51
|
-
RecordDataset or BaseConcatDataset of RecordDataset).
|
|
51
|
+
A function that takes a BIDSPath and returns a dataset.
|
|
52
52
|
pool_size : int
|
|
53
53
|
The number of recordings to read and sample from.
|
|
54
54
|
bids_paths : list[mne_bids.BIDSPath] | None
|
|
@@ -62,7 +62,7 @@ class BIDSIterableDataset(IterableDataset):
|
|
|
62
62
|
The acquisition session. Corresponds to "ses".
|
|
63
63
|
tasks : str | array-like of str | None
|
|
64
64
|
The experimental task. Corresponds to "task".
|
|
65
|
-
acquisitions
|
|
65
|
+
acquisitions: str | array-like of str | None
|
|
66
66
|
The acquisition parameters. Corresponds to "acq".
|
|
67
67
|
runs : str | array-like of str | None
|
|
68
68
|
The run number. Corresponds to "run".
|
|
@@ -106,8 +106,6 @@ class BIDSIterableDataset(IterableDataset):
|
|
|
106
106
|
If True, preload the data. Defaults to False.
|
|
107
107
|
n_jobs : int
|
|
108
108
|
Number of jobs to run in parallel. Defaults to 1.
|
|
109
|
-
|
|
110
|
-
|
|
111
109
|
"""
|
|
112
110
|
|
|
113
111
|
def __init__(
|