braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0

braindecode/models/biot.py

@@ -9,9 +9,9 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class BIOT(EEGModuleMixin, nn.Module):
-    r"""BIOT from Yang et al (2023) [Yang2023]_
+    """BIOT from Yang et al. (2023) [Yang2023]_
 
-    :bdg-danger:`Foundation Model`
+    :bdg-danger:`Large Brain Model`
 
     .. figure:: https://braindecode.org/dev/_static/model/biot.jpg
        :align: center
@@ -19,7 +19,7 @@ class BIOT(EEGModuleMixin, nn.Module):
 
     BIOT: Cross-data Biosignal Learning in the Wild.
 
-    BIOT is a foundation model for biosignal classification. It is
+    BIOT is a large brain model for biosignal classification. It is
     a wrapper around the `BIOTEncoder` and `ClassificationHead` modules.
 
     It is designed for N-dimensional biosignal data such as EEG, ECG, etc.
@@ -41,44 +41,15 @@ class BIOT(EEGModuleMixin, nn.Module):
     linear layer that takes the output of the `BIOTEncoder` and outputs
     the classification probabilities.
 
-    .. important::
-        **Pre-trained Weights Available**
-
-        This model has pre-trained weights available on the Hugging Face Hub.
-        You can load them using:
-
-        .. code-block:: python
-
-            from braindecode.models import BIOT
-
-            # Load the original pre-trained model from Hugging Face Hub
-            # For 16-channel models:
-            model = BIOT.from_pretrained("braindecode/biot-pretrained-prest-16chs")
-
-            # For 18-channel models:
-            model = BIOT.from_pretrained("braindecode/biot-pretrained-shhs-prest-18chs")
-            model = BIOT.from_pretrained("braindecode/biot-pretrained-six-datasets-18chs")
-
-        To push your own trained model to the Hub:
-
-        .. code-block:: python
-
-            # After training your model
-            model.push_to_hub(
-                repo_id="username/my-biot-model", commit_message="Upload trained BIOT model"
-            )
-
-        Requires installing ``braindecode[hug]`` for Hub integration.
-
     .. versionadded:: 0.9
 
     Parameters
     ----------
-    embed_dim : int, optional
+    emb_size : int, optional
         The size of the embedding layer, by default 256
-    num_heads : int, optional
+    att_num_heads : int, optional
         The number of attention heads, by default 8
-    num_layers : int, optional
+    n_layers : int, optional
         The number of transformer layers, by default 4
     activation: nn.Module, default=nn.ELU
         Activation function class to apply. Should be a PyTorch activation
@@ -105,9 +76,9 @@ class BIOT(EEGModuleMixin, nn.Module):
 
     def __init__(
         self,
-        embed_dim=256,
-        num_heads=8,
-        num_layers=4,
+        emb_size=256,
+        att_num_heads=8,
+        n_layers=4,
         sfreq=200,
         hop_length=100,
         return_feature=False,
@@ -116,12 +87,12 @@ class BIOT(EEGModuleMixin, nn.Module):
         chs_info=None,
         n_times=None,
         input_window_seconds=None,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
         drop_prob: float = 0.5,
         # Parameters for the encoder
         max_seq_len: int = 1024,
-        att_drop_prob=0.2,
-        att_layer_drop_prob=0.2,
+        attn_dropout=0.2,
+        attn_layer_dropout=0.2,
     ):
         super().__init__(
             n_outputs=n_outputs,
@@ -132,10 +103,10 @@ class BIOT(EEGModuleMixin, nn.Module):
             sfreq=sfreq,
         )
         del n_outputs, n_chans, chs_info, n_times, sfreq
-        self.embed_dim = embed_dim
+        self.emb_size = emb_size
         self.hop_length = hop_length
-        self.num_heads = num_heads
-        self.num_layers = num_layers
+        self.att_num_heads = att_num_heads
+        self.n_layers = n_layers
         self.return_feature = return_feature
         if (self.sfreq != 200) & (self.sfreq is not None):
             warn(
@@ -143,7 +114,7 @@ class BIOT(EEGModuleMixin, nn.Module):
                 + "no guarantee to generalize well with the default parameters",
                 UserWarning,
             )
-        if self.n_chans > embed_dim:
+        if self.n_chans > emb_size:
             warn(
                 "The number of channels is larger than the embedding size. "
                 + "This may cause overfitting. Consider using a larger "
@@ -171,20 +142,20 @@ class BIOT(EEGModuleMixin, nn.Module):
             self.n_fft = int(self.sfreq)
 
         self.encoder = _BIOTEncoder(
-            emb_size=self.embed_dim,
-            num_heads=self.num_heads,
-            n_layers=self.num_layers,
+            emb_size=emb_size,
+            att_num_heads=att_num_heads,
+            n_layers=n_layers,
             n_chans=self.n_chans,
             n_fft=self.n_fft,
             hop_length=hop_length,
             drop_prob=drop_prob,
             max_seq_len=max_seq_len,
-            attn_dropout=att_drop_prob,
-            attn_layer_dropout=att_layer_drop_prob,
+            attn_dropout=attn_dropout,
+            attn_layer_dropout=attn_layer_dropout,
        )
 
         self.final_layer = _ClassificationHead(
-            emb_size=self.embed_dim,
+            emb_size=emb_size,
             n_outputs=self.n_outputs,
             activation=activation,
         )
@@ -216,7 +187,7 @@ class BIOT(EEGModuleMixin, nn.Module):
 
 
 class _PatchFrequencyEmbedding(nn.Module):
-    r"""
+    """
     Patch Frequency Embedding.
 
     A simple linear layer is used to learn some representation over the
@@ -258,7 +229,7 @@ class _PatchFrequencyEmbedding(nn.Module):
 
 
 class _ClassificationHead(nn.Sequential):
-    r"""
+    """
     Classification head for the BIOT model.
 
     Simple linear layer with ELU activation function.
@@ -279,9 +250,7 @@ class _ClassificationHead(nn.Sequential):
         (batch, n_outputs)
     """
 
-    def __init__(
-        self, emb_size: int, n_outputs: int, activation: type[nn.Module] = nn.ELU
-    ):
+    def __init__(self, emb_size: int, n_outputs: int, activation: nn.Module = nn.ELU):
         super().__init__()
         self.activation_layer = activation()
         self.classification_head = nn.Linear(emb_size, n_outputs)
@@ -293,7 +262,7 @@ class _ClassificationHead(nn.Sequential):
 
 
 class _PositionalEncoding(nn.Module):
-    r"""
+    """
     Positional Encoding.
 
     We first create a `pe` zero matrix of shape (max_len, d_model) where max_len is the
@@ -354,7 +323,7 @@ class _PositionalEncoding(nn.Module):
 
 
 class _BIOTEncoder(nn.Module):
-    r"""
+    """
     BIOT Encoder.
 
     The BIOT encoder is a transformer that takes the time series input data and
@@ -376,7 +345,7 @@ class _BIOTEncoder(nn.Module):
         The number of channels
     emb_size: int
         The size of the embedding layer
-    num_heads: int
+    att_num_heads: int
         The number of attention heads
     n_layers: int
         The number of transformer layers
@@ -389,7 +358,7 @@ class _BIOTEncoder(nn.Module):
     def __init__(
         self,
         emb_size=256,  # The size of the embedding layer
-        num_heads=8,  # The number of attention heads
+        att_num_heads=8,  # The number of attention heads
         n_chans=16,  # The number of channels
         n_layers=4,  # The number of transformer layers
         n_fft=200,  # Related with the frequency resolution
@@ -409,7 +378,7 @@ class _BIOTEncoder(nn.Module):
         )
         self.transformer = LinearAttentionTransformer(
             dim=emb_size,
-            heads=num_heads,
+            heads=att_num_heads,
             depth=n_layers,
             max_seq_len=max_seq_len,
             attn_layer_dropout=attn_layer_dropout,
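
The biot.py hunks above rename the public constructor arguments (embed_dim → emb_size, num_heads → att_num_heads, num_layers → n_layers, att_drop_prob → attn_dropout, att_layer_drop_prob → attn_layer_dropout). A minimal instantiation sketch under the renamed signature; the output, channel, and sample counts below are illustrative placeholders, not values taken from the diff:

    from braindecode.models import BIOT

    # Sketch only: keyword names follow the renamed signature shown in the
    # hunks above; n_outputs / n_chans / n_times are illustrative.
    model = BIOT(
        n_outputs=2,
        n_chans=18,
        n_times=2000,
        sfreq=200,
        emb_size=256,            # was embed_dim
        att_num_heads=8,         # was num_heads
        n_layers=4,              # was num_layers
        attn_dropout=0.2,        # was att_drop_prob
        attn_layer_dropout=0.2,  # was att_layer_drop_prob
    )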

braindecode/models/contrawr.py

@@ -8,7 +8,7 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class ContraWR(EEGModuleMixin, nn.Module):
-    r"""Contrast with the World Representation ContraWR from Yang et al (2021) [Yang2021]_.
+    """Contrast with the World Representation ContraWR from Yang et al (2021) [Yang2021]_.
 
     :bdg-success:`Convolution`
 
@@ -58,7 +58,7 @@ class ContraWR(EEGModuleMixin, nn.Module):
         emb_size: int = 256,
         res_channels: list[int] = [32, 64, 128],
         steps=20,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
         drop_prob: float = 0.5,
         stride_res: int = 2,
         kernel_size_res: int = 3,
@@ -148,7 +148,7 @@ class ContraWR(EEGModuleMixin, nn.Module):
 
 
 class _ResBlock(nn.Module):
-    r"""Convolutional Residual Block 2D.
+    """Convolutional Residual Block 2D.
 
     This block stacks two convolutional layers with batch normalization,
     max pooling, dropout, and residual connection.
@@ -195,7 +195,7 @@ class _ResBlock(nn.Module):
         kernel_size=3,
         padding=1,
         drop_prob=0.5,
-        activation: type[nn.Module] = nn.ReLU,
+        activation: nn.Module = nn.ReLU,
     ):
         super().__init__()
         self.conv1 = nn.Conv2d(
@@ -259,7 +259,7 @@ class _ResBlock(nn.Module):
 
 
 class _STFTModule(nn.Module):
-    r"""
+    """
     A PyTorch module that computes the Short-Time Fourier Transform (STFT)
     of an EEG batch tensor.
 

braindecode/models/ctnet.py

@@ -25,9 +25,9 @@ from braindecode.modules import (
 
 
 class CTNet(EEGModuleMixin, nn.Module):
-    r"""CTNet from Zhao, W et al (2024) [ctnet]_.
+    """CTNet from Zhao, W et al (2024) [ctnet]_.
 
-    :bdg-success:`Convolution` :bdg-info:`Attention/Transformer`
+    :bdg-success:`Convolution` :bdg-info:`Small Attention`
 
     A Convolutional Transformer Network for EEG-Based Motor Imagery Classification
 
@@ -61,11 +61,11 @@ class CTNet(EEGModuleMixin, nn.Module):
     ----------
     activation : nn.Module, default=nn.GELU
         Activation function to use in the network.
-    num_heads : int, default=4
+    heads : int, default=4
         Number of attention heads in the Transformer encoder.
-    embed_dim : int or None, default=None
+    emb_size : int or None, default=None
         Embedding size (dimensionality) for the Transformer encoder.
-    num_layers : int, default=6
+    depth : int, default=6
         Number of encoder layers in the Transformer.
     n_filters_time : int, default=20
         Number of temporal filters in the first convolutional layer.
@@ -77,11 +77,11 @@ class CTNet(EEGModuleMixin, nn.Module):
         Pooling size for the first average pooling layer.
     pool_size_2 : int, default=8
         Pooling size for the second average pooling layer.
-    cnn_drop_prob: float, default=0.3
+    drop_prob_cnn : float, default=0.3
        Dropout probability after convolutional layers.
-    att_positional_drop_prob : float, default=0.1
+    drop_prob_posi : float, default=0.1
        Dropout probability for the positional encoding in the Transformer.
-    final_drop_prob : float, default=0.5
+    drop_prob_final : float, default=0.5
        Dropout probability before the final classification layer.
 
     Notes
@@ -109,15 +109,15 @@ class CTNet(EEGModuleMixin, nn.Module):
         n_times=None,
         input_window_seconds=None,
         # Model specific arguments
-        activation_patch: type[nn.Module] = nn.ELU,
-        activation_transformer: type[nn.Module] = nn.GELU,
-        cnn_drop_prob: float = 0.3,
-        att_positional_drop_prob: float = 0.1,
-        final_drop_prob: float = 0.5,
+        activation_patch: nn.Module = nn.ELU,
+        activation_transformer: nn.Module = nn.GELU,
+        drop_prob_cnn: float = 0.3,
+        drop_prob_posi: float = 0.1,
+        drop_prob_final: float = 0.5,
         # other parameters
-        num_heads: int = 4,
-        embed_dim: Optional[int] = 40,
-        num_layers: int = 6,
+        heads: int = 4,
+        emb_size: Optional[int] = 40,
+        depth: int = 6,
         n_filters_time: Optional[int] = None,
         kernel_size: int = 64,
         depth_multiplier: Optional[int] = 2,
@@ -136,14 +136,14 @@ class CTNet(EEGModuleMixin, nn.Module):
 
         self.activation_patch = activation_patch
         self.activation_transformer = activation_transformer
-        self.cnn_drop_prob = cnn_drop_prob
+        self.drop_prob_cnn = drop_prob_cnn
         self.pool_size_1 = pool_size_1
         self.pool_size_2 = pool_size_2
         self.kernel_size = kernel_size
-        self.att_positional_drop_prob = att_positional_drop_prob
-        self.final_drop_prob = final_drop_prob
-        self.num_heads = num_heads
-        self.num_layers = num_layers
+        self.drop_prob_posi = drop_prob_posi
+        self.drop_prob_final = drop_prob_final
+        self.heads = heads
+        self.depth = depth
         # n_times - pool_size_1 / p
         self.sequence_length = math.floor(
             (
@@ -154,8 +154,8 @@ class CTNet(EEGModuleMixin, nn.Module):
             + 1
         )
 
-        self.depth_multiplier, self.n_filters_time, self.embed_dim = self._resolve_dims(
-            depth_multiplier, n_filters_time, embed_dim
+        self.depth_multiplier, self.n_filters_time, self.emb_size = self._resolve_dims(
+            depth_multiplier, n_filters_time, emb_size
         )
 
         # Layers
@@ -168,32 +168,32 @@ class CTNet(EEGModuleMixin, nn.Module):
             depth_multiplier=self.depth_multiplier,
             pool_size_1=self.pool_size_1,
             pool_size_2=self.pool_size_2,
-            drop_prob=self.cnn_drop_prob,
+            drop_prob=self.drop_prob_cnn,
             n_chans=self.n_chans,
             activation=self.activation_patch,
         )
 
         self.position = _PositionalEncoding(
-            emb_size=self.embed_dim,
-            drop_prob=self.att_positional_drop_prob,
+            emb_size=self.emb_size,
+            drop_prob=self.drop_prob_posi,
             n_times=self.n_times,
             pool_size=self.pool_size_1,
         )
 
         self.trans = _TransformerEncoder(
-            self.num_heads,
-            self.num_layers,
-            self.embed_dim,
+            self.heads,
+            self.depth,
+            self.emb_size,
             activation=self.activation_transformer,
         )
 
         self.flatten_drop_layer = nn.Sequential(
             nn.Flatten(),
-            nn.Dropout(p=self.final_drop_prob),
+            nn.Dropout(p=self.drop_prob_final),
         )
 
         self.final_layer = nn.Linear(
-            in_features=int(self.embed_dim * self.sequence_length),
+            in_features=int(self.emb_size * self.sequence_length),
             out_features=self.n_outputs,
         )
 
@@ -213,7 +213,7 @@ class CTNet(EEGModuleMixin, nn.Module):
         """
         x = self.ensuredim(x)
         cnn = self.cnn(x)
-        cnn = cnn * math.sqrt(self.embed_dim)
+        cnn = cnn * math.sqrt(self.emb_size)
         cnn = self.position(cnn)
         trans = self.trans(cnn)
         features = cnn + trans
@@ -312,7 +312,7 @@ class _PatchEmbeddingEEGNet(nn.Module):
         pool_size_2: int = 8,
         drop_prob: float = 0.3,
         n_chans: int = 22,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
     ):
         super().__init__()
         n_filters_out = depth_multiplier * n_filters_time
@@ -416,7 +416,7 @@ class _TransformerEncoderBlock(nn.Module):
         drop_prob: float = 0.5,
         forward_expansion: int = 4,
         forward_drop_p: float = 0.5,
-        activation: type[nn.Module] = nn.GELU,
+        activation: nn.Module = nn.GELU,
     ):
         super().__init__()
         self.attention = _ResidualAdd(
@@ -466,7 +466,7 @@ class _TransformerEncoder(nn.Module):
         nheads: int,
         depth: int,
         dim_feedforward: int,
-        activation: type[nn.Module] = nn.GELU,
+        activation: nn.Module = nn.GELU,
     ):
         super().__init__()
         self.layers = nn.Sequential(
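
The ctnet.py hunks above apply the same kind of renaming to CTNet (num_heads → heads, embed_dim → emb_size, num_layers → depth, plus the three dropout probabilities). A minimal instantiation sketch under the renamed signature; the output, channel, and sample counts are illustrative placeholders:

    from braindecode.models import CTNet

    # Sketch only: keyword names follow the renamed signature shown above;
    # n_outputs / n_chans / n_times are illustrative.
    model = CTNet(
        n_outputs=4,
        n_chans=22,
        n_times=1000,
        heads=4,              # was num_heads
        emb_size=40,          # was embed_dim
        depth=6,              # was num_layers
        drop_prob_cnn=0.3,    # was cnn_drop_prob
        drop_prob_posi=0.1,   # was att_positional_drop_prob
        drop_prob_final=0.5,  # was final_drop_prob
    )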

braindecode/models/deep4.py

@@ -17,7 +17,7 @@ from braindecode.modules import (
 
 
 class Deep4Net(EEGModuleMixin, nn.Sequential):
-    r"""Deep ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
+    """Deep ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
 
     :bdg-success:`Convolution`
 
@@ -109,12 +109,12 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
         filter_length_3=10,
         n_filters_4=200,
         filter_length_4=10,
-        activation_first_conv_nonlin: type[nn.Module] = nn.ELU,
+        activation_first_conv_nonlin: nn.Module = nn.ELU,
         first_pool_mode="max",
-        first_pool_nonlin: type[nn.Module] = nn.Identity,
-        activation_later_conv_nonlin: type[nn.Module] = nn.ELU,
+        first_pool_nonlin: nn.Module = nn.Identity,
+        activation_later_conv_nonlin: nn.Module = nn.ELU,
         later_pool_mode="max",
-        later_pool_nonlin: type[nn.Module] = nn.Identity,
+        later_pool_nonlin: nn.Module = nn.Identity,
         drop_prob=0.5,
         split_first_layer=True,
         batch_norm=True,

braindecode/models/deepsleepnet.py

@@ -8,7 +8,7 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class DeepSleepNet(EEGModuleMixin, nn.Module):
-    r"""DeepSleepNet from Supratak et al (2017) [Supratak2017]_.
+    """DeepSleepNet from Supratak et al. (2017) [Supratak2017]_.
 
     :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
 
@@ -172,8 +172,8 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
         n_times=None,
         input_window_seconds=None,
         sfreq=None,
-        activation_large: type[nn.Module] = nn.ELU,
-        activation_small: type[nn.Module] = nn.ReLU,
+        activation_large: nn.Module = nn.ELU,
+        activation_small: nn.Module = nn.ReLU,
         drop_prob: float = 0.5,
     ):
         super().__init__(
@@ -240,7 +240,7 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
 
 
 class _SmallCNN(nn.Module):
-    r"""
+    """
     Smaller filter sizes to learn temporal information.
 
     Parameters
@@ -252,7 +252,7 @@ class _SmallCNN(nn.Module):
         The dropout rate for regularization. Values should be between 0 and 1.
     """
 
-    def __init__(self, activation: type[nn.Module] = nn.ReLU, drop_prob: float = 0.5):
+    def __init__(self, activation: nn.Module = nn.ReLU, drop_prob: float = 0.5):
         super().__init__()
         self.conv1 = nn.Sequential(
             nn.Conv2d(
@@ -317,7 +317,7 @@ class _SmallCNN(nn.Module):
 
 
 class _LargeCNN(nn.Module):
-    r"""
+    """
     Larger filter sizes to learn frequency information.
 
     Parameters
@@ -328,7 +328,7 @@ class _LargeCNN(nn.Module):
 
     """
 
-    def __init__(self, activation: type[nn.Module] = nn.ELU, drop_prob: float = 0.5):
+    def __init__(self, activation: nn.Module = nn.ELU, drop_prob: float = 0.5):
         super().__init__()
 
         self.conv1 = nn.Sequential(
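
Several files in this diff (contrawr.py, ctnet.py, deep4.py, deepsleepnet.py, eegconformer.py, biot.py) relax the annotation of the activation arguments from type[nn.Module] to nn.Module while keeping class defaults such as nn.ELU. In the touched modules the argument is still passed as a class and instantiated internally, for example self.activation_layer = activation() in _ClassificationHead in the biot.py hunks above. A small sketch of that calling convention; the helper function below is hypothetical and not part of the package:

    import torch.nn as nn

    # Hypothetical helper mirroring how these modules consume `activation`:
    # the caller passes a class, and the module instantiates it internally.
    def build_head(emb_size: int, n_outputs: int, activation=nn.ELU) -> nn.Sequential:
        return nn.Sequential(activation(), nn.Linear(emb_size, n_outputs))

    head = build_head(256, 2)                # default nn.ELU
    head_relu = build_head(256, 2, nn.ReLU)  # any other activation class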

braindecode/models/eegconformer.py

@@ -12,9 +12,9 @@ from braindecode.modules import FeedForwardBlock, MultiHeadAttention
 
 
 class EEGConformer(EEGModuleMixin, nn.Module):
-    r"""EEG Conformer from Song et al (2022) [song2022]_.
+    """EEG Conformer from Song et al. (2022) [song2022]_.
 
-    :bdg-success:`Convolution` :bdg-info:`Attention/Transformer`
+    :bdg-success:`Convolution` :bdg-info:`Small Attention`
 
     .. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
        :align: center
@@ -57,9 +57,9 @@ class EEGConformer(EEGModuleMixin, nn.Module):
     - :class:`_TransformerEncoder` **(context over temporal tokens)**
 
       - *Operations.*
-        - A stack of ``num_layers`` encoder blocks. :class:`_TransformerEncoderBlock`
+        - A stack of ``att_depth`` encoder blocks. :class:`_TransformerEncoderBlock`
         - Each block applies LayerNorm :class:`torch.nn.LayerNorm`
-        - Multi-Head Self-Attention (``num_heads``) with dropout + residual :class:`MultiHeadAttention` (:class:`torch.nn.Dropout`)
+        - Multi-Head Self-Attention (``att_heads``) with dropout + residual :class:`MultiHeadAttention` (:class:`torch.nn.Dropout`)
         - LayerNorm :class:`torch.nn.LayerNorm`
         - 2-layer feed-forward (≈4x expansion, :class:`torch.nn.GELU`) with dropout + residual.
 
@@ -100,7 +100,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):
 
     .. rubric:: Attention / Sequential Modules
 
-    - **Type.** Standard multi-head self-attention (MHA) with ``num_heads`` heads over the token sequence.
+    - **Type.** Standard multi-head self-attention (MHA) with ``att_heads`` heads over the token sequence.
     - **Shapes.** Input/Output: ``(B, S_tokens, D)``; attention operates along the ``S_tokens`` axis.
     - **Role.** Re-weights and integrates evidence across pooled windows, capturing dependencies
       longer than any single token while leaving channel relationships to the convolutional stem.
@@ -127,7 +127,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):
     - **Instantiation.** Choose ``n_filters_time`` (embedding size ``D``) and
       ``filter_time_length`` to match the rhythms of interest. Tune
       ``pool_time_length/stride`` to trade temporal resolution for sequence length.
-      Keep ``num_layers`` modest (e.g., 4–6) and set ``num_heads`` to divide ``D``.
+      Keep ``att_depth`` modest (e.g., 4–6) and set ``att_heads`` to divide ``D``.
       ``final_fc_length="auto"`` infers the flattened size from PatchEmbedding.
 
     Notes
@@ -160,9 +160,9 @@ class EEGConformer(EEGModuleMixin, nn.Module):
         Length of stride between temporal pooling filters.
     drop_prob: float
         Dropout rate of the convolutional layer.
-    num_layers: int
+    att_depth: int
         Number of self-attention layers.
-    num_heads: int
+    att_heads: int
         Number of attention heads.
     att_drop_prob: float
         Dropout rate of the self-attention layer.
@@ -197,13 +197,13 @@ class EEGConformer(EEGModuleMixin, nn.Module):
         pool_time_length=75,
         pool_time_stride=15,
         drop_prob=0.5,
-        num_layers=6,
-        num_heads=10,
+        att_depth=6,
+        att_heads=10,
         att_drop_prob=0.5,
         final_fc_length="auto",
         return_features=False,
-        activation: type[nn.Module] = nn.ELU,
-        activation_transfor: type[nn.Module] = nn.GELU,
+        activation: nn.Module = nn.ELU,
+        activation_transfor: nn.Module = nn.GELU,
         n_times=None,
         chs_info=None,
         input_window_seconds=None,
@@ -250,9 +250,9 @@ class EEGConformer(EEGModuleMixin, nn.Module):
         self.final_fc_length = final_fc_length
 
         self.transformer = _TransformerEncoder(
-            num_layers=num_layers,
+            att_depth=att_depth,
             emb_size=n_filters_time,
-            num_heads=num_heads,
+            att_heads=att_heads,
             att_drop=att_drop_prob,
             activation=activation_transfor,
         )
@@ -284,7 +284,7 @@ class EEGConformer(EEGModuleMixin, nn.Module):
 
 
 class _PatchEmbedding(nn.Module):
-    r"""Patch Embedding.
+    """Patch Embedding.
 
     The authors used a convolution module to capture local features,
     instead of position embedding.
@@ -318,7 +318,7 @@ class _PatchEmbedding(nn.Module):
         pool_time_length,
         stride_avg_pool,
         drop_prob,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
     ):
         super().__init__()
 
@@ -364,16 +364,16 @@ class _TransformerEncoderBlock(nn.Sequential):
     def __init__(
         self,
         emb_size,
-        num_heads,
+        att_heads,
         att_drop,
         forward_expansion=4,
-        activation: type[nn.Module] = nn.GELU,
+        activation: nn.Module = nn.GELU,
     ):
         super().__init__(
             _ResidualAdd(
                 nn.Sequential(
                     nn.LayerNorm(emb_size),
-                    MultiHeadAttention(emb_size, num_heads, att_drop),
+                    MultiHeadAttention(emb_size, att_heads, att_drop),
                     nn.Dropout(att_drop),
                 )
             ),
@@ -393,17 +393,17 @@ class _TransformerEncoderBlock(nn.Sequential):
 
 
 class _TransformerEncoder(nn.Sequential):
-    r"""Transformer encoder module for the transformer encoder.
+    """Transformer encoder module for the transformer encoder.
 
     Similar to the layers used in ViT.
 
     Parameters
     ----------
-    num_layers : int
+    att_depth : int
         Number of transformer encoder blocks.
     emb_size : int
         Embedding size of the transformer encoder.
-    num_heads : int
+    att_heads : int
         Number of attention heads.
     att_drop : float
         Dropout probability for the attention layers.
@@ -411,19 +411,14 @@ class _TransformerEncoder(nn.Sequential):
     """
 
     def __init__(
-        self,
-        num_layers,
-        emb_size,
-        num_heads,
-        att_drop,
-        activation: type[nn.Module] = nn.GELU,
+        self, att_depth, emb_size, att_heads, att_drop, activation: nn.Module = nn.GELU
     ):
         super().__init__(
             *[
                 _TransformerEncoderBlock(
-                    emb_size, num_heads, att_drop, activation=activation
+                    emb_size, att_heads, att_drop, activation=activation
                 )
-                for _ in range(num_layers)
+                for _ in range(att_depth)
             ]
         )
 
@@ -436,7 +431,7 @@ class _FullyConnected(nn.Module):
         drop_prob_2=0.3,
         out_channels=256,
         hidden_channels=32,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
     ):
         """Fully-connected layer for the transformer encoder.
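
The eegconformer.py hunks rename the transformer hyperparameters (num_layers → att_depth, num_heads → att_heads) consistently in the docstring, the constructor, and the internal _TransformerEncoder. A minimal instantiation sketch under the renamed signature; the output, channel, and sample counts are illustrative placeholders:

    from braindecode.models import EEGConformer

    # Sketch only: keyword names follow the renamed signature shown above;
    # n_outputs / n_chans / n_times are illustrative.
    model = EEGConformer(
        n_outputs=4,
        n_chans=22,
        n_times=1000,
        att_depth=6,       # was num_layers
        att_heads=10,      # was num_heads
        att_drop_prob=0.5,
        final_fc_length="auto",
    )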