braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/augmentation/__init__.py +3 -5
- braindecode/augmentation/base.py +5 -8
- braindecode/augmentation/functional.py +22 -25
- braindecode/augmentation/transforms.py +42 -51
- braindecode/classifier.py +16 -11
- braindecode/datasets/__init__.py +3 -5
- braindecode/datasets/base.py +13 -17
- braindecode/datasets/bbci.py +14 -13
- braindecode/datasets/bcicomp.py +5 -4
- braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
- braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
- braindecode/datasets/{bids/hub.py → hub.py} +350 -375
- braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
- braindecode/datasets/mne.py +19 -19
- braindecode/datasets/moabb.py +10 -10
- braindecode/datasets/nmt.py +56 -58
- braindecode/datasets/sleep_physio_challe_18.py +5 -3
- braindecode/datasets/sleep_physionet.py +5 -5
- braindecode/datasets/tuh.py +18 -21
- braindecode/datasets/xy.py +9 -10
- braindecode/datautil/__init__.py +3 -3
- braindecode/datautil/serialization.py +20 -22
- braindecode/datautil/util.py +7 -120
- braindecode/eegneuralnet.py +52 -22
- braindecode/functional/functions.py +10 -7
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +3 -5
- braindecode/models/atcnet.py +39 -43
- braindecode/models/attentionbasenet.py +41 -37
- braindecode/models/attn_sleep.py +24 -26
- braindecode/models/base.py +6 -6
- braindecode/models/bendr.py +26 -50
- braindecode/models/biot.py +30 -61
- braindecode/models/contrawr.py +5 -5
- braindecode/models/ctnet.py +35 -35
- braindecode/models/deep4.py +5 -5
- braindecode/models/deepsleepnet.py +7 -7
- braindecode/models/eegconformer.py +26 -31
- braindecode/models/eeginception_erp.py +2 -2
- braindecode/models/eeginception_mi.py +6 -6
- braindecode/models/eegitnet.py +5 -5
- braindecode/models/eegminer.py +1 -1
- braindecode/models/eegnet.py +3 -3
- braindecode/models/eegnex.py +2 -2
- braindecode/models/eegsimpleconv.py +2 -2
- braindecode/models/eegsym.py +7 -7
- braindecode/models/eegtcnet.py +6 -6
- braindecode/models/fbcnet.py +2 -2
- braindecode/models/fblightconvnet.py +3 -3
- braindecode/models/fbmsnet.py +3 -3
- braindecode/models/hybrid.py +2 -2
- braindecode/models/ifnet.py +5 -5
- braindecode/models/labram.py +46 -70
- braindecode/models/luna.py +5 -60
- braindecode/models/medformer.py +21 -23
- braindecode/models/msvtnet.py +15 -15
- braindecode/models/patchedtransformer.py +55 -55
- braindecode/models/sccnet.py +2 -2
- braindecode/models/shallow_fbcsp.py +3 -5
- braindecode/models/signal_jepa.py +12 -39
- braindecode/models/sinc_shallow.py +4 -3
- braindecode/models/sleep_stager_blanco_2020.py +2 -2
- braindecode/models/sleep_stager_chambon_2018.py +2 -2
- braindecode/models/sparcnet.py +8 -8
- braindecode/models/sstdpn.py +869 -869
- braindecode/models/summary.csv +17 -19
- braindecode/models/syncnet.py +2 -2
- braindecode/models/tcn.py +5 -5
- braindecode/models/tidnet.py +3 -3
- braindecode/models/tsinception.py +3 -3
- braindecode/models/usleep.py +7 -7
- braindecode/models/util.py +14 -165
- braindecode/modules/__init__.py +1 -9
- braindecode/modules/activation.py +3 -29
- braindecode/modules/attention.py +0 -123
- braindecode/modules/blocks.py +1 -53
- braindecode/modules/convolution.py +0 -53
- braindecode/modules/filter.py +0 -31
- braindecode/modules/layers.py +0 -84
- braindecode/modules/linear.py +1 -22
- braindecode/modules/stats.py +0 -10
- braindecode/modules/util.py +0 -9
- braindecode/modules/wrapper.py +0 -17
- braindecode/preprocessing/preprocess.py +0 -3
- braindecode/regressor.py +18 -15
- braindecode/samplers/ssl.py +1 -1
- braindecode/util.py +28 -38
- braindecode/version.py +1 -1
- braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
- braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
- braindecode/datasets/bids/__init__.py +0 -54
- braindecode/datasets/bids/format.py +0 -717
- braindecode/datasets/bids/hub_format.py +0 -717
- braindecode/datasets/bids/hub_io.py +0 -197
- braindecode/datasets/chb_mit.py +0 -163
- braindecode/datasets/siena.py +0 -162
- braindecode/datasets/utils.py +0 -67
- braindecode/models/brainmodule.py +0 -845
- braindecode/models/config.py +0 -233
- braindecode/models/reve.py +0 -843
- braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
- braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/config.py
DELETED
|
@@ -1,233 +0,0 @@
|
|
|
1
|
-
from collections.abc import Callable
|
|
2
|
-
from inspect import signature
|
|
3
|
-
from types import UnionType
|
|
4
|
-
from typing import Annotated, Any, Literal, Union, get_args, get_origin
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
from mne.utils import _soft_import
|
|
8
|
-
from typing_extensions import TypedDict
|
|
9
|
-
|
|
10
|
-
from braindecode.models.base import EEGModuleMixin
|
|
11
|
-
from braindecode.models.util import SigArgName, models_dict, models_mandatory_parameters
|
|
12
|
-
|
|
13
|
-
# pydantic is an optional dependency ("braindecode[typing]"). With
# ``strict=False`` a missing package is reported through a falsy value, which
# the rest of this module checks with ``if not pydantic``.
pydantic = _soft_import(name="pydantic", purpose="model configuration", strict=False)

try:
    from numpydantic import NDArray, Shape
except ImportError:
    # we can't use soft import for numpydantic because numpydantic does not
    # define its version in __init__

    class _AnyAnnotation:
        """Subscriptable placeholder used when numpydantic is unavailable.

        ``typing.Any`` cannot be subscripted, so aliasing ``NDArray`` and
        ``Shape`` to it made annotations such as
        ``NDArray[Shape["12"], np.float64]`` raise ``TypeError`` at import
        time. Subscripting this class returns ``Any`` instead, so those
        annotations simply become unconstrained.
        """

        def __class_getitem__(cls, params: Any) -> Any:
            return Any

    NDArray = _AnyAnnotation  # type: ignore
    Shape = _AnyAnnotation  # type: ignore
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class ChsInfoType(TypedDict, total=False, closed=True):  # type: ignore[call-arg]
    """Typed schema for one entry of a model's ``chs_info`` list.

    ``total=False`` makes every key optional, and ``closed=True`` (PEP 728)
    disallows keys that are not listed below. The keys appear to mirror the
    per-channel dicts of MNE's ``Info["chs"]`` -- TODO confirm against MNE.
    """

    cal: float
    ch_name: str
    coil_type: int
    coord_frame: int
    # NOTE(review): MNE appears to store ``kind`` as an integer FIFF constant
    # rather than a string -- confirm ``str`` is the intended type here.
    kind: str
    # Expected to be a 12-element float64 array (``Shape["12"]``,
    # ``np.float64``), validated through numpydantic's ``NDArray``.
    loc: NDArray[Shape["12"], np.float64]  # type: ignore[misc]
    logno: int
    range: float
    scanno: int
    unit: int
    unit_mul: int
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def _replace_type_hints(type_hint: Any) -> Any:
|
|
38
|
-
origin = get_origin(type_hint)
|
|
39
|
-
args = get_args(type_hint)
|
|
40
|
-
if origin is type or origin is Callable or type_hint is Callable:
|
|
41
|
-
return pydantic.ImportString
|
|
42
|
-
if origin is None:
|
|
43
|
-
return type_hint
|
|
44
|
-
replaced_args = tuple(_replace_type_hints(arg) for arg in args)
|
|
45
|
-
if origin is UnionType:
|
|
46
|
-
origin = Union
|
|
47
|
-
return origin[replaced_args]
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
# Types for the standard signal arguments shared by braindecode models.
# ``make_model_config`` applies them (wrapped as ``T | None``) whenever a
# model's ``__init__`` has a parameter with one of these names, regardless of
# how -- or whether -- the model annotated it.
SIGNAL_ARGS_TYPES = dict(
    n_chans=int,
    n_times=int,
    sfreq=float,
    input_window_seconds=float,
    n_outputs=int,
    chs_info=list[ChsInfoType],
)
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
class BaseBraindecodeModelConfig(pydantic.BaseModel if pydantic else object):  # type: ignore
    """Base class shared by all generated braindecode model configs.

    Concrete subclasses are created by ``make_model_config``, which adds one
    field per model ``__init__`` parameter plus a ``model_name_`` field.
    When pydantic is not installed the base falls back to ``object`` so that
    importing this module does not fail; the class is unusable in that case.
    """

    def create_instance(self) -> EEGModuleMixin:
        """Build the braindecode model described by this config.

        Returns
        -------
        EEGModuleMixin
            An instance of the class registered in ``models_dict`` under
            ``self.model_name_``, constructed from the config's fields.
        """
        model_cls = models_dict[self.model_name_]
        kwargs = self.model_dump(mode="python", exclude={"model_name_"})
        # ``n_chans`` is redundant with ``chs_info``; drop it when both are set
        # (presumably so the model receives a single source of truth).
        if kwargs.get("n_chans") is not None and kwargs.get("chs_info") is not None:
            kwargs.pop("n_chans")
        # Likewise, ``n_times`` is redundant once ``input_window_seconds`` and
        # ``sfreq`` are both known.
        if (
            kwargs.get("n_times") is not None
            and kwargs.get("input_window_seconds") is not None
            and kwargs.get("sfreq") is not None
        ):
            kwargs.pop("n_times")
        return model_cls(**kwargs)
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def _check_signal_params(data: Any, required: list[SigArgName]) -> dict:
    """Check the signal arguments of a raw config and infer missing ones.

    Parameters
    ----------
    data : Mapping
        Raw keyword arguments given to a model config.
    required : list of SigArgName
        Names of the signal arguments the model cannot be built without.

    Returns
    -------
    dict
        A shallow copy of ``data`` in which ``n_chans``, ``n_times``,
        ``sfreq`` or ``input_window_seconds`` have been filled in when they
        can be derived from the other arguments.

    Raises
    ------
    ValueError
        If a required argument is missing and cannot be inferred, or if the
        provided arguments contradict each other.
    """
    # Copy so that validating a config never mutates the caller's mapping.
    data = dict(data)
    n_outputs = data.get("n_outputs")
    n_chans = data.get("n_chans")
    chs_info = data.get("chs_info")
    n_times = data.get("n_times")
    input_window_seconds = data.get("input_window_seconds")
    sfreq = data.get("sfreq")

    # Check that required parameters are provided or can be inferred.
    if "n_outputs" in required and n_outputs is None:
        raise ValueError("n_outputs is a required parameter but was not provided.")
    if "n_chans" in required and n_chans is None and chs_info is None:
        raise ValueError(
            "n_chans is required and could not be inferred. Either specify n_chans or chs_info."
        )
    if "chs_info" in required and chs_info is None:
        raise ValueError("chs_info is a required parameter but was not provided.")
    if "n_times" in required and (
        n_times is None and (sfreq is None or input_window_seconds is None)
    ):
        raise ValueError(
            "n_times is required and could not be inferred. "
            "Either specify n_times or input_window_seconds and sfreq."
        )
    if "sfreq" in required and (
        sfreq is None and (n_times is None or input_window_seconds is None)
    ):
        raise ValueError(
            "sfreq is required and could not be inferred. "
            "Either specify sfreq or input_window_seconds and n_times."
        )
    if "input_window_seconds" in required and (
        input_window_seconds is None and (n_times is None or sfreq is None)
    ):
        raise ValueError(
            "input_window_seconds is required and could not be inferred. "
            "Either specify input_window_seconds or n_times and sfreq."
        )

    # Infer missing parameters if possible, and check consistency.
    if chs_info is not None:
        if n_chans is None:
            data["n_chans"] = len(chs_info)
        elif n_chans != len(chs_info):
            raise ValueError(
                f"Provided {n_chans=} does not match length of chs_info: {len(chs_info)}."
            )
    if n_times is not None and sfreq is not None and input_window_seconds is not None:
        if n_times != round(input_window_seconds * sfreq):
            raise ValueError(
                f"Provided {n_times=} does not match {input_window_seconds=} * {sfreq=}."
            )
    elif n_times is None and sfreq is not None and input_window_seconds is not None:
        data["n_times"] = round(input_window_seconds * sfreq)
    elif sfreq is None and n_times is not None and input_window_seconds is not None:
        data["sfreq"] = n_times / input_window_seconds
    elif input_window_seconds is None and n_times is not None and sfreq is not None:
        data["input_window_seconds"] = n_times / sfreq
    return data


def make_model_config(
    model_class: type[EEGModuleMixin],
    required: list[SigArgName],
) -> type[BaseBraindecodeModelConfig]:
    """Create a pydantic model config for a given model class.

    One field is created per parameter of ``model_class.__init__`` (``self``
    and ``**kwargs`` excluded), plus a ``model_name_`` literal field that
    identifies the model. Signal arguments are validated, and inferred from
    each other when possible, before field validation runs.

    Parameters
    ----------
    model_class : type[EEGModuleMixin]
        The model class for which to create the config.
    required : list of SigArgName
        The required signal arguments for the model.

    Returns
    -------
    type
        A pydantic BaseModel subclass representing the model config.

    Raises
    ------
    ImportError
        If pydantic is not installed.
    ValueError
        If ``model_class.__init__`` accepts ``*args``.
    """
    if not pydantic:
        raise ImportError(
            "pydantic is required to use make_model_config. "
            "Please install braindecode[typing]."
        )
    from collections.abc import Mapping

    @pydantic.model_validator(mode="before")
    def validate_signal_params(cls, data: Any):
        # Before-validators receive the raw input, which is not necessarily a
        # mapping (e.g. an already-built config instance); leave anything
        # else for pydantic to handle.
        if not isinstance(data, Mapping):
            return data
        return _check_signal_params(data, required)

    signature_params = signature(model_class.__init__, eval_str=True).parameters
    has_args = any(p.kind == p.VAR_POSITIONAL for p in signature_params.values())
    has_kwargs = any(p.kind == p.VAR_KEYWORD for p in signature_params.values())
    if has_args:
        raise ValueError("Model __init__ methods cannot have *args")

    # Unknown keys are only accepted when the model itself takes **kwargs.
    extra = "allow" if has_kwargs else "forbid"
    fields = {}
    for param_name, p in signature_params.items():
        if param_name == "self" or p.kind == p.VAR_KEYWORD:
            continue
        if param_name in SIGNAL_ARGS_TYPES:
            # Most models did not specify types for signal args, so the shared
            # types are used; they are optional at field level because the
            # before-validator enforces the required ones and infers others.
            annot = SIGNAL_ARGS_TYPES[param_name] | None
        elif p.annotation is p.empty:
            annot = Any
        else:
            # case with type[nn.Module] or callable
            annot = _replace_type_hints(p.annotation)
        # ``(type, ...)`` marks a field as required in every pydantic v2
        # release, unlike a bare annotation.
        default = ... if p.default is p.empty else p.default
        fields[param_name] = (annot, default)

    model_name = model_class.__name__
    model_config = pydantic.create_model(
        f"{model_name}Config",
        model_name_=(Literal[model_name], model_name),
        __config__=pydantic.ConfigDict(
            arbitrary_types_allowed=True, extra=extra, validate_default=True
        ),
        __doc__=f"Pydantic config of model {model_class.__name__}\n\n{model_class.__doc__}",
        __base__=BaseBraindecodeModelConfig,
        __module__="braindecode.models.config",
        __validators__={"validate_signal_params": validate_signal_params},
        **fields,
    )
    return model_config
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
# Automatically generate and add classes to the global namespace
# and define __all__ based on generated classes
__all__ = ["make_model_config"]

# Config classes can only be generated when the optional pydantic dependency
# is available; otherwise only ``make_model_config`` is exported.
if pydantic:
    models_configs: list[type[BaseBraindecodeModelConfig]] = []
    for model_name, req, _ in models_mandatory_parameters:
        model_cls = models_dict[model_name]
        model_cfg = make_model_config(model_cls, req)
        globals()[model_cfg.__name__] = model_cfg
        __all__.append(model_cfg.__name__)
        models_configs.append(model_cfg)

    # Union of every generated config, discriminated by the ``model_name_``
    # field so pydantic selects the matching config class for a payload.
    BraindecodeModelConfig = Annotated[  # type: ignore
        Union[tuple(models_configs)],
        pydantic.Field(
            discriminator="model_name_", description="Braindecode model configuration"
        ),
    ]

# # Example usage:
#
# class DummyConfigWithModel(pydantic.BaseModel):
#     model: BraindecodeModelConfig
#
# DummyConfigWithModel.model_validate({'model': dict(model_name_='EEGNet', n_chans=16, n_outputs=1, n_times=200)})
|