braindecode 1.3.0.dev176728557__py3-none-any.whl → 1.3.0.dev177509039__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- braindecode/augmentation/functional.py +154 -54
- braindecode/augmentation/transforms.py +2 -2
- braindecode/datasets/base.py +18 -17
- braindecode/datasets/bcicomp.py +1 -1
- braindecode/datasets/sleep_physio_challe_18.py +2 -1
- braindecode/datautil/serialization.py +11 -6
- braindecode/eegneuralnet.py +2 -0
- braindecode/functional/functions.py +6 -2
- braindecode/functional/initialization.py +2 -3
- braindecode/models/__init__.py +6 -0
- braindecode/models/atcnet.py +32 -33
- braindecode/models/attentionbasenet.py +39 -32
- braindecode/models/base.py +280 -2
- braindecode/models/bendr.py +469 -0
- braindecode/models/biot.py +3 -1
- braindecode/models/ctnet.py +6 -3
- braindecode/models/deepsleepnet.py +27 -18
- braindecode/models/eegconformer.py +2 -2
- braindecode/models/eeginception_erp.py +31 -25
- braindecode/models/eegnet.py +5 -4
- braindecode/models/labram.py +188 -84
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/signal_jepa.py +109 -27
- braindecode/models/sinc_shallow.py +10 -9
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +9 -6
- braindecode/models/usleep.py +26 -21
- braindecode/models/util.py +3 -0
- braindecode/modules/attention.py +10 -10
- braindecode/modules/blocks.py +3 -3
- braindecode/modules/filter.py +2 -3
- braindecode/modules/layers.py +18 -17
- braindecode/preprocessing/preprocess.py +11 -2
- braindecode/preprocessing/windowers.py +2 -2
- braindecode/samplers/base.py +8 -8
- braindecode/version.py +1 -1
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/METADATA +4 -2
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/RECORD +42 -39
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/WHEEL +0 -0
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev176728557.dist-info → braindecode-1.3.0.dev177509039.dist-info}/top_level.txt +0 -0
braindecode/models/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from .atcnet import ATCNet
 from .attentionbasenet import AttentionBaseNet
 from .attn_sleep import AttnSleep
 from .base import EEGModuleMixin
+from .bendr import BENDR
 from .biot import BIOT
 from .contrawr import ContraWR
 from .ctnet import CTNet
@@ -27,6 +28,7 @@ from .hybrid import HybridNet
 from .ifnet import IFNet
 from .labram import Labram
 from .msvtnet import MSVTNet
+from .patchedtransformer import PBT
 from .sccnet import SCCNet
 from .shallow_fbcsp import ShallowFBCSPNet
 from .signal_jepa import (
@@ -39,6 +41,7 @@ from .sinc_shallow import SincShallowNet
 from .sleep_stager_blanco_2020 import SleepStagerBlanco2020
 from .sleep_stager_chambon_2018 import SleepStagerChambon2018
 from .sparcnet import SPARCNet
+from .sstdpn import SSTDPN
 from .syncnet import SyncNet
 from .tcn import BDTCN, TCN
 from .tidnet import TIDNet
@@ -56,6 +59,7 @@ __all__ = [
     "AttentionBaseNet",
     "EEGModuleMixin",
     "BIOT",
+    "BENDR",
     "ContraWR",
     "CTNet",
     "Deep4Net",
@@ -77,6 +81,7 @@ __all__ = [
     "IFNet",
     "Labram",
     "MSVTNet",
+    "PBT",
     "SCCNet",
     "ShallowFBCSPNet",
     "SignalJEPA",
@@ -84,6 +89,7 @@ __all__ = [
     "SignalJEPA_PostLocal",
     "SignalJEPA_PreLocal",
     "SincShallowNet",
+    "SSTDPN",
     "SleepStagerBlanco2020",
     "SleepStagerChambon2018",
     "SPARCNet",
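This release publicly exports three new model classes: BENDR, PBT (from the new patchedtransformer module), and SSTDPN. A quick smoke test of the new exports might look like the sketch below; it assumes each class accepts the standard EEGModuleMixin signal parameters (n_chans, n_outputs, n_times) shared by all braindecode models, with model-specific keyword arguments left at their defaults.

# Hypothetical smoke test for the newly exported models; assumes the common
# EEGModuleMixin constructor parameters, leaving model-specific kwargs at defaults.
import torch

from braindecode.models import BENDR, PBT, SSTDPN

x = torch.randn(2, 22, 1000)  # (batch, n_chans, n_times)
for cls in (BENDR, PBT, SSTDPN):
    model = cls(n_chans=22, n_outputs=4, n_times=1000)
    out = model(x)
    print(cls.__name__, tuple(out.shape))  # expected (2, 4) for a classification head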
braindecode/models/atcnet.py
CHANGED
@@ -50,7 +50,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
     - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_t, 1)`` builds a
       FIR-like filter bank (``F1`` maps).
     - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=F1``) with kernel
-      ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet’s CSP-like step).
+      ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet's CSP-like step).
     - **BN → ELU → AvgPool → Dropout** to stabilize and condense activations.
     - **Refining temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_r, 1)`` +
       **BN → ELU → AvgPool → Dropout**.
@@ -62,13 +62,15 @@ class ATCNet(EEGModuleMixin, nn.Module):
 
     - **Sliding-Window Sequencer**
 
-
-
-
-
+      From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
+      of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
+      ``(B, F2, T_w)`` forwarded to its own attention-TCN branch. This creates *parallel*
+      encoders over shifted contexts and is key to robustness on nonstationary EEG.
 
     - :class:`_AttentionBlock` **(small MHA on temporal positions)**
 
+      Attention here is *local to a window* and purely temporal.
+
     - *Operations.*
       - Rearrange to ``(B, T_w, F2)``,
       - Normalization :class:`torch.nn.LayerNorm`
@@ -76,11 +78,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
       - Dropout :class:`torch.nn.Dropout`
       - Rearrange back to ``(B, F2, T_w)``.
 
-
-
-
-      *Role.* Re-weights evidence across the window, letting the model emphasize informative
-      segments (onsets, bursts) before causal convolutions aggregate history.
+      *Role.* Re-weights evidence across the window, letting the model emphasize informative
+      segments (onsets, bursts) before causal convolutions aggregate history.
 
     - :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**
 
@@ -90,8 +89,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
       a residual (identity or 1x1 mapping).
     - The final feature used per window is the *last* causal step ``[..., -1]`` (forecast-style).
 
-
-
+      *Role.* Efficient long-range temporal integration with stable gradients; the dilated
+      receptive field complements attention's soft selection.
 
     - **Aggregation & Classifier**
 
@@ -104,16 +103,16 @@ class ATCNet(EEGModuleMixin, nn.Module):
     .. rubric:: Convolutional Details
 
     - **Temporal.** Temporal structure is learned in three places:
-      - (1) the stem’s wide ``(L_t, 1)`` conv (learned filter bank),
+      - (1) the stem's wide ``(L_t, 1)`` conv (learned filter bank),
       - (2) the refining ``(L_r, 1)`` conv after pooling (short-term dynamics), and
-      - (3) the TCN’s causal 1-D convolutions with exponentially increasing dilation
+      - (3) the TCN's causal 1-D convolutions with exponentially increasing dilation
       (long-range dependencies). The minimum sequence length required by the TCN stack is
       ``(K_t - 1)·2^{L-1} + 1``; the implementation *auto-scales* kernels/pools/windows
       when inputs are shorter to preserve feasibility.
 
     - **Spatial.** A depthwise spatial conv spans the **full montage** (kernel ``(1, n_chans)``),
       producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
-      This mirrors EEGNet’s interpretability: each temporal filter has its own spatial pattern.
+      This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.
 
 
     .. rubric:: Attention / Sequential Modules
@@ -137,17 +136,17 @@ class ATCNet(EEGModuleMixin, nn.Module):
 
     .. rubric:: Usage and Configuration
 
-
-
-
-
-
-
-
-
-
-
-
+    - ``conv_block_n_filters (F1)``, ``conv_block_depth_mult (D)`` → capacity of the stem
+      (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
+    - Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
+      ``T_c = T/(P1·P2)`` and thus window width ``T_w``.
+    - ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
+    - ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
+    - ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
+      longer inputs (see minimum length above). The implementation warns and *rescales*
+      kernels/pools/windows if inputs are too short.
+    - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
+      the official code; ``concat=True`` mirrors the paper's concatenation variant.
 
 
     Notes
@@ -370,7 +369,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
             nn.Sequential(
                 *[
                     _TCNResidualBlock(
-                        in_channels=self.F2,
+                        in_channels=self.F2 if i == 0 else self.tcn_n_filters,
                         kernel_size=self.tcn_kernel_size,
                         n_filters=self.tcn_n_filters,
                         dropout=self.tcn_dropout,
@@ -388,7 +387,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
         self.final_layer = nn.ModuleList(
             [
                 MaxNormLinear(
-                    in_features=self.
+                    in_features=self.tcn_n_filters * self.n_windows,
                     out_features=self.n_outputs,
                     max_norm_val=self.max_norm_const,
                 )
@@ -398,7 +397,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
         self.final_layer = nn.ModuleList(
             [
                 MaxNormLinear(
-                    in_features=self.
+                    in_features=self.tcn_n_filters,
                    out_features=self.n_outputs,
                     max_norm_val=self.max_norm_const,
                 )
@@ -408,7 +407,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
 
         self.out_fun = nn.Identity()
 
-    def forward(self, X):
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
         # Dimension: (batch_size, C, T)
         X = self.ensuredims(X)
         # Dimension: (batch_size, C, T, 1)
@@ -695,8 +694,8 @@ class _TCNResidualBlock(nn.Module):
         # Reshape the input for the residual connection when necessary
         if in_channels != n_filters:
             self.reshaping_conv = nn.Conv1d(
-                in_channels=in_channels,
-                out_channels=n_filters,
+                in_channels=in_channels,  # Specify input channels
+                out_channels=n_filters,  # Specify output channels
                 kernel_size=1,
                 padding="same",
             )
@@ -716,7 +715,7 @@ class _TCNResidualBlock(nn.Module):
         out = self.activation(out)
         out = self.drop2(out)
 
-
+        X = self.reshaping_conv(X)
 
         # ----- Residual connection -----
         out = X + out
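The window arithmetic in the revised ATCNet docstring can be checked numerically. The sketch below uses the paper's BCI IV-2a defaults (1125 samples, pool sizes 8 and 7, five windows, TCN kernel 4 at depth 2); these values are assumptions chosen for illustration, not read from this diff.

# Numeric check of the docstring's window formulas (assumed default values).
T = 1125          # input samples (4.5 s at 250 Hz)
P1, P2 = 8, 7     # pool sizes of the two conv-block pooling stages
n = 5             # n_windows
K_t, L = 4, 2     # tcn_kernel_size, tcn_depth

T_c = T // (P1 * P2)                    # condensed time axis: 1125 // 56 = 20
T_w = T_c - n + 1                       # sliding-window width: 20 - 5 + 1 = 16
min_len = (K_t - 1) * 2 ** (L - 1) + 1  # TCN minimum input length: 3 * 2 + 1 = 7

assert T_w >= min_len, "windows must be long enough for the dilated TCN stack"
print(T_c, T_w, min_len)  # -> 20 16 7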
braindecode/models/attentionbasenet.py
CHANGED
@@ -97,7 +97,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 
     - **Temporal (where time-domain patterns are learned).**
       Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory
-      bands/transients; the attention block’s depthwise temporal conv (``(1, L_a)``) sharpens
+      bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens
       short-term dynamics after downsampling. Pool sizes/strides (``P₁,S₁`` then ``P₂,S₂``)
       set the token rate and effective temporal resolution.
 
@@ -127,23 +127,24 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 
     .. rubric:: Additional Mechanisms
 
-
-    - ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
-    - ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
-    - ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
-    - ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
-    - ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
-    - ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
-    - ``"gct"``: Gated Channel Transformation (global context normalization + gating).
-    - ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
-    - ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
-    - ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
-    - **Auto-compatibility on short inputs.**
+    **Attention variants at a glance:**
 
-
-
-
+    - ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
+    - ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
+    - ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
+    - ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
+    - ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
+    - ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
+    - ``"gct"``: Gated Channel Transformation (global context normalization + gating).
+    - ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
+    - ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
+    - ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
 
+    **Auto-compatibility on short inputs:**
+
+    If the input duration is too short for the configured kernels/pools, the implementation
+    **automatically rescales** temporal lengths/strides downward (with a warning) to keep
+    shapes valid and preserve the pipeline semantics.
 
     .. rubric:: Usage and Configuration
 
@@ -158,9 +159,9 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
     - ``drop_prob_inp`` and ``drop_prob_attn``: regularize stem and attention stages.
     - **Training tips.**
 
-
-
-
+      Start with moderate pooling (e.g., ``P₁=75,S₁=15``) and ELU activations; enable attention
+      only after the stem learns stable filters. For small datasets, prefer simpler modes
+      (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).
 
     Notes
     -----
@@ -170,6 +171,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
       specific variants (CBAM/CAT).
     - The paper and original code with more details about the methodological
       choices are available at the [Martin2023]_ and [MartinCode]_.
+
     .. versionadded:: 0.9
 
     Parameters
@@ -198,19 +200,21 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
         the depth of the network after the initial layer. Default is 16.
     attention_mode : str, optional
         The type of attention mechanism to apply. If `None`, no attention is applied.
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        - "se" for Squeeze-and-excitation network
+        - "gsop" for Global Second-Order Pooling
+        - "fca" for Frequency Channel Attention Network
+        - "encnet" for context encoding module
+        - "eca" for Efficient channel attention for deep convolutional neural networks
+        - "ge" for Gather-Excite
+        - "gct" for Gated Channel Transformation
+        - "srm" for Style-based Recalibration Module
+        - "cbam" for Convolutional Block Attention Module
+        - "cat" for Learning to collaborate channel and temporal attention
+          from multi-information fusion
+        - "catlite" for Learning to collaborate channel attention
+          from multi-information fusion (lite version, cat w/o temporal attention)
+
     pool_length : int, default=8
         The length of the window for the average pooling operation.
     pool_stride : int, default=8
@@ -381,6 +385,8 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
         for k, pl, ps in zip(kernel_lengths, pool_lengths, pool_strides):
             out = math.floor(out + 2 * (k // 2) - k + 1)
             out = math.floor((out - pl) / ps + 1)
+            # Ensure output is at least 1 to avoid zero-sized tensors
+            out = max(1, out)
             seq_lengths.append(int(out))
         return seq_lengths
 
@@ -497,6 +503,7 @@ class _ChannelAttentionBlock(nn.Module):
     ----------
     attention_mode : str, optional
         The type of attention mechanism to apply. If `None`, no attention is applied.
+
         - "se" for Squeeze-and-excitation network
        - "gsop" for Global Second-Order Pooling
         - "fca" for Frequency Channel Attention Network
braindecode/models/base.py
CHANGED
@@ -5,15 +5,35 @@
 
 from __future__ import annotations
 
+import json
 import warnings
 from collections import OrderedDict
-from
+from pathlib import Path
+from typing import Dict, Iterable, Optional, Type, Union
 
 import numpy as np
 import torch
 from docstring_inheritance import NumpyDocstringInheritanceInitMeta
+from mne.utils import _soft_import
 from torchinfo import ModelStatistics, summary
 
+from braindecode.version import __version__
+
+huggingface_hub = _soft_import(
+    "huggingface_hub", "Hugging Face Hub integration", strict=False
+)
+
+HAS_HF_HUB = huggingface_hub is not False
+
+
+class _BaseHubMixin:
+    pass
+
+
+# Define base class for hub mixin
+if HAS_HF_HUB:
+    _BaseHubMixin: Type = huggingface_hub.PyTorchModelHubMixin  # type: ignore
+
 
 def deprecated_args(obj, *old_new_args):
     out_args = []
@@ -32,10 +52,14 @@ def deprecated_args(obj, *old_new_args):
     return out_args
 
 
-class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
+class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta):
     """
     Mixin class for all EEG models in braindecode.
 
+    This class integrates with Hugging Face Hub when the ``huggingface_hub`` package
+    is installed, enabling models to be pushed to and loaded from the Hub using
+    :func:`push_to_hub()` and :func:`from_pretrained()` methods.
+
     Parameters
     ----------
     n_outputs : int
@@ -62,8 +86,87 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
     -----
     If some input signal-related parameters are not specified,
     there will be an attempt to infer them from the other parameters.
+
+    .. rubric:: Hugging Face Hub integration
+
+    When the optional ``huggingface_hub`` package is installed, all models
+    automatically gain the ability to be pushed to and loaded from the
+    Hugging Face Hub. Install with::
+
+        pip install braindecode[hug]
+
+    **Pushing a model to the Hub:**
+
+    .. code-block:: python
+
+        from braindecode.models import EEGNetv4
+
+        # Train your model
+        model = EEGNetv4(n_chans=22, n_outputs=4, n_times=1000)
+        # ... training code ...
+
+        # Push to the Hub
+        model.push_to_hub(
+            repo_id="username/my-eegnet-model", commit_message="Initial model upload"
+        )
+
+    **Loading a model from the Hub:**
+
+    .. code-block:: python
+
+        from braindecode.models import EEGNetv4
+
+        # Load pretrained model
+        model = EEGNetv4.from_pretrained("username/my-eegnet-model")
+
+    The integration automatically handles EEG-specific parameters (n_chans,
+    n_times, sfreq, chs_info, etc.) by saving them in a config file alongside
+    the model weights. This ensures that loaded models are correctly configured
+    for their original data specifications.
+
+    .. important::
+        Currently, only EEG-specific parameters (n_outputs, n_chans, n_times,
+        input_window_seconds, sfreq, chs_info) are saved to the Hub. Model-specific
+        parameters (e.g., dropout rates, activation functions, number of filters)
+        are not preserved and will use their default values when loading from the Hub.
+
+        To use non-default model parameters, specify them explicitly when calling
+        :func:`from_pretrained()`::
+
+            model = EEGNet.from_pretrained("user/model", dropout=0.3, activation='relu')
+
+        Full parameter serialization will be addressed in a future update.
     """
 
+    def __init_subclass__(cls, **kwargs):
+        if not HAS_HF_HUB:
+            super().__init_subclass__(**kwargs)
+            return
+
+        base_tags = ["braindecode", cls.__name__]
+        user_tags = kwargs.pop("tags", None)
+        tags = list(user_tags) if user_tags is not None else []
+        for tag in base_tags:
+            if tag not in tags:
+                tags.append(tag)
+
+        docs_url = kwargs.pop(
+            "docs_url",
+            f"https://braindecode.org/stable/generated/braindecode.models.{cls.__name__}.html",
+        )
+        repo_url = kwargs.pop("repo_url", "https://braindecode.org")
+        library_name = kwargs.pop("library_name", "braindecode")
+        license = kwargs.pop("license", "bsd-3-clause")
+        # TODO: model_card_template can be added in the future for custom model cards
+        super().__init_subclass__(
+            tags=tags,
+            docs_url=docs_url,
+            repo_url=repo_url,
+            library_name=library_name,
+            license=license,
+            **kwargs,
+        )
+
     def __init__(
         self,
         n_outputs: Optional[int] = None,  # type: ignore[assignment]
@@ -73,6 +176,16 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
         input_window_seconds: Optional[float] = None,  # type: ignore[assignment]
         sfreq: Optional[float] = None,  # type: ignore[assignment]
     ):
+        # Deserialize chs_info if it comes as a list of dicts (from Hub)
+        if chs_info is not None and isinstance(chs_info, list):
+            if len(chs_info) > 0 and isinstance(chs_info[0], dict):
+                # Check if it needs deserialization (has 'loc' as list)
+                if "loc" in chs_info[0] and isinstance(chs_info[0]["loc"], list):
+                    chs_info = self._deserialize_chs_info(chs_info)
+                    warnings.warn(
+                        "Modifying chs_info argument using the _deserialize_chs_info() method"
+                    )
+
         if n_chans is not None and chs_info is not None and len(chs_info) != n_chans:
             raise ValueError(f"{n_chans=} different from {chs_info=} length")
         if (
@@ -294,3 +407,168 @@ class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
 
     def __str__(self) -> str:
         return str(self.get_torchinfo_statistics())
+
+    @staticmethod
+    def _serialize_chs_info(chs_info):
+        """
+        Serialize MNE channel info to JSON-compatible format.
+
+        Parameters
+        ----------
+        chs_info : list of dict or None
+            Channel information from MNE Info object.
+
+        Returns
+        -------
+        list of dict or None
+            Serialized channel information that can be saved to JSON.
+        """
+        if chs_info is None:
+            return None
+
+        serialized = []
+        for ch in chs_info:
+            # Extract serializable fields from MNE channel info
+            ch_dict = {
+                "ch_name": ch.get("ch_name", ""),
+            }
+
+            # Handle kind field - can be either string or integer
+            kind_val = ch.get("kind")
+            if kind_val is not None:
+                ch_dict["kind"] = (
+                    kind_val if isinstance(kind_val, str) else int(kind_val)
+                )
+
+            # Add numeric fields with safe conversion
+            coil_type = ch.get("coil_type")
+            if coil_type is not None:
+                ch_dict["coil_type"] = int(coil_type)
+
+            unit = ch.get("unit")
+            if unit is not None:
+                ch_dict["unit"] = int(unit)
+
+            cal = ch.get("cal")
+            if cal is not None:
+                ch_dict["cal"] = float(cal)
+
+            range_val = ch.get("range")
+            if range_val is not None:
+                ch_dict["range"] = float(range_val)
+
+            # Serialize location array if present
+            if "loc" in ch and ch["loc"] is not None:
+                ch_dict["loc"] = (
+                    ch["loc"].tolist()
+                    if hasattr(ch["loc"], "tolist")
+                    else list(ch["loc"])
+                )
+            serialized.append(ch_dict)
+
+        return serialized
+
+    @staticmethod
+    def _deserialize_chs_info(chs_info_dict):
+        """
+        Deserialize channel info from JSON-compatible format to MNE-like structure.
+
+        Parameters
+        ----------
+        chs_info_dict : list of dict or None
+            Serialized channel information.
+
+        Returns
+        -------
+        list of dict or None
+            Deserialized channel information compatible with MNE.
+        """
+        if chs_info_dict is None:
+            return None
+
+        deserialized = []
+        for ch_dict in chs_info_dict:
+            ch = ch_dict.copy()
+            # Convert location back to numpy array if present
+            if "loc" in ch and ch["loc"] is not None:
+                ch["loc"] = np.array(ch["loc"])
+            deserialized.append(ch)
+
+        return deserialized
+
+    def _save_pretrained(self, save_directory):
+        """
+        Save model configuration and weights to the Hub.
+
+        This method is called by PyTorchModelHubMixin.push_to_hub() to save
+        model-specific configuration alongside the model weights.
+
+        Parameters
+        ----------
+        save_directory : str or Path
+            Directory where the configuration should be saved.
+        """
+        if not HAS_HF_HUB:
+            return
+
+        save_directory = Path(save_directory)
+
+        # Collect EEG-specific configuration
+        config = {
+            "n_outputs": self._n_outputs,
+            "n_chans": self._n_chans,
+            "n_times": self._n_times,
+            "input_window_seconds": self._input_window_seconds,
+            "sfreq": self._sfreq,
+            "chs_info": self._serialize_chs_info(self._chs_info),
+            "braindecode_version": __version__,
+        }
+
+        # Save to config.json
+        config_path = save_directory / "config.json"
+        with open(config_path, "w") as f:
+            json.dump(config, f, indent=2)
+
+        # Save model weights with standard Hub filename
+        weights_path = save_directory / "pytorch_model.bin"
+        torch.save(self.state_dict(), weights_path)
+
+        # Also save in safetensors format using parent's implementation
+        try:
+            super()._save_pretrained(save_directory)
+        except (ImportError, RuntimeError) as e:
+            # Fallback to pytorch_model.bin if safetensors saving fails
+            warnings.warn(
+                f"Could not save model in safetensors format: {e}. "
+                "Model weights saved in pytorch_model.bin instead.",
+                stacklevel=2,
+            )
+
+    if HAS_HF_HUB:
+
+        @classmethod
+        def _from_pretrained(
+            cls,
+            *,
+            model_id: str,
+            revision: Optional[str],
+            cache_dir: Optional[Union[str, Path]],
+            force_download: bool,
+            local_files_only: bool,
+            token: Union[str, bool, None],
+            map_location: str = "cpu",
+            strict: bool = False,
+            **model_kwargs,
+        ):
+            model_kwargs.pop("braindecode_version", None)
+            return super()._from_pretrained(  # type: ignore
+                model_id=model_id,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                local_files_only=local_files_only,
+                token=token,
+                map_location=map_location,
+                strict=strict,
+                **model_kwargs,
+            )