braindecode 1.2.0.dev168908090__py3-none-any.whl → 1.2.0.dev175267687__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of braindecode might be problematic.

Files changed (33)
  1. braindecode/datasets/experimental.py +218 -0
  2. braindecode/models/__init__.py +6 -8
  3. braindecode/models/atcnet.py +1 -1
  4. braindecode/models/attentionbasenet.py +151 -26
  5. braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
  6. braindecode/models/ctnet.py +1 -1
  7. braindecode/models/deep4.py +6 -2
  8. braindecode/models/deepsleepnet.py +118 -5
  9. braindecode/models/eegconformer.py +15 -12
  10. braindecode/models/eeginception_erp.py +76 -7
  11. braindecode/models/eeginception_mi.py +2 -0
  12. braindecode/models/eegnet.py +25 -189
  13. braindecode/models/eegnex.py +113 -6
  14. braindecode/models/eegsimpleconv.py +2 -0
  15. braindecode/models/eegtcnet.py +1 -1
  16. braindecode/models/sccnet.py +79 -8
  17. braindecode/models/shallow_fbcsp.py +2 -0
  18. braindecode/models/sleep_stager_blanco_2020.py +2 -0
  19. braindecode/models/sleep_stager_chambon_2018.py +2 -0
  20. braindecode/models/sparcnet.py +2 -0
  21. braindecode/models/summary.csv +39 -41
  22. braindecode/models/tidnet.py +2 -0
  23. braindecode/models/tsinception.py +15 -3
  24. braindecode/models/usleep.py +103 -9
  25. braindecode/models/util.py +5 -5
  26. braindecode/version.py +1 -1
  27. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/METADATA +7 -2
  28. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/RECORD +32 -32
  29. braindecode/models/eegresnet.py +0 -362
  30. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/WHEEL +0 -0
  31. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/licenses/LICENSE.txt +0 -0
  32. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/licenses/NOTICE.txt +0 -0
  33. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/top_level.txt +0 -0
@@ -8,14 +8,128 @@ from braindecode.models.base import EEGModuleMixin
 
 
  class DeepSleepNet(EEGModuleMixin, nn.Module):
- """Sleep staging architecture from Supratak et al. (2017) [Supratak2017]_.
+ """DeepSleepNet from Supratak et al. (2017) [Supratak2017]_.
 
- .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/refs/heads/master/img/deepsleepnet.png
+ :bdg-success:`Convolution` :bdg-info:`Recurrent`
+
+ .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/master/img/deepsleepnet.png
  :align: center
  :alt: DeepSleepNet Architecture
+ :width: 700px
+
+ .. rubric:: Architectural Overview
+
+ DeepSleepNet couples **dual-path convolutional neural network representation learning** with
+ **sequence residual learning** via bidirectional LSTMs.
+
+ The network:
+
+ - (i) learns complementary, time-frequency features from each
+ 30-s epoch using **two parallel CNNs** (small vs. large first-layer filters), then
+ - (ii) models **temporal dependencies across epochs** using **two-layer BiLSTMs**
+ with a **residual shortcut** from the CNN features, and finally
+ - (iii) outputs per-epoch sleep stages. This design encodes both
+ epoch-local patterns and longer-range transition rules used by human scorers.
+
+ In terms of implementation:
+
+ - (i) :class:`_RepresentationLearning`: two CNNs extract epoch-wise features
+ (small-filter path for temporal precision; large-filter path for frequency precision);
+ - (ii) :class:`_SequenceResidualLearning`: stacked BiLSTMs with peepholes + residual shortcut
+ inject temporal context while preserving CNN evidence;
+ - (iii) :class:`_Classifier`: linear readout (softmax) for the five sleep stages.
+
+ .. rubric:: Macro Components
+
+ - :class:`_RepresentationLearning` **(dual-path CNN → epoch feature)**
+
+ - *Operations.*
+ - **Small-filter CNN**, repeated 4 times:
+ - :class:`~torch.nn.Conv1d`
+ - :class:`~torch.nn.BatchNorm1d`
+ - :class:`~torch.nn.ReLU`
+ - :class:`~torch.nn.MaxPool1d` after.
+ First conv uses **filter length ≈ Fs/2** and **stride ≈ Fs/16** to emphasize *timing* of graphoelements.
+ - **Large-filter CNN**:
+ - Same stack, but the first conv uses **filter length ≈ 4·Fs** and
+ **stride ≈ Fs/2** to emphasize *frequency* content.
+ - Outputs from both paths are **concatenated** into the epoch embedding ``a_t``.
+
+ - *Rationale.*
+ Two first-layer scales provide a **learned, dual-scale filter bank** that trades
+ temporal vs. frequency precision without hand-crafted features.
+
+ - :class:`_SequenceResidualLearning` (:class:`~torch.nn.BiLSTM` **context + residual fusion)**
+
+ - *Operations.*
+ - **Two-layer BiLSTM** with **peephole connections** processes the sequence of epoch embeddings
+ ``{a_t}`` forward and backward; hidden states from both directions are **concatenated**.
+ - A **shortcut MLP** (fully connected + :class:`~torch.nn.BatchNorm1d` + :class:`~torch.nn.ReLU`) projects ``a_t`` to the BiLSTM output
+ dimension and is **added** (residual) to the :class:`~torch.nn.BiLSTM` output at each time step.
+ - *Role.* Encodes **stage-transition rules** and smooths predictions over time while preserving
+ salient CNN features via the residual path.
+
+ - :class:`_Classifier` **(epoch-wise prediction)**
+
+ - *Operations.*
+ - :class:`~torch.nn.Linear` to produce per-epoch class probabilities.
 
- Convolutional neural network and bidirectional-Long Short-Term
- for single channels sleep staging described in [Supratak2017]_.
+ Original training uses two-step optimization: CNN pretraining on class-balanced data,
+ then end-to-end fine-tuning with sequential batches.
+
+ .. rubric:: Convolutional Details
+
+ - **Temporal (where time-domain patterns are learned).**
+
+ Both CNN paths use **1-D temporal convolutions**. The *small-filter* path (first kernel ≈ Fs/2,
+ stride ≈ Fs/16) captures *when* characteristic transients occur; the *large-filter* path
+ (first kernel ≈ 4·Fs, stride ≈ Fs/2) captures *which* frequency components dominate over the
+ epoch. Deeper layers use **small kernels** to refine features with fewer parameters, interleaved
+ with **max pooling** for downsampling.
+
+ - **Spatial (how channels are processed).**
+ The original model operates on **single-channel** raw EEG; convolutions therefore mix only
+ along time (no spatial convolution across electrodes).
+
+ - **Spectral (how frequency information emerges).**
+ No explicit Fourier/wavelet transform is used. The **large-filter path** serves as a
+ *frequency-sensitive* analyzer, while the **small-filter path** remains *time-sensitive*,
+ together functioning as a **two-band learned filter bank** at the first layer.
+
+ .. rubric:: Attention / Sequential Modules
+
+ - **Type.** **Bidirectional LSTM** (two layers) with **peephole connections**; forward and
+ backward streams are independent and concatenated.
+ - **Shapes.** For a sequence of ``N`` epochs, the CNN produces ``{a_t} ∈ R^{D}``;
+ BiLSTM outputs ``h_t ∈ R^{2H}``; the shortcut MLP maps ``a_t → R^{2H}`` to enable
+ **element-wise residual addition**.
+ - **Role.** Models **long-range temporal dependencies** (e.g., persisting N2 without visible
+ K-complex/spindles), stabilizing per-epoch predictions.
+
+
+ .. rubric:: Additional Mechanisms
+
+ - **Residual shortcut over sequence encoder.** Adds projected CNN features to BiLSTM outputs,
+ improving gradient flow and retaining discriminative content from representation learning.
+ - **Two-step training.**
+ - (i) **Pretrain** the CNN paths with class-balanced sampling;
+ - (ii) **fine-tune** the full network with sequential batches, using **lower LR** for CNNs and **higher LR** for the
+ sequence encoder.
+ - **State handling.** BiLSTM states are **reinitialized per subject** so that temporal context
+ does not leak across recordings.
+
+
+ .. rubric:: Usage and Configuration
+
+ - **Epoch pipeline.** Use **two parallel CNNs** with the first conv sized to **Fs/2** (small path)
+ and **4·Fs** (large path), with strides **Fs/16** and **Fs/2**, respectively; stack three more
+ conv blocks with small kernels, plus **max pooling** in each path. Concatenate path outputs
+ to form epoch embeddings.
+ - **Sequence encoder.** Apply **two-layer BiLSTM (peepholes)** over the sequence of embeddings;
+ add a **projection MLP** on the CNN features and **sum** with BiLSTM outputs (residual).
+ Finish with :class:`~torch.nn.Linear` per epoch.
+ - **Reference implementation.** See the official repository for a faithful implementation and
+ training scripts.
 
  Parameters
  ----------
@@ -32,7 +146,6 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
  drop_prob : float, default=0.5
  The dropout rate for regularization. Values should be between 0 and 1.
 
-
  References
  ----------
  .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
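A minimal usage sketch for the DeepSleepNet described above; the constructor keywords follow the ``EEGModuleMixin`` convention (``n_outputs``, ``n_chans``, ``n_times``, ``sfreq``) and the single-channel 30-s epoch shape from the docstring, so the exact argument names and defaults are assumptions to verify against the installed braindecode version::

    # Hedged sketch: one 30-s single-channel epoch per sample, 5 sleep stages.
    # Keyword names are assumed from the EEGModuleMixin convention above.
    import torch

    from braindecode.models import DeepSleepNet

    sfreq = 100  # Hz -> 3000 samples per 30-s epoch
    model = DeepSleepNet(n_outputs=5, n_chans=1, sfreq=sfreq, n_times=30 * sfreq)

    x = torch.randn(4, 1, 30 * sfreq)  # (batch, n_chans, n_times)
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # expected (4, 5): one score per sleep stage and epoch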
@@ -27,22 +27,25 @@ class EEGConformer(EEGModuleMixin, nn.Module):
  EEG-Conformer is a *convolution-first* model augmented with a *lightweight transformer
  encoder*. The end-to-end flow is:
 
- - (i) :class:`_PatchEmbedding` converts the continuous EEG into a compact sequence of tokens via a :class:`ShallowFBCSPNet` temporal–spatial conv stem and temporal pooling;
- - (ii) :class:`_TransformerEncoder applies small multi-head self-attention to integrate longer-range temporal context across tokens;
+ - (i) :class:`_PatchEmbedding` converts the continuous EEG into a compact sequence of tokens via a
+ :class:`ShallowFBCSPNet` temporal–spatial conv stem and temporal pooling;
+ - (ii) :class:`_TransformerEncoder` applies small multi-head self-attention to integrate
+ longer-range temporal context across tokens;
  - (iii) :class:`_ClassificationHead` aggregates the sequence and performs a linear readout.
- This preserves the strong inductive biases of shallow CNN filter banks while adding
- just enough attention to capture dependencies beyond the pooling horizon [song2022]_.
+ This preserves the strong inductive biases of shallow CNN filter banks while adding
+ just enough attention to capture dependencies beyond the pooling horizon [song2022]_.
 
  .. rubric:: Macro Components
 
  - :class:`_PatchEmbedding` **(Shallow conv stem → tokens)**
 
- - *Operations.*
- - A temporal convolution (`:class:`torch.nn.Conv2d`) ``(1 x L_t)`` forms a data-driven "filter bank";
- - A spatial convolution (`:class:`torch.nn.Conv2d`) (n_chans x 1)`` projects across electrodes, collapsing the channel axis into a virtual channel.
- - **Normalization function** `:class:torch.nn.BatchNorm`
- - **Activation function** `:class:torch.nn.ELU`
- - **Average Pooling** `:class:torch.nn.AvgPool` along time (kernel ``(1, P)`` with stride ``(1, S)``)
+ - *Operations.*
+ - A temporal convolution (:class:`torch.nn.Conv2d`) ``(1 x L_t)`` forms a data-driven "filter bank";
+ - A spatial convolution (:class:`torch.nn.Conv2d`) ``(n_chans x 1)`` projects across electrodes,
+ collapsing the channel axis into a virtual channel.
+ - **Normalization function** :class:`torch.nn.BatchNorm`
+ - **Activation function** :class:`torch.nn.ELU`
+ - **Average Pooling** :class:`torch.nn.AvgPool` along time (kernel ``(1, P)`` with stride ``(1, S)``)
  - final ``1x1`` :class:`torch.nn.Linear` projection.
 
  The result is rearranged to a token sequence ``(B, S_tokens, D)``, where ``D = n_filters_time``.
@@ -53,7 +56,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):
 
  - :class:`_TransformerEncoder` **(context over temporal tokens)**
 
- - *Operations.*
+ - *Operations.*
  - A stack of ``att_depth`` encoder blocks. :class:`_TransformerEncoderBlock`
  - Each block applies LayerNorm :class:`torch.nn.LayerNorm`
  - Multi-Head Self-Attention (``att_heads``) with dropout + residual :class:`MultiHeadAttention` (:class:`torch.nn.Dropout`)
@@ -67,7 +70,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):
 
  - :class:`ClassificationHead` **(aggregation + readout)**
 
- - *Operations*.
+ - *Operations*.
  - Flatten, :class:`torch.nn.Flatten` the sequence ``(B, S_tokens·D)`` -
  - MLP (:class:`torch.nn.Linear` → activation (default: :class:`torch.nn.ELU`) → :class:`torch.nn.Dropout` → :class:`torch.nn.Linear`)
  - final Linear to classes.
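To exercise the token/attention flow the EEGConformer docstring describes (conv stem → ``att_depth`` encoder blocks with ``att_heads`` heads → flatten + MLP readout), here is a short sketch with default attention settings; the call signature beyond ``n_chans``/``n_outputs``/``n_times``/``sfreq`` is an assumption to check against the installed version::

    # Hedged sketch of a forward pass through EEGConformer with default
    # n_filters_time / att_depth / att_heads; shapes follow the docstring.
    import torch

    from braindecode.models import EEGConformer

    model = EEGConformer(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)

    x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
    with torch.no_grad():
        scores = model(x)
    print(scores.shape)  # expected (8, 4)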
@@ -15,12 +15,84 @@ from braindecode.modules import DepthwiseConv2d, Ensure4d, InceptionBlock
  class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
  """EEG Inception for ERP-based from Santamaria-Vazquez et al (2020) [santamaria2020]_.
 
+ :bdg-success:`Convolution`
+
  .. figure:: https://braindecode.org/dev/_static/model/eeginceptionerp.jpg
  :align: center
  :alt: EEGInceptionERP Architecture
 
- The code for the paper and this model is also available at [santamaria2020]_
- and an adaptation for PyTorch [2]_.
+ Figure: Overview of EEG-Inception architecture. 2D convolution blocks and depthwise 2D convolution blocks include batch normalization, activation and dropout regularization. The kernel size is displayed for convolutional and average pooling layers.
+
+ .. rubric:: Architectural Overview
+
+ A two-stage, multi-scale CNN tailored to ERP detection from short (0-1000 ms) single-trial epochs. Signals are mapped through:
+ * (i) :class:`_InceptionModule1`: multi-scale temporal feature extraction plus per-branch spatial mixing;
+ * (ii) :class:`_InceptionModule2`: deeper multi-scale refinement at a reduced temporal resolution; and
+ * (iii) :class:`_OutputModule`: compact aggregation and linear readout.
+
+ .. rubric:: Macro Components
+
+ - :class:`_InceptionModule1` **(multi-scale temporal + spatial mixing)**
+
+ - *Operations.*
+ - `EEGInceptionERP.c1`: :class:`torch.nn.Conv2d` ``k=(64,1)``, stride ``(1,1)``, *same* pad on input reshaped to ``(B,1,128,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.d1`: :class:`torch.nn.Conv2d` (depthwise) ``k=(1,8)``, *valid* pad over channels → BN → activation → dropout.
+ - `EEGInceptionERP.c2`: :class:`torch.nn.Conv2d` ``k=(32,1)`` → BN → activation → dropout; then `EEGInceptionERP.d2` depthwise ``k=(1,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c3`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout; then `EEGInceptionERP.d3` depthwise ``k=(1,8)`` → BN → activation → dropout.
+ - `EEGInceptionERP.n1`: :class:`torch.nn.Concat` over branch features.
+ - `EEGInceptionERP.a1`: :class:`torch.nn.AvgPool2d` ``pool=(4,1)``, stride ``(4,1)`` for temporal downsampling.
+
+ *Interpretability/robustness.* Depthwise `1 x n_chans` layers act as learnable montage-wide spatial filters per temporal scale; pooling stabilizes against jitter.
+
+ - :class:`_InceptionModule2` **(refinement at coarser timebase)**
+
+ - *Operations.*
+ - `EEGInceptionERP.c4`: :class:`torch.nn.Conv2d` ``k=(16,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c5`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.c6`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout.
+ - `EEGInceptionERP.n2`: :class:`torch.nn.Concat` (merge C4-C6 outputs).
+ - `EEGInceptionERP.a2`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``, stride ``(2,1)``.
+ - `EEGInceptionERP.c7`: :class:`torch.nn.Conv2d` ``k=(8,1)`` → BN → activation → dropout; then `EEGInceptionERP.a3`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
+ - `EEGInceptionERP.c8`: :class:`torch.nn.Conv2d` ``k=(4,1)`` → BN → activation → dropout; then `EEGInceptionERP.a4`: :class:`torch.nn.AvgPool2d` ``pool=(2,1)``.
+
+ *Role.* Adds higher-level, shorter-window evidence while progressively compressing the temporal dimension.
+
+ - :class:`_OutputModule` **(aggregation + readout)**
+
+ - *Operations.*
+ - :class:`torch.nn.Flatten`
+ - :class:`torch.nn.Linear` ``(features → 2)``
+
+ .. rubric:: Convolutional Details
+
+ - **Temporal (where time-domain patterns are learned).**
+ First module uses 1D temporal kernels along the 128-sample axis: ``64``, ``32``, ``16``
+ (≈500, 250, 125 ms at 128 Hz). After ``pool=(4,1)``, the second module applies ``16``,
+ ``8``, ``4`` (≈125, 62.5, 31.25 ms at the pooled rate). All strides are ``1`` in convs;
+ temporal resolution changes only via average pooling.
+
+ - **Spatial (how electrodes are processed).**
+ Depthwise convs with ``k=(1,8)`` span all channels and are applied **per temporal branch**,
+ yielding scale-specific channel projections (no cross-branch mixing until concatenation).
+ There is no full 2D mixing kernel; spatial mixing is factorized and lightweight.
+
+ - **Spectral (how frequency information is captured).**
+ No explicit transform; multiple temporal kernels form a *learned filter bank* over
+ ERP-relevant bands. Successive pooling acts as low-pass integration to emphasize sustained
+ post-stimulus components.
+
+ .. rubric:: Additional Mechanisms
+
+ - Every conv/depthwise block includes **BatchNorm**, nonlinearity (paper used grid-searched activation), and **dropout**.
+ - Two Inception stages followed by short convs and pooling keep parameters small (≈15k reported) while preserving multi-scale evidence.
+ - Expected input: epochs of shape ``(B,1,128,8)`` (time x channels as a 2D map) or reshaped from ``(B,8,128)`` with an added singleton feature dimension.
+
+ .. rubric:: Usage and Configuration
+
+ - **Key knobs.** Number of filters per branch; kernel lengths in both Inception modules; depthwise kernel over channels (typically ``n_chans``); pooling lengths/strides; dropout rate; choice of activation.
+ - **Training tips.** Use 0-1000 ms windows at 128 Hz with CAR; tune activation and dropout (they strongly affect performance); early-stop on validation loss when overfitting emerges.
+
+ .. rubric:: Implementation Details
 
  The model is strongly based on the original InceptionNet for an image. The main goal is
  to extract features in parallel with different scales. The authors extracted three scales
@@ -33,12 +105,9 @@ class EEGInceptionERP(EEGModuleMixin, nn.Sequential):
  The winners of BEETL Competition/NeurIps 2021 used parts of the
  model [beetl]_.
 
- The model is fully described in [santamaria2020]_.
+ The code for the paper and this model is also available at [santamaria2020]_
+ and an adaptation for PyTorch [2]_.
 
- Notes
- -----
- This implementation is not guaranteed to be correct, has not been checked
- by original authors, only reimplemented from the paper based on [2]_.
 
  Parameters
  ----------
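The shapes discussed above (8 channels, 128 Hz, 0-1000 ms epochs) translate into a small sketch; braindecode models take ``(batch, n_chans, n_times)`` input and add the singleton feature dimension internally, and the keyword names here are assumptions to verify against the installed version::

    # Hedged sketch: 1-s epochs at 128 Hz over 8 channels, binary ERP detection.
    import torch

    from braindecode.models import EEGInceptionERP

    model = EEGInceptionERP(n_chans=8, n_outputs=2, sfreq=128, n_times=128)

    x = torch.randn(16, 8, 128)  # (batch, n_chans, n_times)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected (16, 2)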
@@ -13,6 +13,8 @@ from braindecode.modules import Ensure4d
  class EEGInceptionMI(EEGModuleMixin, nn.Module):
  """EEG Inception for Motor Imagery, as proposed in Zhang et al. (2021) [1]_
 
+ :bdg-success:`Convolution`
+
  .. figure:: https://content.cld.iop.org/journals/1741-2552/18/4/046014/revision3/jneabed81f1_hr.jpg
  :align: center
  :alt: EEGInceptionMI Architecture
@@ -6,7 +6,7 @@ from __future__ import annotations
  from typing import Dict, Optional
 
  from einops.layers.torch import Rearrange
- from mne.utils import warn
+ from mne.utils import deprecated, warn
  from torch import nn
 
  from braindecode.functional import glorot_weight_zero_bias
@@ -19,23 +19,21 @@ from braindecode.modules import (
  )
 
 
- class EEGNetv4(EEGModuleMixin, nn.Sequential):
- """EEGNet v4 model from Lawhern et al. (2018) [Lawhern2018]_.
-
- :bdg-success:`Convolution` :bdg-secondary:`Depthwise–Separable`
+ class EEGNet(EEGModuleMixin, nn.Sequential):
+ """EEGNet model from Lawhern et al. (2018) [Lawhern2018]_.
 
+ :bdg-success:`Convolution`
  .. figure:: https://content.cld.iop.org/journals/1741-2552/15/5/056013/revision2/jneaace8cf01_hr.jpg
  :align: center
- :alt: EEGNetv4 Architecture
+ :alt: EEGNet Architecture
  :width: 600px
 
  .. rubric:: Architectural Overview
 
- EEGNetv4 is a compact convolutional network designed for EEG decoding with a
- pipeline that mirrors classical EEG processing:
+ EEGNet is a compact convolutional network designed for EEG decoding with a pipeline that mirrors classical EEG processing:
  - (i) learn temporal frequency-selective filters,
  - (ii) learn spatial filters for those frequencies, and
- - (iii) condense features with depthwiseseparable convolutions before a lightweight classifier.
+ - (iii) condense features with depthwise-separable convolutions before a lightweight classifier.
 
  The architecture is deliberately small (temporal convolutional and spatial patterns) [Lawhern2018]_.
 
@@ -46,9 +44,9 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
  - **Depthwise Spatial Filtering.**
  Depthwise convolution spanning the channel dimension with ``groups = F1``,
  yielding ``D`` spatial filters for each temporal filter (no cross-filter mixing).
- - **NormNonlinearityPooling (+ dropout).**
+ - **Norm-Nonlinearity-Pooling (+ dropout).**
  Batch normalization → ELU → temporal pooling, with dropout.
- - **DepthwiseSeparable Convolution Block.**
+ - **Depthwise-Separable Convolution Block.**
  (a) depthwise temporal conv to refine temporal structure;
  (b) pointwise 1x1 conv to mix feature maps into ``F2`` combinations.
  - **Classifier Head.**
@@ -56,16 +54,16 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
 
  .. rubric:: Convolutional Details
 
- **Temporal.** The initial temporal convs serve as a *learned filter bank*:
- long 1-D kernels (implemented as 2-D with singleton spatial extent) emphasize oscillatory bands and transients.
- Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each feature’s spectrum [Lawhern2018]_.
+ - **Temporal.** The initial temporal convs serve as a *learned filter bank*:
+ long 1-D kernels (implemented as 2-D with singleton spatial extent) emphasize oscillatory bands and transients.
+ Because this stage is linear prior to BN/ELU, kernels can be analyzed as FIR filters to reveal each feature’s spectrum [Lawhern2018]_.
 
- **Spatial.** The depthwise spatial conv spans the full channel axis (kernel height = #electrodes; temporal size = 1).
- With ``groups = F1``, each temporal filter learns its own set of ``D`` spatial projections—akin to CSP, learned end-to-end and
- typically regularized with max-norm.
+ - **Spatial.** The depthwise spatial conv spans the full channel axis (kernel height = #electrodes; temporal size = 1).
+ With ``groups = F1``, each temporal filter learns its own set of ``D`` spatial projections—akin to CSP, learned end-to-end and
+ typically regularized with max-norm.
 
- **Spectral.** No explicit Fourier/wavelet transform is used. Frequency structure
- is captured implicitly by the temporal filter bank; later depthwise temporal kernels act as short-time integrators/refiners.
+ - **Spectral.** No explicit Fourier/wavelet transform is used. Frequency structure
+ is captured implicitly by the temporal filter bank; later depthwise temporal kernels act as short-time integrators/refiners.
 
  .. rubric:: Additional Comments
 
@@ -74,6 +72,8 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
  - **Depthwise & separable convs:** Parameter-efficient decomposition (depthwise + pointwise) retains power while limiting overfitting
  [Chollet2017]_ and keeps temporal vs. mixing steps interpretable.
  - **Regularization:** Batch norm, dropout, pooling, and optional max-norm on spatial kernels aid stability on small EEG datasets.
+ - The "v4" refers to version 4 of the arXiv paper [Lawhern2018]_.
+
 
  Parameters
  ----------
@@ -349,174 +349,10 @@ class EEGNetv4(EEGModuleMixin, nn.Sequential):
  glorot_weight_zero_bias(self)
 
 
- class EEGNetv1(EEGModuleMixin, nn.Sequential):
- """EEGNet model from Lawhern et al. 2016 from [EEGNet]_.
-
- See details in [EEGNet]_.
-
- Parameters
- ----------
- in_chans :
- Alias for n_chans.
- n_classes:
- Alias for n_outputs.
- input_window_samples :
- Alias for n_times.
- activation: nn.Module, default=nn.ELU
- Activation function class to apply. Should be a PyTorch activation
- module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
-
- Notes
- -----
- This implementation is not guaranteed to be correct, has not been checked
- by original authors, only reimplemented from the paper description.
-
- References
- ----------
- .. [EEGNet] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
- S. M., Hung, C. P., & Lance, B. J. (2016).
- EEGNet: A Compact Convolutional Network for EEG-based
- Brain-Computer Interfaces.
- arXiv preprint arXiv:1611.08024.
- """
-
- def __init__(
- self,
- n_chans=None,
- n_outputs=None,
- n_times=None,
- final_conv_length="auto",
- pool_mode="max",
- second_kernel_size=(2, 32),
- third_kernel_size=(8, 4),
- drop_prob=0.25,
- activation: nn.Module = nn.ELU,
- chs_info=None,
- input_window_seconds=None,
- sfreq=None,
- ):
- super().__init__(
- n_outputs=n_outputs,
- n_chans=n_chans,
- chs_info=chs_info,
- n_times=n_times,
- input_window_seconds=input_window_seconds,
- sfreq=sfreq,
- )
- del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
- warn(
- "The class EEGNetv1 is deprecated and will be removed in the "
- "release 1.0 of braindecode. Please use "
- "braindecode.models.EEGNetv4 instead in the future.",
- DeprecationWarning,
- )
- if final_conv_length == "auto":
- assert self.n_times is not None
- self.final_conv_length = final_conv_length
- self.pool_mode = pool_mode
- self.second_kernel_size = second_kernel_size
- self.third_kernel_size = third_kernel_size
- self.drop_prob = drop_prob
- # For the load_state_dict
- # When padronize all layers,
- # add the old's parameters here
- self.mapping = {
- "conv_classifier.weight": "final_layer.conv_classifier.weight",
- "conv_classifier.bias": "final_layer.conv_classifier.bias",
- }
-
- pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
- self.add_module("ensuredims", Ensure4d())
- n_filters_1 = 16
- self.add_module(
- "conv_1",
- nn.Conv2d(self.n_chans, n_filters_1, (1, 1), stride=1, bias=True),
- )
- self.add_module(
- "bnorm_1",
- nn.BatchNorm2d(n_filters_1, momentum=0.01, affine=True, eps=1e-3),
- )
- self.add_module("elu_1", activation())
- # transpose to examples x 1 x (virtual, not EEG) channels x time
- self.add_module("permute_1", Rearrange("batch x y z -> batch z x y"))
-
- self.add_module("drop_1", nn.Dropout(p=self.drop_prob))
-
- n_filters_2 = 4
- # keras pads unequal padding more in front, so padding
- # too large should be ok.
- # Not padding in time so that cropped training makes sense
- # https://stackoverflow.com/questions/43994604/padding-with-even-kernel-size-in-a-convolutional-layer-in-keras-theano
-
- self.add_module(
- "conv_2",
- nn.Conv2d(
- 1,
- n_filters_2,
- self.second_kernel_size,
- stride=1,
- padding=(self.second_kernel_size[0] // 2, 0),
- bias=True,
- ),
- )
- self.add_module(
- "bnorm_2",
- nn.BatchNorm2d(n_filters_2, momentum=0.01, affine=True, eps=1e-3),
- )
- self.add_module("elu_2", activation())
- self.add_module("pool_2", pool_class(kernel_size=(2, 4), stride=(2, 4)))
- self.add_module("drop_2", nn.Dropout(p=self.drop_prob))
-
- n_filters_3 = 4
- self.add_module(
- "conv_3",
- nn.Conv2d(
- n_filters_2,
- n_filters_3,
- self.third_kernel_size,
- stride=1,
- padding=(self.third_kernel_size[0] // 2, 0),
- bias=True,
- ),
- )
- self.add_module(
- "bnorm_3",
- nn.BatchNorm2d(n_filters_3, momentum=0.01, affine=True, eps=1e-3),
- )
- self.add_module("elu_3", activation())
- self.add_module("pool_3", pool_class(kernel_size=(2, 4), stride=(2, 4)))
- self.add_module("drop_3", nn.Dropout(p=self.drop_prob))
-
- output_shape = self.get_output_shape()
- n_out_virtual_chans = output_shape[2]
-
- if self.final_conv_length == "auto":
- n_out_time = output_shape[3]
- self.final_conv_length = n_out_time
-
- # Incorporating classification module and subsequent ones in one final layer
- module = nn.Sequential()
-
- module.add_module(
- "conv_classifier",
- nn.Conv2d(
- n_filters_3,
- self.n_outputs,
- (n_out_virtual_chans, self.final_conv_length),
- bias=True,
- ),
- )
-
- # Transpose back to the logic of braindecode,
-
- # so time in third dimension (axis=2)
- module.add_module(
- "permute_2",
- Rearrange("batch x y z -> batch x z y"),
- )
-
- module.add_module("squeeze", SqueezeFinalOutput())
-
- self.add_module("final_layer", module)
+ @deprecated(
+ "`EEGNetv4` was renamed to `EEGNet` in v1.12; this alias will be removed in v1.14."
+ )
+ class EEGNetv4(EEGNet):
+ """Deprecated alias for EEGNet."""
 
- glorot_weight_zero_bias(self)
+ pass
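The net effect of this last hunk is a rename with a deprecation shim: ``EEGNetv1`` is removed, and ``EEGNetv4`` becomes a thin ``@deprecated`` subclass of the new ``EEGNet``. A sketch of both import paths, assuming this dev build exports both names from ``braindecode.models`` and that constructing the alias emits a ``DeprecationWarning`` via ``mne.utils.deprecated``::

    # Hedged sketch of the rename: new code uses EEGNet, old code keeps working.
    import warnings

    import torch

    from braindecode.models import EEGNet, EEGNetv4  # EEGNetv4 is now an alias

    model = EEGNet(n_chans=22, n_outputs=4, n_times=1000)
    x = torch.randn(2, 22, 1000)
    with torch.no_grad():
        print(model(x).shape)  # expected (2, 4)

    # The deprecated alias should still construct, but warn.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy = EEGNetv4(n_chans=22, n_outputs=4, n_times=1000)
    print([str(w.message) for w in caught])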