braindecode 1.3.0.dev180329405__py3-none-any.whl → 1.3.0.dev182330353__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/datasets/__init__.py +12 -4
  3. braindecode/datasets/base.py +115 -151
  4. braindecode/datasets/bcicomp.py +4 -4
  5. braindecode/datasets/bids.py +3 -3
  6. braindecode/datasets/experimental.py +2 -2
  7. braindecode/datasets/mne.py +3 -5
  8. braindecode/datasets/moabb.py +17 -7
  9. braindecode/datasets/nmt.py +2 -2
  10. braindecode/datasets/sleep_physio_challe_18.py +2 -2
  11. braindecode/datasets/sleep_physionet.py +2 -2
  12. braindecode/datasets/tuh.py +2 -2
  13. braindecode/datasets/xy.py +2 -2
  14. braindecode/datautil/__init__.py +11 -1
  15. braindecode/datautil/channel_utils.py +114 -0
  16. braindecode/datautil/serialization.py +7 -7
  17. braindecode/functional/functions.py +6 -2
  18. braindecode/functional/initialization.py +2 -3
  19. braindecode/models/__init__.py +6 -0
  20. braindecode/models/atcnet.py +26 -27
  21. braindecode/models/attentionbasenet.py +37 -32
  22. braindecode/models/attn_sleep.py +2 -0
  23. braindecode/models/base.py +280 -2
  24. braindecode/models/bendr.py +469 -0
  25. braindecode/models/biot.py +2 -0
  26. braindecode/models/contrawr.py +2 -0
  27. braindecode/models/ctnet.py +8 -3
  28. braindecode/models/deepsleepnet.py +28 -19
  29. braindecode/models/eegconformer.py +2 -2
  30. braindecode/models/eeginception_erp.py +31 -25
  31. braindecode/models/eegitnet.py +2 -0
  32. braindecode/models/eegminer.py +2 -0
  33. braindecode/models/eegnet.py +1 -1
  34. braindecode/models/eegsym.py +917 -0
  35. braindecode/models/eegtcnet.py +2 -0
  36. braindecode/models/fbcnet.py +5 -1
  37. braindecode/models/fblightconvnet.py +2 -0
  38. braindecode/models/fbmsnet.py +20 -6
  39. braindecode/models/ifnet.py +2 -0
  40. braindecode/models/labram.py +33 -26
  41. braindecode/models/medformer.py +758 -0
  42. braindecode/models/msvtnet.py +2 -0
  43. braindecode/models/patchedtransformer.py +1 -1
  44. braindecode/models/signal_jepa.py +111 -27
  45. braindecode/models/sinc_shallow.py +12 -9
  46. braindecode/models/sstdpn.py +11 -11
  47. braindecode/models/summary.csv +3 -0
  48. braindecode/models/syncnet.py +2 -0
  49. braindecode/models/tcn.py +2 -0
  50. braindecode/models/usleep.py +26 -21
  51. braindecode/models/util.py +3 -0
  52. braindecode/modules/attention.py +10 -10
  53. braindecode/modules/blocks.py +3 -3
  54. braindecode/modules/filter.py +2 -9
  55. braindecode/modules/layers.py +18 -17
  56. braindecode/preprocessing/__init__.py +232 -3
  57. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  58. braindecode/preprocessing/mne_preprocess.py +142 -10
  59. braindecode/preprocessing/preprocess.py +28 -18
  60. braindecode/preprocessing/util.py +166 -0
  61. braindecode/preprocessing/windowers.py +26 -20
  62. braindecode/samplers/base.py +8 -8
  63. braindecode/version.py +1 -1
  64. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/METADATA +6 -2
  65. braindecode-1.3.0.dev182330353.dist-info/RECORD +109 -0
  66. braindecode-1.3.0.dev180329405.dist-info/RECORD +0 -103
  67. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/WHEEL +0 -0
  68. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/LICENSE.txt +0 -0
  69. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/licenses/NOTICE.txt +0 -0
  70. {braindecode-1.3.0.dev180329405.dist-info → braindecode-1.3.0.dev182330353.dist-info}/top_level.txt +0 -0
@@ -6,20 +6,46 @@
6
6
  # License: BSD-3
7
7
  import inspect
8
8
 
9
+ import mne.channels
9
10
  import mne.io
11
+ import mne.preprocessing
10
12
 
11
13
  from braindecode.preprocessing.preprocess import Preprocessor
12
- from braindecode.util import _update_moabb_docstring
13
14
 
14
15
 
15
- def _generate_init_method(func):
16
+ def _is_standalone_function(func):
17
+ """
18
+ Determine if a function is standalone based on its module.
19
+
20
+ Standalone functions are those in mne.preprocessing, mne.channels, mne.filter, etc.
21
+ that are not methods of mne.io.Raw.
22
+ """
23
+ # Check if it's a method of Raw by seeing if it's bound or unbound method
24
+ if hasattr(mne.io.Raw, func.__name__):
25
+ return False
26
+ # Otherwise, it's a standalone function
27
+ return True
28
+
29
+
30
+ def _generate_init_method(func, force_copy_false=False):
16
31
  """
17
32
  Generate an __init__ method for a class based on the function's signature.
33
+
34
+ Parameters
35
+ ----------
36
+ func : callable
37
+ The function to wrap.
38
+ force_copy_false : bool
39
+ If True, forces copy=False by default for functions that have a copy parameter.
18
40
  """
19
41
  parameters = list(inspect.signature(func).parameters.values())
20
42
  param_names = [param.name for param in parameters]
21
43
 
22
44
  def init_method(self, *args, **kwargs):
45
+ # For standalone functions with copy parameter, set copy=False by default
46
+ if force_copy_false and "copy" in param_names and "copy" not in kwargs:
47
+ kwargs["copy"] = False
48
+
23
49
  for name, value in zip(param_names, args):
24
50
  setattr(self, name, value)
25
51
  for name, value in kwargs.items():
@@ -33,19 +59,70 @@ def _generate_init_method(func):
33
59
  def _generate_mne_pre_processor(function):
34
60
  """
35
61
  Generate a class based on an MNE function for preprocessing.
62
+
63
+ Parameters
64
+ ----------
65
+ function : callable
66
+ The MNE function to wrap. Automatically determines if it's standalone
67
+ or a Raw method based on the function's module and name.
36
68
  """
37
69
  class_name = "".join(word.title() for word in function.__name__.split("_")).replace(
38
70
  "Eeg", "EEG"
39
71
  )
40
- import_path = f"{function.__module__}.{function.__name__}"
41
- doc = f" See more details in {import_path}"
72
+ # Create a wrapper note that references the original MNE function
73
+ # For Raw methods, use mne.io.Raw.method_name format with :meth:
74
+ # For standalone functions, use the function name only with :func:
75
+ if hasattr(mne.io.Raw, function.__name__):
76
+ ref_path = f"mne.io.Raw.{function.__name__}"
77
+ ref_role = "meth"
78
+ else:
79
+ # For standalone functions, try common MNE public APIs
80
+ # These are more likely to be in intersphinx inventory
81
+ func_name = function.__name__
82
+ if function.__module__.startswith("mne.preprocessing"):
83
+ ref_path = f"mne.preprocessing.{func_name}"
84
+ elif function.__module__.startswith("mne.channels"):
85
+ ref_path = f"mne.channels.{func_name}"
86
+ elif function.__module__.startswith("mne.filter"):
87
+ ref_path = f"mne.filter.{func_name}"
88
+ else:
89
+ ref_path = f"{function.__module__}.{func_name}"
90
+ ref_role = "func"
91
+
92
+ # Use proper Sphinx cross-reference for intersphinx linking
93
+ wrapper_note = (
94
+ f"Braindecode preprocessor wrapper for :{ref_role}:`~{ref_path}`.\n\n"
95
+ )
42
96
 
43
97
  base_classes = (Preprocessor,)
44
- class_attrs = {
45
- "__init__": _generate_init_method(function),
46
- "__doc__": _update_moabb_docstring(function, doc),
47
- "fn": function.__name__,
48
- }
98
+
99
+ # Automatically determine if function is standalone
100
+ is_standalone = _is_standalone_function(function)
101
+
102
+ # Check if function has a 'copy' parameter
103
+ sig = inspect.signature(function)
104
+ has_copy_param = "copy" in sig.parameters
105
+ force_copy_false = is_standalone and has_copy_param
106
+
107
+ if is_standalone:
108
+ # For standalone functions, store the actual function object
109
+ class_attrs = {
110
+ "__init__": _generate_init_method(
111
+ function, force_copy_false=force_copy_false
112
+ ),
113
+ "__doc__": wrapper_note + (function.__doc__ or ""),
114
+ "fn": function, # Store the function itself, not the name
115
+ "_is_standalone": True,
116
+ }
117
+ else:
118
+ # For methods, store the function name as before
119
+ class_attrs = {
120
+ "__init__": _generate_init_method(function),
121
+ "__doc__": wrapper_note + (function.__doc__ or ""),
122
+ "fn": function.__name__,
123
+ "_is_standalone": False,
124
+ }
125
+
49
126
  generated_class = type(class_name, base_classes, class_attrs)
50
127
 
51
128
  return generated_class
@@ -53,12 +130,65 @@ def _generate_mne_pre_processor(function):
53
130
 
54
131
  # List of MNE functions to generate classes for
55
132
  mne_functions = [
133
+ # From mne.filter
56
134
  mne.filter.resample,
135
+ mne.filter.filter_data,
136
+ mne.filter.notch_filter,
137
+ # From mne.io.Raw methods
138
+ mne.io.Raw.add_channels,
139
+ mne.io.Raw.add_events,
140
+ mne.io.Raw.add_proj,
141
+ mne.io.Raw.add_reference_channels,
142
+ mne.io.Raw.anonymize,
143
+ mne.io.Raw.apply_gradient_compensation,
144
+ mne.io.Raw.apply_hilbert,
145
+ mne.io.Raw.apply_proj,
146
+ mne.io.Raw.crop,
147
+ mne.io.Raw.crop_by_annotations,
148
+ mne.io.Raw.del_proj,
57
149
  mne.io.Raw.drop_channels,
58
150
  mne.io.Raw.filter,
59
- mne.io.Raw.crop,
151
+ mne.io.Raw.fix_mag_coil_types,
152
+ mne.io.Raw.interpolate_bads,
153
+ mne.io.Raw.interpolate_to,
154
+ mne.io.Raw.notch_filter,
60
155
  mne.io.Raw.pick,
156
+ mne.io.Raw.pick_channels,
157
+ mne.io.Raw.pick_types,
158
+ mne.io.Raw.rename_channels,
159
+ mne.io.Raw.reorder_channels,
160
+ mne.io.Raw.rescale,
161
+ mne.io.Raw.resample,
162
+ mne.io.Raw.savgol_filter,
163
+ mne.io.Raw.set_annotations,
164
+ mne.io.Raw.set_channel_types,
61
165
  mne.io.Raw.set_eeg_reference,
166
+ mne.io.Raw.set_meas_date,
167
+ mne.io.Raw.set_montage,
168
+ # Standalone functions from mne.preprocessing
169
+ mne.preprocessing.annotate_amplitude,
170
+ mne.preprocessing.annotate_break,
171
+ mne.preprocessing.annotate_movement,
172
+ mne.preprocessing.annotate_muscle_zscore,
173
+ mne.preprocessing.annotate_nan,
174
+ mne.preprocessing.compute_current_source_density,
175
+ mne.preprocessing.compute_bridged_electrodes,
176
+ mne.preprocessing.equalize_bads,
177
+ mne.preprocessing.find_bad_channels_lof,
178
+ mne.preprocessing.fix_stim_artifact,
179
+ mne.preprocessing.interpolate_bridged_electrodes,
180
+ mne.preprocessing.maxwell_filter,
181
+ mne.preprocessing.oversampled_temporal_projection,
182
+ mne.preprocessing.realign_raw,
183
+ mne.preprocessing.regress_artifact,
184
+ # Standalone functions from mne.channels
185
+ mne.channels.combine_channels,
186
+ mne.channels.equalize_channels,
187
+ mne.channels.rename_channels,
188
+ # Top-level mne functions for referencing
189
+ mne.add_reference_channels,
190
+ mne.set_bipolar_reference,
191
+ mne.set_eeg_reference,
62
192
  ]
63
193
 
64
194
  # Automatically generate and add classes to the global namespace
@@ -71,6 +201,8 @@ __all__ = [
71
201
  class_obj.__name__
72
202
  for class_obj in globals().values()
73
203
  if isinstance(class_obj, type)
204
+ and issubclass(class_obj, Preprocessor)
205
+ and class_obj != Preprocessor
74
206
  ]
75
207
 
76
208
  # Clean up unnecessary variables
@@ -30,8 +30,8 @@ from numpy.typing import NDArray
30
30
 
31
31
  from braindecode.datasets.base import (
32
32
  BaseConcatDataset,
33
- BaseDataset,
34
33
  EEGWindowsDataset,
34
+ RawDataset,
35
35
  WindowsDataset,
36
36
  )
37
37
  from braindecode.datautil.serialization import (
@@ -85,7 +85,7 @@ class Preprocessor(object):
85
85
 
86
86
  def apply(self, raw_or_epochs: BaseRaw | BaseEpochs):
87
87
  try:
88
- self._try_apply(raw_or_epochs)
88
+ return self._try_apply(raw_or_epochs)
89
89
  except RuntimeError:
90
90
  # Maybe the function needs the data to be loaded and the data was
91
91
  # not loaded yet. Not all MNE functions need data to be loaded,
@@ -93,15 +93,20 @@ class Preprocessor(object):
93
93
  # without preloading data which can make the overall preprocessing
94
94
  # pipeline substantially faster.
95
95
  raw_or_epochs.load_data()
96
- self._try_apply(raw_or_epochs)
96
+ return self._try_apply(raw_or_epochs)
97
97
 
98
98
  def _try_apply(self, raw_or_epochs):
99
99
  if callable(self.fn):
100
- self.fn(raw_or_epochs, **self.kwargs)
100
+ result = self.fn(raw_or_epochs, **self.kwargs)
101
+ # For standalone functions that return a new object, propagate it back
102
+ if result is not None and result is not raw_or_epochs:
103
+ return result
104
+ return raw_or_epochs
101
105
  else:
102
106
  if not hasattr(raw_or_epochs, self.fn):
103
107
  raise AttributeError(f"MNE object does not have a {self.fn} method.")
104
108
  getattr(raw_or_epochs, self.fn)(**self.kwargs)
109
+ return raw_or_epochs
105
110
 
106
111
 
107
112
  def preprocess(
@@ -119,7 +124,7 @@ def preprocess(
119
124
  Parameters
120
125
  ----------
121
126
  concat_ds : BaseConcatDataset
122
- A concat of ``BaseDataset`` or ``WindowsDataset`` to be preprocessed.
127
+ A concat of ``RecordDataset`` to be preprocessed.
123
128
  preprocessors : list of Preprocessor
124
129
  Preprocessor objects to apply to each dataset.
125
130
  save_dir : str | None
@@ -229,15 +234,15 @@ def _preprocess(
229
234
 
230
235
  Parameters
231
236
  ----------
232
- ds: BaseDataset | WindowsDataset
237
+ ds: RecordDataset
233
238
  Dataset object to preprocess.
234
239
  ds_index : int
235
- Index of the BaseDataset in its BaseConcatDataset. Ignored if save_dir
240
+ Index of the ``RecordDataset`` in its ``BaseConcatDataset``. Ignored if save_dir
236
241
  is None.
237
242
  preprocessors: list(Preprocessor)
238
243
  List of preprocessors to apply to the dataset.
239
244
  save_dir : str | None
240
- If provided, save the preprocessed BaseDataset in the
245
+ If provided, save the preprocessed RecordDataset in the
241
246
  specified directory.
242
247
  overwrite : bool
243
248
  If True, overwrite existing file with the same name.
@@ -251,20 +256,25 @@ def _preprocess(
251
256
  if raw_or_epochs.preload and copy_data:
252
257
  raw_or_epochs._data = raw_or_epochs._data.copy()
253
258
  for preproc in preprocessors:
254
- preproc.apply(raw_or_epochs)
259
+ raw_or_epochs = preproc.apply(raw_or_epochs)
260
+ return raw_or_epochs
255
261
 
256
262
  if hasattr(ds, "raw"):
257
263
  if isinstance(ds, EEGWindowsDataset):
258
264
  warn(
259
265
  f"Applying preprocessors {preprocessors} to the mne.io.Raw of an EEGWindowsDataset."
260
266
  )
261
- _preprocess_raw_or_epochs(ds.raw, preprocessors)
267
+ processed = _preprocess_raw_or_epochs(ds.raw, preprocessors)
268
+ if processed is not ds.raw:
269
+ ds.raw = processed
262
270
  elif hasattr(ds, "windows"):
263
- _preprocess_raw_or_epochs(ds.windows, preprocessors)
271
+ processed = _preprocess_raw_or_epochs(ds.windows, preprocessors)
272
+ if processed is not ds.windows:
273
+ ds.windows = processed
264
274
  else:
265
275
  raise ValueError(
266
- "Can only preprocess concatenation of BaseDataset or "
267
- "WindowsDataset, with either a `raw` or `windows` attribute."
276
+ "Can only preprocess concatenation of RecordDataset, "
277
+ "with either a `raw` or `windows` attribute."
268
278
  )
269
279
 
270
280
  # Store preprocessing keyword arguments in the dataset
@@ -297,11 +307,11 @@ def _get_preproc_kwargs(preprocessors):
297
307
 
298
308
 
299
309
  def _set_preproc_kwargs(ds, preprocessors):
300
- """Record preprocessing keyword arguments in BaseDataset or WindowsDataset.
310
+ """Record preprocessing keyword arguments in RecordDataset.
301
311
 
302
312
  Parameters
303
313
  ----------
304
- ds : BaseDataset | WindowsDataset
314
+ ds : RecordDataset
305
315
  Dataset in which to record preprocessing keyword arguments.
306
316
  preprocessors : list
307
317
  List of preprocessors.
@@ -309,12 +319,12 @@ def _set_preproc_kwargs(ds, preprocessors):
309
319
  preproc_kwargs = _get_preproc_kwargs(preprocessors)
310
320
  if isinstance(ds, WindowsDataset):
311
321
  kind = "window"
312
- if isinstance(ds, EEGWindowsDataset):
322
+ elif isinstance(ds, EEGWindowsDataset):
313
323
  kind = "raw"
314
- elif isinstance(ds, BaseDataset):
324
+ elif isinstance(ds, RawDataset):
315
325
  kind = "raw"
316
326
  else:
317
- raise TypeError(f"ds must be a BaseDataset or a WindowsDataset, got {type(ds)}")
327
+ raise TypeError(f"ds must be a RecordDataset, got {type(ds)}")
318
328
  setattr(ds, kind + "_preproc_kwargs", preproc_kwargs)
319
329
 
320
330
 
@@ -0,0 +1,166 @@
1
+ """Utilities for preprocessing functionality in Braindecode."""
2
+
3
+ # Authors: Christian Kothe <christian.kothe@intheon.io>
4
+ #
5
+ # License: BSD-3
6
+
7
+ import base64
8
+ import json
9
+ import re
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+ from mne.io.base import BaseRaw
14
+
15
+ __all__ = ["mne_store_metadata", "mne_load_metadata"]
16
+
17
+
18
+ # Use a unique marker for embedding structured data in info['description']
19
+ _MARKER_PATTERN = re.compile(r"<!-- braindecode-meta:\s*(\S+)\s*-->", re.DOTALL)
20
+ _MARKER_START = "<!-- braindecode-meta:"
21
+ _MARKER_END = "-->"
22
+
23
+ # Marker key for numpy arrays
24
+ _NP_ARRAY_TAG = "__numpy_array__"
25
+
26
+
27
+ def _numpy_decoder(dct):
28
+ """Internal JSON decoder hook to handle numpy arrays."""
29
+ if dct.get(_NP_ARRAY_TAG):
30
+ arr = np.array(dct["data"], dtype=dct["dtype"])
31
+ return arr.reshape(dct["shape"])
32
+ return dct
33
+
34
+
35
+ class NumpyEncoder(json.JSONEncoder):
36
+ """Custom JSON encoder hook to handle numpy arrays."""
37
+
38
+ def default(self, obj):
39
+ if isinstance(obj, np.ndarray):
40
+ # Reject complex-valued dtypes as they're not JSON serializable
41
+ if np.issubdtype(obj.dtype, np.complexfloating):
42
+ raise TypeError(
43
+ f"Cannot serialize numpy array with complex dtype {obj.dtype}. "
44
+ "Complex dtypes are not supported."
45
+ )
46
+ return {
47
+ _NP_ARRAY_TAG: True,
48
+ "dtype": obj.dtype.str,
49
+ "shape": obj.shape,
50
+ "data": obj.flatten().tolist(),
51
+ }
52
+ return super().default(obj)
53
+
54
+
55
+ def _encode_payload(data: dict) -> str:
56
+ """Serializes, encodes, and formats data into a marker string."""
57
+ json_str = json.dumps(data, cls=NumpyEncoder)
58
+ encoded = base64.b64encode(json_str.encode("utf-8")).decode("ascii")
59
+ return f"{_MARKER_START} {encoded} {_MARKER_END}"
60
+
61
+
62
+ def mne_store_metadata(
63
+ raw: BaseRaw, payload: Any, *, key: str, no_overwrite: bool = False
64
+ ) -> None:
65
+ """Embed a JSON-serializable metadata payload in an MNE BaseRaw dataset
66
+ under a specified key.
67
+
68
+ This will encode the payload as a base64-encoded JSON string and store it
69
+ in the `info['description']` field of the Raw object while preserving any
70
+ existing content. Note this is not particularly efficient and should not
71
+ be used for very large payloads.
72
+
73
+ Parameters
74
+ ----------
75
+ raw : BaseRaw
76
+ The MNE Raw object to store data in.
77
+ payload : Any
78
+ The JSON-serializable data to store.
79
+ key : str
80
+ The key under which to store the payload.
81
+ no_overwrite : bool
82
+ If True, will not overwrite an existing entry with the same key.
83
+
84
+ """
85
+ # the description is apparently the only viable place where custom metadata may be
86
+ # stored in MNE Raw objects that persists through saving/loading
87
+ description = raw.info.get("description") or ""
88
+
89
+ # Try to find existing eegprep data
90
+ if match := _MARKER_PATTERN.search(description):
91
+ # Parse existing data
92
+ try:
93
+ decoded = base64.b64decode(match.group(1)).decode("utf-8")
94
+ existing_data = json.loads(decoded, object_hook=_numpy_decoder)
95
+ except (ValueError, json.JSONDecodeError):
96
+ existing_data = {}
97
+ # Check no_overwrite condition
98
+ if no_overwrite and key in existing_data:
99
+ return
100
+ # Update data
101
+ existing_data[key] = payload
102
+ new_marker = _encode_payload(existing_data)
103
+ # Replace the old marker with updated one
104
+ new_description = _MARKER_PATTERN.sub(new_marker, description, count=1)
105
+ else:
106
+ # No existing data, append new marker
107
+ data = {key: payload}
108
+ new_marker = _encode_payload(data)
109
+ # Append with spacing if description exists
110
+ if description.strip():
111
+ new_description = f"{description.rstrip()}\n{new_marker}"
112
+ else:
113
+ new_description = new_marker
114
+
115
+ raw.info["description"] = new_description
116
+
117
+
118
+ def mne_load_metadata(raw: BaseRaw, *, key: str, delete: bool = False) -> Any | None:
119
+ """Retrieves data that was previously stored using mne_store_metadata from an MNE
120
+ BaseRaw dataset.
121
+
122
+ This function can retrieve data from an MNE Raw object that was stored
123
+ using `mne_store_metadata`. It decodes the base64-encoded JSON string from the
124
+ `info['description']` field and extracts the payload associated with the
125
+ specified key.
126
+
127
+ Parameters
128
+ ----------
129
+ raw : BaseRaw
130
+ The MNE Raw object to retrieve data from.
131
+ key : str
132
+ The key under which the payload was stored.
133
+ delete : bool
134
+ If True, removes the key from the stored data after retrieval.
135
+
136
+ Returns
137
+ -------
138
+ Any | None
139
+ The retrieved payload, or None if not found.
140
+ """
141
+ description = raw.info.get("description") or ""
142
+ match = _MARKER_PATTERN.search(description)
143
+ if not match:
144
+ return None
145
+
146
+ try:
147
+ decoded = base64.b64decode(match.group(1)).decode("utf-8")
148
+ data = json.loads(decoded, object_hook=_numpy_decoder)
149
+ except (ValueError, json.JSONDecodeError):
150
+ return None
151
+
152
+ result = data.get(key)
153
+
154
+ if delete and key in data:
155
+ # Remove the key from data
156
+ del data[key]
157
+ if data:
158
+ # Still have other keys, update the marker
159
+ new_marker = _encode_payload(data)
160
+ new_description = _MARKER_PATTERN.sub(new_marker, description, count=1)
161
+ else:
162
+ # No more keys, remove the entire marker
163
+ new_description = _MARKER_PATTERN.sub("", description, count=1).rstrip()
164
+ raw.info["description"] = new_description
165
+
166
+ return result
@@ -11,6 +11,7 @@
11
11
  # Maciej Sliwowski <maciek.sliwowski@gmail.com>
12
12
  # Mohammed Fattouh <mo.fattouh@gmail.com>
13
13
  # Robin Schirrmeister <robintibor@gmail.com>
14
+ # Matthew Chen <matt.chen42601@gmail.com>
14
15
  #
15
16
  # License: BSD (3-clause)
16
17
 
@@ -25,7 +26,12 @@ import pandas as pd
25
26
  from joblib import Parallel, delayed
26
27
  from numpy.typing import ArrayLike
27
28
 
28
- from ..datasets.base import BaseConcatDataset, EEGWindowsDataset, WindowsDataset
29
+ from ..datasets.base import (
30
+ BaseConcatDataset,
31
+ EEGWindowsDataset,
32
+ RawDataset,
33
+ WindowsDataset,
34
+ )
29
35
 
30
36
 
31
37
  class _LazyDataFrame:
@@ -189,7 +195,7 @@ def _get_use_mne_epochs(use_mne_epochs, reject, picks, flat, drop_bad_windows):
189
195
 
190
196
  # XXX it's called concat_ds...
191
197
  def create_windows_from_events(
192
- concat_ds: BaseConcatDataset,
198
+ concat_ds: BaseConcatDataset[RawDataset],
193
199
  trial_start_offset_samples: int = 0,
194
200
  trial_stop_offset_samples: int = 0,
195
201
  window_size_samples: int | None = None,
@@ -206,7 +212,7 @@ def create_windows_from_events(
206
212
  use_mne_epochs: bool | None = None,
207
213
  n_jobs: int = 1,
208
214
  verbose: bool | str | int | None = "error",
209
- ):
215
+ ) -> BaseConcatDataset[WindowsDataset | EEGWindowsDataset]:
210
216
  """Create windows based on events in mne.Raw.
211
217
 
212
218
  This function extracts windows of size window_size_samples in the interval
@@ -228,7 +234,7 @@ def create_windows_from_events(
228
234
 
229
235
  Parameters
230
236
  ----------
231
- concat_ds: BaseConcatDataset
237
+ concat_ds: BaseConcatDataset[RawDataset]
232
238
  A concat of base datasets each holding raw and description.
233
239
  trial_start_offset_samples: int
234
240
  Start offset from original trial onsets, in samples. Defaults to zero.
@@ -268,7 +274,7 @@ def create_windows_from_events(
268
274
  rejection based on flatness is done. See mne.Epochs.
269
275
  on_missing: str
270
276
  What to do if one or several event ids are not found in the recording.
271
- Valid keys are ‘error | ‘warning | ‘ignore’. See mne.Epochs.
277
+ Valid keys are ‘error' | ‘warning' | ‘ignore'. See mne.Epochs.
272
278
  accepted_bads_ratio: float, optional
273
279
  Acceptable proportion of trials with inconsistent length in a raw. If
274
280
  the number of trials whose length is exceeded by the window size is
@@ -286,7 +292,7 @@ def create_windows_from_events(
286
292
 
287
293
  Returns
288
294
  -------
289
- windows_datasets: BaseConcatDataset
295
+ windows_datasets: BaseConcatDataset[WindowsDataset | EEGWindowsDataset]
290
296
  Concatenated datasets of WindowsDataset containing the extracted windows.
291
297
  """
292
298
  _check_windowing_arguments(
@@ -341,7 +347,7 @@ def create_windows_from_events(
341
347
 
342
348
 
343
349
  def create_fixed_length_windows(
344
- concat_ds: BaseConcatDataset,
350
+ concat_ds: BaseConcatDataset[RawDataset],
345
351
  start_offset_samples: int = 0,
346
352
  stop_offset_samples: int | None = None,
347
353
  window_size_samples: int | None = None,
@@ -358,12 +364,12 @@ def create_fixed_length_windows(
358
364
  on_missing: str = "error",
359
365
  n_jobs: int = 1,
360
366
  verbose: bool | str | int | None = "error",
361
- ):
367
+ ) -> BaseConcatDataset[EEGWindowsDataset]:
362
368
  """Windower that creates sliding windows.
363
369
 
364
370
  Parameters
365
371
  ----------
366
- concat_ds: ConcatDataset
372
+ concat_ds: ConcatDataset[RawDataset]
367
373
  A concat of base datasets each holding raw and description.
368
374
  start_offset_samples: int
369
375
  Start offset from beginning of recording in samples.
@@ -398,7 +404,7 @@ def create_fixed_length_windows(
398
404
  by using the _LazyDataFrame (experimental).
399
405
  on_missing: str
400
406
  What to do if one or several event ids are not found in the recording.
401
- Valid keys are ‘error | ‘warning | ‘ignore’. See mne.Epochs.
407
+ Valid keys are ‘error' | ‘warning' | ‘ignore'. See mne.Epochs.
402
408
  n_jobs: int
403
409
  Number of jobs to use to parallelize the windowing.
404
410
  verbose: bool | str | int | None
@@ -406,7 +412,7 @@ def create_fixed_length_windows(
406
412
 
407
413
  Returns
408
414
  -------
409
- windows_datasets: BaseConcatDataset
415
+ windows_datasets: BaseConcatDataset[EEGWindowsDataset]
410
416
  Concatenated datasets of WindowsDataset containing the extracted windows.
411
417
  """
412
418
  stop_offset_samples, drop_last_window = (
@@ -473,11 +479,11 @@ def _create_windows_from_events(
473
479
  verbose="error",
474
480
  use_mne_epochs=False,
475
481
  ):
476
- """Create WindowsDataset from BaseDataset based on events.
482
+ """Create WindowsDataset from RawDataset based on events.
477
483
 
478
484
  Parameters
479
485
  ----------
480
- ds : BaseDataset
486
+ ds : RawDataset
481
487
  Dataset containing continuous data and description.
482
488
  infer_mapping : bool
483
489
  If True, extract all events from all datasets and map them to
@@ -509,7 +515,7 @@ def _create_windows_from_events(
509
515
  }
510
516
  )
511
517
 
512
- events, events_id = mne.events_from_annotations(ds.raw, mapping)
518
+ events, events_id = mne.events_from_annotations(ds.raw, mapping, verbose=verbose)
513
519
  onsets = events[:, 0]
514
520
  # Onsets are relative to the beginning of the recording
515
521
  filtered_durations = np.array(
@@ -648,11 +654,11 @@ def _create_fixed_length_windows(
648
654
  on_missing="error",
649
655
  verbose="error",
650
656
  ):
651
- """Create WindowsDataset from BaseDataset with sliding windows.
657
+ """Create WindowsDataset from RawDataset with sliding windows.
652
658
 
653
659
  Parameters
654
660
  ----------
655
- ds : BaseDataset
661
+ ds : RawDataset
656
662
  Dataset containing continuous data and description.
657
663
 
658
664
  See `create_fixed_length_windows` for description of other parameters.
@@ -750,7 +756,7 @@ def _create_fixed_length_windows(
750
756
 
751
757
 
752
758
  def create_windows_from_target_channels(
753
- concat_ds,
759
+ concat_ds: BaseConcatDataset[RawDataset],
754
760
  window_size_samples=None,
755
761
  preload=False,
756
762
  picks=None,
@@ -759,7 +765,7 @@ def create_windows_from_target_channels(
759
765
  n_jobs=1,
760
766
  last_target_only=True,
761
767
  verbose="error",
762
- ):
768
+ ) -> BaseConcatDataset[EEGWindowsDataset]:
763
769
  list_of_windows_ds = Parallel(n_jobs=n_jobs)(
764
770
  delayed(_create_windows_from_target_channels)(
765
771
  ds,
@@ -788,11 +794,11 @@ def _create_windows_from_target_channels(
788
794
  on_missing="error",
789
795
  verbose="error",
790
796
  ):
791
- """Create WindowsDataset from BaseDataset using targets `misc` channels from mne.Raw.
797
+ """Create WindowsDataset from RawDataset using targets `misc` channels from mne.Raw.
792
798
 
793
799
  Parameters
794
800
  ----------
795
- ds : BaseDataset
801
+ ds : RawDataset
796
802
  Dataset containing continuous data and description.
797
803
 
798
804
  See `create_fixed_length_windows` for description of other parameters.