braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""Utilities for preprocessing functionality in Braindecode."""
|
|
2
|
+
|
|
3
|
+
# Authors: Christian Kothe <christian.kothe@intheon.io>
|
|
4
|
+
#
|
|
5
|
+
# License: BSD-3
|
|
6
|
+
|
|
7
|
+
import base64
|
|
8
|
+
import inspect
|
|
9
|
+
import json
|
|
10
|
+
import re
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
from mne.io.base import BaseRaw
|
|
15
|
+
|
|
16
|
+
from braindecode import preprocessing
|
|
17
|
+
|
|
18
|
+
__all__ = ["mne_store_metadata", "mne_load_metadata"]

# Opening/closing delimiters of the HTML-comment-style marker that wraps the
# base64-encoded JSON metadata embedded in info['description'].
_MARKER_START = "<!-- braindecode-meta:"
_MARKER_END = "-->"
# Finds such a marker and captures its encoded blob; base64 output contains no
# whitespace, so \S+ grabs exactly the payload.
_MARKER_PATTERN = re.compile(r"<!-- braindecode-meta:\s*(\S+)\s*-->", re.DOTALL)

# A JSON object carrying this key with a truthy value is an encoded numpy array.
_NP_ARRAY_TAG = "__numpy_array__"

# Class-name -> class registry of Preprocessor subclasses, populated by
# _init_preprocessor_dict().
preprocessor_dict: dict[str, type] = {}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _init_preprocessor_dict():
    """Register every Preprocessor class exposed by ``braindecode.preprocessing``.

    Each class found in the ``preprocessing`` namespace that is a subclass of
    ``preprocessing.Preprocessor`` (including ``Preprocessor`` itself) is stored
    in the module-level ``preprocessor_dict`` under its attribute name.
    """
    preprocessor_dict.update(
        {
            attr_name: klass
            for attr_name, klass in inspect.getmembers(preprocessing, inspect.isclass)
            if issubclass(klass, preprocessing.Preprocessor)
        }
    )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _numpy_decoder(dct):
    """JSON ``object_hook`` that turns tagged objects back into numpy arrays.

    Objects without a truthy ``_NP_ARRAY_TAG`` entry are passed through
    unchanged; tagged ones are rebuilt from their ``data``, ``dtype`` and
    ``shape`` fields.
    """
    if not dct.get(_NP_ARRAY_TAG):
        return dct
    flat = np.array(dct["data"], dtype=dct["dtype"])
    return flat.reshape(dct["shape"])
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes numpy arrays and numpy scalars.

    Arrays are emitted as a tagged object (``_NP_ARRAY_TAG``) carrying the
    dtype string, shape and flattened data, which ``_numpy_decoder`` turns
    back into an array. Numpy scalars that the stock encoder rejects
    (e.g. ``np.int64``, ``np.bool_``, ``np.float32``) are converted to the
    equivalent builtin Python scalar. Complex-valued arrays and scalars are
    rejected with ``TypeError`` because JSON has no complex number type.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            # Reject complex-valued dtypes as they're not JSON serializable
            if np.issubdtype(obj.dtype, np.complexfloating):
                raise TypeError(
                    f"Cannot serialize numpy array with complex dtype {obj.dtype}. "
                    "Complex dtypes are not supported."
                )
            return {
                _NP_ARRAY_TAG: True,
                "dtype": obj.dtype.str,
                "shape": obj.shape,
                # ravel() yields the same C-order elements as flatten() but
                # avoids a copy for contiguous arrays.
                "data": obj.ravel().tolist(),
            }
        if isinstance(obj, np.generic):
            if np.issubdtype(obj.dtype, np.complexfloating):
                raise TypeError(
                    f"Cannot serialize numpy scalar with complex dtype {obj.dtype}. "
                    "Complex dtypes are not supported."
                )
            value = obj.item()
            # Some scalar types (e.g. np.longdouble) have no builtin
            # equivalent and .item() returns a numpy scalar again; returning
            # it would re-enter default() forever, so let the base class
            # raise TypeError for those instead.
            if not isinstance(value, np.generic):
                return value
        return super().default(obj)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _encode_payload(data: dict) -> str:
    """Turn *data* into a complete metadata marker string.

    The dict is dumped to JSON (numpy arrays handled by ``NumpyEncoder``),
    UTF-8 encoded, base64-encoded, and wrapped between ``_MARKER_START`` and
    ``_MARKER_END`` separated by single spaces.
    """
    utf8_json = json.dumps(data, cls=NumpyEncoder).encode("utf-8")
    token = base64.b64encode(utf8_json).decode("ascii")
    return " ".join((_MARKER_START, token, _MARKER_END))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def mne_store_metadata(
    raw: BaseRaw, payload: Any, *, key: str, no_overwrite: bool = False
) -> None:
    """Embed a JSON-serializable metadata payload in an MNE BaseRaw dataset
    under a specified key.

    All keys stored by this function live together in one JSON object, which
    is base64-encoded and wrapped in a ``<!-- braindecode-meta: ... -->``
    marker inside ``raw.info['description']``. Any other text already in the
    description is preserved. Note this is not particularly efficient and
    should not be used for very large payloads.

    If the description already holds a marker whose content cannot be decoded
    into a JSON object, that content is discarded and replaced by a new
    object containing only ``key``.

    Parameters
    ----------
    raw : BaseRaw
        The MNE Raw object to store data in. Modified in place.
    payload : Any
        The JSON-serializable data to store. Non-complex numpy arrays are
        supported as well.
    key : str
        The key under which to store the payload.
    no_overwrite : bool
        If True, will not overwrite an existing entry with the same key.

    Raises
    ------
    TypeError
        If ``payload`` cannot be serialized to JSON.
    """
    # the description is apparently the only viable place where custom metadata
    # may be stored in MNE Raw objects that persists through saving/loading
    description = raw.info.get("description") or ""

    match = _MARKER_PATTERN.search(description)
    if match is None:
        # No existing marker: append a fresh one, on its own line when the
        # description already contains (non-blank) text.
        new_marker = _encode_payload({key: payload})
        if description.strip():
            new_description = f"{description.rstrip()}\n{new_marker}"
        else:
            new_description = new_marker
    else:
        try:
            decoded = base64.b64decode(match.group(1)).decode("utf-8")
            existing_data = json.loads(decoded, object_hook=_numpy_decoder)
        except (ValueError, KeyError, TypeError):
            # ValueError covers bad base64 (binascii.Error), bad UTF-8 and bad
            # JSON; KeyError/TypeError can be raised by _numpy_decoder on a
            # malformed array entry. Corrupt metadata is discarded.
            existing_data = {}
        if not isinstance(existing_data, dict):
            # A marker holding a non-object JSON value cannot host keys.
            existing_data = {}
        if no_overwrite and key in existing_data:
            return
        existing_data[key] = payload
        new_marker = _encode_payload(existing_data)
        # Callable replacement: insert the marker literally instead of having
        # re.sub interpret it as a template with backslash escapes.
        new_description = _MARKER_PATTERN.sub(
            lambda _match: new_marker, description, count=1
        )

    raw.info["description"] = new_description
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def mne_load_metadata(raw: BaseRaw, *, key: str, delete: bool = False) -> Any | None:
    """Retrieves data that was previously stored using mne_store_metadata from an MNE
    BaseRaw dataset.

    This function can retrieve data from an MNE Raw object that was stored
    using `mne_store_metadata`. It decodes the base64-encoded JSON string from the
    `info['description']` field and extracts the payload associated with the
    specified key.

    Parameters
    ----------
    raw : BaseRaw
        The MNE Raw object to retrieve data from.
    key : str
        The key under which the payload was stored.
    delete : bool
        If True, removes the key from the stored data after retrieval. When it
        was the last stored key, the whole marker is removed from the
        description. The description is left untouched if the key is absent.

    Returns
    -------
    Any | None
        The retrieved payload, or None if there is no marker, the marker
        cannot be decoded into a JSON object, or the key is absent. A payload
        that was stored as None is therefore indistinguishable from a missing
        one.
    """
    description = raw.info.get("description") or ""
    match = _MARKER_PATTERN.search(description)
    if not match:
        return None

    try:
        decoded = base64.b64decode(match.group(1)).decode("utf-8")
        data = json.loads(decoded, object_hook=_numpy_decoder)
    except (ValueError, KeyError, TypeError):
        # ValueError covers bad base64 (binascii.Error), bad UTF-8 and bad
        # JSON; KeyError/TypeError can be raised by _numpy_decoder on a
        # malformed array entry.
        return None
    if not isinstance(data, dict):
        # Marker content is valid JSON but not a key/value object.
        return None

    result = data.get(key)

    if delete and key in data:
        del data[key]
        if data:
            # Still have other keys: rewrite the marker without this one.
            # Callable replacement inserts the marker literally (no template
            # escape processing by re.sub).
            new_marker = _encode_payload(data)
            new_description = _MARKER_PATTERN.sub(
                lambda _match: new_marker, description, count=1
            )
        else:
            # No more keys, remove the entire marker
            new_description = _MARKER_PATTERN.sub("", description, count=1).rstrip()
        raw.info["description"] = new_description

    return result
|