braindecode 1.3.0.dev175955015__py3-none-any.whl → 1.3.0.dev176481332__py3-none-any.whl

This diff represents the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of braindecode might be problematic.

Files changed (65)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/datasets/__init__.py +10 -2
  3. braindecode/datasets/base.py +115 -151
  4. braindecode/datasets/bcicomp.py +4 -4
  5. braindecode/datasets/bids.py +3 -3
  6. braindecode/datasets/experimental.py +2 -2
  7. braindecode/datasets/mne.py +3 -5
  8. braindecode/datasets/moabb.py +2 -2
  9. braindecode/datasets/nmt.py +2 -2
  10. braindecode/datasets/sleep_physio_challe_18.py +2 -2
  11. braindecode/datasets/sleep_physionet.py +2 -2
  12. braindecode/datasets/tuh.py +2 -2
  13. braindecode/datasets/xy.py +2 -2
  14. braindecode/datautil/serialization.py +7 -7
  15. braindecode/functional/functions.py +6 -2
  16. braindecode/functional/initialization.py +2 -3
  17. braindecode/models/__init__.py +2 -0
  18. braindecode/models/atcnet.py +26 -27
  19. braindecode/models/attentionbasenet.py +37 -32
  20. braindecode/models/attn_sleep.py +2 -0
  21. braindecode/models/base.py +2 -2
  22. braindecode/models/bendr.py +469 -0
  23. braindecode/models/biot.py +2 -0
  24. braindecode/models/contrawr.py +2 -0
  25. braindecode/models/ctnet.py +8 -3
  26. braindecode/models/deepsleepnet.py +28 -19
  27. braindecode/models/eegconformer.py +2 -2
  28. braindecode/models/eeginception_erp.py +31 -25
  29. braindecode/models/eegitnet.py +2 -0
  30. braindecode/models/eegminer.py +2 -0
  31. braindecode/models/eegnet.py +1 -1
  32. braindecode/models/eegtcnet.py +2 -0
  33. braindecode/models/fbcnet.py +2 -0
  34. braindecode/models/fblightconvnet.py +2 -0
  35. braindecode/models/fbmsnet.py +2 -0
  36. braindecode/models/ifnet.py +2 -0
  37. braindecode/models/labram.py +33 -26
  38. braindecode/models/msvtnet.py +2 -0
  39. braindecode/models/patchedtransformer.py +1 -1
  40. braindecode/models/signal_jepa.py +8 -0
  41. braindecode/models/sinc_shallow.py +12 -9
  42. braindecode/models/sstdpn.py +11 -11
  43. braindecode/models/summary.csv +1 -0
  44. braindecode/models/syncnet.py +2 -0
  45. braindecode/models/tcn.py +2 -0
  46. braindecode/models/usleep.py +26 -21
  47. braindecode/models/util.py +1 -0
  48. braindecode/modules/attention.py +10 -10
  49. braindecode/modules/blocks.py +3 -3
  50. braindecode/modules/filter.py +2 -3
  51. braindecode/modules/layers.py +18 -17
  52. braindecode/preprocessing/__init__.py +24 -0
  53. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  54. braindecode/preprocessing/preprocess.py +12 -12
  55. braindecode/preprocessing/util.py +166 -0
  56. braindecode/preprocessing/windowers.py +26 -20
  57. braindecode/samplers/base.py +8 -8
  58. braindecode/version.py +1 -1
  59. {braindecode-1.3.0.dev175955015.dist-info → braindecode-1.3.0.dev176481332.dist-info}/METADATA +4 -2
  60. braindecode-1.3.0.dev176481332.dist-info/RECORD +106 -0
  61. braindecode-1.3.0.dev175955015.dist-info/RECORD +0 -103
  62. {braindecode-1.3.0.dev175955015.dist-info → braindecode-1.3.0.dev176481332.dist-info}/WHEEL +0 -0
  63. {braindecode-1.3.0.dev175955015.dist-info → braindecode-1.3.0.dev176481332.dist-info}/licenses/LICENSE.txt +0 -0
  64. {braindecode-1.3.0.dev175955015.dist-info → braindecode-1.3.0.dev176481332.dist-info}/licenses/NOTICE.txt +0 -0
  65. {braindecode-1.3.0.dev175955015.dist-info → braindecode-1.3.0.dev176481332.dist-info}/top_level.txt +0 -0
@@ -35,51 +35,57 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
  - :class:`_InceptionModule1` **(multi-scale temporal + spatial mixing)**

  - *Operations.*
- - `EEGInceptionERP.c1`: :class:`torch.nn.Conv2d` ``k=(64,1)``, stride ``(1,1)``, *same* pad on input reshaped to ``(B,1,128,8)`` → BN → activation → dropout.
- - `EEGInceptionERP.d1`: :class:`torch.nn.Conv2d` (depthwise) ``k=(1,8)``, *valid* pad over channels → BN → activation → dropout.
- - `EEGInceptionERP.c2`: :class:`torch.nn.Conv2d` ``k=(32,1)`` → BN → activation → dropout; then `EEGInceptionERP.d2` depthwise ``k=(1,8)`` → BN → activation → dropout.
- - `EEGInceptionERP.c3`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout; then `EEGInceptionERP.d3` depthwise ``k=(1,8)`` → BN → activation → dropout.
- - `EEGInceptionERP.n1`: :class:`torch.nn.Concat` over branch features.
- - `EEGInceptionERP.a1`: :class:`torch.nn.AvgPool2d` ``pool=(4,1)``, stride ``(4,1)`` for temporal downsampling.
+
+ - `EEGInceptionERP.c1`: :class:`torch.nn.Conv2d` ``k=(64,1)``, stride ``(1,1)``, *same* pad on input reshaped to ``(B,1,128,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.d1`: :class:`torch.nn.Conv2d` (depthwise) ``k=(1,8)``, *valid* pad over channels → BN → activation → dropout.
+ - `EEGInceptionERP.c2`: :class:`torch.nn.Conv2d` ``k=(32,1)`` → BN → activation → dropout; then `EEGInceptionERP.d2` depthwise ``k=(1,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c3`: :class:`torch.nn.Conv2d` ``k=(16,1)`` BN → activation → dropout; then `EEGInceptionERP.d3` depthwise ``k=(1,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.n1`: :class:`torch.nn.Concat` over branch features.
+ - `EEGInceptionERP.a1`: :class:`torch.nn.AvgPool2d` ``pool=(4,1)``, stride ``(4,1)`` for temporal downsampling.

  *Interpretability/robustness.* Depthwise `1 x n_chans` layers act as learnable montage-wide spatial filters per temporal scale; pooling stabilizes against jitter.

  - :class:`_InceptionModule2` **(refinement at coarser timebase)**

  - *Operations.*
- - `EEGInceptionERP.c4`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout.
- - `EEGInceptionERP.c5`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout.
- - `EEGInceptionERP.c6`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout.
- - `EEGInceptionERP.n2`: :class:`torch.nn.Concat` (merge C4-C6 outputs).
- - `EEGInceptionERP.a2`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``, stride ``(2,1)``.
- - `EEGInceptionERP.c7`: :class:`torch.nn.Conv2d` ``k=(8,1)`` BN → activation → dropout; then `EEGInceptionERP.a3`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
- - `EEGInceptionERP.c8`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout; then `EEGInceptionERP.a4`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
+
+ - `EEGInceptionERP.c4`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c5`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c6`: :class:`torch.nn.Conv2d` ``k=(4,1)`` BN → activation → dropout.
+ - `EEGInceptionERP.n2`: :class:`torch.nn.Concat` (merge C4-C6 outputs).
+ - `EEGInceptionERP.a2`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``, stride ``(2,1)``.
+ - `EEGInceptionERP.c7`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout; then `EEGInceptionERP.a3`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
+ - `EEGInceptionERP.c8`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout; then `EEGInceptionERP.a4`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.

  *Role.* Adds higher-level, shorter-window evidence while progressively compressing temporal dimension.

  - :class:`_OutputModule` **(aggregation + readout)**

  - *Operations.*
- - :class:`torch.nn.Flatten`
- - :class:`torch.nn.Linear` ``(features → 2)``
+
+ - :class:`torch.nn.Flatten`
+ - :class:`torch.nn.Linear` ``(features → 2)``

  .. rubric:: Convolutional Details

  - **Temporal (where time-domain patterns are learned).**
- First module uses 1D temporal kernels along the 128-sample axis: ``64``, ``32``, ``16``
- (≈500, 250, 125 ms at 128 Hz). After ``pool=(4,1)``, the second module applies ``16``,
- ``8``, ``4`` (≈125, 62.5, 31.25 ms at the pooled rate). All strides are ``1`` in convs;
- temporal resolution changes only via average pooling.
+
+ First module uses 1D temporal kernels along the 128-sample axis: ``64``, ``32``, ``16``
+ (≈500, 250, 125 ms at 128 Hz). After ``pool=(4,1)``, the second module applies ``16``,
+ ``8``, ``4`` (≈125, 62.5, 31.25 ms at the pooled rate). All strides are ``1`` in convs;
+ temporal resolution changes only via average pooling.

  - **Spatial (how electrodes are processed).**
- Depthwise convs with ``k=(1,8)`` span all channels and are applied **per temporal branch**,
- yielding scale-specific channel projections (no cross-branch mixing until concatenation).
- There is no full 2D mixing kernel; spatial mixing is factorized and lightweight.
+
+ Depthwise convs with ``k=(1,8)`` span all channels and are applied **per temporal branch**,
+ yielding scale-specific channel projections (no cross-branch mixing until concatenation).
+ There is no full 2D mixing kernel; spatial mixing is factorized and lightweight.

  - **Spectral (how frequency information is captured).**
- No explicit transform; multiple temporal kernels form a *learned filter bank* over
- ERP-relevant bands. Successive pooling acts as low-pass integration to emphasize sustained
- post-stimulus components.
+
+ No explicit transform; multiple temporal kernels form a *learned filter bank* over
+ ERP-relevant bands. Successive pooling acts as low-pass integration to emphasize sustained
+ post-stimulus components.

  .. rubric:: Additional Mechanisms
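
The branch structure quoted above is easy to sanity-check shape-wise. A minimal PyTorch sketch of one first-module branch (filter counts and the dropout rate are placeholder values, not the model's actual hyperparameters):

    import torch
    from torch import nn

    # One temporal branch of the first inception module, following the shapes quoted
    # above: input reshaped to (B, 1, 128, 8) = 128 samples x 8 channels.
    x = torch.randn(2, 1, 128, 8)

    c1 = nn.Sequential(                      # temporal conv, k=(64, 1), 'same' padding
        nn.Conv2d(1, 8, kernel_size=(64, 1), padding="same"),
        nn.BatchNorm2d(8), nn.ELU(), nn.Dropout(0.25),
    )
    d1 = nn.Sequential(                      # depthwise spatial conv, k=(1, 8), 'valid' padding
        nn.Conv2d(8, 16, kernel_size=(1, 8), groups=8),
        nn.BatchNorm2d(16), nn.ELU(), nn.Dropout(0.25),
    )
    a1 = nn.AvgPool2d(kernel_size=(4, 1), stride=(4, 1))

    print(a1(d1(c1(x))).shape)  # torch.Size([2, 16, 32, 1]): time /4, channel axis collapsed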
 
@@ -11,6 +11,8 @@ from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
  class EEGITNet(EEGModuleMixin, nn.Sequential):
  """EEG-ITNet from Salami, et al (2022) [Salami2022]_

+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
+
  .. figure:: https://braindecode.org/dev/_static/model/eegitnet.jpg
  :align: center
  :alt: EEG-ITNet Architecture
@@ -21,6 +21,8 @@ _eeg_miner_methods = ["mag", "corr", "plv"]
  class EEGMiner(EEGModuleMixin, nn.Module):
  """EEGMiner from Ludwig et al (2024) [eegminer]_.

+ :bdg-success:`Convolution` :bdg-warning:`Interpretability`
+
  .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036010/revision2/jnead44d7f1_hr.jpg
  :align: center
  :alt: EEGMiner Architecture
@@ -57,7 +57,7 @@ class EEGNet(EEGModuleMixin, nn.Sequential):

  - **Temporal.** The initial temporal convs serve as a *learned filter bank*:
  long 1-D kernels (implemented as 2-D with singleton spatial extent) emphasize oscillatory bands and transients.
- Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each features spectrum [Lawhern2018]_.
+ Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each feature's spectrum [Lawhern2018]_.

  - **Spatial.** The depthwise spatial conv spans the full channel axis (kernel height = #electrodes; temporal size = 1).
  With ``groups = F1``, each temporal filter learns its own set of ``D`` spatial projections—akin to CSP, learned end-to-end and
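
As a side note on the FIR-filter reading of the temporal kernels mentioned in the hunk above, the spectrum of a learned kernel can be inspected with a plain FFT. A minimal sketch (the kernel is a random placeholder for a trained EEGNet temporal filter; 128 Hz and length 64 are assumed values):

    import numpy as np

    sfreq = 128.0                   # assumed sampling rate
    kernel = np.random.randn(64)    # stand-in for one learned temporal kernel

    freqs = np.fft.rfftfreq(512, d=1.0 / sfreq)    # zero-padded for a smoother curve
    response = np.abs(np.fft.rfft(kernel, n=512))  # FIR magnitude response

    print(f"dominant frequency of this kernel: {freqs[np.argmax(response)]:.1f} Hz")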
@@ -15,6 +15,8 @@ from braindecode.modules import Chomp1d, MaxNormLinear
  class EEGTCNet(EEGModuleMixin, nn.Module):
  """EEGTCNet model from Ingolfsson et al. (2020) [ingolfsson2020]_.

+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
+
  .. figure:: https://braindecode.org/dev/_static/model/eegtcnet.jpg
  :align: center
  :alt: EEGTCNet Architecture
@@ -31,6 +31,8 @@ _valid_layers = {
  class FBCNet(EEGModuleMixin, nn.Module):
  """FBCNet from Mane, R et al (2021) [fbcnet2021]_.

+ :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
  .. figure:: https://raw.githubusercontent.com/ravikiran-mane/FBCNet/refs/heads/master/FBCNet-V2.png
  :align: center
  :alt: FBCNet Architecture
@@ -18,6 +18,8 @@ from braindecode.modules import (
  class FBLightConvNet(EEGModuleMixin, nn.Module):
  """LightConvNet from Ma, X et al (2023) [lightconvnet]_.

+ :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
  .. figure:: https://raw.githubusercontent.com/Ma-Xinzhi/LightConvNet/refs/heads/main/network_architecture.png
  :align: center
  :alt: LightConvNet Neural Network
@@ -19,6 +19,8 @@ from braindecode.modules import (
  class FBMSNet(EEGModuleMixin, nn.Module):
  """FBMSNet from Liu et al (2022) [fbmsnet]_.

+ :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
  .. figure:: https://raw.githubusercontent.com/Want2Vanish/FBMSNet/refs/heads/main/FBMSNet.png
  :align: center
  :alt: FBMSNet Architecture
@@ -31,6 +31,8 @@ from braindecode.modules import (
  class IFNet(EEGModuleMixin, nn.Module):
  """IFNetV2 from Wang J et al (2023) [ifnet]_.

+ :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
  .. figure:: https://raw.githubusercontent.com/Jiaheng-Wang/IFNet/main/IFNet.png
  :align: center
  :alt: IFNetV2 Architecture
@@ -2,6 +2,7 @@
  Labram module.
  Authors: Wei-Bang Jiang
  Bruno Aristimunha <b.aristimunha@gmail.com>
+ Matthew Chen <matt.chen4260@gmail.com>
  License: BSD 3 clause
  """

@@ -22,12 +23,14 @@ from braindecode.modules import MLP, DropPath
  class Labram(EEGModuleMixin, nn.Module):
  """Labram from Jiang, W B et al (2024) [Jiang2024]_.

+ :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`
+
  .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
  :align: center
  :alt: Labram Architecture.

  Large Brain Model for Learning Generic Representations with Tremendous
- EEG Data in BCI from [Jiang2024]_
+ EEG Data in BCI from [Jiang2024]_.

  This is an **adaptation** of the code [Code2024]_ from the Labram model.

@@ -35,7 +38,8 @@ class Labram(EEGModuleMixin, nn.Module):
  BEiTv2 [BeiTv2]_.

  The models can be used in two modes:
- - Neural Tokenizor: Design to get an embedding layers (e.g. classification).
+
+ - Neural Tokenizer: Design to get an embedding layers (e.g. classification).
  - Neural Decoder: To extract the ampliture and phase outputs with a VQSNP.

  The braindecode's modification is to allow the model to be used in
@@ -43,33 +47,36 @@ class Labram(EEGModuleMixin, nn.Module):
  equals True. The original implementation uses (batch, n_chans, n_patches,
  patch_size) as input with static segmentation of the input data.

- The models have the following sequence of steps:
- if neural tokenizer:
- - SegmentPatch: Segment the input data in patches;
- - TemporalConv: Apply a temporal convolution to the segmented data;
- - Residual adding cls, temporal and position embeddings (optional);
- - WindowsAttentionBlock: Apply a windows attention block to the data;
- - LayerNorm: Apply layer normalization to the data;
- - Linear: An head linear layer to transformer the data into classes.
-
- else:
- - PatchEmbed: Apply a patch embedding to the input data;
- - Residual adding cls, temporal and position embeddings (optional);
- - WindowsAttentionBlock: Apply a windows attention block to the data;
- - LayerNorm: Apply layer normalization to the data;
- - Linear: An head linear layer to transformer the data into classes.
+ The models have the following sequence of steps::
+
+ if neural tokenizer:
+ - SegmentPatch: Segment the input data in patches;
+ - TemporalConv: Apply a temporal convolution to the segmented data;
+ - Residual adding cls, temporal and position embeddings (optional);
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
+ - LayerNorm: Apply layer normalization to the data;
+ - Linear: An head linear layer to transformer the data into classes.
+
+ else:
+ - PatchEmbed: Apply a patch embedding to the input data;
+ - Residual adding cls, temporal and position embeddings (optional);
+ - WindowsAttentionBlock: Apply a windows attention block to the data;
+ - LayerNorm: Apply layer normalization to the data;
+ - Linear: An head linear layer to transformer the data into classes.

  .. versionadded:: 0.9


- Examples on how to load pre-trained weights:
- --------------------------------------------
- >>> import torch
- >>> from braindecode.models import Labram
- >>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
- >>> url = 'https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt'
- >>> state = torch.hub.load_state_dict_from_url(url, progress=True)
- >>> model.load_state_dict(state)
+ Examples
+ --------
+ Load pre-trained weights::
+
+ >>> import torch
+ >>> from braindecode.models import Labram
+ >>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
+ >>> url = "https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt"
+ >>> state = torch.hub.load_state_dict_from_url(url, progress=True)
+ >>> model.load_state_dict(state)


  Parameters
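
The two modes listed above differ in how the input is segmented; only the Neural Tokenizer path takes raw windows. A hedged sketch of the default mode, reusing the constructor arguments from the doctest above (the expected output shape is an assumption based on the docstring, not verified here):

    import torch
    from braindecode.models import Labram

    # Neural Tokenizer mode (default): raw (batch, n_chans, n_times) windows in,
    # one logit per class out.
    model = Labram(n_times=1600, n_chans=64, n_outputs=4, neural_tokenizer=True)
    x = torch.randn(2, 64, 1600)
    print(model(x).shape)  # expected: torch.Size([2, 4])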
@@ -116,7 +123,7 @@ class Labram(EEGModuleMixin, nn.Module):
  init_scale : float (default=0.001)
  The initial scale to be used in the parameters of the model.
  neural_tokenizer : bool (default=True)
- The model can be used in two modes: Neural Tokenizor or Neural Decoder.
+ The model can be used in two modes: Neural Tokenizer or Neural Decoder.
  attn_head_dim : bool (default=None)
  The head dimension to be used in the attention layer, to be used only
  during pre-training.
@@ -13,6 +13,8 @@ from braindecode.models.base import EEGModuleMixin
  class MSVTNet(EEGModuleMixin, nn.Module):
  """MSVTNet model from Liu K et al (2024) from [msvt2024]_.

+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Small Attention`
+
  This model implements a multi-scale convolutional transformer network
  for EEG signal classification, as described in [msvt2024]_.

@@ -16,7 +16,7 @@ from braindecode.models.base import EEGModuleMixin
  class PBT(EEGModuleMixin, nn.Module):
  r"""Patched Brain Transformer (PBT) model from Klein et al. (2025) [pbt]_.

- :bdg-danger:`Large Brain Models`
+ :bdg-danger:`Large Brain Model`

  This implementation was based in https://github.com/timonkl/PatchedBrainTransformer/

@@ -146,6 +146,8 @@ class _BaseSignalJEPA(EEGModuleMixin, nn.Module):
  class SignalJEPA(_BaseSignalJEPA):
  """Architecture introduced in signal-JEPA for self-supervised pre-training, Guetschel, P et al (2024) [1]_

+ :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
+
  This model is not meant for classification but for SSL pre-training.
  Its output shape depends on the input shape.
  For classification purposes, three variants of this model are available:
@@ -232,6 +234,8 @@ class SignalJEPA(_BaseSignalJEPA):
  class SignalJEPA_Contextual(_BaseSignalJEPA):
  """Contextual downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.

+ :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
+
  This architecture is one of the variants of :class:`SignalJEPA`
  that can be used for classification purposes.

@@ -403,6 +407,8 @@ class SignalJEPA_Contextual(_BaseSignalJEPA):
  class SignalJEPA_PostLocal(_BaseSignalJEPA):
  """Post-local downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.

+ :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
+
  This architecture is one of the variants of :class:`SignalJEPA`
  that can be used for classification purposes.

@@ -552,6 +558,8 @@ class SignalJEPA_PostLocal(_BaseSignalJEPA):
  class SignalJEPA_PreLocal(_BaseSignalJEPA):
  """Pre-local downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.

+ :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
+
  This architecture is one of the variants of :class:`SignalJEPA`
  that can be used for classification purposes.

@@ -12,6 +12,8 @@ from braindecode.models.base import EEGModuleMixin
  class SincShallowNet(EEGModuleMixin, nn.Module):
  """Sinc-ShallowNet from Borra, D et al (2020) [borra2020]_.

+ :bdg-success:`Convolution` :bdg-warning:`Interpretability`
+
  .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S0893608020302021-gr2_lrg.jpg
  :align: center
  :alt: SincShallowNet Architecture
@@ -19,23 +21,24 @@ class SincShallowNet(EEGModuleMixin, nn.Module):
  The Sinc-ShallowNet architecture has these fundamental blocks:

  1. **Block 1: Spectral and Spatial Feature Extraction**
- - *Temporal Sinc-Convolutional Layer*:
- Uses parametrized sinc functions to learn band-pass filters,
- significantly reducing the number of trainable parameters by only
- learning the lower and upper cutoff frequencies for each filter.
- - *Spatial Depthwise Convolutional Layer*:
- Applies depthwise convolutions to learn spatial filters for
- each temporal feature map independently, further reducing
- parameters and enhancing interpretability.
- - *Batch Normalization*
+
+ - *Temporal Sinc-Convolutional Layer*: Uses parametrized sinc functions to learn band-pass filters,
+ significantly reducing the number of trainable parameters by only
+ learning the lower and upper cutoff frequencies for each filter.
+ - *Spatial Depthwise Convolutional Layer*: Applies depthwise convolutions to learn spatial filters for
+ each temporal feature map independently, further reducing
+ parameters and enhancing interpretability.
+ - *Batch Normalization*

  2. **Block 2: Temporal Aggregation**
+
  - *Activation Function*: ELU
  - *Average Pooling Layer*: Aggregation by averaging spatial dim
  - *Dropout Layer*
  - *Flatten Layer*

  3. **Block 3: Classification**
+
  - *Fully Connected Layer*: Maps the feature vector to n_outputs.

  **Implementation Notes:**
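
The sinc-parametrized band-pass idea summarized above (only the two cutoff frequencies are learned per filter) can be written down generically. This is a sketch of the standard formulation, not the package's actual layer; the cutoffs, kernel length, and Hamming window are illustrative choices:

    import numpy as np

    def sinc_bandpass(f_low, f_high, kernel_size=65, sfreq=128.0):
        # Band-pass FIR kernel = difference of two low-pass sinc kernels,
        # so only f_low and f_high need to be learned.
        t = (np.arange(kernel_size) - (kernel_size - 1) / 2) / sfreq
        low = 2 * f_low * np.sinc(2 * f_low * t)
        high = 2 * f_high * np.sinc(2 * f_high * t)
        kernel = (high - low) * np.hamming(kernel_size)  # window to reduce ripple
        return kernel / np.sum(np.abs(kernel))

    mu_band = sinc_bandpass(f_low=8.0, f_high=13.0)  # e.g. a mu-band (8-13 Hz) filter
    print(mu_band.shape)  # (65,)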
@@ -24,7 +24,7 @@ class SSTDPN(EEGModuleMixin, nn.Module):
  :alt: SSTDPN Architecture
  :width: 1000px

- The **SpatialSpectral** and **Temporal - Dual Prototype Network** (SST-DPN)
+ The **Spatial-Spectral** and **Temporal - Dual Prototype Network** (SST-DPN)
  is an end-to-end 1D convolutional architecture designed for motor imagery (MI) EEG decoding,
  aiming to address challenges related to discriminative feature extraction and
  small-sample sizes [Han2025]_.
@@ -37,9 +37,9 @@ class SSTDPN(EEGModuleMixin, nn.Module):
  SST-DPN consists of a feature extractor (_SSTEncoder, comprising Adaptive Spatial-Spectral
  Fusion and Multi-scale Variance Pooling) followed by Dual Prototype Learning classification [Han2025]_.

- 1. **Adaptive SpatialSpectral Fusion (ASSF)**: Uses :class:`_DepthwiseTemporalConv1d` to generate a
- multi-channel spatialspectral representation, followed by :class:`_SpatSpectralAttn`
- (Spatial-Spectral Attention) to model relationships and highlight key spatialspectral
+ 1. **Adaptive Spatial-Spectral Fusion (ASSF)**: Uses :class:`_DepthwiseTemporalConv1d` to generate a
+ multi-channel spatial-spectral representation, followed by :class:`_SpatSpectralAttn`
+ (Spatial-Spectral Attention) to model relationships and highlight key spatial-spectral
  channels [Han2025]_.

  2. **Multi-scale Variance Pooling (MVP)**: Applies :class:`_MultiScaleVarPooler` with variance pooling
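
Variance pooling, referenced in the MVP step above, reduces each time window to its variance instead of its mean. A generic sketch (not the package's _MultiScaleVarPooler; window and stride values are illustrative):

    import torch

    def variance_pool(x, window, stride):
        # x: (batch, channels, time) -> (batch, channels, n_windows)
        segments = x.unfold(dimension=-1, size=window, step=stride)
        return segments.var(dim=-1)

    x = torch.randn(2, 48, 1000)  # e.g. fused features for a 4-s trial at 250 Hz
    print(variance_pool(x, window=250, stride=125).shape)  # torch.Size([2, 48, 7])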
@@ -57,7 +57,7 @@ class SSTDPN(EEGModuleMixin, nn.Module):

  - `SSTDPN.encoder` **(Feature Extractor)**

- - *Operations.* Combines Adaptive SpatialSpectral Fusion and Multi-scale Variance Pooling
+ - *Operations.* Combines Adaptive Spatial-Spectral Fusion and Multi-scale Variance Pooling
  via an internal :class:`_SSTEncoder`.
  - *Role.* Maps the raw MI-EEG trial :math:`X_i \in \mathbb{R}^{C \times T}` to the
  feature space :math:`z_i \in \mathbb{R}^d`.
@@ -69,11 +69,11 @@ class SSTDPN(EEGModuleMixin, nn.Module):
  depth multiplier `n_spectral_filters_temporal` (equivalent to :math:`F_1` in the paper).
  - *Role.* Extracts multiple distinct spectral bands from each EEG channel independently.

- - `_SSTEncoder.spt_attn` **(SpatialSpectral Attention for Channel Gating)**
+ - `_SSTEncoder.spt_attn` **(Spatial-Spectral Attention for Channel Gating)**

  - *Operations.* Internal :class:`_SpatSpectralAttn` module using Global Context Embedding
  via variance-based pooling, followed by adaptive channel normalization and gating.
- - *Role.* Reweights channels in the spatialspectral dimension to extract efficient and
+ - *Role.* Reweights channels in the spatial-spectral dimension to extract efficient and
  discriminative features by emphasizing task-relevant regions and frequency bands.

  - `_SSTEncoder.chan_conv` **(Pointwise Fusion across Channels)**
@@ -81,7 +81,7 @@ class SSTDPN(EEGModuleMixin, nn.Module):
  - *Operations.* A 1D pointwise convolution with `n_fused_filters` output channels
  (equivalent to :math:`F_2` in the paper), followed by BatchNorm and the specified
  `activation` function (default: ELU).
- - *Role.* Fuses the weighted spatialspectral features across all electrodes to produce
+ - *Role.* Fuses the weighted spatial-spectral features across all electrodes to produce
  a fused representation :math:`X_{fused} \in \mathbb{R}^{F_2 \times T}`.

  - `_SSTEncoder.mvp` **(Multi-scale Variance Pooling for Temporal Extraction)**
@@ -109,11 +109,11 @@ class SSTDPN(EEGModuleMixin, nn.Module):
  * **Spatial.**
  The initial convolution at the classes :class:`_DepthwiseTemporalConv1d` groups parameter :math:`h=1`,
  meaning :math:`F_1` temporal filters are shared across channels. The Spatial-Spectral Attention
- mechanism explicitly models the relationships among these channels in the spatialspectral
+ mechanism explicitly models the relationships among these channels in the spatial-spectral
  dimension, allowing for finer-grained spatial feature modeling compared to conventional
  GCNs according to the authors [Han2025]_.
  In other words, all electrode channels share :math:`F_1` temporal filters
- independently to produce the spatialspectral representation.
+ independently to produce the spatial-spectral representation.

  * **Spectral.**
  Spectral information is implicitly extracted via the :math:`F_1` filters in :class:`_DepthwiseTemporalConv1d`.
@@ -123,7 +123,7 @@ class SSTDPN(EEGModuleMixin, nn.Module):

  .. rubric:: Additional Mechanisms

- - **Attention.** A lightweight Spatial-Spectral Attention mechanism models spatialspectral relationships
+ - **Attention.** A lightweight Spatial-Spectral Attention mechanism models spatial-spectral relationships
  at the channel level, distinct from applying attention to deep feature dimensions,
  which is common in comparison methods like :class:`ATCNet`.
  - **Regularization.** Dual Prototype Learning acts as a regularization technique
@@ -39,3 +39,4 @@ FBLightConvNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sf
  IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860,"IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
  PBT,General,Classification,250,"n_chans, n_outputs, n_times",818948,"PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Large Brain Model"
  SSTDPN,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",19502,"SSTDPN(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
+ BENDR,General,"Classification,Embedding",250,"n_chans, n_times, n_outputs",157141049,"BENDR(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
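
This new summary.csv row registers the BENDR model introduced in this release (see braindecode/models/bendr.py, +469 lines, in the file list). A minimal instantiation following the signature quoted in the row, assuming BENDR is exported from braindecode.models like the other entries:

    from braindecode.models import BENDR  # assumed export, new in this release

    # Constructor call copied from the summary.csv row above.
    model = BENDR(n_chans=22, n_outputs=4, n_times=1000)

    n_params = sum(p.numel() for p in model.parameters())
    print(n_params)  # the CSV row reports 157141049 parameters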
@@ -10,6 +10,8 @@ from braindecode.models.base import EEGModuleMixin
  class SyncNet(EEGModuleMixin, nn.Module):
  """Synchronization Network (SyncNet) from Li, Y et al (2017) [Li2017]_.

+ :bdg-warning:`Interpretability`
+
  .. figure:: https://braindecode.org/dev/_static/model/SyncNet.png
  :align: center
  :alt: SyncNet Architecture
braindecode/models/tcn.py CHANGED
@@ -14,6 +14,8 @@ from braindecode.modules import Chomp1d, Ensure4d, SqueezeFinalOutput
  class BDTCN(EEGModuleMixin, nn.Module):
  """Braindecode TCN from Gemein, L et al (2020) [gemein2020]_.

+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
+
  .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S1053811920305073-gr3_lrg.jpg
  :align: center
  :alt: Braindecode TCN Architecture
@@ -62,43 +62,48 @@ class USleep(EEGModuleMixin, nn.Module):
  - Decoder :class:`_DecoderBlock` **(progressive upsampling + skip fusion to high-frequency map, 12 blocks; upsampling x2 per block)**

  - *Operations.*
- - **Nearest-neighbor upsample**, :class:`nn.Upsample` (x2)
- - **Convolution2d** (k=2), :class:`torch.nn.Conv2d`
- - ELU, :class:`torch.nn.ELU`
- - Batch Norm, :class:`torch.nn.BatchNorm2d`
- - **Concatenate** with the encoder skip at the same temporal scale, :function:`torch.cat`
- - **Convolution**, :class:`torch.nn.Conv2d`
- - ELU, :class:`torch.nn.ELU`
- - Batch Norm, :class:`torch.nn.BatchNorm2d`.
+
+ - **Nearest-neighbor upsample**, :class:`nn.Upsample` (x2)
+ - **Convolution2d** (k=2), :class:`torch.nn.Conv2d`
+ - ELU, :class:`torch.nn.ELU`
+ - Batch Norm, :class:`torch.nn.BatchNorm2d`
+ - **Concatenate** with the encoder skip at the same temporal scale, ``torch.cat``
+ - **Convolution**, :class:`torch.nn.Conv2d`
+ - ELU, :class:`torch.nn.ELU`
+ - Batch Norm, :class:`torch.nn.BatchNorm2d`.

  **Output**: A multi-class, **high-frequency** per-sample representation aligned to the input rate (128 Hz).

  - **Segment Classifier incorporate into :class:`braindecode.models.USleep` (aggregation to fixed epochs)**

  - *Operations.*
- - **Mean-pool**, :class:`torch.nn.AvgPool2d` per class with kernel = epoch length *i* and stride *i*
- - **1x1 conv**, :class:`torch.nn.Conv2d`
- - ELU, :class:`torch.nn.ELU`
- - **1x1 conv**, :class:`torch.nn.Conv2d` with ``(T, K)`` (epochs x stages).
+
+ - **Mean-pool**, :class:`torch.nn.AvgPool2d` per class with kernel = epoch length *i* and stride *i*
+ - **1x1 conv**, :class:`torch.nn.Conv2d`
+ - ELU, :class:`torch.nn.ELU`
+ - **1x1 conv**, :class:`torch.nn.Conv2d` with ``(T, K)`` (epochs x stages).

  **Role**: Learns a **non-linear** weighted combination over each 30-s window (unlike U-Time's linear combiner).

  .. rubric:: Convolutional Details

  - **Temporal (where time-domain patterns are learned).**
- All convolutions are **1-D along time**; depth (12 levels) plus pooling yields an extensive receptive field
- (reported sensitivity to ±6.75 min around each epoch; theoretical field 9.6 min at the deepest layer).
- The decoder restores sample-level resolution before epoch aggregation.
+
+ All convolutions are **1-D along time**; depth (12 levels) plus pooling yields an extensive receptive field
+ (reported sensitivity to ±6.75 min around each epoch; theoretical field ≈ 9.6 min at the deepest layer).
+ The decoder restores sample-level resolution before epoch aggregation.

  - **Spatial (how channels are processed).**
- Convolutions mix across the *channel* dimension jointly with time (no separate spatial operator). The system
- is **montage-agnostic** (any reasonable EEG/EOG pair) and was trained across diverse cohorts/protocols,
- supporting robustness to channel placement and hardware differences.
+
+ Convolutions mix across the *channel* dimension jointly with time (no separate spatial operator). The system
+ is **montage-agnostic** (any reasonable EEG/EOG pair) and was trained across diverse cohorts/protocols,
+ supporting robustness to channel placement and hardware differences.

  - **Spectral (how frequency content is captured).**
- No explicit Fourier/wavelet transform is used; the **stack of temporal convolutions** acts as a learned
- filter bank whose effective bandwidth grows with depth. The high-frequency decoder output (128 Hz)
- retains fine temporal detail for the segment classifier.
+
+ No explicit Fourier/wavelet transform is used; the **stack of temporal convolutions** acts as a learned
+ filter bank whose effective bandwidth grows with depth. The high-frequency decoder output (128 Hz)
+ retains fine temporal detail for the segment classifier.


  .. rubric:: Attention / Sequential Modules
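
The segment classifier described above collapses the 128 Hz per-sample map to one score per class per 30-s epoch by mean-pooling with kernel = stride = epoch length. A back-of-the-envelope sketch using a 1-D pooling layer (stage count and recording length are illustrative; the model itself uses AvgPool2d):

    import torch
    from torch import nn

    sfreq, epoch_sec, n_stages = 128, 30, 5
    i = sfreq * epoch_sec                          # 3840 samples per 30-s epoch

    per_sample = torch.randn(1, n_stages, 35 * i)  # per-sample scores for 35 epochs
    pool = nn.AvgPool1d(kernel_size=i, stride=i)   # one value per class per epoch
    print(pool(per_sample).shape)                  # torch.Size([1, 5, 35])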
@@ -97,6 +97,7 @@ models_mandatory_parameters = [
  ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
  ("PBT", ["n_chans", "n_outputs", "n_times"], None),
  ("SSTDPN", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
+ ("BENDR", ["n_chans", "n_outputs", "n_times"], None),
  ]

  ################################################################