braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/attentionbasenet.py
@@ -24,15 +24,16 @@ from braindecode.modules.attention import (
 
 
 class AttentionBaseNet(EEGModuleMixin, nn.Module):
- r"""AttentionBaseNet from Wimpff M et al (2023) [Martin2023]_.
+ """AttentionBaseNet from Wimpff M et al. (2023) [Martin2023]_.
 
- :bdg-success:`Convolution` :bdg-info:`Attention/Transformer`
+ :bdg-success:`Convolution` :bdg-info:`Small Attention`
 
 .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
 :align: center
 :alt: AttentionBaseNet Architecture
 :width: 640px
 
+
 .. rubric:: Architectural Overview
 
 AttentionBaseNet is a *convolution-first* network with a *channel-attention* stage.
@@ -49,6 +50,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 attention unit that *re-weights channels* (and optionally temporal positions) before
 classification.
 
+
 .. rubric:: Macro Components
 
 - :class:`_FeatureExtractor` **(Shallow conv stem → condensed feature map)**
@@ -90,6 +92,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 *Operations.* :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from
 ``(B, ch_dim·T₂)`` to classes.
 
+
 .. rubric:: Convolutional Details
 
 - **Temporal (where time-domain patterns are learned).**
@@ -108,6 +111,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 emerges from learned temporal kernels. When ``attention_mode="fca"``, a frequency
 channel attention (DCT-based) summarizes frequencies to drive channel weights.
 
+
 .. rubric:: Attention / Sequential Modules
 
 - **Type.** Channel attention chosen by ``attention_mode`` (SE, ECA, CBAM, CAT, GSoP,
@@ -120,6 +124,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 - **Role.** Re-weights channels (and optionally time) to highlight informative sources
 and suppress distractors, improving SNR ahead of the linear head.
 
+
 .. rubric:: Additional Mechanisms
 
 **Attention variants at a glance:**
@@ -158,6 +163,17 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 only after the stem learns stable filters. For small datasets, prefer simpler modes
 (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).
 
+ Notes
+ -----
+ - Sequence length after each stage is computed internally; the final classifier expects
+ a flattened ``ch_dim x T₂`` vector.
+ - Attention operates on *channel* dimension by design; temporal gating exists only in
+ specific variants (CBAM/CAT).
+ - The paper and original code with more details about the methodological
+ choices are available at the [Martin2023]_ and [MartinCode]_.
+
+ .. versionadded:: 0.9
+
 Parameters
 ----------
 n_temporal_filters : int, optional
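The docstring guidance added in the hunk above suggests starting with the lighter attention modes on small datasets. A minimal usage sketch of that advice, assuming the usual EEGModuleMixin constructor arguments (`n_chans`, `n_outputs`, `n_times`) together with the `attention_mode` parameter documented in this file; shapes and values are illustrative, not taken from the package:

    import torch
    from braindecode.models import AttentionBaseNet

    # Illustrative BCI-IV-2a-like setup: 22 channels, 4 classes, 1000 samples.
    # "se" is one of the lighter channel-attention modes recommended above.
    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000, attention_mode="se")
    x = torch.randn(8, 22, 1000)   # (batch, n_chans, n_times)
    y = model(x)                   # expected shape: (8, 4)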
@@ -219,24 +235,13 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 kernel_size : int, default=9
 The kernel size used in certain types of attention mechanisms for convolution
 operations.
- activation : type[nn.Module] = nn.ELU,
+ activation: nn.Module, default=nn.ELU
 Activation function class to apply. Should be a PyTorch activation
 module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
 extra_params : bool, default=False
 Flag to indicate whether additional, custom parameters should be passed to
 the attention mechanism.
 
- Notes
- -----
- - Sequence length after each stage is computed internally; the final classifier expects
- a flattened ``ch_dim x T₂`` vector.
- - Attention operates on *channel* dimension by design; temporal gating exists only in
- specific variants (CBAM/CAT).
- - The paper and original code with more details about the methodological
- choices are available at the [Martin2023]_ and [MartinCode]_.
-
- .. versionadded:: 0.9
-
 References
 ----------
 .. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
@@ -272,7 +277,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 freq_idx: int = 0,
 n_codewords: int = 4,
 kernel_size: int = 9,
- activation: type[nn.Module] = nn.ELU,
+ activation: nn.Module = nn.ELU,
 extra_params: bool = False,
 ):
 super(AttentionBaseNet, self).__init__()
@@ -392,8 +397,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 pool_length: int,
 ) -> int:
 """
- Calculates the minimum n_times required for the model to work.
-
+ Calculates the minimum n_times required for the model to work
 with the given parameters.
 
 The calculation is based on reversing the pooling operations to
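The helper documented in this hunk reverses the pooling arithmetic to find the smallest usable input length. A back-of-the-envelope sketch of that reversal (my own illustration of the idea, not the function's actual body; the pool lengths/strides below are illustrative stand-ins):

    def min_input_length(pool_lengths, pool_strides):
        # Illustrative reversal of successive pooling stages: a stage with length L
        # and stride S maps n samples to floor((n - L) / S) + 1, so requiring at
        # least one output sample and walking backwards gives a lower bound on n.
        required = 1
        for length, stride in reversed(list(zip(pool_lengths, pool_strides))):
            required = (required - 1) * stride + length
        return required

    # e.g. two pooling stages, both with length 75 and stride 15 (illustrative values)
    print(min_input_length([75, 75], [15, 15]))  # -> 1185 samples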
@@ -409,15 +413,15 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
 
 
 class _FeatureExtractor(nn.Module):
- r"""
- A module for feature extraction of the data with temporal and spatial.
-
+ """
+ A module for feature extraction of the data with temporal and spatial
 transformations.
 
 This module sequentially processes the input through a series of layers:
 rearrangement, temporal convolution, batch normalization, spatial convolution,
 another batch normalization, an ELU non-linearity, average pooling, and dropout.
 
+
 Parameters
 ----------
 n_chans : int
@@ -435,7 +439,7 @@ class _FeatureExtractor(nn.Module):
 The stride of the average pooling operation. Default is 15.
 drop_prob : float, optional
 The dropout rate for regularization. Default is 0.5.
- activation : nn.Module, default=nn.ELU
+ activation: nn.Module, default=nn.ELU
 Activation function class to apply. Should be a PyTorch activation
 module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
 """
@@ -449,7 +453,7 @@ class _FeatureExtractor(nn.Module):
 pool_length: int = 75,
 pool_stride: int = 15,
 drop_prob: float = 0.5,
- activation: type[nn.Module] = nn.ELU,
+ activation: nn.Module = nn.ELU,
 ):
 super().__init__()
 
@@ -489,9 +493,8 @@ class _FeatureExtractor(nn.Module):
 
 
 class _ChannelAttentionBlock(nn.Module):
- r"""
- A neural network module implementing channel-wise attention mechanisms to enhance.
-
+ """
+ A neural network module implementing channel-wise attention mechanisms to enhance
 feature representations by selectively emphasizing important channels and suppressing
 less useful ones. This block integrates convolutional layers, pooling, dropout, and
 an optional attention mechanism that can be customized based on the given mode.
@@ -545,7 +548,7 @@ class _ChannelAttentionBlock(nn.Module):
 extra_params : bool, default=False
 Flag to indicate whether additional, custom parameters should be passed to
 the attention mechanism.
- activation : nn.Module, default=nn.ELU
+ activation: nn.Module, default=nn.ELU
 Activation function class to apply. Should be a PyTorch activation
 module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
 
@@ -561,7 +564,7 @@ class _ChannelAttentionBlock(nn.Module):
 attention_block : torch.nn.Module or None
 The attention mechanism applied to the output of the convolutional layers,
 if `attention_mode` is not None. Otherwise, it's set to None.
- activation : nn.Module, default=nn.ELU
+ activation: nn.Module, default=nn.ELU
 Activation function class to apply. Should be a PyTorch activation
 module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
 
@@ -571,6 +574,7 @@ class _ChannelAttentionBlock(nn.Module):
 >>> x = torch.randn(1, 16, 64, 64) # Example input tensor
 >>> output = channel_attention_block(x)
 The output tensor then can be further processed or used as input to another block.
+
 """
 
 def __init__(
@@ -588,7 +592,7 @@ class _ChannelAttentionBlock(nn.Module):
 n_codewords: int = 4,
 kernel_size: int = 9,
 extra_params: bool = False,
- activation: type[nn.Module] = nn.ELU,
+ activation: nn.Module = nn.ELU,
 ):
 super().__init__()
 self.conv = nn.Sequential(
@@ -648,31 +652,31 @@ def get_attention_block(
 
 Parameters
 ----------
- attention_mode : str
+ attention_mode: str
 The type of attention mechanism to apply.
- ch_dim : int
+ ch_dim: int
 The number of input channels to the block.
- reduction_rate : int
+ reduction_rate: int
 The reduction rate used in the attention mechanism to reduce
 dimensionality and computational complexity.
 Used in all the methods, except for the
 encnet and eca.
- use_mlp : bool
+ use_mlp: bool
 Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used
 within the attention mechanism for further processing. Used in the ge
 and srm attention mechanism.
- seq_len : int
+ seq_len: int
 The sequence length, used in certain types of attention mechanisms to
 process temporal dimensions. Used in the ge or fca attention mechanism.
- freq_idx : int
+ freq_idx: int
 DCT index used in fca attention mechanism.
- n_codewords : int
+ n_codewords: int
 The number of codewords (clusters) used in attention mechanisms
 that employ quantization or clustering strategies, encnet.
- kernel_size : int
+ kernel_size: int
 The kernel size used in certain types of attention mechanisms for convolution
 operations, used in the cbam, eca, and cat attention mechanisms.
- extra_params : bool
+ extra_params: bool
 Parameter to pass additional parameters to the GatherExcite mechanism.
 
 Returns
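The `attention_mode` values documented above map onto well-known channel-attention designs. As a generic illustration of the lightest of them, squeeze-and-excitation ("se") with a reduction rate, written from the standard SE recipe rather than taken from this file:

    import torch
    from torch import nn

    class SqueezeExcite(nn.Module):
        """Generic SE channel attention: squeeze over time, excite per channel."""
        def __init__(self, ch_dim: int, reduction_rate: int = 4):
            super().__init__()
            self.fc = nn.Sequential(
                nn.Linear(ch_dim, ch_dim // reduction_rate),
                nn.ReLU(inplace=True),
                nn.Linear(ch_dim // reduction_rate, ch_dim),
                nn.Sigmoid(),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, channels, time); the learned weights are broadcast over time
            weights = self.fc(x.mean(dim=-1))
            return x * weights.unsqueeze(-1)

    out = SqueezeExcite(ch_dim=16, reduction_rate=4)(torch.randn(8, 16, 250))
    print(out.shape)  # torch.Size([8, 16, 250]) - attention re-weights, shape is preserved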
braindecode/models/attn_sleep.py
@@ -16,9 +16,9 @@ from braindecode.modules import CausalConv1d
 
 
 class AttnSleep(EEGModuleMixin, nn.Module):
- r"""Sleep Staging Architecture from Eldele et al (2021) [Eldele2021]_.
+ """Sleep Staging Architecture from Eldele et al. (2021) [Eldele2021]_.
 
- :bdg-success:`Convolution` :bdg-info:`Attention/Transformer`
+ :bdg-success:`Convolution` :bdg-info:`Small Attention`
 
 .. figure:: https://raw.githubusercontent.com/emadeldeen24/AttnSleep/refs/heads/main/imgs/AttnSleep.png
 :align: center
@@ -63,10 +63,10 @@ class AttnSleep(EEGModuleMixin, nn.Module):
 Alias for `n_outputs`.
 input_size_s : float
 Alias for `input_window_seconds`.
- activation : nn.Module, default=nn.ReLU
+ activation: nn.Module, default=nn.ReLU
 Activation function class to apply. Should be a PyTorch activation
 module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
- activation_mrcnn : nn.Module, default=nn.ReLU
+ activation_mrcnn: nn.Module, default=nn.ReLU
 Activation function class to apply in the Mask R-CNN layer.
 Should be a PyTorch activation module class like ``nn.ReLU`` or
 ``nn.GELU``. Default is ``nn.GELU``.
@@ -90,8 +90,8 @@ class AttnSleep(EEGModuleMixin, nn.Module):
 d_ff=120,
 n_attn_heads=5,
 drop_prob=0.1,
- activation_mrcnn: type[nn.Module] = nn.GELU,
- activation: type[nn.Module] = nn.ReLU,
+ activation_mrcnn: nn.Module = nn.GELU,
+ activation: nn.Module = nn.ReLU,
 input_window_seconds=None,
 n_outputs=None,
 after_reduced_cnn_size=30,
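Note that the defaults remain the activation classes (`nn.GELU` / `nn.ReLU`); only the annotations change in this hunk. A minimal, hedged usage sketch with the canonical single-channel, 30 s at 100 Hz sleep-staging configuration (argument names taken from the signature above and from EEGModuleMixin; values illustrative and not verified against this build):

    import torch
    from torch import nn
    from braindecode.models import AttnSleep

    # 5 sleep stages, one EEG channel, 30 s epochs sampled at 100 Hz (3000 samples).
    model = AttnSleep(n_outputs=5, n_chans=1, sfreq=100, input_window_seconds=30.0,
                      activation=nn.ReLU)
    y = model(torch.randn(4, 1, 3000))   # expected shape: (4, 5)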
@@ -175,7 +175,7 @@ class AttnSleep(EEGModuleMixin, nn.Module):
 
 Parameters
 ----------
- x : torch.Tensor
+ x: torch.Tensor
 Batch of EEG windows of shape (batch_size, n_channels, n_times).
 """
 
@@ -230,7 +230,7 @@ class _SEBasicBlock(nn.Module):
 planes,
 stride=1,
 downsample=None,
- activation: type[nn.Module] = nn.ReLU,
+ activation: nn.Module = nn.ReLU,
 *,
 reduction=16,
 ):
@@ -278,8 +278,8 @@ class _MRCNN(nn.Module):
 self,
 after_reduced_cnn_size,
 kernel_size=7,
- activation: type[nn.Module] = nn.GELU,
- activation_se: type[nn.Module] = nn.ReLU,
+ activation: nn.Module = nn.GELU,
+ activation_se: nn.Module = nn.ReLU,
 ):
 super(_MRCNN, self).__init__()
 drate = 0.5
@@ -325,7 +325,7 @@ class _MRCNN(nn.Module):
 )
 
 def _make_layer(
- self, block, planes, blocks, stride=1, activate: type[nn.Module] = nn.ReLU
+ self, block, planes, blocks, stride=1, activate: nn.Module = nn.ReLU
 ): # makes residual SE block
 downsample = None
 if stride != 1 or self.inplanes != planes * block.expansion:
@@ -363,7 +363,7 @@ class _MRCNN(nn.Module):
 def _attention(
 query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor]:
- """Implementation of Scaled dot product attention."""
+ """Implementation of Scaled dot product attention"""
 # d_k - dimension of the query and key vectors
 d_k = query.size(-1)
 scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
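For reference, the `_attention` helper touched here is standard scaled dot-product attention; the last context line computes the pre-softmax scores of the usual formula

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V

where d_k is the query/key dimension taken from `query.size(-1)` above; the softmax and the multiplication by V happen in the lines that follow outside this hunk.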
@@ -392,7 +392,7 @@ class _MultiHeadedAttention(nn.Module):
 self.dropout = nn.Dropout(p=dropout)
 
 def forward(self, query, key, value: torch.Tensor) -> torch.Tensor:
- """Implements Multi-head attention."""
+ """Implements Multi-head attention"""
 nbatches = query.size(0)
 
 query = query.view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
@@ -423,7 +423,9 @@
 
 
 class _ResidualLayerNormAttn(nn.Module):
- r"""A residual connection followed by a layer norm."""
+ """
+ A residual connection followed by a layer norm.
+ """
 
 def __init__(self, size, dropout, fn_attn):
 super().__init__()
@@ -462,9 +464,8 @@ class _ResidualLayerNormFF(nn.Module):
 
 
 class _TCE(nn.Module):
- r"""
- Transformer Encoder.
-
+ """
+ Transformer Encoder
 It is a stack of n layers.
 """
 
@@ -482,9 +483,8 @@
 
 
 class _EncoderLayer(nn.Module):
- r"""
- An encoder layer.
-
+ """
+ An encoder layer
 Made up of self-attention and a feed forward layer.
 Each of these sublayers have residual and layer norm, implemented by _ResidualLayerNorm.
 """
@@ -515,7 +515,7 @@ class _EncoderLayer(nn.Module):
 )
 
 def forward(self, x_in: torch.Tensor) -> torch.Tensor:
- """Transformer Encoder."""
+ """Transformer Encoder"""
 query = self.conv(x_in)
 # Encoder self-attention
 x = self.residual_self_attn(query, x_in, x_in)
@@ -524,11 +524,9 @@
 
 
 class _PositionwiseFeedForward(nn.Module):
- r"""Positionwise feed-forward network."""
+ """Positionwise feed-forward network."""
 
- def __init__(
- self, d_model, d_ff, dropout=0.1, activation: type[nn.Module] = nn.ReLU
- ):
+ def __init__(self, d_model, d_ff, dropout=0.1, activation: nn.Module = nn.ReLU):
 super().__init__()
 self.w_1 = nn.Linear(d_model, d_ff)
 self.w_2 = nn.Linear(d_ff, d_model)
@@ -544,6 +542,6 @@ class _PositionwiseFeedForward(nn.Module):
 "`SleepStagerEldele2021` was renamed to `AttnSleep` in v1.12 to follow original author's name; this alias will be removed in v1.14."
 )
 class SleepStagerEldele2021(AttnSleep):
- r"""Deprecated alias for SleepStagerEldele2021."""
+ """Deprecated alias for SleepStagerEldele2021."""
 
 pass

braindecode/models/base.py
@@ -192,7 +192,7 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
 n_times is not None
 and input_window_seconds is not None
 and sfreq is not None
- and n_times != round(input_window_seconds * sfreq)
+ and n_times != int(input_window_seconds * sfreq)
 ):
 raise ValueError(
 f"{n_times=} different from {input_window_seconds=} * {sfreq=}"
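Both this hunk and the next one (on the derived `n_times` property) swap `round()` for `int()` when turning `input_window_seconds * sfreq` into a sample count. The two agree whenever the product is an integer; otherwise `int()` truncates toward zero while `round()` goes to the nearest integer, so the inferred length can now be one sample shorter. A purely illustrative comparison:

    # Illustrative values only: 4.0 s at 250.7 Hz is not an integer number of samples.
    input_window_seconds, sfreq = 4.0, 250.7
    product = input_window_seconds * sfreq   # ~1002.8
    print(round(product), int(product))      # 1003 1002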
@@ -236,7 +236,7 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
 and self._input_window_seconds is not None
 and self._sfreq is not None
 ):
- return round(self._input_window_seconds * self._sfreq)
+ return int(self._input_window_seconds * self._sfreq)
 elif self._n_times is None:
 raise ValueError(
 "n_times could not be inferred. "
@@ -284,7 +284,7 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
 
 Returns
 -------
- output_shape : tuple[int, ...]
+ output_shape: tuple[int, ...]
 shape of the network output for `batch_size==1` (1, ...)
 """
 with torch.inference_mode():
@@ -330,14 +330,13 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
 
 def to_dense_prediction_model(self, axis: tuple[int, ...] | int = (2, 3)) -> None:
 """
- Transform a sequential model with strides to a model that outputs.
-
+ Transform a sequential model with strides to a model that outputs
 dense predictions by removing the strides and instead inserting dilations.
 Modifies model in-place.
 
 Parameters
 ----------
- axis : int or (int,int)
+ axis: int or (int,int)
 Axis to transform (in terms of intermediate output axes)
 can either be 2, 3, or (2,3).
 
@@ -346,6 +345,7 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
 Does not yet work correctly for average pooling.
 Prior to version 0.1.7, there had been a bug that could move strides
 backwards one layer.
+
 """
 if not hasattr(axis, "__iter__"):
 axis = (axis,)
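For context, `to_dense_prediction_model` is the hook braindecode uses for cropped decoding. A hedged sketch of the usual call pattern, with `ShallowFBCSPNet` standing in as a familiar strided model and all parameter values illustrative:

    from braindecode.models import ShallowFBCSPNet

    # Build a strided model, then convert its strides to dilations in place so it
    # emits dense per-crop predictions instead of a single prediction per window.
    model = ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, final_conv_length=30)
    model.to_dense_prediction_model()
    # The mixin's output-shape helper (documented in the first hunk above) should
    # then report something like (1, n_outputs, n_preds_per_input).
    print(model.get_output_shape())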
braindecode/models/bendr.py
@@ -8,15 +8,16 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class BENDR(EEGModuleMixin, nn.Module):
- r"""BENDR (BErt-inspired Neural Data Representations) from Kostas et al (2021) [bendr]_.
+ """BENDR (BErt-inspired Neural Data Representations) from Kostas et al. (2021) [bendr]_.
 
- :bdg-success:`Convolution` :bdg-danger:`Foundation Model`
+ :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`
 
 .. figure:: https://www.frontiersin.org/files/Articles/653659/fnhum-15-653659-HTML/image_m/fnhum-15-653659-g001.jpg
 :align: center
 :alt: BENDR Architecture
 :width: 1000px
 
+
 The **BENDR** architecture adapts techniques used for language modeling (LM) toward the
 development of encephalography modeling (EM) [bendr]_. It utilizes a self-supervised
 training objective to learn compressed representations of raw EEG signals [bendr]_. The
@@ -78,31 +79,6 @@ class BENDR(EEGModuleMixin, nn.Module):
 prepended to the BENDR sequence before input to the transformer, serving as the aggregate
 representation token [bendr]_.
 
- .. important::
- **Pre-trained Weights Available**
-
- This model has pre-trained weights available on the Hugging Face Hub.
- You can load them using:
-
- .. code-block:: python
-
- from braindecode.models import BENDR
-
- # Load pre-trained model from Hugging Face Hub
- # you can specify `n_outputs` for your downstream task
- model = BENDR.from_pretrained("braindecode/braindecode-bendr", n_outputs=2)
-
- To push your own trained model to the Hub:
-
- .. code-block:: python
-
- # After training your model
- model.push_to_hub(
- repo_id="username/my-bendr-model", commit_message="Upload trained BENDR model"
- )
-
- Requires installing ``braindecode[hug]`` for Hub integration.
-
 Notes
 -----
 * The full BENDR architecture contains a large number of parameters; configuration (1)
@@ -119,27 +95,6 @@ class BENDR(EEGModuleMixin, nn.Module):
 **self-supervised pre-training** on large, unlabeled EEG datasets (like TUEG) followed
 by subsequent fine-tuning on the specific downstream classification task [bendr]_.
 
- References
- ----------
- .. [bendr] Kostas, D., Aroca-Ouellette, S., & Rudzicz, F. (2021).
- BENDR: Using transformers and a contrastive self-supervised learning task to learn from
- massive amounts of EEG data.
- Frontiers in Human Neuroscience, 15, 653659.
- https://doi.org/10.3389/fnhum.2021.653659
- .. [wav2vec2] Baevski, A., Zhou, Y., Mohamed, A., & Auli, M. (2020).
- wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.
- In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, & H. Lin (Eds),
- Advances in Neural Information Processing Systems (Vol. 33, pp. 12449-12460).
- https://dl.acm.org/doi/10.5555/3495724.3496768
- .. [tfixup] Huang, T. K., Liang, S., Jha, A., & Salakhutdinov, R. (2020).
- Improving Transformer Optimization Through Better Initialization.
- In International Conference on Machine Learning (pp. 4475-4483). PMLR.
- https://dl.acm.org/doi/10.5555/3524938.3525354
- .. [layerdrop] Fan, A., Grave, E., & Joulin, A. (2020).
- Reducing Transformer Depth on Demand with Structured Dropout.
- International Conference on Learning Representations.
- Retrieved from https://openreview.net/forum?id=SylO2yStDr
-
 Parameters
 ----------
 encoder_h : int, default=512
@@ -183,6 +138,27 @@ class BENDR(EEGModuleMixin, nn.Module):
 final_layer : bool, default=True
 If True, includes a final linear classification layer that maps from encoder_h to
 n_outputs. If False, the model outputs the contextualized features directly.
+
+ References
+ ----------
+ .. [bendr] Kostas, D., Aroca-Ouellette, S., & Rudzicz, F. (2021).
+ BENDR: Using transformers and a contrastive self-supervised learning task to learn from
+ massive amounts of EEG data.
+ Frontiers in Human Neuroscience, 15, 653659.
+ https://doi.org/10.3389/fnhum.2021.653659
+ .. [wav2vec2] Baevski, A., Zhou, Y., Mohamed, A., & Auli, M. (2020).
+ wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.
+ In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, & H. Lin (Eds),
+ Advances in Neural Information Processing Systems (Vol. 33, pp. 12449-12460).
+ https://dl.acm.org/doi/10.5555/3495724.3496768
+ .. [tfixup] Huang, T. K., Liang, S., Jha, A., & Salakhutdinov, R. (2020).
+ Improving Transformer Optimization Through Better Initialization.
+ In International Conference on Machine Learning (pp. 4475-4483). PMLR.
+ https://dl.acm.org/doi/10.5555/3524938.3525354
+ .. [layerdrop] Fan, A., Grave, E., & Joulin, A. (2020).
+ Reducing Transformer Depth on Demand with Structured Dropout.
+ International Conference on Learning Representations.
+ Retrieved from https://openreview.net/forum?id=SylO2yStDr
 """
 
 def __init__(
@@ -200,7 +176,7 @@ class BENDR(EEGModuleMixin, nn.Module):
 projection_head=False, # Whether encoder should project back to input feature size (unused in original fine-tuning)
 drop_prob=0.1, # General dropout probability (paper: 0.15 for pretraining, 0.0 for fine-tuning)
 layer_drop=0.0, # Probability of dropping transformer layers during training (paper: 0.01 for pretraining)
- activation: type[nn.Module] = nn.GELU, # Activation function
+ activation=nn.GELU, # Activation function
 # Transformer specific parameters
 transformer_layers=8,
 transformer_heads=8,
@@ -349,7 +325,7 @@ class _ConvEncoderBENDR(nn.Module):
 
 
 class _BENDRContextualizer(nn.Module):
- r"""Transformer-based contextualizer for BENDR."""
+ """Transformer-based contextualizer for BENDR."""
 
 def __init__(
 self,