braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,574 @@
|
|
|
1
|
+
# Authors: Pierre Guetschel
|
|
2
|
+
# Maciej Sliwowski
|
|
3
|
+
#
|
|
4
|
+
# License: BSD-3
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import warnings
|
|
10
|
+
from collections import OrderedDict
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Dict, Iterable, Optional, Type, Union
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import torch
|
|
16
|
+
from docstring_inheritance import NumpyDocstringInheritanceInitMeta
|
|
17
|
+
from mne.utils import _soft_import
|
|
18
|
+
from torchinfo import ModelStatistics, summary
|
|
19
|
+
|
|
20
|
+
from braindecode.version import __version__
|
|
21
|
+
|
|
22
|
+
# Optional dependency: with ``strict=False``, ``_soft_import`` hands back
# ``False`` (instead of raising) when the package cannot be imported, so the
# flag below records whether Hugging Face Hub support is available.
huggingface_hub = _soft_import(
    "huggingface_hub",
    "Hugging Face Hub integration",
    strict=False,
)

HAS_HF_HUB = not (huggingface_hub is False)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Base class for the Hub mixin: the real ``PyTorchModelHubMixin`` when
# huggingface_hub is importable, otherwise an empty placeholder so that
# ``EEGModuleMixin`` can always list ``_BaseHubMixin`` among its bases.
if HAS_HF_HUB:
    _BaseHubMixin: Type = huggingface_hub.PyTorchModelHubMixin  # type: ignore
else:

    class _BaseHubMixin:  # type: ignore[no-redef]
        # Intentionally empty (and without a docstring) so that it contributes
        # nothing to subclasses' behavior or inherited documentation.
        pass
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def deprecated_args(obj, *old_new_args):
    """Resolve deprecated arguments against their replacement names.

    Parameters
    ----------
    obj : object
        Instance whose class name prefixes the warning and error messages.
    *old_new_args : tuple of (str, str, object, object)
        One ``(old_name, new_name, old_val, new_val)`` tuple per renamed
        argument.

    Returns
    -------
    list
        One value per input tuple: ``new_val`` when ``old_val`` is None,
        otherwise ``old_val`` (after warning that ``old_name`` is deprecated).

    Raises
    ------
    ValueError
        If both ``old_val`` and ``new_val`` are given (not None) for the same
        tuple.
    """
    out_args = []
    for old_name, new_name, old_val, new_val in old_new_args:
        if old_val is None:
            out_args.append(new_val)
            continue
        warnings.warn(
            f"{obj.__class__.__name__}: {old_name!r} is deprecated. Use {new_name!r} instead.",
            # Skip this helper and (presumably) the model ``__init__`` calling
            # it, so the warning points at the user's call site.
            stacklevel=3,
        )
        if new_val is not None:
            raise ValueError(
                f"{obj.__class__.__name__}: Both {old_name!r} and {new_name!r} were specified."
            )
        out_args.append(old_val)
    return out_args
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta):
    """
    Mixin class for all EEG models in braindecode.

    This class integrates with Hugging Face Hub when the ``huggingface_hub`` package
    is installed, enabling models to be pushed to and loaded from the Hub using
    :func:`push_to_hub()` and :func:`from_pretrained()` methods.

    Parameters
    ----------
    n_outputs : int
        Number of outputs of the model. This is the number of classes
        in the case of classification.
    n_chans : int
        Number of EEG channels.
    chs_info : list of dict
        Information about each individual EEG channel. This should be filled with
        ``info["chs"]``. Refer to :class:`mne.Info` for more details.
    n_times : int
        Number of time samples of the input window.
    input_window_seconds : float
        Length of the input window in seconds.
    sfreq : float
        Sampling frequency of the EEG recordings.

    Raises
    ------
    ValueError
        If some input signal-related parameters are not specified
        and cannot be inferred, or if the specified parameters are
        inconsistent with each other.

    Notes
    -----
    If some input signal-related parameters are not specified,
    there will be an attempt to infer them from the other parameters.

    .. rubric:: Hugging Face Hub integration

    When the optional ``huggingface_hub`` package is installed, all models
    automatically gain the ability to be pushed to and loaded from the
    Hugging Face Hub. Install with::

        pip install braindecode[hug]

    **Pushing a model to the Hub:**

    .. code-block:: python

        from braindecode.models import EEGNetv4

        # Train your model
        model = EEGNetv4(n_chans=22, n_outputs=4, n_times=1000)
        # ... training code ...

        # Push to the Hub
        model.push_to_hub(
            repo_id="username/my-eegnet-model", commit_message="Initial model upload"
        )

    **Loading a model from the Hub:**

    .. code-block:: python

        from braindecode.models import EEGNetv4

        # Load pretrained model
        model = EEGNetv4.from_pretrained("username/my-eegnet-model")

    The integration automatically handles EEG-specific parameters (n_chans,
    n_times, sfreq, chs_info, etc.) by saving them in a config file alongside
    the model weights. This ensures that loaded models are correctly configured
    for their original data specifications.

    .. important::
        Currently, only EEG-specific parameters (n_outputs, n_chans, n_times,
        input_window_seconds, sfreq, chs_info) are saved to the Hub. Model-specific
        parameters (e.g., dropout rates, activation functions, number of filters)
        are not preserved and will use their default values when loading from the Hub.

        To use non-default model parameters, specify them explicitly when calling
        :func:`from_pretrained()`::

            model = EEGNet.from_pretrained("user/model", dropout=0.3, activation='relu')

        Full parameter serialization will be addressed in a future update.
    """

    def __init_subclass__(cls, **kwargs):
        # Without huggingface_hub there is no Hub metadata to register.
        if not HAS_HF_HUB:
            super().__init_subclass__(**kwargs)
            return

        # Always tag models with "braindecode" and the class name, after any
        # user-supplied tags and without duplicates.
        base_tags = ["braindecode", cls.__name__]
        user_tags = kwargs.pop("tags", None)
        tags = list(user_tags) if user_tags is not None else []
        for tag in base_tags:
            if tag not in tags:
                tags.append(tag)

        docs_url = kwargs.pop(
            "docs_url",
            f"https://braindecode.org/stable/generated/braindecode.models.{cls.__name__}.html",
        )
        repo_url = kwargs.pop("repo_url", "https://braindecode.org")
        library_name = kwargs.pop("library_name", "braindecode")
        license_ = kwargs.pop("license", "bsd-3-clause")
        # TODO: model_card_template can be added in the future for custom model cards
        super().__init_subclass__(
            tags=tags,
            docs_url=docs_url,
            repo_url=repo_url,
            library_name=library_name,
            license=license_,
            **kwargs,
        )

    def __init__(
        self,
        n_outputs: Optional[int] = None,  # type: ignore[assignment]
        n_chans: Optional[int] = None,  # type: ignore[assignment]
        chs_info=None,  # type: ignore[assignment]
        n_times: Optional[int] = None,  # type: ignore[assignment]
        input_window_seconds: Optional[float] = None,  # type: ignore[assignment]
        sfreq: Optional[float] = None,  # type: ignore[assignment]
    ):
        # chs_info read back from a Hub config.json is a list of plain dicts
        # whose "loc" is a list; turn "loc" back into a numpy array.
        if (
            isinstance(chs_info, list)
            and chs_info
            and isinstance(chs_info[0], dict)
            and "loc" in chs_info[0]
            and isinstance(chs_info[0]["loc"], list)
        ):
            chs_info = self._deserialize_chs_info(chs_info)
            warnings.warn(
                "Modifying chs_info argument using the _deserialize_chs_info() method",
                stacklevel=2,
            )

        # Reject redundant parameters that contradict each other.
        if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
            raise ValueError(f"{n_chans=} different from {chs_info=} length")
        if (
            n_times is not None
            and input_window_seconds is not None
            and sfreq is not None
            and n_times != round(input_window_seconds * sfreq)
        ):
            raise ValueError(
                f"{n_times=} different from {input_window_seconds=} * {sfreq=}"
            )

        # Raw values are stored; the public properties infer missing ones.
        self._input_window_seconds = input_window_seconds  # type: ignore[assignment]
        self._chs_info = chs_info  # type: ignore[assignment]
        self._n_outputs = n_outputs  # type: ignore[assignment]
        self._n_chans = n_chans  # type: ignore[assignment]
        self._n_times = n_times  # type: ignore[assignment]
        self._sfreq = sfreq  # type: ignore[assignment]

        super().__init__()

    @property
    def n_outputs(self) -> int:
        """Number of model outputs; raises ValueError if not specified."""
        if self._n_outputs is None:
            raise ValueError("n_outputs not specified.")
        return self._n_outputs

    @property
    def n_chans(self) -> int:
        """Number of channels, falling back to ``len(chs_info)``."""
        if self._n_chans is None and self._chs_info is not None:
            return len(self._chs_info)
        elif self._n_chans is None:
            raise ValueError(
                "n_chans could not be inferred. Either specify n_chans or chs_info."
            )
        return self._n_chans

    @property
    def chs_info(self) -> list[dict]:
        """Per-channel info dicts; raises ValueError if not specified."""
        if self._chs_info is None:
            raise ValueError("chs_info not specified.")
        return self._chs_info

    @property
    def n_times(self) -> int:
        """Window length in samples, or ``round(input_window_seconds * sfreq)``."""
        if (
            self._n_times is None
            and self._input_window_seconds is not None
            and self._sfreq is not None
        ):
            return round(self._input_window_seconds * self._sfreq)
        elif self._n_times is None:
            raise ValueError(
                "n_times could not be inferred. "
                "Either specify n_times or input_window_seconds and sfreq."
            )
        return self._n_times

    @property
    def input_window_seconds(self) -> float:
        """Window length in seconds, or ``n_times / sfreq``."""
        if (
            self._input_window_seconds is None
            and self._n_times is not None
            and self._sfreq is not None
        ):
            return float(self._n_times / self._sfreq)
        elif self._input_window_seconds is None:
            raise ValueError(
                "input_window_seconds could not be inferred. "
                "Either specify input_window_seconds or n_times and sfreq."
            )
        return self._input_window_seconds

    @property
    def sfreq(self) -> float:
        """Sampling frequency, or ``n_times / input_window_seconds``."""
        if (
            self._sfreq is None
            and self._input_window_seconds is not None
            and self._n_times is not None
        ):
            return float(self._n_times / self._input_window_seconds)
        elif self._sfreq is None:
            raise ValueError(
                "sfreq could not be inferred. "
                "Either specify sfreq or input_window_seconds and n_times."
            )
        return self._sfreq

    @property
    def input_shape(self) -> tuple[int, int, int]:
        """Input data shape."""
        return (1, self.n_chans, self.n_times)

    def get_output_shape(self) -> tuple[int, ...]:
        """Returns shape of neural network output for batch size equal 1.

        The model is run once under :func:`torch.inference_mode` on a zero
        tensor of shape :attr:`input_shape`. The tensor uses the dtype and
        device of the model's first parameter, or torch defaults if the model
        has no parameters.

        Returns
        -------
        output_shape : tuple[int, ...]
            shape of the network output for `batch_size==1` (1, ...)

        Raises
        ------
        ValueError
            If the forward pass fails because the input window is too short
            for some layer (e.g. a kernel larger than its input).
        """
        # Fetch the first parameter once. Parameter-free models fall back to
        # torch defaults instead of failing with StopIteration.
        first_param = next(self.parameters(), None)  # type: ignore
        dtype = None if first_param is None else first_param.dtype
        device = None if first_param is None else first_param.device
        with torch.inference_mode():
            try:
                dummy_input = torch.zeros(self.input_shape, dtype=dtype, device=device)
                return tuple(self.forward(dummy_input).shape)  # type: ignore
            except RuntimeError as exc:
                if str(exc).endswith(
                    (
                        "Output size is too small",
                        "Kernel size can't be greater than actual input size",
                    )
                ):
                    msg = (
                        "During model prediction RuntimeError was thrown showing that at some "
                        f"layer `{str(exc).split('.')[-1]}` (see above in the stacktrace). This "
                        "could be caused by providing too small `n_times`/`input_window_seconds`. "
                        "Model may require longer chunks of signal in the input than "
                        f"{self.input_shape}."
                    )
                    raise ValueError(msg) from exc
                raise

    # Optional renaming of state-dict keys (old name -> new name) applied by
    # ``load_state_dict``; subclasses override it to load legacy checkpoints.
    mapping: Optional[Dict[str, str]] = None

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load ``state_dict`` after renaming keys listed in :attr:`mapping`.

        Keys found in ``mapping`` are replaced by their mapped name; all other
        keys are kept. Extra positional/keyword arguments are forwarded to the
        parent ``load_state_dict``.
        """
        mapping = self.mapping or {}
        new_state_dict = OrderedDict(
            (mapping.get(key, key), value) for key, value in state_dict.items()
        )
        # torch attaches per-module version info to state dicts as
        # ``_metadata``, and ``nn.Module.load_state_dict`` reads it. Rebuilding
        # the dict would drop it, so carry it over.
        metadata = getattr(state_dict, "_metadata", None)
        if metadata is not None:
            new_state_dict._metadata = metadata  # type: ignore[attr-defined]

        return super().load_state_dict(new_state_dict, *args, **kwargs)

    def to_dense_prediction_model(self, axis: tuple[int, ...] | int = (2, 3)) -> None:
        """Transform a sequential model with strides into a model that outputs
        dense predictions, by removing the strides and inserting dilations instead.

        Modifies the model in-place: along the selected axes, each module's
        stride is set to 1 and each module's dilation is set to the product of
        the strides of the modules preceding it.

        Parameters
        ----------
        axis : int or (int,int)
            Axis to transform (in terms of intermediate output axes)
            can either be 2, 3, or (2,3).

        Notes
        -----
        Does not yet work correctly for average pooling.
        Prior to version 0.1.7, there had been a bug that could move strides
        backwards one layer.
        """
        if not hasattr(axis, "__iter__"):
            axis = (axis,)
        assert all(ax in (2, 3) for ax in axis), "Only 2 and 3 allowed for axis"  # type: ignore[union-attr]
        # Shift to 0/1 indices into the 2-element stride/dilation tuples.
        axes = np.array(axis) - 2
        stride_so_far = np.array([1, 1])
        for module in self.modules():  # type: ignore
            if hasattr(module, "dilation"):
                assert module.dilation == 1 or (module.dilation == (1, 1)), (
                    "Dilation should equal 1 before conversion, maybe the model is "
                    "already converted?"
                )
                new_dilation = [1, 1]
                for ax in axes:
                    new_dilation[ax] = int(stride_so_far[ax])
                module.dilation = tuple(new_dilation)
            if hasattr(module, "stride"):
                if not hasattr(module.stride, "__len__"):
                    module.stride = (module.stride, module.stride)
                # Accumulate the original stride before removing it.
                stride_so_far *= np.array(module.stride)
                new_stride = list(module.stride)
                for ax in axes:
                    new_stride[ax] = 1
                module.stride = tuple(new_stride)

    def get_torchinfo_statistics(
        self,
        col_names: Optional[Iterable[str]] = (
            "input_size",
            "output_size",
            "num_params",
            "kernel_size",
        ),
        row_settings: Optional[Iterable[str]] = ("var_names", "depth"),
    ) -> ModelStatistics:
        """Generate table describing the model using torchinfo.summary.

        Parameters
        ----------
        col_names : tuple, optional
            Specify which columns to show in the output, see torchinfo for details, by default
            ("input_size", "output_size", "num_params", "kernel_size")
        row_settings : tuple, optional
            Specify which features to show in a row, see torchinfo for details, by default
            ("var_names", "depth")

        Returns
        -------
        torchinfo.ModelStatistics
            ModelStatistics generated by torchinfo.summary.
        """
        return summary(
            self,
            input_size=self.input_shape,
            col_names=col_names,
            row_settings=row_settings,
            verbose=0,
        )

    def __str__(self) -> str:
        return str(self.get_torchinfo_statistics())

    @staticmethod
    def _serialize_chs_info(chs_info):
        """
        Serialize MNE channel info to JSON-compatible format.

        Only ``ch_name``, ``kind``, ``coil_type``, ``unit``, ``cal``,
        ``range`` and ``loc`` are kept. ``loc`` is converted to a list.

        Parameters
        ----------
        chs_info : list of dict or None
            Channel information from MNE Info object.

        Returns
        -------
        list of dict or None
            Serialized channel information that can be saved to JSON.
        """
        if chs_info is None:
            return None

        serialized = []
        for ch in chs_info:
            # Extract serializable fields from MNE channel info
            ch_dict = {
                "ch_name": ch.get("ch_name", ""),
            }

            # Handle kind field - can be either string or integer
            kind_val = ch.get("kind")
            if kind_val is not None:
                ch_dict["kind"] = (
                    kind_val if isinstance(kind_val, str) else int(kind_val)
                )

            # Add numeric fields with safe conversion
            coil_type = ch.get("coil_type")
            if coil_type is not None:
                ch_dict["coil_type"] = int(coil_type)

            unit = ch.get("unit")
            if unit is not None:
                ch_dict["unit"] = int(unit)

            cal = ch.get("cal")
            if cal is not None:
                ch_dict["cal"] = float(cal)

            range_val = ch.get("range")
            if range_val is not None:
                ch_dict["range"] = float(range_val)

            # Serialize location array if present
            if "loc" in ch and ch["loc"] is not None:
                ch_dict["loc"] = (
                    ch["loc"].tolist()
                    if hasattr(ch["loc"], "tolist")
                    else list(ch["loc"])
                )
            serialized.append(ch_dict)

        return serialized

    @staticmethod
    def _deserialize_chs_info(chs_info_dict):
        """
        Deserialize channel info from JSON-compatible format to MNE-like structure.

        Each dict is shallow-copied and its ``loc`` (if present) is converted
        back to a numpy array.

        Parameters
        ----------
        chs_info_dict : list of dict or None
            Serialized channel information.

        Returns
        -------
        list of dict or None
            Deserialized channel information compatible with MNE.
        """
        if chs_info_dict is None:
            return None

        deserialized = []
        for ch_dict in chs_info_dict:
            ch = ch_dict.copy()
            # Convert location back to numpy array if present
            if "loc" in ch and ch["loc"] is not None:
                ch["loc"] = np.array(ch["loc"])
            deserialized.append(ch)

        return deserialized

    def _save_pretrained(self, save_directory):
        """
        Save model configuration and weights to ``save_directory``.

        This hook is used by the Hub mixin's saving/pushing machinery. It writes
        ``config.json`` (EEG-specific parameters and the braindecode version),
        then ``pytorch_model.bin``. It then delegates to the parent
        implementation for the safetensors weights; if that fails with
        ImportError or RuntimeError, it warns instead of failing.

        Parameters
        ----------
        save_directory : str or Path
            Directory where the configuration should be saved.
        """
        if not HAS_HF_HUB:
            return

        save_directory = Path(save_directory)

        # Collect EEG-specific configuration
        config = {
            "n_outputs": self._n_outputs,
            "n_chans": self._n_chans,
            "n_times": self._n_times,
            "input_window_seconds": self._input_window_seconds,
            "sfreq": self._sfreq,
            "chs_info": self._serialize_chs_info(self._chs_info),
            "braindecode_version": __version__,
        }

        # Save to config.json
        config_path = save_directory / "config.json"
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        # Save model weights with standard Hub filename
        weights_path = save_directory / "pytorch_model.bin"
        torch.save(self.state_dict(), weights_path)

        # Also save in safetensors format using parent's implementation
        try:
            super()._save_pretrained(save_directory)
        except (ImportError, RuntimeError) as e:
            # pytorch_model.bin was already written above, so only warn.
            warnings.warn(
                f"Could not save model in safetensors format: {e}. "
                "Model weights saved in pytorch_model.bin instead.",
                stacklevel=2,
            )

    if HAS_HF_HUB:

        @classmethod
        def _from_pretrained(
            cls,
            *,
            model_id: str,
            revision: Optional[str],
            cache_dir: Optional[Union[str, Path]],
            force_download: bool,
            local_files_only: bool,
            token: Union[str, bool, None],
            map_location: str = "cpu",
            strict: bool = False,
            **model_kwargs,
        ):
            # config.json written by _save_pretrained includes
            # "braindecode_version", which is not an __init__ argument; drop it
            # before the parent (presumably) forwards the kwargs to the
            # constructor.
            model_kwargs.pop("braindecode_version", None)
            return super()._from_pretrained(  # type: ignore
                model_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                map_location=map_location,
                strict=strict,
                **model_kwargs,
            )
|