braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,574 @@
1
+ # Authors: Pierre Guetschel
2
+ # Maciej Sliwowski
3
+ #
4
+ # License: BSD-3
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import warnings
10
+ from collections import OrderedDict
11
+ from pathlib import Path
12
+ from typing import Dict, Iterable, Optional, Type, Union
13
+
14
+ import numpy as np
15
+ import torch
16
+ from docstring_inheritance import NumpyDocstringInheritanceInitMeta
17
+ from mne.utils import _soft_import
18
+ from torchinfo import ModelStatistics, summary
19
+
20
+ from braindecode.version import __version__
21
+
22
# ``huggingface_hub`` is an optional dependency. With ``strict=False``,
# ``_soft_import`` hands back ``False`` (rather than a module) when the
# package cannot be imported, which is exactly what ``HAS_HF_HUB`` tests.
huggingface_hub = _soft_import(
    "huggingface_hub",
    "Hugging Face Hub integration",
    strict=False,
)
HAS_HF_HUB = not (huggingface_hub is False)
27
+
28
+
29
# Base class for ``EEGModuleMixin``: the Hugging Face ``PyTorchModelHubMixin``
# when ``huggingface_hub`` is importable, otherwise an empty placeholder so
# that ``EEGModuleMixin`` can always inherit from ``_BaseHubMixin``.
if HAS_HF_HUB:
    _BaseHubMixin: Type = huggingface_hub.PyTorchModelHubMixin  # type: ignore
else:

    class _BaseHubMixin:  # type: ignore[no-redef]
        # Intentionally empty: no Hub features without huggingface_hub.
        pass
36
+
37
+
38
def deprecated_args(obj, *old_new_args):
    """Resolve renamed arguments, preferring the new name.

    Parameters
    ----------
    obj : object
        Object whose class name prefixes the warning and error messages
        (typically the model instance calling this helper).
    *old_new_args : tuple of (str, str, object, object)
        One ``(old_name, new_name, old_val, new_val)`` tuple per renamed
        argument.

    Returns
    -------
    list
        One value per tuple, in order: ``new_val`` when ``old_val`` is None,
        otherwise ``old_val``.

    Raises
    ------
    ValueError
        If both ``old_val`` and ``new_val`` are not None for some tuple.

    Warns
    -----
    UserWarning
        For every tuple whose deprecated ``old_val`` is not None.
    """
    cls_name = obj.__class__.__name__
    out_args = []
    for old_name, new_name, old_val, new_val in old_new_args:
        if old_val is None:
            # Deprecated name not used: take the new value (possibly None).
            out_args.append(new_val)
            continue
        warnings.warn(
            f"{cls_name}: {old_name!r} is deprecated. Use {new_name!r} instead."
        )
        if new_val is not None:
            raise ValueError(
                f"{cls_name}: Both {old_name!r} and {new_name!r} were specified."
            )
        out_args.append(old_val)
    return out_args
53
+
54
+
55
class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta):
    """
    Mixin class for all EEG models in braindecode.

    This class integrates with Hugging Face Hub when the ``huggingface_hub`` package
    is installed, enabling models to be pushed to and loaded from the Hub using
    :func:`push_to_hub()` and :func:`from_pretrained()` methods.

    Parameters
    ----------
    n_outputs : int
        Number of outputs of the model. This is the number of classes
        in the case of classification.
    n_chans : int
        Number of EEG channels.
    chs_info : list of dict
        Information about each individual EEG channel. This should be filled with
        ``info["chs"]``. Refer to :class:`mne.Info` for more details.
    n_times : int
        Number of time samples of the input window.
    input_window_seconds : float
        Length of the input window in seconds.
    sfreq : float
        Sampling frequency of the EEG recordings.

    Raises
    ------
    ValueError
        If the given signal-related parameters are inconsistent (``n_chans``
        differs from ``len(chs_info)``, or ``n_times`` differs from
        ``round(input_window_seconds * sfreq)``). Accessing a signal-related
        property that was neither specified nor inferable from the other
        parameters also raises ``ValueError``.

    Notes
    -----
    If some input signal-related parameters are not specified,
    there will be an attempt to infer them from the other parameters.

    .. rubric:: Hugging Face Hub integration

    When the optional ``huggingface_hub`` package is installed, all models
    automatically gain the ability to be pushed to and loaded from the
    Hugging Face Hub. Install with::

        pip install braindecode[hug]

    **Pushing a model to the Hub:**

    .. code-block:: python

        from braindecode.models import EEGNetv4

        # Train your model
        model = EEGNetv4(n_chans=22, n_outputs=4, n_times=1000)
        # ... training code ...

        # Push to the Hub
        model.push_to_hub(
            repo_id="username/my-eegnet-model", commit_message="Initial model upload"
        )

    **Loading a model from the Hub:**

    .. code-block:: python

        from braindecode.models import EEGNetv4

        # Load pretrained model
        model = EEGNetv4.from_pretrained("username/my-eegnet-model")

    The integration automatically handles EEG-specific parameters (n_chans,
    n_times, sfreq, chs_info, etc.) by saving them in a config file alongside
    the model weights. This ensures that loaded models are correctly configured
    for their original data specifications.

    .. important::
        Currently, only EEG-specific parameters (n_outputs, n_chans, n_times,
        input_window_seconds, sfreq, chs_info) are saved to the Hub. Model-specific
        parameters (e.g., dropout rates, activation functions, number of filters)
        are not preserved and will use their default values when loading from the Hub.

        To use non-default model parameters, specify them explicitly when calling
        :func:`from_pretrained()`::

            model = EEGNetv4.from_pretrained("user/model", drop_prob=0.3)

        Full parameter serialization will be addressed in a future update.
    """

    def __init_subclass__(cls, **kwargs):
        """Attach Hub metadata to every subclass.

        Without ``huggingface_hub`` the keyword arguments are forwarded
        unchanged to the parent ``__init_subclass__``. Otherwise the
        user-supplied ``tags`` are extended with ``"braindecode"`` and the
        subclass name, and defaults for ``docs_url``, ``repo_url``,
        ``library_name`` and ``license`` are filled in before delegating to
        the parent (Hub mixin) ``__init_subclass__``.
        """
        if not HAS_HF_HUB:
            super().__init_subclass__(**kwargs)
            return

        base_tags = ["braindecode", cls.__name__]
        user_tags = kwargs.pop("tags", None)
        tags = list(user_tags) if user_tags is not None else []
        # Append the default tags without duplicating user-provided ones.
        for tag in base_tags:
            if tag not in tags:
                tags.append(tag)

        docs_url = kwargs.pop(
            "docs_url",
            f"https://braindecode.org/stable/generated/braindecode.models.{cls.__name__}.html",
        )
        repo_url = kwargs.pop("repo_url", "https://braindecode.org")
        library_name = kwargs.pop("library_name", "braindecode")
        license = kwargs.pop("license", "bsd-3-clause")
        # TODO: model_card_template can be added in the future for custom model cards
        super().__init_subclass__(
            tags=tags,
            docs_url=docs_url,
            repo_url=repo_url,
            library_name=library_name,
            license=license,
            **kwargs,
        )

    def __init__(
        self,
        n_outputs: Optional[int] = None,  # type: ignore[assignment]
        n_chans: Optional[int] = None,  # type: ignore[assignment]
        chs_info=None,  # type: ignore[assignment]
        n_times: Optional[int] = None,  # type: ignore[assignment]
        input_window_seconds: Optional[float] = None,  # type: ignore[assignment]
        sfreq: Optional[float] = None,  # type: ignore[assignment]
    ):
        # Deserialize chs_info if it comes as a list of dicts (from Hub): the
        # Hub config stores each channel's "loc" as a plain list (see
        # _serialize_chs_info), which is converted back to a numpy array.
        if (
            isinstance(chs_info, list)
            and len(chs_info) > 0
            and isinstance(chs_info[0], dict)
            and "loc" in chs_info[0]
            and isinstance(chs_info[0]["loc"], list)
        ):
            chs_info = self._deserialize_chs_info(chs_info)
            warnings.warn(
                "Modifying chs_info argument using the _deserialize_chs_info() method"
            )

        # Reject explicitly inconsistent signal parameters up front.
        if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
            raise ValueError(f"{n_chans=} different from {chs_info=} length")
        if (
            n_times is not None
            and input_window_seconds is not None
            and sfreq is not None
            and n_times != round(input_window_seconds * sfreq)
        ):
            raise ValueError(
                f"{n_times=} different from {input_window_seconds=} * {sfreq=}"
            )

        # Raw values are stored privately; the public properties below infer
        # missing ones lazily from the others.
        self._input_window_seconds = input_window_seconds  # type: ignore[assignment]
        self._chs_info = chs_info  # type: ignore[assignment]
        self._n_outputs = n_outputs  # type: ignore[assignment]
        self._n_chans = n_chans  # type: ignore[assignment]
        self._n_times = n_times  # type: ignore[assignment]
        self._sfreq = sfreq  # type: ignore[assignment]

        super().__init__()

    @property
    def n_outputs(self) -> int:
        """Number of model outputs; raises ``ValueError`` if not specified."""
        if self._n_outputs is None:
            raise ValueError("n_outputs not specified.")
        return self._n_outputs

    @property
    def n_chans(self) -> int:
        """Number of channels, falling back to ``len(chs_info)``."""
        if self._n_chans is None and self._chs_info is not None:
            return len(self._chs_info)
        elif self._n_chans is None:
            raise ValueError(
                "n_chans could not be inferred. Either specify n_chans or chs_info."
            )
        return self._n_chans

    @property
    def chs_info(self) -> list[dict]:
        """Per-channel information dicts; raises ``ValueError`` if not specified."""
        if self._chs_info is None:
            raise ValueError("chs_info not specified.")
        return self._chs_info

    @property
    def n_times(self) -> int:
        """Window length in samples, falling back to ``round(input_window_seconds * sfreq)``."""
        if (
            self._n_times is None
            and self._input_window_seconds is not None
            and self._sfreq is not None
        ):
            return round(self._input_window_seconds * self._sfreq)
        elif self._n_times is None:
            raise ValueError(
                "n_times could not be inferred. "
                "Either specify n_times or input_window_seconds and sfreq."
            )
        return self._n_times

    @property
    def input_window_seconds(self) -> float:
        """Window length in seconds, falling back to ``n_times / sfreq``."""
        if (
            self._input_window_seconds is None
            and self._n_times is not None
            and self._sfreq is not None
        ):
            return float(self._n_times / self._sfreq)
        elif self._input_window_seconds is None:
            raise ValueError(
                "input_window_seconds could not be inferred. "
                "Either specify input_window_seconds or n_times and sfreq."
            )
        return self._input_window_seconds

    @property
    def sfreq(self) -> float:
        """Sampling frequency, falling back to ``n_times / input_window_seconds``."""
        if (
            self._sfreq is None
            and self._input_window_seconds is not None
            and self._n_times is not None
        ):
            return float(self._n_times / self._input_window_seconds)
        elif self._sfreq is None:
            raise ValueError(
                "sfreq could not be inferred. "
                "Either specify sfreq or input_window_seconds and n_times."
            )
        return self._sfreq

    @property
    def input_shape(self) -> tuple[int, int, int]:
        """Input data shape."""
        return (1, self.n_chans, self.n_times)

    def get_output_shape(self) -> tuple[int, ...]:
        """Returns shape of neural network output for batch size equal 1.

        A zero-filled input of shape :attr:`input_shape`, with the dtype and
        device of the model's first parameter, is passed through
        :meth:`forward` under :func:`torch.inference_mode`.

        Returns
        -------
        output_shape : tuple[int, ...]
            shape of the network output for `batch_size==1` (1, ...)

        Raises
        ------
        ValueError
            If the forward pass raises a ``RuntimeError`` whose message ends
            with "Output size is too small" or "Kernel size can't be greater
            than actual input size", which usually means the input window
            (``n_times`` / ``input_window_seconds``) is too short.

        Notes
        -----
        The module's train/eval mode is not changed by this method.
        NOTE(review): in training mode, layers such as batch normalization may
        update their running statistics during this call — confirm this is
        acceptable for callers.
        """
        with torch.inference_mode():
            try:
                input_shape = self.input_shape
                # Match dtype/device of the model so the forward pass is valid.
                first_param = next(self.parameters())  # type: ignore
                dummy_input = torch.zeros(
                    input_shape,
                    dtype=first_param.dtype,
                    device=first_param.device,
                )
                return tuple(self.forward(dummy_input).shape)  # type: ignore
            except RuntimeError as exc:
                if str(exc).endswith(
                    (
                        "Output size is too small",
                        "Kernel size can't be greater than actual input size",
                    )
                ):
                    msg = (
                        "During model prediction RuntimeError was thrown showing that at some "
                        f"layer `{str(exc).split('.')[-1]}` (see above in the stacktrace). This "
                        "could be caused by providing too small `n_times`/`input_window_seconds`. "
                        "Model may require longer chunks of signal in the input than "
                        f"{self.input_shape}."
                    )
                    raise ValueError(msg) from exc
                raise

    # Optional renaming of state-dict keys applied by load_state_dict():
    # each incoming key found in this dict is replaced by its mapped value.
    mapping: Optional[Dict[str, str]] = None

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load ``state_dict`` after renaming keys according to :attr:`mapping`.

        Keys not present in :attr:`mapping` are kept unchanged; remaining
        arguments are forwarded to the parent ``load_state_dict``.
        """
        mapping = self.mapping if self.mapping else {}
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k in mapping:
                new_state_dict[mapping[k]] = v
            else:
                new_state_dict[k] = v

        return super().load_state_dict(new_state_dict, *args, **kwargs)

    def to_dense_prediction_model(self, axis: tuple[int, ...] | int = (2, 3)) -> None:
        """Transform a strided model into one that outputs dense predictions.

        Strides are removed and dilations are inserted instead. Modules are
        visited in ``self.modules()`` order. A module with a ``dilation``
        attribute gets, along the selected axes, a dilation equal to the
        product of all strides seen before it. A module with a ``stride``
        attribute contributes its stride to that product and then has its
        stride set to 1 along the selected axes. Modifies model in-place.

        Parameters
        ----------
        axis : int or (int,int)
            Axis to transform (in terms of intermediate output axes)
            can either be 2, 3, or (2,3).

        Raises
        ------
        AssertionError
            If ``axis`` contains values other than 2 or 3, or if a module's
            dilation is not 1 / (1, 1) before conversion (e.g. the model was
            already converted).

        Notes
        -----
        Does not yet work correctly for average pooling.
        Prior to version 0.1.7, there had been a bug that could move strides
        backwards one layer.
        """
        if not hasattr(axis, "__iter__"):
            axis = (axis,)
        assert all(ax in [2, 3] for ax in axis), "Only 2 and 3 allowed for axis"  # type: ignore[union-attr]
        # Convert to 0-based indices into the 2-element dilation/stride tuples.
        axis = np.array(axis) - 2
        stride_so_far = np.array([1, 1])
        for module in self.modules():  # type: ignore
            if hasattr(module, "dilation"):
                assert module.dilation == 1 or (module.dilation == (1, 1)), (
                    "Dilation should equal 1 before conversion, maybe the model is "
                    "already converted?"
                )
                new_dilation = [1, 1]
                for ax in axis:  # type: ignore[union-attr]
                    new_dilation[ax] = int(stride_so_far[ax])
                module.dilation = tuple(new_dilation)
            if hasattr(module, "stride"):
                # Normalize scalar strides to a 2-tuple before accumulating.
                if not hasattr(module.stride, "__len__"):
                    module.stride = (module.stride, module.stride)
                stride_so_far *= np.array(module.stride)
                new_stride = list(module.stride)
                for ax in axis:  # type: ignore[union-attr]
                    new_stride[ax] = 1
                module.stride = tuple(new_stride)

    def get_torchinfo_statistics(
        self,
        col_names: Optional[Iterable[str]] = (
            "input_size",
            "output_size",
            "num_params",
            "kernel_size",
        ),
        row_settings: Optional[Iterable[str]] = ("var_names", "depth"),
    ) -> ModelStatistics:
        """Generate table describing the model using torchinfo.summary.

        The summary is computed for an input of shape :attr:`input_shape`.

        Parameters
        ----------
        col_names : tuple, optional
            Specify which columns to show in the output, see torchinfo for details, by default
            ("input_size", "output_size", "num_params", "kernel_size")
        row_settings : tuple, optional
            Specify which features to show in a row, see torchinfo for details, by default
            ("var_names", "depth")

        Returns
        -------
        torchinfo.ModelStatistics
            ModelStatistics generated by torchinfo.summary.
        """
        return summary(
            self,
            input_size=self.input_shape,
            col_names=col_names,
            row_settings=row_settings,
            verbose=0,
        )

    def __str__(self) -> str:
        """Return the torchinfo summary table of the model."""
        return str(self.get_torchinfo_statistics())

    @staticmethod
    def _serialize_chs_info(chs_info):
        """
        Serialize MNE channel info to JSON-compatible format.

        Only ``ch_name``, ``kind``, ``coil_type``, ``unit``, ``cal``,
        ``range`` and ``loc`` are kept; other keys are dropped.

        Parameters
        ----------
        chs_info : list of dict or None
            Channel information from MNE Info object.

        Returns
        -------
        list of dict or None
            Serialized channel information that can be saved to JSON.
        """
        if chs_info is None:
            return None

        serialized = []
        for ch in chs_info:
            # Extract serializable fields from MNE channel info
            ch_dict = {
                "ch_name": ch.get("ch_name", ""),
            }

            # Handle kind field - can be either string or integer
            kind_val = ch.get("kind")
            if kind_val is not None:
                ch_dict["kind"] = (
                    kind_val if isinstance(kind_val, str) else int(kind_val)
                )

            # Add numeric fields with safe conversion (plain int/float so that
            # json.dump accepts numpy scalars / int-like constants).
            coil_type = ch.get("coil_type")
            if coil_type is not None:
                ch_dict["coil_type"] = int(coil_type)

            unit = ch.get("unit")
            if unit is not None:
                ch_dict["unit"] = int(unit)

            cal = ch.get("cal")
            if cal is not None:
                ch_dict["cal"] = float(cal)

            range_val = ch.get("range")
            if range_val is not None:
                ch_dict["range"] = float(range_val)

            # Serialize location array if present
            if "loc" in ch and ch["loc"] is not None:
                ch_dict["loc"] = (
                    ch["loc"].tolist()
                    if hasattr(ch["loc"], "tolist")
                    else list(ch["loc"])
                )
            serialized.append(ch_dict)

        return serialized

    @staticmethod
    def _deserialize_chs_info(chs_info_dict):
        """
        Deserialize channel info from JSON-compatible format to MNE-like structure.

        Each channel dict is shallow-copied; its ``loc`` (if present and not
        None) is converted back to a numpy array.

        Parameters
        ----------
        chs_info_dict : list of dict or None
            Serialized channel information.

        Returns
        -------
        list of dict or None
            Deserialized channel information compatible with MNE.
        """
        if chs_info_dict is None:
            return None

        deserialized = []
        for ch_dict in chs_info_dict:
            ch = ch_dict.copy()
            # Convert location back to numpy array if present
            if "loc" in ch and ch["loc"] is not None:
                ch["loc"] = np.array(ch["loc"])
            deserialized.append(ch)

        return deserialized

    def _save_pretrained(self, save_directory):
        """
        Save model configuration and weights into a local directory.

        This method is called by PyTorchModelHubMixin.push_to_hub() to save
        model-specific configuration alongside the model weights. It writes
        ``config.json`` (EEG signal parameters, serialized ``chs_info`` and
        the braindecode version) and ``pytorch_model.bin``, then asks the
        parent implementation to also save the weights as safetensors.
        Does nothing when ``huggingface_hub`` is not installed.

        Parameters
        ----------
        save_directory : str or Path
            Directory where the configuration should be saved.
        """
        if not HAS_HF_HUB:
            return

        save_directory = Path(save_directory)

        # Collect EEG-specific configuration (raw, possibly-None values so
        # that inference via the properties works again after loading).
        config = {
            "n_outputs": self._n_outputs,
            "n_chans": self._n_chans,
            "n_times": self._n_times,
            "input_window_seconds": self._input_window_seconds,
            "sfreq": self._sfreq,
            "chs_info": self._serialize_chs_info(self._chs_info),
            "braindecode_version": __version__,
        }

        # Save to config.json
        config_path = save_directory / "config.json"
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        # Save model weights with standard Hub filename
        weights_path = save_directory / "pytorch_model.bin"
        torch.save(self.state_dict(), weights_path)

        # Also save in safetensors format using parent's implementation
        try:
            super()._save_pretrained(save_directory)
        except (ImportError, RuntimeError) as e:
            # Fallback to pytorch_model.bin if safetensors saving fails
            warnings.warn(
                f"Could not save model in safetensors format: {e}. "
                "Model weights saved in pytorch_model.bin instead.",
                stacklevel=2,
            )

    if HAS_HF_HUB:

        @classmethod
        def _from_pretrained(
            cls,
            *,
            model_id: str,
            revision: Optional[str],
            cache_dir: Optional[Union[str, Path]],
            force_download: bool,
            local_files_only: bool,
            token: Union[str, bool, None],
            map_location: str = "cpu",
            strict: bool = False,
            **model_kwargs,
        ):
            """Load a model via the Hub mixin, ignoring ``braindecode_version``.

            ``braindecode_version`` is written to ``config.json`` by
            ``_save_pretrained`` but is not a constructor argument, so it is
            dropped from ``model_kwargs`` if it was forwarded. All other
            arguments are passed unchanged to the parent ``_from_pretrained``.
            """
            model_kwargs.pop("braindecode_version", None)
            return super()._from_pretrained(  # type: ignore
                model_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                map_location=map_location,
                strict=strict,
                **model_kwargs,
            )