braindecode 1.3.0.dev180851780__py3-none-any.whl → 1.3.0.dev181594385__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode has been flagged as potentially problematic.

Files changed (66)
  1. braindecode/augmentation/base.py +1 -1
  2. braindecode/datasets/__init__.py +10 -2
  3. braindecode/datasets/base.py +115 -151
  4. braindecode/datasets/bcicomp.py +4 -4
  5. braindecode/datasets/bids.py +3 -3
  6. braindecode/datasets/experimental.py +2 -2
  7. braindecode/datasets/mne.py +3 -5
  8. braindecode/datasets/moabb.py +2 -2
  9. braindecode/datasets/nmt.py +2 -2
  10. braindecode/datasets/sleep_physio_challe_18.py +2 -2
  11. braindecode/datasets/sleep_physionet.py +2 -2
  12. braindecode/datasets/tuh.py +2 -2
  13. braindecode/datasets/xy.py +2 -2
  14. braindecode/datautil/serialization.py +7 -7
  15. braindecode/eegneuralnet.py +2 -0
  16. braindecode/functional/functions.py +6 -2
  17. braindecode/functional/initialization.py +2 -3
  18. braindecode/models/__init__.py +2 -0
  19. braindecode/models/atcnet.py +26 -27
  20. braindecode/models/attentionbasenet.py +39 -32
  21. braindecode/models/attn_sleep.py +2 -0
  22. braindecode/models/base.py +280 -2
  23. braindecode/models/bendr.py +469 -0
  24. braindecode/models/biot.py +2 -0
  25. braindecode/models/contrawr.py +2 -0
  26. braindecode/models/ctnet.py +8 -3
  27. braindecode/models/deepsleepnet.py +28 -19
  28. braindecode/models/eegconformer.py +2 -2
  29. braindecode/models/eeginception_erp.py +31 -25
  30. braindecode/models/eegitnet.py +2 -0
  31. braindecode/models/eegminer.py +2 -0
  32. braindecode/models/eegnet.py +1 -1
  33. braindecode/models/eegtcnet.py +2 -0
  34. braindecode/models/fbcnet.py +2 -0
  35. braindecode/models/fblightconvnet.py +2 -0
  36. braindecode/models/fbmsnet.py +2 -0
  37. braindecode/models/ifnet.py +2 -0
  38. braindecode/models/labram.py +193 -87
  39. braindecode/models/msvtnet.py +2 -0
  40. braindecode/models/patchedtransformer.py +1 -1
  41. braindecode/models/signal_jepa.py +111 -27
  42. braindecode/models/sinc_shallow.py +12 -9
  43. braindecode/models/sstdpn.py +11 -11
  44. braindecode/models/summary.csv +1 -0
  45. braindecode/models/syncnet.py +2 -0
  46. braindecode/models/tcn.py +2 -0
  47. braindecode/models/usleep.py +26 -21
  48. braindecode/models/util.py +1 -0
  49. braindecode/modules/attention.py +10 -10
  50. braindecode/modules/blocks.py +3 -3
  51. braindecode/modules/filter.py +2 -3
  52. braindecode/modules/layers.py +18 -17
  53. braindecode/preprocessing/__init__.py +24 -0
  54. braindecode/preprocessing/eegprep_preprocess.py +1202 -0
  55. braindecode/preprocessing/preprocess.py +12 -12
  56. braindecode/preprocessing/util.py +166 -0
  57. braindecode/preprocessing/windowers.py +24 -19
  58. braindecode/samplers/base.py +8 -8
  59. braindecode/version.py +1 -1
  60. {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev181594385.dist-info}/METADATA +6 -2
  61. braindecode-1.3.0.dev181594385.dist-info/RECORD +106 -0
  62. braindecode-1.3.0.dev180851780.dist-info/RECORD +0 -103
  63. {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev181594385.dist-info}/WHEEL +0 -0
  64. {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev181594385.dist-info}/licenses/LICENSE.txt +0 -0
  65. {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev181594385.dist-info}/licenses/NOTICE.txt +0 -0
  66. {braindecode-1.3.0.dev180851780.dist-info → braindecode-1.3.0.dev181594385.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,7 @@ from braindecode.models.base import EEGModuleMixin
  class PBT(EEGModuleMixin, nn.Module):
  r"""Patched Brain Transformer (PBT) model from Klein et al. (2025) [pbt]_.

- :bdg-danger:`Large Brain Models`
+ :bdg-danger:`Large Brain Model`

  This implementation was based in https://github.com/timonkl/PatchedBrainTransformer/

@@ -5,7 +5,8 @@ from __future__ import annotations

  import math
  from copy import deepcopy
- from typing import Any, Sequence
+ from pathlib import Path
+ from typing import Any, Optional, Sequence

  import torch
  from einops.layers.torch import Rearrange
@@ -145,6 +146,8 @@ class _BaseSignalJEPA(EEGModuleMixin, nn.Module):
  class SignalJEPA(_BaseSignalJEPA):
  """Architecture introduced in signal-JEPA for self-supervised pre-training, Guetschel, P et al (2024) [1]_

+ :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
+
  This model is not meant for classification but for SSL pre-training.
  Its output shape depends on the input shape.
  For classification purposes, three variants of this model are available:
@@ -231,6 +234,8 @@ class SignalJEPA(_BaseSignalJEPA):
  class SignalJEPA_Contextual(_BaseSignalJEPA):
  """Contextual downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.

+ :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
+
  This architecture is one of the variants of :class:`SignalJEPA`
  that can be used for classification purposes.

@@ -319,25 +324,50 @@ class SignalJEPA_Contextual(_BaseSignalJEPA):
  @classmethod
  def from_pretrained(
  cls,
- model: SignalJEPA,
- n_outputs: int,
+ model: Optional[SignalJEPA | str | Path] = None, # type: ignore
+ n_outputs: Optional[int] = None, # type: ignore
  n_spat_filters: int = 4,
- chs_info: list[dict[str, Any]] | None = None,
+ chs_info: Optional[list[dict[str, Any]]] = None, # type: ignore
+ **kwargs,
  ):
- """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
+ """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.

  Parameters
  ----------
- model: SignalJEPA
- Pre-trained model.
- n_outputs: int
- Number of classes for the new model.
+ model: SignalJEPA, str, Path, or None
+ Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
+ (for Hub-style loading), or None (for Hub loading via kwargs).
+ n_outputs: int or None
+ Number of classes for the new model. Required when loading from a SignalJEPA model,
+ optional when loading from Hub (will be read from config).
  n_spat_filters: int
  Number of spatial filters.
  chs_info: list of dict | None
  Information about each individual EEG channel. This should be filled with
  ``info["chs"]``. Refer to :class:`mne.Info` for more details.
+ **kwargs
+ Additional keyword arguments passed to the parent class for Hub loading.
  """
+ # Check if this is a Hub-style load (from a directory path)
+ if isinstance(model, (str, Path)) or (model is None and kwargs):
+ # This is a Hub load, delegate to parent class
+ if isinstance(model, (str, Path)):
+ # model is actually the repo_id or directory path
+ return super().from_pretrained(model, **kwargs)
+ else:
+ # model is None, treat as hub-style load
+ return super().from_pretrained(**kwargs)
+
+ # This is the original SignalJEPA transfer learning case
+ if not isinstance(model, SignalJEPA):
+ raise TypeError(
+ f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
+ )
+ if n_outputs is None:
+ raise ValueError(
+ "n_outputs must be provided when loading from a SignalJEPA model"
+ )
+
  feature_encoder = model.feature_encoder
  pos_encoder = model.pos_encoder
  transformer = model.transformer
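The hunk above makes `from_pretrained` accept either an in-memory `SignalJEPA` encoder or a Hub-style repo id/path. A minimal usage sketch (the shape arguments and checkpoint name below are hypothetical, not taken from the package):

from braindecode.models import SignalJEPA, SignalJEPA_Contextual

# Transfer-learning branch: pass an already pre-trained SignalJEPA instance.
ssl_encoder = SignalJEPA(n_chans=22, n_times=1000, sfreq=250)  # hypothetical shape arguments
clf = SignalJEPA_Contextual.from_pretrained(ssl_encoder, n_outputs=4)

# Hub-style branch: a str/Path first argument is forwarded to the parent
# class's from_pretrained; n_outputs would then come from the stored config.
# clf = SignalJEPA_Contextual.from_pretrained("some-user/signal-jepa-checkpoint")  # hypothetical repo id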
@@ -377,6 +407,8 @@ class SignalJEPA_Contextual(_BaseSignalJEPA):
  class SignalJEPA_PostLocal(_BaseSignalJEPA):
  """Post-local downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.

+ :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
+
  This architecture is one of the variants of :class:`SignalJEPA`
  that can be used for classification purposes.

@@ -463,22 +495,47 @@ class SignalJEPA_PostLocal(_BaseSignalJEPA):

  @classmethod
  def from_pretrained(
- cls, model: SignalJEPA, n_outputs: int, n_spat_filters: int = 4
+ cls,
+ model: SignalJEPA | str | Path = None, # type: ignore
+ n_outputs: int = None, # type: ignore
+ n_spat_filters: int = 4,
+ **kwargs,
  ):
- """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
+ """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.

  Parameters
  ----------
- model: SignalJEPA
- Pre-trained model.
- n_outputs: int
- Number of classes for the new model.
+ model: SignalJEPA, str, Path, or None
+ Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
+ (for Hub-style loading), or None (for Hub loading via kwargs).
+ n_outputs: int or None
+ Number of classes for the new model. Required when loading from a SignalJEPA model,
+ optional when loading from Hub (will be read from config).
  n_spat_filters: int
  Number of spatial filters.
- chs_info: list of dict | None
- Information about each individual EEG channel. This should be filled with
- ``info["chs"]``. Refer to :class:`mne.Info` for more details.
+ **kwargs
+ Additional keyword arguments passed to the parent class for Hub loading.
  """
+ # Check if this is a Hub-style load (from a directory path)
+ if isinstance(model, (str, Path)) or (model is None and kwargs):
+ # This is a Hub load, delegate to parent class
+ if isinstance(model, (str, Path)):
+ # model is actually the repo_id or directory path
+ return super().from_pretrained(model, **kwargs)
+ else:
+ # model is None, treat as hub-style load
+ return super().from_pretrained(**kwargs)
+
+ # This is the original SignalJEPA transfer learning case
+ if not isinstance(model, SignalJEPA):
+ raise TypeError(
+ f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
+ )
+ if n_outputs is None:
+ raise ValueError(
+ "n_outputs must be provided when loading from a SignalJEPA model"
+ )
+
  feature_encoder = model.feature_encoder
  assert feature_encoder is not None
  new_model = cls(
@@ -501,6 +558,8 @@ class SignalJEPA_PostLocal(_BaseSignalJEPA):
  class SignalJEPA_PreLocal(_BaseSignalJEPA):
  """Pre-local downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.

+ :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
+
  This architecture is one of the variants of :class:`SignalJEPA`
  that can be used for classification purposes.

@@ -597,22 +656,47 @@ class SignalJEPA_PreLocal(_BaseSignalJEPA):

  @classmethod
  def from_pretrained(
- cls, model: SignalJEPA, n_outputs: int, n_spat_filters: int = 4
+ cls,
+ model: SignalJEPA | str | Path = None, # type: ignore
+ n_outputs: int = None, # type: ignore
+ n_spat_filters: int = 4,
+ **kwargs,
  ):
- """Instantiate a new model from a pre-trained :class:`SignalJEPA` model.
+ """Instantiate a new model from a pre-trained :class:`SignalJEPA` model or from Hub.

  Parameters
  ----------
- model: SignalJEPA
- Pre-trained model.
- n_outputs: int
- Number of classes for the new model.
+ model: SignalJEPA, str, Path, or None
+ Either a pre-trained :class:`SignalJEPA` model, a string/Path to a local directory
+ (for Hub-style loading), or None (for Hub loading via kwargs).
+ n_outputs: int or None
+ Number of classes for the new model. Required when loading from a SignalJEPA model,
+ optional when loading from Hub (will be read from config).
  n_spat_filters: int
  Number of spatial filters.
- chs_info: list of dict | None
- Information about each individual EEG channel. This should be filled with
- ``info["chs"]``. Refer to :class:`mne.Info` for more details.
+ **kwargs
+ Additional keyword arguments passed to the parent class for Hub loading.
  """
+ # Check if this is a Hub-style load (from a directory path)
+ if isinstance(model, (str, Path)) or (model is None and kwargs):
+ # This is a Hub load, delegate to parent class
+ if isinstance(model, (str, Path)):
+ # model is actually the repo_id or directory path
+ return super().from_pretrained(model, **kwargs)
+ else:
+ # model is None, treat as hub-style load
+ return super().from_pretrained(**kwargs)
+
+ # This is the original SignalJEPA transfer learning case
+ if not isinstance(model, SignalJEPA):
+ raise TypeError(
+ f"model must be a SignalJEPA instance, a path string, or Path object, got {type(model)}"
+ )
+ if n_outputs is None:
+ raise ValueError(
+ "n_outputs must be provided when loading from a SignalJEPA model"
+ )
+
  feature_encoder = model.feature_encoder
  assert feature_encoder is not None
  new_model = cls(
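The PostLocal and PreLocal variants gain the same dual-mode dispatch. A short sketch of the Hub-style branch and of the new guards (the checkpoint directory below is hypothetical):

from pathlib import Path
from braindecode.models import SignalJEPA_PreLocal

# A str/Path first argument is routed to the parent class's Hub loader.
model = SignalJEPA_PreLocal.from_pretrained(Path("./checkpoints/signal_jepa_prelocal"))  # hypothetical path

# Anything that is neither a SignalJEPA instance nor a str/Path now fails fast:
# SignalJEPA_PreLocal.from_pretrained(object())  # -> TypeError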
@@ -12,6 +12,8 @@ from braindecode.models.base import EEGModuleMixin
  class SincShallowNet(EEGModuleMixin, nn.Module):
  """Sinc-ShallowNet from Borra, D et al (2020) [borra2020]_.

+ :bdg-success:`Convolution` :bdg-warning:`Interpretability`
+
  .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S0893608020302021-gr2_lrg.jpg
  :align: center
  :alt: SincShallowNet Architecture
@@ -19,23 +21,24 @@ class SincShallowNet(EEGModuleMixin, nn.Module):
  The Sinc-ShallowNet architecture has these fundamental blocks:

  1. **Block 1: Spectral and Spatial Feature Extraction**
- - *Temporal Sinc-Convolutional Layer*:
- Uses parametrized sinc functions to learn band-pass filters,
- significantly reducing the number of trainable parameters by only
- learning the lower and upper cutoff frequencies for each filter.
- - *Spatial Depthwise Convolutional Layer*:
- Applies depthwise convolutions to learn spatial filters for
- each temporal feature map independently, further reducing
- parameters and enhancing interpretability.
- - *Batch Normalization*
+
+ - *Temporal Sinc-Convolutional Layer*: Uses parametrized sinc functions to learn band-pass filters,
+ significantly reducing the number of trainable parameters by only
+ learning the lower and upper cutoff frequencies for each filter.
+ - *Spatial Depthwise Convolutional Layer*: Applies depthwise convolutions to learn spatial filters for
+ each temporal feature map independently, further reducing
+ parameters and enhancing interpretability.
+ - *Batch Normalization*

  2. **Block 2: Temporal Aggregation**
+
  - *Activation Function*: ELU
  - *Average Pooling Layer*: Aggregation by averaging spatial dim
  - *Dropout Layer*
  - *Flatten Layer*

  3. **Block 3: Classification**
+
  - *Fully Connected Layer*: Maps the feature vector to n_outputs.

  **Implementation Notes:**
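As context for the temporal sinc-convolutional layer described in this hunk, here is a generic sketch of how a band-pass FIR kernel can be built from just two cutoff frequencies; only those two cutoffs per filter would be trainable. This illustrates the idea, not braindecode's SincShallowNet code, and the function name and Hamming windowing are assumptions.

import torch

def sinc_bandpass_kernel(f_low: float, f_high: float, kernel_size: int, sfreq: float) -> torch.Tensor:
    """Band-pass FIR kernel parametrized only by its two cutoff frequencies (Hz)."""
    t = (torch.arange(kernel_size) - (kernel_size - 1) / 2) / sfreq
    # Difference of two low-pass sinc kernels gives a band-pass response.
    band_pass = 2 * f_high * torch.sinc(2 * f_high * t) - 2 * f_low * torch.sinc(2 * f_low * t)
    # Window to reduce ripple from truncating the ideal filter.
    return band_pass * torch.hamming_window(kernel_size, periodic=False)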
@@ -24,7 +24,7 @@ class SSTDPN(EEGModuleMixin, nn.Module):
  :alt: SSTDPN Architecture
  :width: 1000px

- The **SpatialSpectral** and **Temporal - Dual Prototype Network** (SST-DPN)
+ The **Spatial-Spectral** and **Temporal - Dual Prototype Network** (SST-DPN)
  is an end-to-end 1D convolutional architecture designed for motor imagery (MI) EEG decoding,
  aiming to address challenges related to discriminative feature extraction and
  small-sample sizes [Han2025]_.
@@ -37,9 +37,9 @@ class SSTDPN(EEGModuleMixin, nn.Module):
  SST-DPN consists of a feature extractor (_SSTEncoder, comprising Adaptive Spatial-Spectral
  Fusion and Multi-scale Variance Pooling) followed by Dual Prototype Learning classification [Han2025]_.

- 1. **Adaptive SpatialSpectral Fusion (ASSF)**: Uses :class:`_DepthwiseTemporalConv1d` to generate a
- multi-channel spatialspectral representation, followed by :class:`_SpatSpectralAttn`
- (Spatial-Spectral Attention) to model relationships and highlight key spatialspectral
+ 1. **Adaptive Spatial-Spectral Fusion (ASSF)**: Uses :class:`_DepthwiseTemporalConv1d` to generate a
+ multi-channel spatial-spectral representation, followed by :class:`_SpatSpectralAttn`
+ (Spatial-Spectral Attention) to model relationships and highlight key spatial-spectral
  channels [Han2025]_.

  2. **Multi-scale Variance Pooling (MVP)**: Applies :class:`_MultiScaleVarPooler` with variance pooling
@@ -57,7 +57,7 @@ class SSTDPN(EEGModuleMixin, nn.Module):

  - `SSTDPN.encoder` **(Feature Extractor)**

- - *Operations.* Combines Adaptive SpatialSpectral Fusion and Multi-scale Variance Pooling
+ - *Operations.* Combines Adaptive Spatial-Spectral Fusion and Multi-scale Variance Pooling
  via an internal :class:`_SSTEncoder`.
  - *Role.* Maps the raw MI-EEG trial :math:`X_i \in \mathbb{R}^{C \times T}` to the
  feature space :math:`z_i \in \mathbb{R}^d`.
@@ -69,11 +69,11 @@ class SSTDPN(EEGModuleMixin, nn.Module):
  depth multiplier `n_spectral_filters_temporal` (equivalent to :math:`F_1` in the paper).
  - *Role.* Extracts multiple distinct spectral bands from each EEG channel independently.

- - `_SSTEncoder.spt_attn` **(SpatialSpectral Attention for Channel Gating)**
+ - `_SSTEncoder.spt_attn` **(Spatial-Spectral Attention for Channel Gating)**

  - *Operations.* Internal :class:`_SpatSpectralAttn` module using Global Context Embedding
  via variance-based pooling, followed by adaptive channel normalization and gating.
- - *Role.* Reweights channels in the spatialspectral dimension to extract efficient and
+ - *Role.* Reweights channels in the spatial-spectral dimension to extract efficient and
  discriminative features by emphasizing task-relevant regions and frequency bands.

  - `_SSTEncoder.chan_conv` **(Pointwise Fusion across Channels)**
@@ -81,7 +81,7 @@ class SSTDPN(EEGModuleMixin, nn.Module):
  - *Operations.* A 1D pointwise convolution with `n_fused_filters` output channels
  (equivalent to :math:`F_2` in the paper), followed by BatchNorm and the specified
  `activation` function (default: ELU).
- - *Role.* Fuses the weighted spatialspectral features across all electrodes to produce
+ - *Role.* Fuses the weighted spatial-spectral features across all electrodes to produce
  a fused representation :math:`X_{fused} \in \mathbb{R}^{F_2 \times T}`.

  - `_SSTEncoder.mvp` **(Multi-scale Variance Pooling for Temporal Extraction)**
@@ -109,11 +109,11 @@ class SSTDPN(EEGModuleMixin, nn.Module):
  * **Spatial.**
  The initial convolution at the classes :class:`_DepthwiseTemporalConv1d` groups parameter :math:`h=1`,
  meaning :math:`F_1` temporal filters are shared across channels. The Spatial-Spectral Attention
- mechanism explicitly models the relationships among these channels in the spatialspectral
+ mechanism explicitly models the relationships among these channels in the spatial-spectral
  dimension, allowing for finer-grained spatial feature modeling compared to conventional
  GCNs according to the authors [Han2025]_.
  In other words, all electrode channels share :math:`F_1` temporal filters
- independently to produce the spatialspectral representation.
+ independently to produce the spatial-spectral representation.

  * **Spectral.**
  Spectral information is implicitly extracted via the :math:`F_1` filters in :class:`_DepthwiseTemporalConv1d`.
@@ -123,7 +123,7 @@ class SSTDPN(EEGModuleMixin, nn.Module):

  .. rubric:: Additional Mechanisms

- - **Attention.** A lightweight Spatial-Spectral Attention mechanism models spatialspectral relationships
+ - **Attention.** A lightweight Spatial-Spectral Attention mechanism models spatial-spectral relationships
  at the channel level, distinct from applying attention to deep feature dimensions,
  which is common in comparison methods like :class:`ATCNet`.
  - **Regularization.** Dual Prototype Learning acts as a regularization technique
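Several of the SSTDPN hunks above refer to variance pooling over temporal windows. A generic sketch of that operation follows; it illustrates the concept only and is not the library's _MultiScaleVarPooler.

import torch

def variance_pool1d(x: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Variance pooling along the time axis for x of shape (batch, channels, time)."""
    # Slice the time axis into overlapping or non-overlapping windows...
    windows = x.unfold(-1, kernel_size, stride)  # (batch, channels, n_windows, kernel_size)
    # ...and summarize each window by its variance instead of its mean/max.
    return windows.var(dim=-1)  # (batch, channels, n_windows)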
@@ -39,3 +39,4 @@ FBLightConvNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sf
  IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860,"IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
  PBT,General,Classification,250,"n_chans, n_outputs, n_times",818948,"PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Large Brain Model"
  SSTDPN,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",19502,"SSTDPN(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Small Attention"
+ BENDR,General,"Classification,Embedding",250,"n_chans, n_times, n_outputs",157141049,"BENDR(n_chans=22, n_outputs=4, n_times=1000)","Large Brain Model,Convolution"
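The new summary.csv row registers BENDR, which the file list shows being added in braindecode/models/bendr.py and exported via braindecode/models/__init__.py. Instantiation exactly as given in the row (assuming the export is at the usual models namespace):

from braindecode.models import BENDR

# Signature taken from the summary.csv entry added above (~157 M parameters).
model = BENDR(n_chans=22, n_outputs=4, n_times=1000)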
@@ -10,6 +10,8 @@ from braindecode.models.base import EEGModuleMixin
  class SyncNet(EEGModuleMixin, nn.Module):
  """Synchronization Network (SyncNet) from Li, Y et al (2017) [Li2017]_.

+ :bdg-warning:`Interpretability`
+
  .. figure:: https://braindecode.org/dev/_static/model/SyncNet.png
  :align: center
  :alt: SyncNet Architecture
braindecode/models/tcn.py CHANGED
@@ -14,6 +14,8 @@ from braindecode.modules import Chomp1d, Ensure4d, SqueezeFinalOutput
  class BDTCN(EEGModuleMixin, nn.Module):
  """Braindecode TCN from Gemein, L et al (2020) [gemein2020]_.

+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
+
  .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S1053811920305073-gr3_lrg.jpg
  :align: center
  :alt: Braindecode TCN Architecture
@@ -62,43 +62,48 @@ class USleep(EEGModuleMixin, nn.Module):
  - Decoder :class:`_DecoderBlock` **(progressive upsampling + skip fusion to high-frequency map, 12 blocks; upsampling x2 per block)**

  - *Operations.*
- - **Nearest-neighbor upsample**, :class:`nn.Upsample` (x2)
- - **Convolution2d** (k=2), :class:`torch.nn.Conv2d`
- - ELU, :class:`torch.nn.ELU`
- - Batch Norm, :class:`torch.nn.BatchNorm2d`
- - **Concatenate** with the encoder skip at the same temporal scale, :function:`torch.cat`
- - **Convolution**, :class:`torch.nn.Conv2d`
- - ELU, :class:`torch.nn.ELU`
- - Batch Norm, :class:`torch.nn.BatchNorm2d`.
+
+ - **Nearest-neighbor upsample**, :class:`nn.Upsample` (x2)
+ - **Convolution2d** (k=2), :class:`torch.nn.Conv2d`
+ - ELU, :class:`torch.nn.ELU`
+ - Batch Norm, :class:`torch.nn.BatchNorm2d`
+ - **Concatenate** with the encoder skip at the same temporal scale, ``torch.cat``
+ - **Convolution**, :class:`torch.nn.Conv2d`
+ - ELU, :class:`torch.nn.ELU`
+ - Batch Norm, :class:`torch.nn.BatchNorm2d`.

  **Output**: A multi-class, **high-frequency** per-sample representation aligned to the input rate (128 Hz).

  - **Segment Classifier incorporate into :class:`braindecode.models.USleep` (aggregation to fixed epochs)**

  - *Operations.*
- - **Mean-pool**, :class:`torch.nn.AvgPool2d` per class with kernel = epoch length *i* and stride *i*
- - **1x1 conv**, :class:`torch.nn.Conv2d`
- - ELU, :class:`torch.nn.ELU`
- - **1x1 conv**, :class:`torch.nn.Conv2d` with ``(T, K)`` (epochs x stages).
+
+ - **Mean-pool**, :class:`torch.nn.AvgPool2d` per class with kernel = epoch length *i* and stride *i*
+ - **1x1 conv**, :class:`torch.nn.Conv2d`
+ - ELU, :class:`torch.nn.ELU`
+ - **1x1 conv**, :class:`torch.nn.Conv2d` with ``(T, K)`` (epochs x stages).

  **Role**: Learns a **non-linear** weighted combination over each 30-s window (unlike U-Time's linear combiner).

  .. rubric:: Convolutional Details

  - **Temporal (where time-domain patterns are learned).**
- All convolutions are **1-D along time**; depth (12 levels) plus pooling yields an extensive receptive field
- (reported sensitivity to ±6.75 min around each epoch; theoretical field 9.6 min at the deepest layer).
- The decoder restores sample-level resolution before epoch aggregation.
+
+ All convolutions are **1-D along time**; depth (12 levels) plus pooling yields an extensive receptive field
+ (reported sensitivity to ±6.75 min around each epoch; theoretical field ≈ 9.6 min at the deepest layer).
+ The decoder restores sample-level resolution before epoch aggregation.

  - **Spatial (how channels are processed).**
- Convolutions mix across the *channel* dimension jointly with time (no separate spatial operator). The system
- is **montage-agnostic** (any reasonable EEG/EOG pair) and was trained across diverse cohorts/protocols,
- supporting robustness to channel placement and hardware differences.
+
+ Convolutions mix across the *channel* dimension jointly with time (no separate spatial operator). The system
+ is **montage-agnostic** (any reasonable EEG/EOG pair) and was trained across diverse cohorts/protocols,
+ supporting robustness to channel placement and hardware differences.

  - **Spectral (how frequency content is captured).**
- No explicit Fourier/wavelet transform is used; the **stack of temporal convolutions** acts as a learned
- filter bank whose effective bandwidth grows with depth. The high-frequency decoder output (128 Hz)
- retains fine temporal detail for the segment classifier.
+
+ No explicit Fourier/wavelet transform is used; the **stack of temporal convolutions** acts as a learned
+ filter bank whose effective bandwidth grows with depth. The high-frequency decoder output (128 Hz)
+ retains fine temporal detail for the segment classifier.


  .. rubric:: Attention / Sequential Modules
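The segment classifier described in this hunk aggregates the per-sample decoder output into fixed 30-s epochs (mean-pool per epoch, then two pointwise convolutions). A rough sketch of that aggregation idea; dimensions, layer choices, and the 1-D layout below are assumptions for illustration, not braindecode's USleep code.

import torch
from torch import nn

sfreq, epoch_s, n_stages, n_filters = 128, 30, 5, 16
# Collapse each 30-s epoch of per-sample features into one value per feature map.
pool = nn.AvgPool1d(kernel_size=sfreq * epoch_s, stride=sfreq * epoch_s)
# Non-linear per-epoch combination via two pointwise (1x1) convolutions.
head = nn.Sequential(
    nn.Conv1d(n_filters, n_filters, kernel_size=1),
    nn.ELU(),
    nn.Conv1d(n_filters, n_stages, kernel_size=1),
)
dense = torch.randn(2, n_filters, sfreq * epoch_s * 10)  # 10 epochs of per-sample features
logits = head(pool(dense))  # (batch, n_stages, 10 epochs)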
@@ -97,6 +97,7 @@ models_mandatory_parameters = [
  ("IFNet", ["n_chans", "n_outputs", "n_times", "sfreq"], dict(sfreq=200.0)),
  ("PBT", ["n_chans", "n_outputs", "n_times"], None),
  ("SSTDPN", ["n_chans", "n_outputs", "n_times", "sfreq"], None),
+ ("BENDR", ["n_chans", "n_outputs", "n_times"], None),
  ]

  ################################################################
@@ -38,7 +38,7 @@ class SqueezeAndExcitation(nn.Module):
  References
  ----------
  .. [Hu2018] Hu, J., Albanie, S., Sun, G., Wu, E., 2018.
- Squeeze-and-Excitation Networks. CVPR 2018.
+ Squeeze-and-Excitation Networks. CVPR 2018.
  """

  def __init__(self, in_channels: int, reduction_rate: int, bias: bool = False):
@@ -93,7 +93,7 @@ class GSoP(nn.Module):
  References
  ----------
  .. [Gao2018] Gao, Z., Jiangtao, X., Wang, Q., Li, P., 2018.
- Global Second-order Pooling Convolutional Networks. CVPR 2018.
+ Global Second-order Pooling Convolutional Networks. CVPR 2018.
  """

  def __init__(self, in_channels: int, reduction_rate: int, bias: bool = True):
@@ -149,7 +149,7 @@ class FCA(nn.Module):
  References
  ----------
  .. [Qin2021] Qin, Z., Zhang, P., Wu, F., Li, X., 2021.
- FcaNet: Frequency Channel Attention Networks. ICCV 2021.
+ FcaNet: Frequency Channel Attention Networks. ICCV 2021.
  """

  def __init__(
@@ -233,7 +233,7 @@ class EncNet(nn.Module):
  References
  ----------
  .. [Zhang2018] Zhang, H. et al. 2018.
- Context Encoding for Semantic Segmentation. CVPR 2018.
+ Context Encoding for Semantic Segmentation. CVPR 2018.
  """

  def __init__(self, in_channels: int, n_codewords: int):
@@ -290,7 +290,7 @@ class ECA(nn.Module):
  References
  ----------
  .. [Wang2021] Wang, Q. et al., 2021. ECA-Net: Efficient Channel Attention
- for Deep Convolutional Neural Networks. CVPR 2021.
+ for Deep Convolutional Neural Networks. CVPR 2021.
  """

  def __init__(self, in_channels: int, kernel_size: int):
@@ -341,8 +341,8 @@ class GatherExcite(nn.Module):
  References
  ----------
  .. [Hu2018b] Hu, J., Albanie, S., Sun, G., Vedaldi, A., 2018.
- Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks.
- NeurIPS 2018.
+ Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks.
+ NeurIPS 2018.
  """

  def __init__(
@@ -410,7 +410,7 @@ class GCT(nn.Module):
  References
  ----------
  .. [Yang2020] Yang, Z. Linchao, Z., Wu, Y., Yang, Y., 2020.
- Gated Channel Transformation for Visual Recognition. CVPR 2020.
+ Gated Channel Transformation for Visual Recognition. CVPR 2020.
  """

  def __init__(self, in_channels: int):
@@ -455,7 +455,7 @@ class SRM(nn.Module):
  References
  ----------
  .. [Lee2019] Lee, H., Kim, H., Nam, H., 2019. SRM: A Style-based
- Recalibration Module for Convolutional Neural Networks. ICCV 2019.
+ Recalibration Module for Convolutional Neural Networks. ICCV 2019.
  """

  def __init__(
@@ -520,7 +520,7 @@ class CBAM(nn.Module):
  References
  ----------
  .. [Woo2018] Woo, S., Park, J., Lee, J., Kweon, I., 2018.
- CBAM: Convolutional Block Attention Module. ECCV 2018.
+ CBAM: Convolutional Block Attention Module. ECCV 2018.
  """

  def __init__(self, in_channels: int, reduction_rate: int, kernel_size: int):
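The attention.py hunks above are whitespace-only fixes to the reference entries of several channel-attention modules (SqueezeAndExcitation, GSoP, FCA, EncNet, ECA, GatherExcite, GCT, SRM, CBAM). For orientation, a generic squeeze-and-excitation gate looks roughly like the sketch below; it illustrates the concept and mirrors the documented constructor arguments, but it is not braindecode's implementation.

import torch
from torch import nn

class SqueezeExciteSketch(nn.Module):
    """Minimal per-channel gating in the squeeze-and-excitation style."""

    def __init__(self, in_channels: int, reduction_rate: int, bias: bool = False):
        super().__init__()
        hidden = max(1, in_channels // reduction_rate)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=bias),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time); squeeze over time, excite per channel.
        scale = self.fc(x.mean(dim=-1))
        return x * scale.unsqueeze(-1)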
@@ -37,8 +37,8 @@ class MLP(nn.Sequential):
  :math:`a_i` are called activation functions. The trainable parameters of an
  MLP are its weights and biases :math:`\\phi = \{W_i, b_i | i = 1, \dots, L\}`.

- Parameters:
- -----------
+ Parameters
+ ----------
  in_features: int
  Number of input features.
  hidden_features: Sequential[int] (default=None)
@@ -49,7 +49,7 @@ class MLP(nn.Sequential):
  out_features: int (default=None)
  Number of output features, if None, set to in_features.
  act_layer: nn.GELU (default)
- The activation function constructor. If :py:`None`, use
+ The activation function constructor. If ``None``, use
  :class:`torch.nn.GELU` instead.
  drop: float (default=0.0)
  Dropout rate.
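A small usage sketch matching the docstring parameters fixed above; the import path and the tuple form of hidden_features are assumptions, not verified against the package.

from braindecode.modules import MLP  # assumed export location of blocks.MLP

# Two hidden layers, GELU activation (the documented default), 10 % dropout.
mlp = MLP(in_features=64, hidden_features=(128, 128), out_features=10, drop=0.1)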
@@ -17,9 +17,8 @@ class FilterBankLayer(nn.Module):
  It uses MNE's `create_filter` function to create the band-specific filters and
  applies them to multi-channel time-series data. Each filter in the bank corresponds to a
  specific frequency band and is applied to all channels of the input data. The filtering is
- performed using FFT-based convolution via the `fftconvolve` function from
- :func:`torchaudio.functional if the method is FIR, and `filtfilt` function from
- :func:`torchaudio.functional if the method is IIR.
+ performed using FFT-based convolution via the ``torchaudio.functional`` if the method is FIR,
+ and ``torchaudio.functional`` if the method is IIR.

  The default configuration creates 9 non-overlapping frequency bands with a 4 Hz bandwidth,
  spanning from 4 Hz to 40 Hz (i.e., 4-8 Hz, 8-12 Hz, ..., 36-40 Hz). This setup is based on the
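The 9 default 4 Hz-wide bands described above can be written out explicitly; this is a plain Python sketch of that configuration, not library code, and the variable name is illustrative only.

# Nine non-overlapping 4 Hz bands from 4 Hz to 40 Hz.
band_filters = [(low, low + 4) for low in range(4, 40, 4)]
# -> [(4, 8), (8, 12), ..., (36, 40)]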
@@ -70,26 +70,27 @@ class TimeDistributed(nn.Module):
  class DropPath(nn.Module):
  """Drop paths, also known as Stochastic Depth, per sample.

- When applied in main path of residual blocks.
+ When applied in main path of residual blocks.

- Parameters:
- -----------
- drop_prob: float (default=None)
- Drop path probability (should be in range 0-1).
+ Parameters
+ ----------
+ drop_prob: float (default=None)
+ Drop path probability (should be in range 0-1).

- Notes
- -----
- Code copied and modified from VISSL facebookresearch:
+ Notes
+ -----
+ Code copied and modified from VISSL facebookresearch:
  https://github.com/facebookresearch/vissl/blob/0b5d6a94437bc00baed112ca90c9d78c6ccfbafb/vissl/models/model_helpers.py#L676
- All rights reserved.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
+
+ All rights reserved.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
  """

  def __init__(self, drop_prob=None):
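The DropPath docstring reworked above describes per-sample stochastic depth. A generic sketch of the technique (one Bernoulli draw per sample, rescaled to keep the expectation unchanged); this is the standard formulation, not braindecode's exact code.

import torch

def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Randomly zero out whole samples in a residual branch during training."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    # Rescale so the expected activation matches the no-drop case.
    return x / keep_prob * mask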