braindecode 1.3.0.dev180329405__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/base.py +1 -1
- braindecode/datasets/__init__.py +12 -4
- braindecode/datasets/base.py +115 -151
- braindecode/datasets/bcicomp.py +4 -4
- braindecode/datasets/bids.py +3 -3
- braindecode/datasets/experimental.py +2 -2
- braindecode/datasets/mne.py +3 -5
- braindecode/datasets/moabb.py +17 -7
- braindecode/datasets/nmt.py +2 -2
- braindecode/datasets/sleep_physio_challe_18.py +2 -2
- braindecode/datasets/sleep_physionet.py +2 -2
- braindecode/datasets/tuh.py +2 -2
- braindecode/datasets/xy.py +2 -2
- braindecode/datautil/__init__.py +11 -1
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/serialization.py +7 -7
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +6 -0
- braindecode/models/atcnet.py +26 -27
- braindecode/models/attentionbasenet.py +37 -32
- braindecode/models/attn_sleep.py +2 -0
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +2 -0
- braindecode/models/contrawr.py +2 -0
- braindecode/models/ctnet.py +8 -3
- braindecode/models/deepsleepnet.py +28 -19
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegitnet.py +2 -0
- braindecode/models/eegminer.py +2 -0
- braindecode/models/eegnet.py +1 -1
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +2 -0
- braindecode/models/fbcnet.py +5 -1
- braindecode/models/fblightconvnet.py +2 -0
- braindecode/models/fbmsnet.py +20 -6
- braindecode/models/ifnet.py +2 -0
- braindecode/models/labram.py +33 -26
- braindecode/models/medformer.py +758 -0
- braindecode/models/msvtnet.py +2 -0
- braindecode/models/patchedtransformer.py +1 -1
- braindecode/models/signal_jepa.py +111 -27
- braindecode/models/sinc_shallow.py +12 -9
- braindecode/models/sstdpn.py +11 -11
- braindecode/models/summary.csv +3 -0
- braindecode/models/syncnet.py +2 -0
- braindecode/models/tcn.py +2 -0
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +3 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -9
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/__init__.py +232 -3
- braindecode/preprocessing/eegprep_preprocess.py +1202 -0
- braindecode/preprocessing/mne_preprocess.py +142 -10
- braindecode/preprocessing/preprocess.py +28 -18
- braindecode/preprocessing/util.py +166 -0
- braindecode/preprocessing/windowers.py +26 -20
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +6 -2
- braindecode-1.3.0.dev182330353.dist-info/RECORD +109 -0
- braindecode-1.3.0.dev180329405.dist-info/RECORD +0 -103
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
braindecode/preprocessing/mne_preprocess.py

@@ -6,20 +6,46 @@
 # License: BSD-3
 import inspect
 
+import mne.channels
 import mne.io
+import mne.preprocessing
 
 from braindecode.preprocessing.preprocess import Preprocessor
-from braindecode.util import _update_moabb_docstring
 
 
-def _generate_init_method(func):
+def _is_standalone_function(func):
+    """
+    Determine if a function is standalone based on its module.
+
+    Standalone functions are those in mne.preprocessing, mne.channels, mne.filter, etc.
+    that are not methods of mne.io.Raw.
+    """
+    # Check if it's a method of Raw by seeing if it's bound or unbound method
+    if hasattr(mne.io.Raw, func.__name__):
+        return False
+    # Otherwise, it's a standalone function
+    return True
+
+
+def _generate_init_method(func, force_copy_false=False):
     """
     Generate an __init__ method for a class based on the function's signature.
+
+    Parameters
+    ----------
+    func : callable
+        The function to wrap.
+    force_copy_false : bool
+        If True, forces copy=False by default for functions that have a copy parameter.
     """
     parameters = list(inspect.signature(func).parameters.values())
     param_names = [param.name for param in parameters]
 
     def init_method(self, *args, **kwargs):
+        # For standalone functions with copy parameter, set copy=False by default
+        if force_copy_false and "copy" in param_names and "copy" not in kwargs:
+            kwargs["copy"] = False
+
         for name, value in zip(param_names, args):
             setattr(self, name, value)
         for name, value in kwargs.items():

@@ -33,19 +59,70 @@ def _generate_init_method(func)
 def _generate_mne_pre_processor(function):
     """
     Generate a class based on an MNE function for preprocessing.
+
+    Parameters
+    ----------
+    function : callable
+        The MNE function to wrap. Automatically determines if it's standalone
+        or a Raw method based on the function's module and name.
     """
     class_name = "".join(word.title() for word in function.__name__.split("_")).replace(
         "Eeg", "EEG"
     )
-
-
+    # Create a wrapper note that references the original MNE function
+    # For Raw methods, use mne.io.Raw.method_name format with :meth:
+    # For standalone functions, use the function name only with :func:
+    if hasattr(mne.io.Raw, function.__name__):
+        ref_path = f"mne.io.Raw.{function.__name__}"
+        ref_role = "meth"
+    else:
+        # For standalone functions, try common MNE public APIs
+        # These are more likely to be in intersphinx inventory
+        func_name = function.__name__
+        if function.__module__.startswith("mne.preprocessing"):
+            ref_path = f"mne.preprocessing.{func_name}"
+        elif function.__module__.startswith("mne.channels"):
+            ref_path = f"mne.channels.{func_name}"
+        elif function.__module__.startswith("mne.filter"):
+            ref_path = f"mne.filter.{func_name}"
+        else:
+            ref_path = f"{function.__module__}.{func_name}"
+        ref_role = "func"
+
+    # Use proper Sphinx cross-reference for intersphinx linking
+    wrapper_note = (
+        f"Braindecode preprocessor wrapper for :{ref_role}:`~{ref_path}`.\n\n"
+    )
 
     base_classes = (Preprocessor,)
-
-
-
-
-
+
+    # Automatically determine if function is standalone
+    is_standalone = _is_standalone_function(function)
+
+    # Check if function has a 'copy' parameter
+    sig = inspect.signature(function)
+    has_copy_param = "copy" in sig.parameters
+    force_copy_false = is_standalone and has_copy_param
+
+    if is_standalone:
+        # For standalone functions, store the actual function object
+        class_attrs = {
+            "__init__": _generate_init_method(
+                function, force_copy_false=force_copy_false
+            ),
+            "__doc__": wrapper_note + (function.__doc__ or ""),
+            "fn": function,  # Store the function itself, not the name
+            "_is_standalone": True,
+        }
+    else:
+        # For methods, store the function name as before
+        class_attrs = {
+            "__init__": _generate_init_method(function),
+            "__doc__": wrapper_note + (function.__doc__ or ""),
+            "fn": function.__name__,
+            "_is_standalone": False,
+        }
+
     generated_class = type(class_name, base_classes, class_attrs)
 
     return generated_class

@@ -53,12 +130,65 @@ def _generate_mne_pre_processor(function)
 
 # List of MNE functions to generate classes for
 mne_functions = [
+    # From mne.filter
     mne.filter.resample,
+    mne.filter.filter_data,
+    mne.filter.notch_filter,
+    # From mne.io.Raw methods
+    mne.io.Raw.add_channels,
+    mne.io.Raw.add_events,
+    mne.io.Raw.add_proj,
+    mne.io.Raw.add_reference_channels,
+    mne.io.Raw.anonymize,
+    mne.io.Raw.apply_gradient_compensation,
+    mne.io.Raw.apply_hilbert,
+    mne.io.Raw.apply_proj,
+    mne.io.Raw.crop,
+    mne.io.Raw.crop_by_annotations,
+    mne.io.Raw.del_proj,
     mne.io.Raw.drop_channels,
     mne.io.Raw.filter,
-    mne.io.Raw.
+    mne.io.Raw.fix_mag_coil_types,
+    mne.io.Raw.interpolate_bads,
+    mne.io.Raw.interpolate_to,
+    mne.io.Raw.notch_filter,
     mne.io.Raw.pick,
+    mne.io.Raw.pick_channels,
+    mne.io.Raw.pick_types,
+    mne.io.Raw.rename_channels,
+    mne.io.Raw.reorder_channels,
+    mne.io.Raw.rescale,
+    mne.io.Raw.resample,
+    mne.io.Raw.savgol_filter,
+    mne.io.Raw.set_annotations,
+    mne.io.Raw.set_channel_types,
     mne.io.Raw.set_eeg_reference,
+    mne.io.Raw.set_meas_date,
+    mne.io.Raw.set_montage,
+    # Standalone functions from mne.preprocessing
+    mne.preprocessing.annotate_amplitude,
+    mne.preprocessing.annotate_break,
+    mne.preprocessing.annotate_movement,
+    mne.preprocessing.annotate_muscle_zscore,
+    mne.preprocessing.annotate_nan,
+    mne.preprocessing.compute_current_source_density,
+    mne.preprocessing.compute_bridged_electrodes,
+    mne.preprocessing.equalize_bads,
+    mne.preprocessing.find_bad_channels_lof,
+    mne.preprocessing.fix_stim_artifact,
+    mne.preprocessing.interpolate_bridged_electrodes,
+    mne.preprocessing.maxwell_filter,
+    mne.preprocessing.oversampled_temporal_projection,
+    mne.preprocessing.realign_raw,
+    mne.preprocessing.regress_artifact,
+    # Standalone functions from mne.channels
+    mne.channels.combine_channels,
+    mne.channels.equalize_channels,
+    mne.channels.rename_channels,
+    # Top-level mne functions for referencing
+    mne.add_reference_channels,
+    mne.set_bipolar_reference,
+    mne.set_eeg_reference,
 ]
 
 # Automatically generate and add classes to the global namespace

@@ -71,6 +201,8 @@ __all__ = [
     class_obj.__name__
     for class_obj in globals().values()
     if isinstance(class_obj, type)
+    and issubclass(class_obj, Preprocessor)
+    and class_obj != Preprocessor
 ]
 
 # Clean up unnecessary variables
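For context, a brief usage sketch (not part of the diff): each generated class is a plain Preprocessor subclass named by title-casing the wrapped MNE function, with "Eeg" upper-cased to "EEG", so mne.io.Raw.pick becomes Pick, mne.io.Raw.filter becomes Filter, and so on. The class names and the dataset name below are illustrative assumptions following that rule.

    # Hypothetical usage of the auto-generated wrapper classes; they are assumed
    # to be exported from braindecode.preprocessing as in earlier releases.
    from braindecode.datasets import MOABBDataset
    from braindecode.preprocessing import Filter, Pick, Resample, preprocess

    ds = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1])
    preprocess(
        ds,
        [
            Pick(picks="eeg"),                # wraps mne.io.Raw.pick
            Filter(l_freq=4.0, h_freq=38.0),  # wraps mne.io.Raw.filter
            Resample(sfreq=100.0),            # wraps the resample entry in mne_functions
        ],
    )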
braindecode/preprocessing/preprocess.py

@@ -30,8 +30,8 @@ from numpy.typing import NDArray
 
 from braindecode.datasets.base import (
     BaseConcatDataset,
-    BaseDataset,
     EEGWindowsDataset,
+    RawDataset,
     WindowsDataset,
 )
 from braindecode.datautil.serialization import (

@@ -85,7 +85,7 @@ class Preprocessor(object):
 
     def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
         try:
-            self._try_apply(raw_or_epochs)
+            return self._try_apply(raw_or_epochs)
         except RuntimeError:
             # Maybe the function needs the data to be loaded and the data was
             # not loaded yet. Not all MNE functions need data to be loaded,

@@ -93,15 +93,20 @@
             # without preloading data which can make the overall preprocessing
             # pipeline substantially faster.
             raw_or_epochs.load_data()
-            self._try_apply(raw_or_epochs)
+            return self._try_apply(raw_or_epochs)
 
     def _try_apply(self, raw_or_epochs):
         if callable(self.fn):
-            self.fn(raw_or_epochs, **self.kwargs)
+            result = self.fn(raw_or_epochs, **self.kwargs)
+            # For standalone functions that return a new object, propagate it back
+            if result is not None and result is not raw_or_epochs:
+                return result
+            return raw_or_epochs
         else:
             if not hasattr(raw_or_epochs, self.fn):
                 raise AttributeError(f"MNE object does not have a {self.fn} method.")
             getattr(raw_or_epochs, self.fn)(**self.kwargs)
+            return raw_or_epochs
 
 
 def preprocess(

@@ -119,7 +124,7 @@ def preprocess(
     Parameters
     ----------
     concat_ds : BaseConcatDataset
-        A concat of ``
+        A concat of ``RecordDataset`` to be preprocessed.
     preprocessors : list of Preprocessor
         Preprocessor objects to apply to each dataset.
     save_dir : str | None

@@ -229,15 +234,15 @@ def _preprocess(
 
     Parameters
     ----------
-    ds:
+    ds: RecordDataset
         Dataset object to preprocess.
     ds_index : int
-        Index of the
+        Index of the ``RecordDataset`` in its ``BaseConcatDataset``. Ignored if save_dir
        is None.
     preprocessors: list(Preprocessor)
         List of preprocessors to apply to the dataset.
     save_dir : str | None
-        If provided, save the preprocessed
+        If provided, save the preprocessed RecordDataset in the
        specified directory.
     overwrite : bool
         If True, overwrite existing file with the same name.

@@ -251,20 +256,25 @@
         if raw_or_epochs.preload and copy_data:
             raw_or_epochs._data = raw_or_epochs._data.copy()
         for preproc in preprocessors:
-            preproc.apply(raw_or_epochs)
+            raw_or_epochs = preproc.apply(raw_or_epochs)
+        return raw_or_epochs
 
     if hasattr(ds, "raw"):
         if isinstance(ds, EEGWindowsDataset):
             warn(
                 f"Applying preprocessors {preprocessors} to the mne.io.Raw of an EEGWindowsDataset."
             )
-        _preprocess_raw_or_epochs(ds.raw, preprocessors)
+        processed = _preprocess_raw_or_epochs(ds.raw, preprocessors)
+        if processed is not ds.raw:
+            ds.raw = processed
     elif hasattr(ds, "windows"):
-        _preprocess_raw_or_epochs(ds.windows, preprocessors)
+        processed = _preprocess_raw_or_epochs(ds.windows, preprocessors)
+        if processed is not ds.windows:
+            ds.windows = processed
     else:
         raise ValueError(
-            "Can only preprocess concatenation of
-            "
+            "Can only preprocess concatenation of RecordDataset, "
+            "with either a `raw` or `windows` attribute."
         )
 
     # Store preprocessing keyword arguments in the dataset

@@ -297,11 +307,11 @@ def _get_preproc_kwargs(preprocessors):
 
 
 def _set_preproc_kwargs(ds, preprocessors):
-    """Record preprocessing keyword arguments in
+    """Record preprocessing keyword arguments in RecordDataset.
 
     Parameters
     ----------
-    ds :
+    ds : RecordDataset
         Dataset in which to record preprocessing keyword arguments.
     preprocessors : list
         List of preprocessors.

@@ -309,12 +319,12 @@ def _set_preproc_kwargs(ds, preprocessors):
     preproc_kwargs = _get_preproc_kwargs(preprocessors)
     if isinstance(ds, WindowsDataset):
         kind = "window"
-
+    elif isinstance(ds, EEGWindowsDataset):
         kind = "raw"
-    elif isinstance(ds,
+    elif isinstance(ds, RawDataset):
         kind = "raw"
     else:
-        raise TypeError(f"ds must be a
+        raise TypeError(f"ds must be a RecordDataset, got {type(ds)}")
     setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)
 
 
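The net effect of the apply/_try_apply changes above: a preprocessor whose function returns a new object (typical for the standalone MNE functions wrapped in mne_preprocess.py) is no longer applied only in place; the returned object is propagated back into ds.raw or ds.windows. A minimal sketch, assuming concat_ds is an existing BaseConcatDataset whose recordings have a montage set:

    # Hedged example: compute_current_source_density returns a *new* Raw instance;
    # with the changes above, that new object replaces ds.raw after preprocessing
    # instead of being silently discarded.
    import mne
    from braindecode.preprocessing import Preprocessor, preprocess

    csd = Preprocessor(mne.preprocessing.compute_current_source_density)
    preprocess(concat_ds, [csd])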
braindecode/preprocessing/util.py (new file)

@@ -0,0 +1,166 @@
+"""Utilities for preprocessing functionality in Braindecode."""
+
+# Authors: Christian Kothe <christian.kothe@intheon.io>
+#
+# License: BSD-3
+
+import base64
+import json
+import re
+from typing import Any
+
+import numpy as np
+from mne.io.base import BaseRaw
+
+__all__ = ["mne_store_metadata", "mne_load_metadata"]
+
+
+# Use a unique marker for embedding structured data in info['description']
+_MARKER_PATTERN = re.compile(r"<!-- braindecode-meta:\s*(\S+)\s*-->", re.DOTALL)
+_MARKER_START = "<!-- braindecode-meta:"
+_MARKER_END = "-->"
+
+# Marker key for numpy arrays
+_NP_ARRAY_TAG = "__numpy_array__"
+
+
+def _numpy_decoder(dct):
+    """Internal JSON decoder hook to handle numpy arrays."""
+    if dct.get(_NP_ARRAY_TAG):
+        arr = np.array(dct["data"], dtype=dct["dtype"])
+        return arr.reshape(dct["shape"])
+    return dct
+
+
+class NumpyEncoder(json.JSONEncoder):
+    """Custom JSON encoder hook to handle numpy arrays."""
+
+    def default(self, obj):
+        if isinstance(obj, np.ndarray):
+            # Reject complex-valued dtypes as they're not JSON serializable
+            if np.issubdtype(obj.dtype, np.complexfloating):
+                raise TypeError(
+                    f"Cannot serialize numpy array with complex dtype {obj.dtype}. "
+                    "Complex dtypes are not supported."
+                )
+            return {
+                _NP_ARRAY_TAG: True,
+                "dtype": obj.dtype.str,
+                "shape": obj.shape,
+                "data": obj.flatten().tolist(),
+            }
+        return super().default(obj)
+
+
+def _encode_payload(data: dict) -> str:
+    """Serializes, encodes, and formats data into a marker string."""
+    json_str = json.dumps(data, cls=NumpyEncoder)
+    encoded = base64.b64encode(json_str.encode("utf-8")).decode("ascii")
+    return f"{_MARKER_START} {encoded} {_MARKER_END}"
+
+
+def mne_store_metadata(
+    raw: BaseRaw, payload: Any, *, key: str, no_overwrite: bool = False
+) -> None:
+    """Embed a JSON-serializable metadata payload in an MNE BaseRaw dataset
+    under a specified key.
+
+    This will encode the payload as a base64-encoded JSON string and store it
+    in the `info['description']` field of the Raw object while preserving any
+    existing content. Note this is not particularly efficient and should not
+    be used for very large payloads.
+
+    Parameters
+    ----------
+    raw : BaseRaw
+        The MNE Raw object to store data in.
+    payload : Any
+        The JSON-serializable data to store.
+    key : str
+        The key under which to store the payload.
+    no_overwrite : bool
+        If True, will not overwrite an existing entry with the same key.
+
+    """
+    # the description is apparently the only viable place where custom metadata may be
+    # stored in MNE Raw objects that persists through saving/loading
+    description = raw.info.get("description") or ""
+
+    # Try to find existing eegprep data
+    if match := _MARKER_PATTERN.search(description):
+        # Parse existing data
+        try:
+            decoded = base64.b64decode(match.group(1)).decode("utf-8")
+            existing_data = json.loads(decoded, object_hook=_numpy_decoder)
+        except (ValueError, json.JSONDecodeError):
+            existing_data = {}
+        # Check no_overwrite condition
+        if no_overwrite and key in existing_data:
+            return
+        # Update data
+        existing_data[key] = payload
+        new_marker = _encode_payload(existing_data)
+        # Replace the old marker with updated one
+        new_description = _MARKER_PATTERN.sub(new_marker, description, count=1)
+    else:
+        # No existing data, append new marker
+        data = {key: payload}
+        new_marker = _encode_payload(data)
+        # Append with spacing if description exists
+        if description.strip():
+            new_description = f"{description.rstrip()}\n{new_marker}"
+        else:
+            new_description = new_marker
+
+    raw.info["description"] = new_description
+
+
+def mne_load_metadata(raw: BaseRaw, *, key: str, delete: bool = False) -> Any | None:
+    """Retrieves data that was previously stored using mne_store_metadata from an MNE
+    BaseRaw dataset.
+
+    This function can retrieve data from an MNE Raw object that was stored
+    using `mne_store_metadata`. It decodes the base64-encoded JSON string from the
+    `info['description']` field and extracts the payload associated with the
+    specified key.
+
+    Parameters
+    ----------
+    raw : BaseRaw
+        The MNE Raw object to retrieve data from.
+    key : str
+        The key under which the payload was stored.
+    delete : bool
+        If True, removes the key from the stored data after retrieval.
+
+    Returns
+    -------
+    Any | None
+        The retrieved payload, or None if not found.
+    """
+    description = raw.info.get("description") or ""
+    match = _MARKER_PATTERN.search(description)
+    if not match:
+        return None
+
+    try:
+        decoded = base64.b64decode(match.group(1)).decode("utf-8")
+        data = json.loads(decoded, object_hook=_numpy_decoder)
+    except (ValueError, json.JSONDecodeError):
+        return None
+
+    result = data.get(key)
+
+    if delete and key in data:
+        # Remove the key from data
+        del data[key]
+        if data:
+            # Still have other keys, update the marker
+            new_marker = _encode_payload(data)
+            new_description = _MARKER_PATTERN.sub(new_marker, description, count=1)
+        else:
+            # No more keys, remove the entire marker
+            new_description = _MARKER_PATTERN.sub("", description, count=1).rstrip()
+        raw.info["description"] = new_description
+
+    return result
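A short round trip with the new helpers (the import path follows the file location braindecode/preprocessing/util.py; raw is assumed to be an existing mne.io.BaseRaw):

    # Store a small JSON-serializable payload (numpy arrays are handled by the
    # custom encoder) and read it back out of info['description'].
    import numpy as np
    from braindecode.preprocessing.util import mne_load_metadata, mne_store_metadata

    mne_store_metadata(raw, {"bad_channels": ["Cz"], "weights": np.eye(2)}, key="eegprep")
    payload = mne_load_metadata(raw, key="eegprep")
    # payload["weights"] comes back as a numpy array via the _numpy_decoder hook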
braindecode/preprocessing/windowers.py

@@ -11,6 +11,7 @@
 # Maciej Sliwowski <maciek.sliwowski@gmail.com>
 # Mohammed Fattouh <mo.fattouh@gmail.com>
 # Robin Schirrmeister <robintibor@gmail.com>
+# Matthew Chen <matt.chen42601@gmail.com>
 #
 # License: BSD (3-clause)
 

@@ -25,7 +26,12 @@ import pandas as pd
 from joblib import Parallel, delayed
 from numpy.typing import ArrayLike
 
-from ..datasets.base import
+from ..datasets.base import (
+    BaseConcatDataset,
+    EEGWindowsDataset,
+    RawDataset,
+    WindowsDataset,
+)
 
 
 class _LazyDataFrame:

@@ -189,7 +195,7 @@ def _get_use_mne_epochs(use_mne_epochs, reject, picks, flat, drop_bad_windows):
 
 # XXX it's called concat_ds...
 def create_windows_from_events(
-    concat_ds: BaseConcatDataset,
+    concat_ds: BaseConcatDataset[RawDataset],
     trial_start_offset_samples: int = 0,
     trial_stop_offset_samples: int = 0,
     window_size_samples: int | None = None,

@@ -206,7 +212,7 @@ def create_windows_from_events(
     use_mne_epochs: bool | None = None,
     n_jobs: int = 1,
     verbose: bool | str | int | None = "error",
-):
+) -> BaseConcatDataset[WindowsDataset | EEGWindowsDataset]:
     """Create windows based on events in mne.Raw.
 
     This function extracts windows of size window_size_samples in the interval

@@ -228,7 +234,7 @@ def create_windows_from_events(
 
     Parameters
     ----------
-    concat_ds: BaseConcatDataset
+    concat_ds: BaseConcatDataset[RawDataset]
         A concat of base datasets each holding raw and description.
     trial_start_offset_samples: int
         Start offset from original trial onsets, in samples. Defaults to zero.

@@ -268,7 +274,7 @@ def create_windows_from_events(
         rejection based on flatness is done. See mne.Epochs.
     on_missing: str
         What to do if one or several event ids are not found in the recording.
-        Valid keys are ‘error
+        Valid keys are ‘error’ | ‘warning’ | ‘ignore’. See mne.Epochs.
     accepted_bads_ratio: float, optional
         Acceptable proportion of trials with inconsistent length in a raw. If
         the number of trials whose length is exceeded by the window size is

@@ -286,7 +292,7 @@ def create_windows_from_events(
 
     Returns
     -------
-    windows_datasets: BaseConcatDataset
+    windows_datasets: BaseConcatDataset[WindowsDataset | EEGWindowsDataset]
         Concatenated datasets of WindowsDataset containing the extracted windows.
     """
     _check_windowing_arguments(

@@ -341,7 +347,7 @@ def create_windows_from_events(
 
 
 def create_fixed_length_windows(
-    concat_ds: BaseConcatDataset,
+    concat_ds: BaseConcatDataset[RawDataset],
     start_offset_samples: int = 0,
     stop_offset_samples: int | None = None,
     window_size_samples: int | None = None,

@@ -358,12 +364,12 @@ def create_fixed_length_windows(
     on_missing: str = "error",
     n_jobs: int = 1,
     verbose: bool | str | int | None = "error",
-):
+) -> BaseConcatDataset[EEGWindowsDataset]:
     """Windower that creates sliding windows.
 
     Parameters
     ----------
-    concat_ds: ConcatDataset
+    concat_ds: ConcatDataset[RawDataset]
         A concat of base datasets each holding raw and description.
     start_offset_samples: int
         Start offset from beginning of recording in samples.

@@ -398,7 +404,7 @@ def create_fixed_length_windows(
         by using the _LazyDataFrame (experimental).
     on_missing: str
         What to do if one or several event ids are not found in the recording.
-        Valid keys are ‘error
+        Valid keys are ‘error’ | ‘warning’ | ‘ignore’. See mne.Epochs.
     n_jobs: int
         Number of jobs to use to parallelize the windowing.
     verbose: bool | str | int | None

@@ -406,7 +412,7 @@ def create_fixed_length_windows(
 
     Returns
     -------
-    windows_datasets: BaseConcatDataset
+    windows_datasets: BaseConcatDataset[EEGWindowsDataset]
         Concatenated datasets of WindowsDataset containing the extracted windows.
     """
     stop_offset_samples, drop_last_window = (

@@ -473,11 +479,11 @@ def _create_windows_from_events(
     verbose="error",
     use_mne_epochs=False,
 ):
-    """Create WindowsDataset from
+    """Create WindowsDataset from RawDataset based on events.
 
     Parameters
     ----------
-    ds :
+    ds : RawDataset
         Dataset containing continuous data and description.
     infer_mapping : bool
         If True, extract all events from all datasets and map them to

@@ -509,7 +515,7 @@ def _create_windows_from_events(
         }
     )
 
-    events, events_id = mne.events_from_annotations(ds.raw, mapping)
+    events, events_id = mne.events_from_annotations(ds.raw, mapping, verbose=verbose)
     onsets = events[:, 0]
     # Onsets are relative to the beginning of the recording
     filtered_durations = np.array(

@@ -648,11 +654,11 @@ def _create_fixed_length_windows(
     on_missing="error",
     verbose="error",
 ):
-    """Create WindowsDataset from
+    """Create WindowsDataset from RawDataset with sliding windows.
 
     Parameters
     ----------
-    ds :
+    ds : RawDataset
         Dataset containing continuous data and description.
 
     See `create_fixed_length_windows` for description of other parameters.

@@ -750,7 +756,7 @@ def _create_fixed_length_windows(
 
 
 def create_windows_from_target_channels(
-    concat_ds,
+    concat_ds: BaseConcatDataset[RawDataset],
     window_size_samples=None,
     preload=False,
     picks=None,

@@ -759,7 +765,7 @@ def create_windows_from_target_channels(
     n_jobs=1,
     last_target_only=True,
     verbose="error",
-):
+) -> BaseConcatDataset[EEGWindowsDataset]:
     list_of_windows_ds = Parallel(n_jobs=n_jobs)(
         delayed(_create_windows_from_target_channels)(
             ds,

@@ -788,11 +794,11 @@ def _create_windows_from_target_channels(
     on_missing="error",
     verbose="error",
 ):
-    """Create WindowsDataset from
+    """Create WindowsDataset from RawDataset using targets `misc` channels from mne.Raw.
 
     Parameters
     ----------
-    ds :
+    ds : RawDataset
         Dataset containing continuous data and description.
 
     See `create_fixed_length_windows` for description of other parameters.
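As a closing usage note, a sketch of the now-annotated windower API (concat_ds is assumed to be a BaseConcatDataset of raw recordings, e.g. the output of the preprocessing step above):

    # The annotations added above say: raw records in, windowed records out.
    from braindecode.preprocessing import create_windows_from_events

    windows_ds = create_windows_from_events(
        concat_ds,                      # BaseConcatDataset[RawDataset]
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        preload=True,
    )                                   # BaseConcatDataset[WindowsDataset | EEGWindowsDataset]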