braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,177 @@
1
+ """Utilities for preprocessing functionality in Braindecode."""
2
+
3
+ # Authors: Christian Kothe <christian.kothe@intheon.io>
4
+ #
5
+ # License: BSD-3
6
+
7
+ import base64
8
+ import inspect
9
+ import json
10
+ import re
11
+ from typing import Any
12
+
13
+ import numpy as np
14
+ from mne.io.base import BaseRaw
15
+
16
+ from braindecode import preprocessing
17
+
18
# Public API of this module: embed metadata into, and retrieve it from, the
# info['description'] field of an MNE Raw object.
__all__ = ["mne_store_metadata", "mne_load_metadata"]
19
+
20
+
21
# Key whose truthy value tags a JSON object as an encoded numpy array
# (written by NumpyEncoder, read back by _numpy_decoder).
_NP_ARRAY_TAG = "__numpy_array__"

# Structured data is embedded in info['description'] as
# "<_MARKER_START> <base64 payload> <_MARKER_END>"; the pattern's single
# capture group is the base64 payload token. (re.DOTALL has no effect here
# because the pattern contains no '.'.)
_MARKER_START = "<!-- braindecode-meta:"
_MARKER_END = "-->"
_MARKER_PATTERN = re.compile(r"<!-- braindecode-meta:\s*(\S+)\s*-->", re.DOTALL)

# Registry of Preprocessor classes keyed by name; filled in place by
# _init_preprocessor_dict().
preprocessor_dict = {}
30
+
31
+
32
def _init_preprocessor_dict():
    """Fill the module-level ``preprocessor_dict`` registry in place.

    Every class in the ``braindecode.preprocessing`` namespace that is a
    subclass of ``preprocessing.Preprocessor`` is stored under the attribute
    name it has in that namespace. Because ``issubclass`` is reflexive, this
    includes ``Preprocessor`` itself if it is exposed there. Entries with the
    same name are overwritten; nothing is removed.
    """
    found = {
        attr_name: candidate
        for attr_name, candidate in inspect.getmembers(preprocessing, inspect.isclass)
        if issubclass(candidate, preprocessing.Preprocessor)
    }
    preprocessor_dict.update(found)
36
+
37
+
38
def _numpy_decoder(dct):
    """``json.loads`` object hook that revives arrays encoded by ``NumpyEncoder``.

    If the object's ``_NP_ARRAY_TAG`` entry is truthy, the object is rebuilt
    as an ndarray: the flat ``"data"`` list is converted with the stored
    ``"dtype"`` string and reshaped to ``"shape"``. Any other object is
    returned unchanged.
    """
    if not dct.get(_NP_ARRAY_TAG):
        # Ordinary JSON object: pass it through untouched.
        return dct
    flat = np.array(dct["data"], dtype=dct["dtype"])
    return flat.reshape(dct["shape"])
44
+
45
+
46
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that additionally handles numpy arrays and numpy scalars.

    Arrays are written as a tagged object holding ``_NP_ARRAY_TAG: True``, the
    dtype string (``dtype.str``), the shape, and the flattened (C-order)
    values; ``_numpy_decoder`` turns such objects back into arrays.

    Numpy scalars that ``json`` cannot encode natively (e.g. ``np.int64``,
    ``np.float32``, ``np.bool_``) are converted to the equivalent built-in
    Python value via ``.item()``. They are therefore read back as plain Python
    numbers/bools, not numpy scalars.

    Complex-valued arrays and scalars are rejected with ``TypeError`` because
    JSON has no complex type.
    """

    def default(self, obj):
        """Return a JSON-encodable stand-in for *obj*, or raise ``TypeError``."""
        if isinstance(obj, np.ndarray):
            # Reject complex-valued dtypes as they're not JSON serializable
            if np.issubdtype(obj.dtype, np.complexfloating):
                raise TypeError(
                    f"Cannot serialize numpy array with complex dtype {obj.dtype}. "
                    "Complex dtypes are not supported."
                )
            return {
                _NP_ARRAY_TAG: True,
                "dtype": obj.dtype.str,
                "shape": obj.shape,
                "data": obj.flatten().tolist(),
            }
        if isinstance(obj, np.generic):
            # Keep complex scalars consistent with the array rule above.
            if isinstance(obj, np.complexfloating):
                raise TypeError(
                    f"Cannot serialize numpy scalar with complex dtype {obj.dtype}. "
                    "Complex dtypes are not supported."
                )
            value = obj.item()
            # .item() can hand back a numpy scalar again (e.g. np.longdouble
            # when it is wider than a Python float). Returning that would make
            # json call default() on it forever, so fall through to the
            # TypeError below instead.
            if not isinstance(value, np.generic):
                return value
        return super().default(obj)
64
+
65
+
66
def _encode_payload(data: dict) -> str:
    """Wrap *data* in a braindecode metadata marker string.

    *data* is JSON-encoded with ``NumpyEncoder``, the UTF-8 bytes are
    base64-encoded, and the ASCII token is returned framed as
    ``"<!-- braindecode-meta: <token> -->"``.
    """
    utf8_bytes = json.dumps(data, cls=NumpyEncoder).encode("utf-8")
    token = base64.b64encode(utf8_bytes).decode("ascii")
    return " ".join((_MARKER_START, token, _MARKER_END))
71
+
72
+
73
def mne_store_metadata(
    raw: BaseRaw, payload: Any, *, key: str, no_overwrite: bool = False
) -> None:
    """Embed a JSON-serializable metadata payload in an MNE BaseRaw dataset
    under a specified key.

    All payloads stored this way share one dict. That dict is JSON-encoded
    (numpy arrays via ``NumpyEncoder``), base64-encoded, and stored inside a
    ``<!-- braindecode-meta: ... -->`` marker in ``raw.info['description']``.
    Any other text in the description is preserved. This is not particularly
    efficient and should not be used for very large payloads.

    Parameters
    ----------
    raw : BaseRaw
        The MNE Raw object to store data in; ``raw.info['description']`` is
        updated in place.
    payload : Any
        The data to store. It must be JSON-serializable; numpy arrays with a
        non-complex dtype are accepted too. Because the payload passes through
        JSON, containers such as tuples are read back as lists.
    key : str
        The key under which to store the payload.
    no_overwrite : bool
        If True, will not overwrite an existing entry with the same key.

    Notes
    -----
    If an existing marker cannot be decoded, or does not decode to a JSON
    object, its previous contents are discarded and replaced by
    ``{key: payload}``.
    """
    # the description is apparently the only viable place where custom metadata may be
    # stored in MNE Raw objects that persists through saving/loading
    description = raw.info.get("description") or ""

    # Try to find previously embedded braindecode metadata
    if match := _MARKER_PATTERN.search(description):
        try:
            decoded = base64.b64decode(match.group(1)).decode("utf-8")
            existing_data = json.loads(decoded, object_hook=_numpy_decoder)
        except (ValueError, json.JSONDecodeError):
            # binascii.Error and UnicodeDecodeError are ValueError subclasses;
            # an undecodable marker is treated as holding no data.
            existing_data = {}
        if not isinstance(existing_data, dict):
            # Valid JSON that is not an object (a list, a string, or a bare
            # tagged array revived by _numpy_decoder) cannot hold keyed
            # entries; treat it like an undecodable marker rather than failing
            # on the key lookup/assignment below.
            existing_data = {}
        # Check no_overwrite condition
        if no_overwrite and key in existing_data:
            return
        existing_data[key] = payload
        new_marker = _encode_payload(existing_data)
        # Replace only the first marker (the one that was read) with the
        # updated one.
        new_description = _MARKER_PATTERN.sub(new_marker, description, count=1)
    else:
        # No existing data, append new marker
        new_marker = _encode_payload({key: payload})
        # Put the marker on its own line if the description has other text
        if description.strip():
            new_description = f"{description.rstrip()}\n{new_marker}"
        else:
            new_description = new_marker

    raw.info["description"] = new_description
127
+
128
+
129
def mne_load_metadata(raw: BaseRaw, *, key: str, delete: bool = False) -> Any | None:
    """Retrieves data that was previously stored using mne_store_metadata from an MNE
    BaseRaw dataset.

    Locates the first ``<!-- braindecode-meta: ... -->`` marker in
    ``raw.info['description']``, decodes its base64-encoded JSON content
    (reviving numpy arrays via ``_numpy_decoder``), and returns the entry
    stored under *key*.

    Parameters
    ----------
    raw : BaseRaw
        The MNE Raw object to retrieve data from.
    key : str
        The key under which the payload was stored.
    delete : bool
        If True and *key* is present, remove it from the stored data. If it
        was the last entry, the whole marker is removed from the description
        and trailing whitespace is stripped. ``raw.info['description']`` is
        modified only in this case.

    Returns
    -------
    Any | None
        The retrieved payload, or None in any of these cases: there is no
        marker, the marker cannot be decoded, the marker does not hold a JSON
        object, or *key* is absent. A payload that was itself stored as None
        is indistinguishable from a missing key.
    """
    description = raw.info.get("description") or ""
    match = _MARKER_PATTERN.search(description)
    if not match:
        return None

    try:
        decoded = base64.b64decode(match.group(1)).decode("utf-8")
        data = json.loads(decoded, object_hook=_numpy_decoder)
    except (ValueError, json.JSONDecodeError):
        # binascii.Error and UnicodeDecodeError are ValueError subclasses too
        return None
    if not isinstance(data, dict):
        # Valid JSON that is not an object (a list, a string, or a bare tagged
        # array revived by _numpy_decoder) has no keyed entries; treat it like
        # an undecodable marker instead of failing on ``data.get``.
        return None

    result = data.get(key)

    if delete and key in data:
        del data[key]
        if data:
            # Other keys remain: re-encode them into the first marker
            new_marker = _encode_payload(data)
            new_description = _MARKER_PATTERN.sub(new_marker, description, count=1)
        else:
            # Last key removed: drop the entire marker
            new_description = _MARKER_PATTERN.sub("", description, count=1).rstrip()
        raw.info["description"] = new_description

    return result