braindecode 1.3.0.dev175415232__tar.gz → 1.3.0.dev175955015__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. Click here for more details.

Files changed (136)
  1. {braindecode-1.3.0.dev175415232/braindecode.egg-info → braindecode-1.3.0.dev175955015}/PKG-INFO +4 -2
  2. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/eegneuralnet.py +2 -0
  3. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/attentionbasenet.py +2 -0
  4. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/base.py +280 -2
  5. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/labram.py +168 -69
  6. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/signal_jepa.py +103 -27
  7. braindecode-1.3.0.dev175955015/braindecode/version.py +1 -0
  8. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015/braindecode.egg-info}/PKG-INFO +4 -2
  9. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode.egg-info/requires.txt +4 -1
  10. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/api.rst +3 -0
  11. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/whats_new.rst +4 -1
  12. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/pyproject.toml +2 -1
  13. braindecode-1.3.0.dev175415232/braindecode/version.py +0 -1
  14. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/LICENSE.txt +0 -0
  15. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/MANIFEST.in +0 -0
  16. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/NOTICE.txt +0 -0
  17. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/README.rst +0 -0
  18. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/__init__.py +0 -0
  19. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/augmentation/__init__.py +0 -0
  20. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/augmentation/base.py +0 -0
  21. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/augmentation/functional.py +0 -0
  22. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/augmentation/transforms.py +0 -0
  23. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/classifier.py +0 -0
  24. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/__init__.py +0 -0
  25. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/base.py +0 -0
  26. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/bbci.py +0 -0
  27. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/bcicomp.py +0 -0
  28. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/bids.py +0 -0
  29. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/experimental.py +0 -0
  30. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/mne.py +0 -0
  31. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/moabb.py +0 -0
  32. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/nmt.py +0 -0
  33. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/sleep_physio_challe_18.py +0 -0
  34. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/sleep_physionet.py +0 -0
  35. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/tuh.py +0 -0
  36. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datasets/xy.py +0 -0
  37. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datautil/__init__.py +0 -0
  38. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datautil/serialization.py +0 -0
  39. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/datautil/util.py +0 -0
  40. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/functional/__init__.py +0 -0
  41. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/functional/functions.py +0 -0
  42. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/functional/initialization.py +0 -0
  43. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/__init__.py +0 -0
  44. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/atcnet.py +0 -0
  45. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/attn_sleep.py +0 -0
  46. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/biot.py +0 -0
  47. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/contrawr.py +0 -0
  48. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/ctnet.py +0 -0
  49. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/deep4.py +0 -0
  50. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/deepsleepnet.py +0 -0
  51. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/eegconformer.py +0 -0
  52. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/eeginception_erp.py +0 -0
  53. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/eeginception_mi.py +0 -0
  54. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/eegitnet.py +0 -0
  55. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/eegminer.py +0 -0
  56. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/eegnet.py +0 -0
  57. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/eegnex.py +0 -0
  58. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/eegsimpleconv.py +0 -0
  59. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/eegtcnet.py +0 -0
  60. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/fbcnet.py +0 -0
  61. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/fblightconvnet.py +0 -0
  62. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/fbmsnet.py +0 -0
  63. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/hybrid.py +0 -0
  64. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/ifnet.py +0 -0
  65. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/msvtnet.py +0 -0
  66. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/patchedtransformer.py +0 -0
  67. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/sccnet.py +0 -0
  68. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/shallow_fbcsp.py +0 -0
  69. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/sinc_shallow.py +0 -0
  70. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/sleep_stager_blanco_2020.py +0 -0
  71. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/sleep_stager_chambon_2018.py +0 -0
  72. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/sparcnet.py +0 -0
  73. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/sstdpn.py +0 -0
  74. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/summary.csv +0 -0
  75. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/syncnet.py +0 -0
  76. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/tcn.py +0 -0
  77. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/tidnet.py +0 -0
  78. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/tsinception.py +0 -0
  79. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/usleep.py +0 -0
  80. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/models/util.py +0 -0
  81. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/__init__.py +0 -0
  82. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/activation.py +0 -0
  83. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/attention.py +0 -0
  84. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/blocks.py +0 -0
  85. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/convolution.py +0 -0
  86. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/filter.py +0 -0
  87. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/layers.py +0 -0
  88. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/linear.py +0 -0
  89. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/parametrization.py +0 -0
  90. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/stats.py +0 -0
  91. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/util.py +0 -0
  92. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/modules/wrapper.py +0 -0
  93. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/preprocessing/__init__.py +0 -0
  94. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/preprocessing/mne_preprocess.py +0 -0
  95. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/preprocessing/preprocess.py +0 -0
  96. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/preprocessing/windowers.py +0 -0
  97. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/regressor.py +0 -0
  98. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/samplers/__init__.py +0 -0
  99. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/samplers/base.py +0 -0
  100. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/samplers/ssl.py +0 -0
  101. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/training/__init__.py +0 -0
  102. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/training/callbacks.py +0 -0
  103. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/training/losses.py +0 -0
  104. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/training/scoring.py +0 -0
  105. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/util.py +0 -0
  106. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/visualization/__init__.py +0 -0
  107. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/visualization/confusion_matrices.py +0 -0
  108. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode/visualization/gradients.py +0 -0
  109. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode.egg-info/SOURCES.txt +0 -0
  110. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode.egg-info/dependency_links.txt +0 -0
  111. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/braindecode.egg-info/top_level.txt +0 -0
  112. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/Makefile +0 -0
  113. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/_templates/autosummary/class.rst +0 -0
  114. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/_templates/autosummary/function.rst +0 -0
  115. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/cite.rst +0 -0
  116. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/conf.py +0 -0
  117. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/help.rst +0 -0
  118. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/index.rst +0 -0
  119. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/install/install.rst +0 -0
  120. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/install/install_pip.rst +0 -0
  121. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/install/install_source.rst +0 -0
  122. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/categorization/attention.rst +0 -0
  123. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/categorization/channel.rst +0 -0
  124. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/categorization/convolution.rst +0 -0
  125. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/categorization/filterbank.rst +0 -0
  126. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/categorization/gnn.rst +0 -0
  127. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/categorization/interpretable.rst +0 -0
  128. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/categorization/lbm.rst +0 -0
  129. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/categorization/recurrent.rst +0 -0
  130. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/categorization/spd.rst +0 -0
  131. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/models.rst +0 -0
  132. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/models_categorization.rst +0 -0
  133. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/models_table.rst +0 -0
  134. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/models/models_visualization.rst +0 -0
  135. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/docs/sg_execution_times.rst +0 -0
  136. {braindecode-1.3.0.dev175415232 → braindecode-1.3.0.dev175955015}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.3.0.dev175415232
3
+ Version: 1.3.0.dev175955015
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -40,6 +40,8 @@ Requires-Dist: linear_attention_transformer
40
40
  Requires-Dist: docstring_inheritance
41
41
  Provides-Extra: moabb
42
42
  Requires-Dist: moabb>=1.2.0; extra == "moabb"
43
+ Provides-Extra: hug
44
+ Requires-Dist: huggingface_hub[torch]>=0.20.0; extra == "hug"
43
45
  Provides-Extra: tests
44
46
  Requires-Dist: pytest; extra == "tests"
45
47
  Requires-Dist: pytest-cov; extra == "tests"
@@ -65,7 +67,7 @@ Requires-Dist: pre-commit; extra == "docs"
65
67
  Requires-Dist: openneuro-py; extra == "docs"
66
68
  Requires-Dist: plotly; extra == "docs"
67
69
  Provides-Extra: all
68
- Requires-Dist: braindecode[docs,moabb,tests]; extra == "all"
70
+ Requires-Dist: braindecode[docs,hug,moabb,tests]; extra == "all"
69
71
  Dynamic: license-file
70
72
 
71
73
  .. image:: https://badges.gitter.im/braindecodechat/community.svg
@@ -189,6 +189,8 @@ class _EEGNeuralNet(NeuralNet, abc.ABC):
189
189
  "Skipping setting signal-related parameters from data."
190
190
  )
191
191
  return
192
+ if classes is None:
193
+ classes = getattr(self, "classes", None)
192
194
  # get kwargs from signal:
193
195
  signal_kwargs = dict()
194
196
  # Using shape to work both with torch.tensor and numpy.array:
@@ -381,6 +381,8 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
381
381
  for k, pl, ps in zip(kernel_lengths, pool_lengths, pool_strides):
382
382
  out = math.floor(out + 2 * (k // 2) - k + 1)
383
383
  out = math.floor((out - pl) / ps + 1)
384
+ # Ensure output is at least 1 to avoid zero-sized tensors
385
+ out = max(1, out)
384
386
  seq_lengths.append(int(out))
385
387
  return seq_lengths
386
388
 
@@ -5,15 +5,35 @@
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
+ import json
8
9
  import warnings
9
10
  from collections import OrderedDict
10
- from typing import Dict, Iterable, Optional
11
+ from pathlib import Path
12
+ from typing import Dict, Iterable, Optional, Type, Union
11
13
 
12
14
  import numpy as np
13
15
  import torch
14
16
  from docstring_inheritance import NumpyDocstringInheritanceInitMeta
17
+ from mne.utils import _soft_import
15
18
  from torchinfo import ModelStatistics, summary
16
19
 
20
+ from braindecode.version import __version__
21
+
22
+ huggingface_hub = _soft_import(
23
+ "huggingface_hub", "Hugging Face Hub integration", strict=False
24
+ )
25
+
26
+ HAS_HF_HUB = huggingface_hub is not False
27
+
28
+
29
+ class _BaseHubMixin:
30
+ pass
31
+
32
+
33
+ # Define base class for hub mixin
34
+ if HAS_HF_HUB:
35
+ _BaseHubMixin: Type = huggingface_hub.PyTorchModelHubMixin # type: ignore
36
+
17
37
 
18
38
  def deprecated_args(obj, *old_new_args):
19
39
  out_args = []
@@ -32,10 +52,14 @@ def deprecated_args(obj, *old_new_args):
32
52
  return out_args
33
53
 
34
54
 
35
- class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
55
+ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta):
36
56
  """
37
57
  Mixin class for all EEG models in braindecode.
38
58
 
59
+ This class integrates with Hugging Face Hub when the ``huggingface_hub`` package
60
+ is installed, enabling models to be pushed to and loaded from the Hub using
61
+ :func:`push_to_hub()` and :func:`from_pretrained()` methods.
62
+
39
63
  Parameters
40
64
  ----------
41
65
  n_outputs : int
@@ -62,8 +86,87 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
62
86
  -----
63
87
  If some input signal-related parameters are not specified,
64
88
  there will be an attempt to infer them from the other parameters.
89
+
90
+ Hugging Face Hub Integration
91
+ -----------------------------
92
+ When the optional ``huggingface_hub`` package is installed, all models
93
+ automatically gain the ability to be pushed to and loaded from the
94
+ Hugging Face Hub. Install with::
95
+
96
+ pip install braindecode[hug]
97
+
98
+ **Pushing a model to the Hub:**
99
+
100
+ .. code-block:: python
101
+
102
+ from braindecode.models import EEGNetv4
103
+
104
+ # Train your model
105
+ model = EEGNetv4(n_chans=22, n_outputs=4, n_times=1000)
106
+ # ... training code ...
107
+
108
+ # Push to the Hub
109
+ model.push_to_hub(
110
+ repo_id="username/my-eegnet-model", commit_message="Initial model upload"
111
+ )
112
+
113
+ **Loading a model from the Hub:**
114
+
115
+ .. code-block:: python
116
+
117
+ from braindecode.models import EEGNetv4
118
+
119
+ # Load pretrained model
120
+ model = EEGNetv4.from_pretrained("username/my-eegnet-model")
121
+
122
+ The integration automatically handles EEG-specific parameters (n_chans,
123
+ n_times, sfreq, chs_info, etc.) by saving them in a config file alongside
124
+ the model weights. This ensures that loaded models are correctly configured
125
+ for their original data specifications.
126
+
127
+ .. important::
128
+ Currently, only EEG-specific parameters (n_outputs, n_chans, n_times,
129
+ input_window_seconds, sfreq, chs_info) are saved to the Hub. Model-specific
130
+ parameters (e.g., dropout rates, activation functions, number of filters)
131
+ are not preserved and will use their default values when loading from the Hub.
132
+
133
+ To use non-default model parameters, specify them explicitly when calling
134
+ :func:`from_pretrained()`::
135
+
136
+ model = EEGNet.from_pretrained("user/model", dropout=0.3, activation='relu')
137
+
138
+ Full parameter serialization will be addressed in a future update.
65
139
  """
66
140
 
141
+ def __init_subclass__(cls, **kwargs):
142
+ if not HAS_HF_HUB:
143
+ super().__init_subclass__(**kwargs)
144
+ return
145
+
146
+ base_tags = ["braindecode", cls.__name__]
147
+ user_tags = kwargs.pop("tags", None)
148
+ tags = list(user_tags) if user_tags is not None else []
149
+ for tag in base_tags:
150
+ if tag not in tags:
151
+ tags.append(tag)
152
+
153
+ docs_url = kwargs.pop(
154
+ "docs_url",
155
+ f"https://braindecode.org/stable/generated/braindecode.models.{cls.__name__}.html",
156
+ )
157
+ repo_url = kwargs.pop("repo_url", "https://braindecode.org")
158
+ library_name = kwargs.pop("library_name", "braindecode")
159
+ license = kwargs.pop("license", "bsd-3-clause")
160
+ # TODO: model_card_template can be added in the future for custom model cards
161
+ super().__init_subclass__(
162
+ tags=tags,
163
+ docs_url=docs_url,
164
+ repo_url=repo_url,
165
+ library_name=library_name,
166
+ license=license,
167
+ **kwargs,
168
+ )
169
+
67
170
  def __init__(
68
171
  self,
69
172
  n_outputs: Optional[int] = None, # type: ignore[assignment]
@@ -73,6 +176,16 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
73
176
  input_window_seconds: Optional[float] = None, # type: ignore[assignment]
74
177
  sfreq: Optional[float] = None, # type: ignore[assignment]
75
178
  ):
179
+ # Deserialize chs_info if it comes as a list of dicts (from Hub)
180
+ if chs_info is not None and isinstance(chs_info, list):
181
+ if len(chs_info) > 0 and isinstance(chs_info[0], dict):
182
+ # Check if it needs deserialization (has 'loc' as list)
183
+ if "loc" in chs_info[0] and isinstance(chs_info[0]["loc"], list):
184
+ chs_info = self._deserialize_chs_info(chs_info)
185
+ warnings.warn(
186
+ "Modifying chs_info argument using the _deserialize_chs_info() method"
187
+ )
188
+
76
189
  if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
77
190
  raise ValueError(f"{n_chans=} different from {chs_info=} length")
78
191
  if (
@@ -294,3 +407,168 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
294
407
 
295
408
  def __str__(self) -> str:
296
409
  return str(self.get_torchinfo_statistics())
410
+
411
+ @staticmethod
412
+ def _serialize_chs_info(chs_info):
413
+ """
414
+ Serialize MNE channel info to JSON-compatible format.
415
+
416
+ Parameters
417
+ ----------
418
+ chs_info : list of dict or None
419
+ Channel information from MNE Info object.
420
+
421
+ Returns
422
+ -------
423
+ list of dict or None
424
+ Serialized channel information that can be saved to JSON.
425
+ """
426
+ if chs_info is None:
427
+ return None
428
+
429
+ serialized = []
430
+ for ch in chs_info:
431
+ # Extract serializable fields from MNE channel info
432
+ ch_dict = {
433
+ "ch_name": ch.get("ch_name", ""),
434
+ }
435
+
436
+ # Handle kind field - can be either string or integer
437
+ kind_val = ch.get("kind")
438
+ if kind_val is not None:
439
+ ch_dict["kind"] = (
440
+ kind_val if isinstance(kind_val, str) else int(kind_val)
441
+ )
442
+
443
+ # Add numeric fields with safe conversion
444
+ coil_type = ch.get("coil_type")
445
+ if coil_type is not None:
446
+ ch_dict["coil_type"] = int(coil_type)
447
+
448
+ unit = ch.get("unit")
449
+ if unit is not None:
450
+ ch_dict["unit"] = int(unit)
451
+
452
+ cal = ch.get("cal")
453
+ if cal is not None:
454
+ ch_dict["cal"] = float(cal)
455
+
456
+ range_val = ch.get("range")
457
+ if range_val is not None:
458
+ ch_dict["range"] = float(range_val)
459
+
460
+ # Serialize location array if present
461
+ if "loc" in ch and ch["loc"] is not None:
462
+ ch_dict["loc"] = (
463
+ ch["loc"].tolist()
464
+ if hasattr(ch["loc"], "tolist")
465
+ else list(ch["loc"])
466
+ )
467
+ serialized.append(ch_dict)
468
+
469
+ return serialized
470
+
471
+ @staticmethod
472
+ def _deserialize_chs_info(chs_info_dict):
473
+ """
474
+ Deserialize channel info from JSON-compatible format to MNE-like structure.
475
+
476
+ Parameters
477
+ ----------
478
+ chs_info_dict : list of dict or None
479
+ Serialized channel information.
480
+
481
+ Returns
482
+ -------
483
+ list of dict or None
484
+ Deserialized channel information compatible with MNE.
485
+ """
486
+ if chs_info_dict is None:
487
+ return None
488
+
489
+ deserialized = []
490
+ for ch_dict in chs_info_dict:
491
+ ch = ch_dict.copy()
492
+ # Convert location back to numpy array if present
493
+ if "loc" in ch and ch["loc"] is not None:
494
+ ch["loc"] = np.array(ch["loc"])
495
+ deserialized.append(ch)
496
+
497
+ return deserialized
498
+
499
+ def _save_pretrained(self, save_directory):
500
+ """
501
+ Save model configuration and weights to the Hub.
502
+
503
+ This method is called by PyTorchModelHubMixin.push_to_hub() to save
504
+ model-specific configuration alongside the model weights.
505
+
506
+ Parameters
507
+ ----------
508
+ save_directory : str or Path
509
+ Directory where the configuration should be saved.
510
+ """
511
+ if not HAS_HF_HUB:
512
+ return
513
+
514
+ save_directory = Path(save_directory)
515
+
516
+ # Collect EEG-specific configuration
517
+ config = {
518
+ "n_outputs": self._n_outputs,
519
+ "n_chans": self._n_chans,
520
+ "n_times": self._n_times,
521
+ "input_window_seconds": self._input_window_seconds,
522
+ "sfreq": self._sfreq,
523
+ "chs_info": self._serialize_chs_info(self._chs_info),
524
+ "braindecode_version": __version__,
525
+ }
526
+
527
+ # Save to config.json
528
+ config_path = save_directory / "config.json"
529
+ with open(config_path, "w") as f:
530
+ json.dump(config, f, indent=2)
531
+
532
+ # Save model weights with standard Hub filename
533
+ weights_path = save_directory / "pytorch_model.bin"
534
+ torch.save(self.state_dict(), weights_path)
535
+
536
+ # Also save in safetensors format using parent's implementation
537
+ try:
538
+ super()._save_pretrained(save_directory)
539
+ except (ImportError, RuntimeError) as e:
540
+ # Fallback to pytorch_model.bin if safetensors saving fails
541
+ warnings.warn(
542
+ f"Could not save model in safetensors format: {e}. "
543
+ "Model weights saved in pytorch_model.bin instead.",
544
+ stacklevel=2,
545
+ )
546
+
547
+ if HAS_HF_HUB:
548
+
549
+ @classmethod
550
+ def _from_pretrained(
551
+ cls,
552
+ *,
553
+ model_id: str,
554
+ revision: Optional[str],
555
+ cache_dir: Optional[Union[str, Path]],
556
+ force_download: bool,
557
+ local_files_only: bool,
558
+ token: Union[str, bool, None],
559
+ map_location: str = "cpu",
560
+ strict: bool = False,
561
+ **model_kwargs,
562
+ ):
563
+ model_kwargs.pop("braindecode_version", None)
564
+ return super()._from_pretrained( # type: ignore
565
+ model_id=model_id,
566
+ revision=revision,
567
+ cache_dir=cache_dir,
568
+ force_download=force_download,
569
+ local_files_only=local_files_only,
570
+ token=token,
571
+ map_location=map_location,
572
+ strict=strict,
573
+ **model_kwargs,
574
+ )