braindecode 1.3.0.dev176728557__py3-none-any.whl → 1.3.0.dev177509039__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of braindecode might be problematic.

Files changed (42)
  1. braindecode/augmentation/functional.py +154 -54
  2. braindecode/augmentation/transforms.py +2 -2
  3. braindecode/datasets/base.py +18 -17
  4. braindecode/datasets/bcicomp.py +1 -1
  5. braindecode/datasets/sleep_physio_challe_18.py +2 -1
  6. braindecode/datautil/serialization.py +11 -6
  7. braindecode/eegneuralnet.py +2 -0
  8. braindecode/functional/functions.py +6 -2
  9. braindecode/functional/initialization.py +2 -3
  10. braindecode/models/__init__.py +6 -0
  11. braindecode/models/atcnet.py +32 -33
  12. braindecode/models/attentionbasenet.py +39 -32
  13. braindecode/models/base.py +280 -2
  14. braindecode/models/bendr.py +469 -0
  15. braindecode/models/biot.py +3 -1
  16. braindecode/models/ctnet.py +6 -3
  17. braindecode/models/deepsleepnet.py +27 -18
  18. braindecode/models/eegconformer.py +2 -2
  19. braindecode/models/eeginception_erp.py +31 -25
  20. braindecode/models/eegnet.py +5 -4
  21. braindecode/models/labram.py +188 -84
  22. braindecode/models/patchedtransformer.py +640 -0
  23. braindecode/models/signal_jepa.py +109 -27
  24. braindecode/models/sinc_shallow.py +10 -9
  25. braindecode/models/sstdpn.py +869 -0
  26. braindecode/models/summary.csv +9 -6
  27. braindecode/models/usleep.py +26 -21
  28. braindecode/models/util.py +3 -0
  29. braindecode/modules/attention.py +10 -10
  30. braindecode/modules/blocks.py +3 -3
  31. braindecode/modules/filter.py +2 -3
  32. braindecode/modules/layers.py +18 -17
  33. braindecode/preprocessing/preprocess.py +11 -2
  34. braindecode/preprocessing/windowers.py +2 -2
  35. braindecode/samplers/base.py +8 -8
  36. braindecode/version.py +1 -1
  37. {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/METADATA +4 -2
  38. {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/RECORD +42 -39
  39. {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/WHEEL +0 -0
  40. {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/licenses/LICENSE.txt +0 -0
  41. {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/licenses/NOTICE.txt +0 -0
  42. {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/top_level.txt +0 -0
braindecode/models/__init__.py

@@ -6,6 +6,7 @@ from .atcnet import ATCNet
  from .attentionbasenet import AttentionBaseNet
  from .attn_sleep import AttnSleep
  from .base import EEGModuleMixin
+ from .bendr import BENDR
  from .biot import BIOT
  from .contrawr import ContraWR
  from .ctnet import CTNet
@@ -27,6 +28,7 @@ from .hybrid import HybridNet
  from .ifnet import IFNet
  from .labram import Labram
  from .msvtnet import MSVTNet
+ from .patchedtransformer import PBT
  from .sccnet import SCCNet
  from .shallow_fbcsp import ShallowFBCSPNet
  from .signal_jepa import (
@@ -39,6 +41,7 @@ from .sinc_shallow import SincShallowNet
  from .sleep_stager_blanco_2020 import SleepStagerBlanco2020
  from .sleep_stager_chambon_2018 import SleepStagerChambon2018
  from .sparcnet import SPARCNet
+ from .sstdpn import SSTDPN
  from .syncnet import SyncNet
  from .tcn import BDTCN, TCN
  from .tidnet import TIDNet
@@ -56,6 +59,7 @@ __all__ = [
  "AttentionBaseNet",
  "EEGModuleMixin",
  "BIOT",
+ "BENDR",
  "ContraWR",
  "CTNet",
  "Deep4Net",
@@ -77,6 +81,7 @@ __all__ = [
  "IFNet",
  "Labram",
  "MSVTNet",
+ "PBT",
  "SCCNet",
  "ShallowFBCSPNet",
  "SignalJEPA",
@@ -84,6 +89,7 @@ __all__ = [
  "SignalJEPA_PostLocal",
  "SignalJEPA_PreLocal",
  "SincShallowNet",
+ "SSTDPN",
  "SleepStagerBlanco2020",
  "SleepStagerChambon2018",
  "SPARCNet",
braindecode/models/atcnet.py

@@ -50,7 +50,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
  - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_t, 1)`` builds a
  FIR-like filter bank (``F1`` maps).
  - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=F1``) with kernel
- ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNets CSP-like step).
+ ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet's CSP-like step).
  - **BN → ELU → AvgPool → Dropout** to stabilize and condense activations.
  - **Refining temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_r, 1)`` +
  **BN → ELU → AvgPool → Dropout**.
@@ -62,13 +62,15 @@ class ATCNet(EEGModuleMixin, nn.Module):

  - **Sliding-Window Sequencer**

- From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
- of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
- ``(B, F2, T_w)`` forwarded to its own attentionTCN branch. This creates *parallel*
- encoders over shifted contexts and is key to robustness on nonstationary EEG.
+ From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
+ of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
+ ``(B, F2, T_w)`` forwarded to its own attention-TCN branch. This creates *parallel*
+ encoders over shifted contexts and is key to robustness on nonstationary EEG.

  - :class:`_AttentionBlock` **(small MHA on temporal positions)**

+ Attention here is *local to a window* and purely temporal.
+
  - *Operations.*
  - Rearrange to ``(B, T_w, F2)``,
  - Normalization :class:`torch.nn.LayerNorm`
@@ -76,11 +78,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
  - Dropout :class:`torch.nn.Dropout`
  - Rearrange back to ``(B, F2, T_w)``.

-
- **Note**: Attention is *local to a window* and purely temporal.
-
- *Role.* Re-weights evidence across the window, letting the model emphasize informative
- segments (onsets, bursts) before causal convolutions aggregate history.
+ *Role.* Re-weights evidence across the window, letting the model emphasize informative
+ segments (onsets, bursts) before causal convolutions aggregate history.

  - :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**

@@ -90,8 +89,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
  a residual (identity or 1x1 mapping).
  - The final feature used per window is the *last* causal step ``[..., -1]`` (forecast-style).

- *Role.* Efficient long-range temporal integration with stable gradients; the dilated
- receptive field complements attentions soft selection.
+ *Role.* Efficient long-range temporal integration with stable gradients; the dilated
+ receptive field complements attention's soft selection.

  - **Aggregation & Classifier**

@@ -104,16 +103,16 @@ class ATCNet(EEGModuleMixin, nn.Module):
  .. rubric:: Convolutional Details

  - **Temporal.** Temporal structure is learned in three places:
- - (1) the stems wide ``(L_t, 1)`` conv (learned filter bank),
+ - (1) the stem's wide ``(L_t, 1)`` conv (learned filter bank),
  - (2) the refining ``(L_r, 1)`` conv after pooling (short-term dynamics), and
- - (3) the TCNs causal 1-D convolutions with exponentially increasing dilation
+ - (3) the TCN's causal 1-D convolutions with exponentially increasing dilation
  (long-range dependencies). The minimum sequence length required by the TCN stack is
  ``(K_t - 1)·2^{L-1} + 1``; the implementation *auto-scales* kernels/pools/windows
  when inputs are shorter to preserve feasibility.

  - **Spatial.** A depthwise spatial conv spans the **full montage** (kernel ``(1, n_chans)``),
  producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
- This mirrors EEGNets interpretability: each temporal filter has its own spatial pattern.
+ This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.


  .. rubric:: Attention / Sequential Modules
@@ -137,17 +136,17 @@ class ATCNet(EEGModuleMixin, nn.Module):

  .. rubric:: Usage and Configuration

- - ``conv_block_n_filters (F1)``, ``conv_block_depth_mult (D)`` → capacity of the stem
- (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
- - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
- ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
- - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
- - ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
- - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
- longer inputs (see minimum length above). The implementation warns and *rescales*
- kernels/pools/windows if inputs are too short.
- - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
- the official code; ``concat=True`` mirrors the papers concatenation variant.
+ - ``conv_block_n_filters (F1)``, ``conv_block_depth_mult (D)`` → capacity of the stem
+ (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
+ - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
+ ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
+ - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
+ - ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
+ - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
+ longer inputs (see minimum length above). The implementation warns and *rescales*
+ kernels/pools/windows if inputs are too short.
+ - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
+ the official code; ``concat=True`` mirrors the paper's concatenation variant.


  Notes
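
The window bookkeeping described in the docstring hunks above can be checked with a short numeric sketch; the sample count and hyperparameter values below are illustrative (chosen to resemble common ATCNet settings), not values mandated by this diff:

    # Illustrative arithmetic only: condensed length, per-window width, and the
    # minimum length the TCN stack needs before ATCNet would rescale its
    # kernels/pools/windows.
    T, P1, P2, n_windows = 1125, 8, 7, 5   # assumed input length and pool sizes
    K_t, L = 4, 2                          # assumed TCN kernel size and depth

    T_c = T // (P1 * P2)                    # condensed time axis
    T_w = T_c - n_windows + 1               # width of each sliding window
    min_len = (K_t - 1) * 2 ** (L - 1) + 1  # minimum length for the TCN stack

    print(T_c, T_w, min_len)                # 20 16 7 -> T_w >= min_len, feasible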
@@ -370,7 +369,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
  nn.Sequential(
  *[
  _TCNResidualBlock(
- in_channels=self.F2,
+ in_channels=self.F2 if i == 0 else self.tcn_n_filters,
  kernel_size=self.tcn_kernel_size,
  n_filters=self.tcn_n_filters,
  dropout=self.tcn_dropout,
@@ -388,7 +387,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
  self.final_layer = nn.ModuleList(
  [
  MaxNormLinear(
- in_features=self.F2 * self.n_windows,
+ in_features=self.tcn_n_filters * self.n_windows,
  out_features=self.n_outputs,
  max_norm_val=self.max_norm_const,
  )
@@ -398,7 +397,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
  self.final_layer = nn.ModuleList(
  [
  MaxNormLinear(
- in_features=self.F2,
+ in_features=self.tcn_n_filters,
  out_features=self.n_outputs,
  max_norm_val=self.max_norm_const,
  )
@@ -408,7 +407,7 @@ class ATCNet(EEGModuleMixin, nn.Module):

  self.out_fun = nn.Identity()

- def forward(self, X):
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
  # Dimension: (batch_size, C, T)
  X = self.ensuredims(X)
  # Dimension: (batch_size, C, T, 1)
@@ -695,8 +694,8 @@ class _TCNResidualBlock(nn.Module):
  # Reshape the input for the residual connection when necessary
  if in_channels != n_filters:
  self.reshaping_conv = nn.Conv1d(
- in_channels=in_channels,
- out_channels=n_filters,
+ in_channels=in_channels, # Specify input channels
+ out_channels=n_filters, # Specify output channels
  kernel_size=1,
  padding="same",
  )
@@ -716,7 +715,7 @@ class _TCNResidualBlock(nn.Module):
  out = self.activation(out)
  out = self.drop2(out)

- out = self.reshaping_conv(out)
+ X = self.reshaping_conv(X)

  # ----- Residual connection -----
  out = X + out
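
The hunks above change how channel counts flow through the TCN stack: only the first ``_TCNResidualBlock`` sees the stem's ``F2`` feature maps, later blocks (and the classifier) work with ``tcn_n_filters``, and the 1x1 reshaping conv is applied to the block *input* so the residual sum has matching shapes. A stand-alone toy sketch of that pattern (not the actual braindecode module; names mirror the diff, the block body is simplified):

    import torch
    from torch import nn

    class ToyResidualBlock(nn.Module):
        """Simplified residual block: reshape the input when channel counts differ."""

        def __init__(self, in_channels: int, n_filters: int, kernel_size: int = 4):
            super().__init__()
            self.conv = nn.Conv1d(in_channels, n_filters, kernel_size, padding="same")
            self.reshape = (
                nn.Conv1d(in_channels, n_filters, kernel_size=1)
                if in_channels != n_filters
                else nn.Identity()
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            out = torch.relu(self.conv(x))
            return self.reshape(x) + out  # residual on the (possibly reshaped) input

    F2, tcn_n_filters, depth = 32, 16, 3  # assumed sizes for illustration
    stack = nn.Sequential(
        *[ToyResidualBlock(F2 if i == 0 else tcn_n_filters, tcn_n_filters) for i in range(depth)]
    )
    print(stack(torch.randn(2, F2, 100)).shape)  # torch.Size([2, 16, 100])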
braindecode/models/attentionbasenet.py

@@ -97,7 +97,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):

  - **Temporal (where time-domain patterns are learned).**
  Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory
- bands/transients; the attention blocks depthwise temporal conv (``(1, L_a)``) sharpens
+ bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens
  short-term dynamics after downsampling. Pool sizes/strides (``P₁,S₁`` then ``P₂,S₂``)
  set the token rate and effective temporal resolution.

@@ -127,23 +127,24 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):

  .. rubric:: Additional Mechanisms

- - **Attention variants at a glance.**
- - ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
- - ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
- - ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
- - ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
- - ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
- - ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
- - ``"gct"``: Gated Channel Transformation (global context normalization + gating).
- - ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
- - ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
- - ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
- - **Auto-compatibility on short inputs.**
+ **Attention variants at a glance:**

- If the input duration is too short for the configured kernels/pools, the implementation
- **automatically rescales** temporal lengths/strides downward (with a warning) to keep
- shapes valid and preserve the pipeline semantics.
+ - ``"se"``: Squeeze-and-Excitation (global pooling bottleneck gates).
+ - ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
+ - ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
+ - ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
+ - ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
+ - ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
+ - ``"gct"``: Gated Channel Transformation (global context normalization + gating).
+ - ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
+ - ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
+ - ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.

+ **Auto-compatibility on short inputs:**
+
+ If the input duration is too short for the configured kernels/pools, the implementation
+ **automatically rescales** temporal lengths/strides downward (with a warning) to keep
+ shapes valid and preserve the pipeline semantics.

  .. rubric:: Usage and Configuration

@@ -158,9 +159,9 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
  - ``drop_prob_inp`` and ``drop_prob_attn``: regularize stem and attention stages.
  - **Training tips.**

- Start with moderate pooling (e.g., ``P₁=75,S₁=15``) and ELU activations; enable attention
- only after the stem learns stable filters. For small datasets, prefer simpler modes
- (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).
+ Start with moderate pooling (e.g., ``P₁=75,S₁=15``) and ELU activations; enable attention
+ only after the stem learns stable filters. For small datasets, prefer simpler modes
+ (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).

  Notes
  -----
@@ -170,6 +171,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
  specific variants (CBAM/CAT).
  - The paper and original code with more details about the methodological
  choices are available at the [Martin2023]_ and [MartinCode]_.
+
  .. versionadded:: 0.9

  Parameters
@@ -198,19 +200,21 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
  the depth of the network after the initial layer. Default is 16.
  attention_mode : str, optional
  The type of attention mechanism to apply. If `None`, no attention is applied.
- - "se" for Squeeze-and-excitation network
- - "gsop" for Global Second-Order Pooling
- - "fca" for Frequency Channel Attention Network
- - "encnet" for context encoding module
- - "eca" for Efficient channel attention for deep convolutional neural networks
- - "ge" for Gather-Excite
- - "gct" for Gated Channel Transformation
- - "srm" for Style-based Recalibration Module
- - "cbam" for Convolutional Block Attention Module
- - "cat" for Learning to collaborate channel and temporal attention
- from multi-information fusion
- - "catlite" for Learning to collaborate channel attention
- from multi-information fusion (lite version, cat w/o temporal attention)
+
+ - "se" for Squeeze-and-excitation network
+ - "gsop" for Global Second-Order Pooling
+ - "fca" for Frequency Channel Attention Network
+ - "encnet" for context encoding module
+ - "eca" for Efficient channel attention for deep convolutional neural networks
+ - "ge" for Gather-Excite
+ - "gct" for Gated Channel Transformation
+ - "srm" for Style-based Recalibration Module
+ - "cbam" for Convolutional Block Attention Module
+ - "cat" for Learning to collaborate channel and temporal attention
+ from multi-information fusion
+ - "catlite" for Learning to collaborate channel attention
+ from multi-information fusion (lite version, cat w/o temporal attention)
+
  pool_length : int, default=8
  The length of the window for the average pooling operation.
  pool_stride : int, default=8
@@ -381,6 +385,8 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
  for k, pl, ps in zip(kernel_lengths, pool_lengths, pool_strides):
  out = math.floor(out + 2 * (k // 2) - k + 1)
  out = math.floor((out - pl) / ps + 1)
+ # Ensure output is at least 1 to avoid zero-sized tensors
+ out = max(1, out)
  seq_lengths.append(int(out))
  return seq_lengths

@@ -497,6 +503,7 @@ class _ChannelAttentionBlock(nn.Module):
  ----------
  attention_mode : str, optional
  The type of attention mechanism to apply. If `None`, no attention is applied.
+
  - "se" for Squeeze-and-excitation network
  - "gsop" for Global Second-Order Pooling
  - "fca" for Frequency Channel Attention Network
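
The ``max(1, out)`` clamp added in the ``@@ -381`` hunk above keeps the computed sequence lengths positive for very short inputs. A stand-alone sketch of the same arithmetic (the helper name and example values are illustrative only, not taken from braindecode):

    import math

    def downsampled_lengths(n_times, kernel_lengths, pool_lengths, pool_strides):
        # Each stage: "same"-padded conv of length k, then average pooling (pl, ps).
        seq_lengths = []
        out = n_times
        for k, pl, ps in zip(kernel_lengths, pool_lengths, pool_strides):
            out = math.floor(out + 2 * (k // 2) - k + 1)
            out = math.floor((out - pl) / ps + 1)
            out = max(1, out)  # clamp added in this release
            seq_lengths.append(int(out))
        return seq_lengths

    print(downsampled_lengths(1000, [65, 15], [75, 8], [15, 8]))  # [62, 7]
    print(downsampled_lengths(50, [65, 15], [75, 8], [15, 8]))    # [1, 1] rather than a non-positive length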
braindecode/models/base.py

@@ -5,15 +5,35 @@

  from __future__ import annotations

+ import json
  import warnings
  from collections import OrderedDict
- from typing import Dict, Iterable, Optional
+ from pathlib import Path
+ from typing import Dict, Iterable, Optional, Type, Union

  import numpy as np
  import torch
  from docstring_inheritance import NumpyDocstringInheritanceInitMeta
+ from mne.utils import _soft_import
  from torchinfo import ModelStatistics, summary

+ from braindecode.version import __version__
+
+ huggingface_hub = _soft_import(
+ "huggingface_hub", "Hugging Face Hub integration", strict=False
+ )
+
+ HAS_HF_HUB = huggingface_hub is not False
+
+
+ class _BaseHubMixin:
+ pass
+
+
+ # Define base class for hub mixin
+ if HAS_HF_HUB:
+ _BaseHubMixin: Type = huggingface_hub.PyTorchModelHubMixin # type: ignore
+

  def deprecated_args(obj, *old_new_args):
  out_args = []
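
The new imports above gate the Hub mixin behind an optional dependency: ``_soft_import`` returns ``False`` when ``huggingface_hub`` is missing, and ``_BaseHubMixin`` is only rebound to ``PyTorchModelHubMixin`` when it is present. A minimal stand-alone sketch of the same pattern (using ``importlib`` instead of MNE's ``_soft_import``, purely for illustration):

    import importlib.util

    # Fall back to an empty placeholder base class when the optional package is absent,
    # so model classes still define (and import) cleanly without huggingface_hub.
    HAS_HF_HUB = importlib.util.find_spec("huggingface_hub") is not None

    class _BaseHubMixin:
        pass

    if HAS_HF_HUB:
        from huggingface_hub import PyTorchModelHubMixin as _BaseHubMixin  # type: ignore

    class MyModel(_BaseHubMixin):
        """Gains push_to_hub()/from_pretrained() only when huggingface_hub is installed."""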
@@ -32,10 +52,14 @@ def deprecated_args(obj, *old_new_args):
  return out_args


- class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
+ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta):
  """
  Mixin class for all EEG models in braindecode.

+ This class integrates with Hugging Face Hub when the ``huggingface_hub`` package
+ is installed, enabling models to be pushed to and loaded from the Hub using
+ :func:`push_to_hub()` and :func:`from_pretrained()` methods.
+
  Parameters
  ----------
  n_outputs : int
@@ -62,8 +86,87 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
  -----
  If some input signal-related parameters are not specified,
  there will be an attempt to infer them from the other parameters.
+
+ .. rubric:: Hugging Face Hub integration
+
+ When the optional ``huggingface_hub`` package is installed, all models
+ automatically gain the ability to be pushed to and loaded from the
+ Hugging Face Hub. Install with::
+
+ pip install braindecode[hug]
+
+ **Pushing a model to the Hub:**
+
+ .. code-block:: python
+
+ from braindecode.models import EEGNetv4
+
+ # Train your model
+ model = EEGNetv4(n_chans=22, n_outputs=4, n_times=1000)
+ # ... training code ...
+
+ # Push to the Hub
+ model.push_to_hub(
+ repo_id="username/my-eegnet-model", commit_message="Initial model upload"
+ )
+
+ **Loading a model from the Hub:**
+
+ .. code-block:: python
+
+ from braindecode.models import EEGNetv4
+
+ # Load pretrained model
+ model = EEGNetv4.from_pretrained("username/my-eegnet-model")
+
+ The integration automatically handles EEG-specific parameters (n_chans,
+ n_times, sfreq, chs_info, etc.) by saving them in a config file alongside
+ the model weights. This ensures that loaded models are correctly configured
+ for their original data specifications.
+
+ .. important::
+ Currently, only EEG-specific parameters (n_outputs, n_chans, n_times,
+ input_window_seconds, sfreq, chs_info) are saved to the Hub. Model-specific
+ parameters (e.g., dropout rates, activation functions, number of filters)
+ are not preserved and will use their default values when loading from the Hub.
+
+ To use non-default model parameters, specify them explicitly when calling
+ :func:`from_pretrained()`::
+
+ model = EEGNet.from_pretrained("user/model", dropout=0.3, activation='relu')
+
+ Full parameter serialization will be addressed in a future update.
  """

+ def __init_subclass__(cls, **kwargs):
+ if not HAS_HF_HUB:
+ super().__init_subclass__(**kwargs)
+ return
+
+ base_tags = ["braindecode", cls.__name__]
+ user_tags = kwargs.pop("tags", None)
+ tags = list(user_tags) if user_tags is not None else []
+ for tag in base_tags:
+ if tag not in tags:
+ tags.append(tag)
+
+ docs_url = kwargs.pop(
+ "docs_url",
+ f"https://braindecode.org/stable/generated/braindecode.models.{cls.__name__}.html",
+ )
+ repo_url = kwargs.pop("repo_url", "https://braindecode.org")
+ library_name = kwargs.pop("library_name", "braindecode")
+ license = kwargs.pop("license", "bsd-3-clause")
+ # TODO: model_card_template can be added in the future for custom model cards
+ super().__init_subclass__(
+ tags=tags,
+ docs_url=docs_url,
+ repo_url=repo_url,
+ library_name=library_name,
+ license=license,
+ **kwargs,
+ )
+
  def __init__(
  self,
  n_outputs: Optional[int] = None, # type: ignore[assignment]
@@ -73,6 +176,16 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
  input_window_seconds: Optional[float] = None, # type: ignore[assignment]
  sfreq: Optional[float] = None, # type: ignore[assignment]
  ):
+ # Deserialize chs_info if it comes as a list of dicts (from Hub)
+ if chs_info is not None and isinstance(chs_info, list):
+ if len(chs_info) > 0 and isinstance(chs_info[0], dict):
+ # Check if it needs deserialization (has 'loc' as list)
+ if "loc" in chs_info[0] and isinstance(chs_info[0]["loc"], list):
+ chs_info = self._deserialize_chs_info(chs_info)
+ warnings.warn(
+ "Modifying chs_info argument using the _deserialize_chs_info() method"
+ )
+
  if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
  raise ValueError(f"{n_chans=} different from {chs_info=} length")
  if (
@@ -294,3 +407,168 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):

  def __str__(self) -> str:
  return str(self.get_torchinfo_statistics())
+
+ @staticmethod
+ def _serialize_chs_info(chs_info):
+ """
+ Serialize MNE channel info to JSON-compatible format.
+
+ Parameters
+ ----------
+ chs_info : list of dict or None
+ Channel information from MNE Info object.
+
+ Returns
+ -------
+ list of dict or None
+ Serialized channel information that can be saved to JSON.
+ """
+ if chs_info is None:
+ return None
+
+ serialized = []
+ for ch in chs_info:
+ # Extract serializable fields from MNE channel info
+ ch_dict = {
+ "ch_name": ch.get("ch_name", ""),
+ }
+
+ # Handle kind field - can be either string or integer
+ kind_val = ch.get("kind")
+ if kind_val is not None:
+ ch_dict["kind"] = (
+ kind_val if isinstance(kind_val, str) else int(kind_val)
+ )
+
+ # Add numeric fields with safe conversion
+ coil_type = ch.get("coil_type")
+ if coil_type is not None:
+ ch_dict["coil_type"] = int(coil_type)
+
+ unit = ch.get("unit")
+ if unit is not None:
+ ch_dict["unit"] = int(unit)
+
+ cal = ch.get("cal")
+ if cal is not None:
+ ch_dict["cal"] = float(cal)
+
+ range_val = ch.get("range")
+ if range_val is not None:
+ ch_dict["range"] = float(range_val)
+
+ # Serialize location array if present
+ if "loc" in ch and ch["loc"] is not None:
+ ch_dict["loc"] = (
+ ch["loc"].tolist()
+ if hasattr(ch["loc"], "tolist")
+ else list(ch["loc"])
+ )
+ serialized.append(ch_dict)
+
+ return serialized
+
+ @staticmethod
+ def _deserialize_chs_info(chs_info_dict):
+ """
+ Deserialize channel info from JSON-compatible format to MNE-like structure.
+
+ Parameters
+ ----------
+ chs_info_dict : list of dict or None
+ Serialized channel information.
+
+ Returns
+ -------
+ list of dict or None
+ Deserialized channel information compatible with MNE.
+ """
+ if chs_info_dict is None:
+ return None
+
+ deserialized = []
+ for ch_dict in chs_info_dict:
+ ch = ch_dict.copy()
+ # Convert location back to numpy array if present
+ if "loc" in ch and ch["loc"] is not None:
+ ch["loc"] = np.array(ch["loc"])
+ deserialized.append(ch)
+
+ return deserialized
+
+ def _save_pretrained(self, save_directory):
+ """
+ Save model configuration and weights to the Hub.
+
+ This method is called by PyTorchModelHubMixin.push_to_hub() to save
+ model-specific configuration alongside the model weights.
+
+ Parameters
+ ----------
+ save_directory : str or Path
+ Directory where the configuration should be saved.
+ """
+ if not HAS_HF_HUB:
+ return
+
+ save_directory = Path(save_directory)
+
+ # Collect EEG-specific configuration
+ config = {
+ "n_outputs": self._n_outputs,
+ "n_chans": self._n_chans,
+ "n_times": self._n_times,
+ "input_window_seconds": self._input_window_seconds,
+ "sfreq": self._sfreq,
+ "chs_info": self._serialize_chs_info(self._chs_info),
+ "braindecode_version": __version__,
+ }
+
+ # Save to config.json
+ config_path = save_directory / "config.json"
+ with open(config_path, "w") as f:
+ json.dump(config, f, indent=2)
+
+ # Save model weights with standard Hub filename
+ weights_path = save_directory / "pytorch_model.bin"
+ torch.save(self.state_dict(), weights_path)
+
+ # Also save in safetensors format using parent's implementation
+ try:
+ super()._save_pretrained(save_directory)
+ except (ImportError, RuntimeError) as e:
+ # Fallback to pytorch_model.bin if safetensors saving fails
+ warnings.warn(
+ f"Could not save model in safetensors format: {e}. "
+ "Model weights saved in pytorch_model.bin instead.",
+ stacklevel=2,
+ )
+
+ if HAS_HF_HUB:
+
+ @classmethod
+ def _from_pretrained(
+ cls,
+ *,
+ model_id: str,
+ revision: Optional[str],
+ cache_dir: Optional[Union[str, Path]],
+ force_download: bool,
+ local_files_only: bool,
+ token: Union[str, bool, None],
+ map_location: str = "cpu",
+ strict: bool = False,
+ **model_kwargs,
+ ):
+ model_kwargs.pop("braindecode_version", None)
+ return super()._from_pretrained( # type: ignore
+ model_id=model_id,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ token=token,
+ map_location=map_location,
+ strict=strict,
+ **model_kwargs,
+ )
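
The helpers added above turn MNE channel dictionaries into JSON-friendly structures for ``config.json`` and back again on load. A small round-trip sketch in the same spirit (the channel entry is a made-up minimal example, not real MNE metadata, and the one-liners stand in for ``_serialize_chs_info`` / ``_deserialize_chs_info``):

    import json
    import numpy as np

    # Hypothetical channel entry with the kinds of fields the diff serializes.
    chs_info = [{"ch_name": "Cz", "kind": 2, "cal": 1.0, "range": 1.0, "loc": np.zeros(12)}]

    # Serialize: numpy arrays become plain lists so the config is JSON-compatible.
    serialized = [
        {**{k: v for k, v in ch.items() if k != "loc"}, "loc": ch["loc"].tolist()}
        for ch in chs_info
    ]
    config_json = json.dumps({"n_chans": 1, "chs_info": serialized}, indent=2)

    # Deserialize: lists become numpy arrays again, as the new __init__ hook does.
    restored = [{**ch, "loc": np.array(ch["loc"])} for ch in json.loads(config_json)["chs_info"]]
    assert np.allclose(restored[0]["loc"], chs_info[0]["loc"])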