braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
# Authors: Robin Schirrmeister <robintibor@gmail.com>
|
|
2
|
+
# Hubert Banville <hubert.jbanville@gmail.com>
|
|
3
|
+
#
|
|
4
|
+
# License: BSD (3-clause)
|
|
5
|
+
import inspect
|
|
6
|
+
from copy import deepcopy
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Dict, Literal, Optional, Sequence
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
import braindecode.models as models
|
|
14
|
+
|
|
15
|
+
# Registry mapping exported class names to braindecode model classes. It is
# filled in by ``_init_models_dict`` below.
models_dict = {}


# Go through every class exposed by ``braindecode.models`` and register those
# that inherit from ``EEGModuleMixin`` in ``models_dict``.


def _init_models_dict():
    """Register the EEG model classes exported by :mod:`braindecode.models`.

    Every class that is a strict subclass of
    :class:`~braindecode.models.base.EEGModuleMixin` is stored in
    ``models_dict`` under the name it is exported with. The mixin itself and
    the class named ``EEGNetv4`` are not registered.
    """
    mixin = models.base.EEGModuleMixin
    for exported_name, klass in inspect.getmembers(models, inspect.isclass):
        # Guard clauses: keep only concrete subclasses of the mixin.
        if not issubclass(klass, mixin) or klass == mixin:
            continue
        if klass.__name__ == "EEGNetv4":
            continue
        models_dict[exported_name] = klass
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Names of the signal-related constructor arguments of braindecode models,
# used to type the test-case tables and helpers of this module.
SigArgName = Literal[
    "n_outputs", "n_chans", "chs_info", "n_times", "input_window_seconds", "sfreq"
]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
################################################################
# Test cases for models
#
# This list should be updated whenever a new model is added to
# braindecode (otherwise `test_completeness__models_test_cases`
# will fail).
# Each element in the list is a tuple with structure
# (model_name, required_params, signal_params), such that:
#
# model_name: str
#     The name of the class of the model to be tested.
# required_params: list[str]
#     The signal-related parameters that are needed to initialize
#     the model.
# signal_params: dict | None
#     The characteristics of the signal that should be passed to
#     the model tested in case the default_signal_params are not
#     compatible with this model.
#     The keys of this dictionary can only be among those of
#     default_signal_params.
################################################################
models_mandatory_parameters: list[
    tuple[str, list[SigArgName], dict[SigArgName, Any] | None]
] = [
    ("ATCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("BDTCN", ["n_chans", "n_outputs"], None),
    ("Deep4Net", ["n_chans", "n_outputs", "n_times"], None),
    ("DeepSleepNet", ["n_outputs"], None),
    ("EEGConformer", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGInceptionERP", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("EEGInceptionMI", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("EEGITNet", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGNet", ["n_chans", "n_outputs", "n_times"], None),
    ("ShallowFBCSPNet", ["n_chans", "n_outputs", "n_times"], None),
    (
        "SleepStagerBlanco2020",
        ["n_chans", "n_outputs", "n_times"],
        dict(n_chans=4),  # n_chans divisible by n_groups=2
    ),
    ("SleepStagerChambon2018", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    (
        "AttnSleep",
        ["n_outputs", "n_times", "sfreq"],
        # single-channel signal
        dict(
            sfreq=100.0,
            n_times=3000,
            chs_info=[dict(ch_name="C1", kind="eeg")],
        ),
    ),
    ("TIDNet", ["n_chans", "n_outputs", "n_times"], None),
    ("USleep", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=128.0)),
    ("BIOT", ["n_chans", "n_outputs", "sfreq", "n_times"], None),
    ("AttentionBaseNet", ["n_chans", "n_outputs", "n_times"], None),
    ("Labram", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGSimpleConv", ["n_chans", "n_outputs", "sfreq"], None),
    ("SPARCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("ContraWR", ["n_chans", "n_outputs", "sfreq", "n_times"], dict(sfreq=200.0)),
    ("EEGNeX", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGSym", ["chs_info", "n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("TSception", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("EEGTCNet", ["n_chans", "n_outputs", "n_times"], None),
    ("SyncNet", ["n_chans", "n_outputs", "n_times"], None),
    ("MSVTNet", ["n_chans", "n_outputs", "n_times"], None),
    ("EEGMiner", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("CTNet", ["n_chans", "n_outputs", "n_times"], None),
    (
        "SincShallowNet",
        ["n_chans", "n_outputs", "n_times", "sfreq"],
        dict(sfreq=250.0),
    ),
    ("SCCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("SignalJEPA", ["chs_info"], None),
    ("SignalJEPA_Contextual", ["chs_info", "n_times", "n_outputs"], None),
    ("SignalJEPA_PostLocal", ["n_chans", "n_times", "n_outputs"], None),
    ("SignalJEPA_PreLocal", ["n_chans", "n_times", "n_outputs"], None),
    ("FBCNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("FBMSNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    (
        "FBLightConvNet",
        ["n_chans", "n_outputs", "n_times", "sfreq"],
        dict(sfreq=200.0),
    ),
    ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
    ("PBT", ["n_chans", "n_outputs", "n_times"], None),
    ("SSTDPN", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("BrainModule", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
    ("BENDR", ["n_chans", "n_outputs", "n_times"], None),
    ("LUNA", ["n_chans", "n_times", "n_outputs"], None),
    ("MEDFormer", ["n_chans", "n_outputs", "n_times"], None),
    (
        "REVE",
        ["n_times", "n_outputs", "n_chans", "chs_info"],
        dict(
            sfreq=200.0,
            n_chans=19,
            n_times=1_000,
            # channels named E1 ... E19
            chs_info=[dict(ch_name=f"E{idx}", kind="eeg") for idx in range(1, 20)],
        ),
    ),
]
|
|
135
|
+
|
|
136
|
+
################################################################
# Models that are not meant for classification.
#
# Their output shape may differ from the output shape expected
# from classification models.
################################################################
non_classification_models = ["SignalJEPA"]

################################################################
|
|
147
|
+
|
|
148
|
+
# Seeded generator, so the random channel positions below (and those drawn
# later by the helpers of this module) are reproducible.
rng = np.random.default_rng(12)

# Fake EEG channel info: channels C1, C2 and C3, each with a random
# 12-element "loc" vector.
chs_info = [
    dict(ch_name=f"C{idx}", kind="eeg", loc=rng.random(12)) for idx in range(1, 4)
]

# Signal characteristics used when a test case does not override them.
default_signal_params: dict[SigArgName, Any] = dict(
    n_times=1000,
    sfreq=250.0,
    n_outputs=2,
    chs_info=chs_info,
    n_chans=len(chs_info),
    input_window_seconds=4.0,
)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _get_signal_params(
    signal_params: dict[SigArgName, Any] | None,
    required_params: list[SigArgName] | None = None,
) -> dict[SigArgName, Any]:
    """Get signal parameters for model initialization in tests.

    Starts from a deep copy of ``default_signal_params`` and overrides it
    with ``signal_params``. Values not set by ``signal_params`` are then
    re-derived, so that ``n_chans`` agrees with ``chs_info`` and
    ``n_times``, ``sfreq`` and ``input_window_seconds`` stay consistent.

    Parameters
    ----------
    signal_params : dict or None
        Overrides of ``default_signal_params``. ``None`` returns (a copy of)
        the defaults unchanged.
    required_params : list of str or None
        If given, the returned dict is restricted to these keys plus the keys
        of ``signal_params``.

    Returns
    -------
    dict
        The signal parameters.
    """
    sp = deepcopy(default_signal_params)
    if signal_params is not None:
        sp.update(signal_params)
        if "chs_info" in signal_params and "n_chans" not in signal_params:
            sp["n_chans"] = len(signal_params["chs_info"])
        if "n_chans" in signal_params and "chs_info" not in signal_params:
            # New channel positions drawn from the module-level generator.
            sp["chs_info"] = [
                {"ch_name": f"C{i}", "kind": "eeg", "loc": rng.random(12)}
                for i in range(signal_params["n_chans"])
            ]
        assert isinstance(sp["n_times"], int)
        assert isinstance(sp["sfreq"], float)
        assert isinstance(sp["input_window_seconds"], float)
        if "input_window_seconds" not in signal_params:
            sp["input_window_seconds"] = sp["n_times"] / sp["sfreq"]
        if "sfreq" not in signal_params:
            sp["sfreq"] = sp["n_times"] / sp["input_window_seconds"]
        if "n_times" not in signal_params:
            # Round instead of truncating: float products such as
            # 0.29 * 100.0 == 28.999999999999996 would otherwise lose a sample.
            sp["n_times"] = round(sp["input_window_seconds"] * sp["sfreq"])
    if required_params is not None:
        kept_keys = set((signal_params or {}).keys()).union(required_params)
        sp = {k: sp[k] for k in kept_keys}
    return sp
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _get_possible_signal_params(
    signal_params: dict[SigArgName, Any], required_params: list[SigArgName]
):
    """Enumerate keyword-argument combinations to instantiate a model with.

    Each returned dict picks one option for the output argument, one for the
    channel arguments and one for the time arguments. The result is their
    Cartesian product, with outputs varying slowest and time fastest.

    A fixed set of base options is always produced. Extra options that set
    further parameters to ``None`` are only added when those parameters are
    not listed in ``required_params``.

    Parameters
    ----------
    signal_params : dict
        Signal characteristics. ``n_outputs``, ``chs_info``, ``n_times``,
        ``sfreq`` and ``input_window_seconds`` are always read. ``n_chans``
        is only read when ``chs_info`` is not required.
    required_params : list of str
        Parameters the model needs.

    Returns
    -------
    list of dict
        Candidate keyword arguments.
    """
    outputs = [{"n_outputs": signal_params["n_outputs"]}]
    if "n_outputs" not in required_params:
        outputs.append({"n_outputs": None})

    channels = [{"chs_info": signal_params["chs_info"], "n_chans": None}]
    if "chs_info" not in required_params:
        channels.append({"n_chans": signal_params["n_chans"], "chs_info": None})
        if "n_chans" not in required_params:
            channels.append({"n_chans": None, "chs_info": None})

    n_times = signal_params["n_times"]
    sfreq = signal_params["sfreq"]
    window = signal_params["input_window_seconds"]
    n_times_optional = "n_times" not in required_params
    sfreq_optional = "sfreq" not in required_params
    window_optional = "input_window_seconds" not in required_params

    # (n_times, sfreq, input_window_seconds) triplets, in the order tried.
    triplets = [
        (n_times, sfreq, None),
        (None, sfreq, window),
        (n_times, None, window),
    ]
    if n_times_optional and sfreq_optional:
        triplets.append((None, None, window))
    if n_times_optional and window_optional:
        triplets.append((None, sfreq, None))
    if sfreq_optional and window_optional:
        triplets.append((n_times, None, None))
    if n_times_optional and sfreq_optional and window_optional:
        triplets.append((None, None, None))
    times = [
        {"n_times": nt, "sfreq": sf, "input_window_seconds": iw}
        for nt, sf, iw in triplets
    ]

    # The three option groups use disjoint keys, so merging never collides.
    return [
        {**out_kw, **chan_kw, **time_kw}
        for out_kw in outputs
        for chan_kw in channels
        for time_kw in times
    ]
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
################################################################
def get_summary_table(dir_name=None):
    """Load the summary table of braindecode models.

    Parameters
    ----------
    dir_name : str or pathlib.Path or None
        Directory holding ``summary.csv``. When None, the directory that
        contains this module is used.

    Returns
    -------
    pandas.DataFrame
        Contents of ``summary.csv`` (first row as header), indexed by its
        ``Model`` column. Spaces right after each delimiter are skipped.
    """
    if dir_name is None:
        base_dir = Path(__file__).parent
    elif isinstance(dir_name, Path):
        base_dir = dir_name
    else:
        base_dir = Path(dir_name)

    return pd.read_csv(
        base_dir / "summary.csv",
        header=0,
        index_col="Model",
        skipinitialspace=True,
    )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def extract_channel_locations_from_chs_info(
    chs_info: Optional[Sequence[Dict[str, Any]]],
    num_channels: Optional[int] = None,
) -> Optional[np.ndarray]:
    """Extract 3D channel locations from MNE-style channel information.

    This function provides a unified approach to extract 3D channel locations
    from MNE channel information. It's compatible with models like SignalJEPA
    and LUNA that need to work with channel spatial information.

    Parameters
    ----------
    chs_info : list of dict or None
        Channel information, typically from ``mne.Info.chs``. Each dict should
        contain a ``'loc'`` entry. It is either a 12-element array (MNE
        format), from which indices 3:6 are used, or any other 1D array with
        at least 3 elements, from which the first 3 elements are used.
    num_channels : int or None
        If specified, only extract the first ``num_channels`` channel locations.
        If None, extract all available channels.

    Returns
    -------
    channel_locations : np.ndarray of shape (n_channels, 3) or None
        float32 array with one row of 3D coordinates per extracted channel.
        Returns None if no valid locations are found.

    Notes
    -----
    - 12-element ``loc`` arrays use indices 3:6. Other 1D arrays with at
      least 3 elements (e.g. plain ``[x, y, z]``) use their first 3 elements.
    - Extraction stops at the first entry that is not a dict, has no
      ``'loc'``, or whose ``'loc'`` cannot be converted to a 1D float array
      of at least 3 elements. Only the channels before it are returned.
    - Returns None if no valid locations can be extracted.
    - NOTE(review): MNE's ``Info`` documentation describes ``loc[:3]`` as the
      channel position and, for EEG, ``loc[3:6]`` as the reference-electrode
      position. Confirm that slicing 3:6 of 12-element arrays is intended;
      it is kept unchanged here for backward compatibility.

    Examples
    --------
    >>> from braindecode.models.util import extract_channel_locations_from_chs_info
    >>> extract_channel_locations_from_chs_info([{"loc": [1.0, 2.0, 3.0]}])
    array([[1., 2., 3.]], dtype=float32)
    >>> import mne  # doctest: +SKIP
    >>> raw = mne.io.read_raw_edf("sample.edf")  # doctest: +SKIP
    >>> locs = extract_channel_locations_from_chs_info(raw.info['chs'], num_channels=22)  # doctest: +SKIP
    >>> print(locs.shape)  # doctest: +SKIP
    (22, 3)
    """
    if chs_info is None:
        return None

    n_to_extract = num_channels if num_channels is not None else len(chs_info)

    locations = []
    for ch_info in chs_info[:n_to_extract]:
        if not isinstance(ch_info, dict):
            break

        loc = ch_info.get("loc")
        if loc is None:
            break

        try:
            loc_array = np.asarray(loc, dtype=np.float32)
        except (ValueError, TypeError):
            break

        # Need a flat vector holding at least one (x, y, z) triplet.
        if loc_array.ndim != 1 or loc_array.size < 3:
            break

        if loc_array.size == 12:
            # Standard 12-element MNE layout.
            coordinates = loc_array[3:6]
        else:
            # Assume the first 3 elements are the coordinates.
            coordinates = loc_array[:3]
        locations.append(coordinates)

    if not locations:
        return None

    return np.stack(locations, axis=0)
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
# Model summary table, read once at import time from the ``summary.csv`` file
# located next to this module.
_summary_table = get_summary_table(dir_name=None)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from .activation import LogActivation, SafeLog
|
|
2
|
+
from .attention import (
|
|
3
|
+
CAT,
|
|
4
|
+
CBAM,
|
|
5
|
+
ECA,
|
|
6
|
+
FCA,
|
|
7
|
+
GCT,
|
|
8
|
+
SRM,
|
|
9
|
+
CATLite,
|
|
10
|
+
EncNet,
|
|
11
|
+
GatherExcite,
|
|
12
|
+
GSoP,
|
|
13
|
+
MultiHeadAttention,
|
|
14
|
+
SqueezeAndExcitation,
|
|
15
|
+
)
|
|
16
|
+
from .blocks import MLP, FeedForwardBlock, InceptionBlock
|
|
17
|
+
from .convolution import (
|
|
18
|
+
AvgPool2dWithConv,
|
|
19
|
+
CausalConv1d,
|
|
20
|
+
CombinedConv,
|
|
21
|
+
Conv2dWithConstraint,
|
|
22
|
+
DepthwiseConv2d,
|
|
23
|
+
)
|
|
24
|
+
from .filter import FilterBankLayer, GeneralizedGaussianFilter
|
|
25
|
+
from .layers import (
|
|
26
|
+
Chomp1d,
|
|
27
|
+
DropPath,
|
|
28
|
+
Ensure4d,
|
|
29
|
+
SqueezeFinalOutput,
|
|
30
|
+
SubjectLayers,
|
|
31
|
+
TimeDistributed,
|
|
32
|
+
)
|
|
33
|
+
from .linear import LinearWithConstraint, MaxNormLinear
|
|
34
|
+
from .parametrization import MaxNorm, MaxNormParametrize
|
|
35
|
+
from .stats import (
|
|
36
|
+
LogPowerLayer,
|
|
37
|
+
LogVarLayer,
|
|
38
|
+
MaxLayer,
|
|
39
|
+
MeanLayer,
|
|
40
|
+
StatLayer,
|
|
41
|
+
StdLayer,
|
|
42
|
+
VarLayer,
|
|
43
|
+
)
|
|
44
|
+
from .util import aggregate_probas
|
|
45
|
+
from .wrapper import Expression, IntermediateOutputWrapper
|
|
46
|
+
|
|
47
|
+
__all__ = [
|
|
48
|
+
"LogActivation",
|
|
49
|
+
"SafeLog",
|
|
50
|
+
"CAT",
|
|
51
|
+
"CBAM",
|
|
52
|
+
"ECA",
|
|
53
|
+
"FCA",
|
|
54
|
+
"GCT",
|
|
55
|
+
"SRM",
|
|
56
|
+
"CATLite",
|
|
57
|
+
"EncNet",
|
|
58
|
+
"GatherExcite",
|
|
59
|
+
"GSoP",
|
|
60
|
+
"MultiHeadAttention",
|
|
61
|
+
"SqueezeAndExcitation",
|
|
62
|
+
"MLP",
|
|
63
|
+
"FeedForwardBlock",
|
|
64
|
+
"InceptionBlock",
|
|
65
|
+
"AvgPool2dWithConv",
|
|
66
|
+
"CausalConv1d",
|
|
67
|
+
"CombinedConv",
|
|
68
|
+
"Conv2dWithConstraint",
|
|
69
|
+
"DepthwiseConv2d",
|
|
70
|
+
"FilterBankLayer",
|
|
71
|
+
"GeneralizedGaussianFilter",
|
|
72
|
+
"Chomp1d",
|
|
73
|
+
"DropPath",
|
|
74
|
+
"Ensure4d",
|
|
75
|
+
"SubjectLayers",
|
|
76
|
+
"SqueezeFinalOutput",
|
|
77
|
+
"TimeDistributed",
|
|
78
|
+
"LinearWithConstraint",
|
|
79
|
+
"MaxNormLinear",
|
|
80
|
+
"MaxNorm",
|
|
81
|
+
"MaxNormParametrize",
|
|
82
|
+
"LogPowerLayer",
|
|
83
|
+
"LogVarLayer",
|
|
84
|
+
"MaxLayer",
|
|
85
|
+
"MeanLayer",
|
|
86
|
+
"StatLayer",
|
|
87
|
+
"StdLayer",
|
|
88
|
+
"VarLayer",
|
|
89
|
+
"aggregate_probas",
|
|
90
|
+
"Expression",
|
|
91
|
+
"IntermediateOutputWrapper",
|
|
92
|
+
]
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor, nn
|
|
3
|
+
|
|
4
|
+
import braindecode.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SafeLog(nn.Module):
    r"""
    Safe logarithm activation function module.

    :math:`\text{SafeLog}(x) = \log\left(\max(x, \epsilon)\right)`

    Parameters
    ----------
    epsilon : float, optional
        Lower bound applied to the input before the logarithm, so that log(0)
        and logs of negative numbers are never computed. Default is 1e-6.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import SafeLog
    >>> module = SafeLog(epsilon=1e-6)
    >>> inputs = torch.rand(2, 3)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3])

    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x) -> Tensor:
        """
        Apply the safe logarithm element-wise.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Result of ``braindecode.functional.safe_log`` with ``eps=epsilon``.
        """
        floor = self.epsilon
        return F.safe_log(x=x, eps=floor)

    def extra_repr(self) -> str:
        # Shown inside the module's repr, e.g. "SafeLog(eps=1e-06)".
        return "eps={}".format(self.epsilon)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class LogActivation(nn.Module):
    """Logarithm activation function.

    Computes ``log(x + epsilon)`` element-wise. The input is shifted by
    ``epsilon`` rather than clamped, so ``x == 0`` stays finite, but inputs
    below ``-epsilon`` yield NaN.

    Parameters
    ----------
    epsilon : float, default=1e-6
        Small float added to the input before taking the logarithm.
    *args, **kwargs
        Forwarded to :class:`torch.nn.Module`.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import LogActivation
    >>> module = LogActivation(epsilon=1e-6)
    >>> inputs = torch.rand(2, 3)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 3])
    """

    def __init__(self, epsilon: float = 1e-6, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shift away from zero so that log(0) is not evaluated for x == 0.
        shifted = x + self.epsilon
        return torch.log(shifted)
|