braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
@@ -21,9 +21,9 @@ from braindecode.modules import MLP, DropPath
 
 
 class Labram(EEGModuleMixin, nn.Module):
- r"""Labram from Jiang, W B et al (2024) [Jiang2024]_.
+ """Labram from Jiang, W B et al (2024) [Jiang2024]_.
 
- :bdg-success:`Convolution` :bdg-danger:`Foundation Model`
+ :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`
 
 .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
 :align: center
@@ -64,30 +64,6 @@ class Labram(EEGModuleMixin, nn.Module):
 - LayerNorm: Apply layer normalization to the data;
 - Linear: An head linear layer to transformer the data into classes.
 
- .. important::
- **Pre-trained Weights Available**
-
- This model has pre-trained weights available on the Hugging Face Hub.
- You can load them using:
-
- .. code-block:: python
-
- from braindecode.models import Labram
-
- # Load pre-trained model from Hugging Face Hub
- model = Labram.from_pretrained("braindecode/labram-pretrained")
-
- To push your own trained model to the Hub:
-
- .. code-block:: python
-
- # After training your model
- model.push_to_hub(
- repo_id="username/my-labram-model", commit_message="Upload trained Labram model"
- )
-
- Requires installing ``braindecode[hug]`` for Hub integration.
-
 .. versionadded:: 0.9
 
 
@@ -107,15 +83,15 @@ class Labram(EEGModuleMixin, nn.Module):
 ----------
 patch_size : int
 The size of the patch to be used in the patch embedding.
- embed_dim : int
+ emb_size : int
 The dimension of the embedding.
- conv_in_channels : int
+ in_conv_channels : int
 The number of convolutional input channels.
- conv_out_channels : int
+ out_channels : int
 The number of convolutional output channels.
- num_layers : int (default=12)
+ n_layers : int (default=12)
 The number of attention layers of the model.
- num_heads : int (default=10)
+ att_num_heads : int (default=10)
 The number of attention heads.
 mlp_ratio : float (default=4.0)
 The expansion ratio of the mlp layer
@@ -179,26 +155,26 @@ class Labram(EEGModuleMixin, nn.Module):
 sfreq=None,
 input_window_seconds=None,
 patch_size=200,
- embed_dim=200,
- conv_in_channels=1,
- conv_out_channels=8,
- num_layers=12,
- num_heads=10,
+ emb_size=200,
+ in_conv_channels=1,
+ out_channels=8,
+ n_layers=12,
+ att_num_heads=10,
 mlp_ratio=4.0,
 qkv_bias=False,
- qk_norm: type[nn.Module] = nn.LayerNorm,
+ qk_norm=nn.LayerNorm,
 qk_scale=None,
 drop_prob=0.0,
 attn_drop_prob=0.0,
 drop_path_prob=0.0,
- norm_layer: type[nn.Module] = nn.LayerNorm,
+ norm_layer=nn.LayerNorm,
 init_values=0.1,
 use_abs_pos_emb=True,
 use_mean_pooling=True,
 init_scale=0.001,
 neural_tokenizer=True,
 attn_head_dim=None,
- activation: type[nn.Module] = nn.GELU,
+ activation: nn.Module = nn.GELU,
 ):
 super().__init__(
 n_outputs=n_outputs,
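For reference, a minimal instantiation sketch using the keyword arguments renamed in this hunk (the keyword values repeat the new defaults; n_outputs, n_chans, and n_times are illustrative, and the rest of the signature is assumed unchanged):

    from braindecode.models import Labram

    # Hypothetical example: 2 classes, 22 channels, 1000 time samples.
    model = Labram(
        n_outputs=2,
        n_chans=22,
        n_times=1000,
        emb_size=200,        # was embed_dim
        in_conv_channels=1,  # was conv_in_channels
        out_channels=8,      # was conv_out_channels
        n_layers=12,         # was num_layers
        att_num_heads=10,    # was num_heads
    )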
@@ -211,7 +187,7 @@ class Labram(EEGModuleMixin, nn.Module):
 del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
 
 self.patch_size = patch_size
- self.num_features = self.embed_dim = embed_dim
+ self.num_features = self.emb_size = emb_size
 self.neural_tokenizer = neural_tokenizer
 self.init_scale = init_scale
 
@@ -223,20 +199,20 @@ class Labram(EEGModuleMixin, nn.Module):
 )
 self.patch_size = self.n_times
 self.num_features = None
- self.embed_dim = None
+ self.emb_size = None
 else:
 self.patch_size = patch_size
 self.n_path = self.n_times // self.patch_size
 
- if neural_tokenizer and conv_in_channels != 1:
+ if neural_tokenizer and in_conv_channels != 1:
 warn(
 "The model is in Neural Tokenizer mode, but the variable "
- + "`conv_in_channels` is different from the default values."
- + "`conv_in_channels` is only needed for the Neural Decoder mode."
- + "conv_in_channels is not used in the Neural Tokenizer mode.",
+ + "`in_conv_channels` is different from the default values."
+ + "`in_conv_channels` is only needed for the Neural Decoder mode."
+ + "in_conv_channels is not used in the Neural Tokenizer mode.",
 UserWarning,
 )
- conv_in_channels = 1
+ in_conv_channels = 1
 # If you can use the model in Neural Tokenizer mode,
 # temporal conv layer will be use over the patched dataset
 if neural_tokenizer:
@@ -255,7 +231,7 @@ class Labram(EEGModuleMixin, nn.Module):
 (
 "temporal_conv",
 _TemporalConv(
- out_channels=conv_out_channels, activation=activation
+ out_channels=out_channels, activation=activation
 ),
 ),
 ]
@@ -273,8 +249,8 @@ class Labram(EEGModuleMixin, nn.Module):
 _PatchEmbed(
 n_times=self.n_times,
 patch_size=patch_size,
- in_channels=conv_in_channels,
- emb_dim=self.embed_dim,
+ in_channels=in_conv_channels,
+ emb_dim=self.emb_size,
 ),
 )
 
@@ -283,12 +259,12 @@ class Labram(EEGModuleMixin, nn.Module):
 out = self.patch_embed(dummy)
 # out.shape for tokenizer: (1, n_chans, emb_dim)
 # for decoder: (1, n_patch, patch_size, emb_dim), but we want last dim
- self.embed_dim = out.shape[-1]
- self.num_features = self.embed_dim
+ self.emb_size = out.shape[-1]
+ self.num_features = self.emb_size
 
 # Defining the parameters
 # Creating a parameter list with cls token]
- self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.emb_size))
 # Positional embedding and time embedding are complementary
 # one is for the spatial information and the other is for the temporal
 # information.
@@ -297,26 +273,26 @@ class Labram(EEGModuleMixin, nn.Module):
 # information.
 if use_abs_pos_emb:
 self.position_embedding = nn.Parameter(
- torch.zeros(1, self.n_chans + 1, self.embed_dim),
+ torch.zeros(1, self.n_chans + 1, self.emb_size),
 requires_grad=True,
 )
 else:
 self.position_embedding = None
 
 self.temporal_embedding = nn.Parameter(
- torch.zeros(1, self.patch_embed[0].n_patchs + 1, self.embed_dim),
+ torch.zeros(1, self.patch_embed[0].n_patchs + 1, self.emb_size),
 requires_grad=True,
 )
 self.pos_drop = nn.Dropout(p=drop_prob)
 
 dpr = [
- x.item() for x in torch.linspace(0, drop_path_prob, num_layers)
+ x.item() for x in torch.linspace(0, drop_path_prob, n_layers)
 ] # stochastic depth decay rule
 self.blocks = nn.ModuleList(
 [
 _WindowsAttentionBlock(
- dim=self.embed_dim,
- num_heads=num_heads,
+ dim=self.emb_size,
+ num_heads=att_num_heads,
 mlp_ratio=mlp_ratio,
 qkv_bias=qkv_bias,
 qk_norm=qk_norm,
@@ -334,14 +310,14 @@ class Labram(EEGModuleMixin, nn.Module):
 attn_head_dim=attn_head_dim,
 activation=activation,
 )
- for i in range(num_layers)
+ for i in range(n_layers)
 ]
 )
- self.norm = nn.Identity() if use_mean_pooling else norm_layer(self.embed_dim)
- self.fc_norm = norm_layer(self.embed_dim) if use_mean_pooling else None
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(self.emb_size)
+ self.fc_norm = norm_layer(self.emb_size) if use_mean_pooling else None
 
 if self.n_outputs > 0:
- self.final_layer = nn.Linear(self.embed_dim, self.n_outputs)
+ self.final_layer = nn.Linear(self.emb_size, self.n_outputs)
 else:
 self.final_layer = nn.Identity()
 
@@ -439,7 +415,7 @@ class Labram(EEGModuleMixin, nn.Module):
 x = self.patch_embed(x)
 # x shape: (batch, n_chans, emb_dim)
 n_patch = self.n_chans
- temporal = self.embed_dim
+ temporal = self.emb_size
 else:
 # For neural decoder: input is (batch, n_chans, n_times)
 # patch_embed returns (batch, n_patchs, emb_dim)
@@ -486,7 +462,7 @@ class Labram(EEGModuleMixin, nn.Module):
 # In decoder mode, we have n_patch patches and don't need to expand
 # Just broadcast the temporal embedding
 if temporal is None:
- temporal = self.embed_dim
+ temporal = self.emb_size
 
 # Get temporal embeddings for n_patch patches
 n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
@@ -661,7 +637,7 @@ class Labram(EEGModuleMixin, nn.Module):
 
 
 class _SegmentPatch(nn.Module):
- r"""Segment and Patch for EEG data.
+ """Segment and Patch for EEG data.
 
 Adapted Patch Embedding inspired in the Visual Transform approach
 to extract the learned segmentor, we expect get the input shape as:
@@ -767,7 +743,7 @@ class _SegmentPatch(nn.Module):
 
 
 class _PatchEmbed(nn.Module):
- r"""EEG to Patch Embedding for Neural Decoder mode.
+ """EEG to Patch Embedding for Neural Decoder mode.
 
 This code is used when we want to apply the patch embedding
 after the codebook layer (Neural Decoder mode).
@@ -873,7 +849,7 @@ class _PatchEmbed(nn.Module):
 
 
 class _Attention(nn.Module):
- r"""
+ """
 Attention with the options of Window-based multi-head self attention (W-MSA).
 
 This code is strong inspired by:
@@ -1071,7 +1047,7 @@ class _Attention(nn.Module):
 
 
 class _WindowsAttentionBlock(nn.Module):
- r"""Blocks of Windows Attention with Layer norm and MLP.
+ """Blocks of Windows Attention with Layer norm and MLP.
 
 Notes: This code is strong inspired by:
 BeiTv2 from Microsoft.
@@ -1130,7 +1106,7 @@ class _WindowsAttentionBlock(nn.Module):
 attn_drop=0.0,
 drop_path=0.0,
 init_values=None,
- activation: type[nn.Module] = nn.GELU,
+ activation: nn.Module = nn.GELU,
 norm_layer=nn.LayerNorm,
 window_size=None,
 attn_head_dim=None,
@@ -1206,7 +1182,7 @@ class _WindowsAttentionBlock(nn.Module):
 
 
 class _TemporalConv(nn.Module):
- r"""
+ """
 Temporal Convolutional Module inspired by Visual Transformer.
 
 In this module we apply the follow steps three times repeatedly
@@ -1253,7 +1229,7 @@ class _TemporalConv(nn.Module):
 padding_1=(0, 7),
 kernel_size_2=(1, 3),
 padding_2=(0, 1),
- activation: type[nn.Module] = nn.GELU,
+ activation: nn.Module = nn.GELU,
 ):
 super().__init__()
 
@@ -27,9 +27,9 @@ from braindecode.modules.layers import DropPath
 
 
 class LUNA(EEGModuleMixin, nn.Module):
- r"""LUNA from Döner et al [LUNA]_.
+ """LUNA from Döner et al. [LUNA]_.
 
- :bdg-success:`Convolution` :bdg-danger:`Foundation Model` :bdg-dark-line:`Channel`
+ :bdg-success:`Convolution` :bdg-danger:`Large Brain Model` :bdg-dark-line:`Channel`
 
 .. figure:: https://arxiv.org/html/2510.22257v1/x1.png
 :align: center
@@ -44,61 +44,6 @@ class LUNA(EEGModuleMixin, nn.Module):
 3. Patch-wise Temporal Encoder (RoPE-based transformer)
 4. Decoder Heads (classification or reconstruction)
 
- .. important::
- **Pre-trained Weights Available**
-
- This model has pre-trained weights available on the Hugging Face Hub
- at `thorir/LUNA <https://huggingface.co/thorir/LUNA>`_.
-
- Available model variants:
-
- - **LUNA_base.safetensors** - Base model (embed_dim=64, num_queries=4, depth=8)
- - **LUNA_large.safetensors** - Large model (embed_dim=96, num_queries=6, depth=10)
- - **LUNA_huge.safetensors** - Huge model (embed_dim=128, num_queries=8, depth=24)
-
- Example loading for fine-tuning:
-
- .. code-block:: python
-
- from huggingface_hub import hf_hub_download
- from safetensors.torch import load_file
- from braindecode.models import LUNA
-
- # Download pre-trained weights
- model_path = hf_hub_download(
- repo_id="thorir/LUNA",
- filename="LUNA_base.safetensors",
- )
-
- # Create model for classification (fine-tuning)
- model = LUNA(
- n_outputs=2, # Number of classes for your task
- n_chans=22,
- n_times=1000,
- embed_dim=64,
- num_queries=4,
- depth=8,
- )
-
- # Load pre-trained encoder weights
- state_dict = load_file(model_path)
- # Apply key mapping for pretrained weights
- mapping = model.mapping.copy()
- mapping["cross_attn.temparature"] = "cross_attn.temperature"
- mapped_state_dict = {mapping.get(k, k): v for k, v in state_dict.items()}
- model.load_state_dict(mapped_state_dict, strict=False)
-
- To push your own trained model to the Hub:
-
- .. code-block:: python
-
- # After training your model
- model.push_to_hub(
- repo_id="username/my-luna-model", commit_message="Upload trained LUNA model"
- )
-
- Requires installing ``braindecode[hug]`` for Hub integration.
-
 Parameters
 ----------
 patch_size : int
@@ -457,7 +402,7 @@ def nerf_positional_encoding(coords: torch.Tensor, embed_size: int) -> torch.Ten
 
 
 class _ChannelEmbeddings(nn.Module):
- r"""
+ """
 This class creates embeddings for each EEG channel based on a predefined
 mapping of channel names to indices.
 
@@ -485,7 +430,7 @@ class _ChannelEmbeddings(nn.Module):
 
 
 class _FrequencyFeatureEmbedder(nn.Module):
- r"""
+ """
 This class takes data that is of the form (B, C, T) and patches it
 along the time dimension (T) into patches of size P (patch_size).
 The output is of the form (B, C, S, P) where S = T // P.
@@ -861,7 +806,7 @@ class _PatchEmbedNetwork(nn.Module):
 
 
 class _Mlp(nn.Module):
- r"""MLP as used in Vision Transformer, MLP-Mixer and related networks.
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks.
 
 Code copied from timm.models.mlp.Mlp
 """
@@ -18,21 +18,20 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class MEDFormer(EEGModuleMixin, nn.Module):
- r"""
- Medformer from Wang et al (2024) [Medformer2024]_.
+ r"""Medformer from Wang et al. (2024) [Medformer2024]_.
 
- :bdg-success:`Convolution` :bdg-danger:`Foundation Model`
+ :bdg-success:`Convolution` :bdg-danger:`Large Brain Model`
 
 .. figure:: https://raw.githubusercontent.com/DL4mHealth/Medformer/refs/heads/main/figs/medformer_architecture.png
 :align: center
 :alt: MEDFormer Architecture.
 
- a) Workflow. b) For the input sample :math:`{x}_{\text{in}}`, the authors apply :math:`n`
+ a) Workflow. b) For the input sample :math:`{x}_{\\textrm{in}}`, the authors apply :math:`n`
 different patch lengths in parallel to create patched features :math:`{x}_p^{(i)}`, where :math:`i`
 ranges from 1 to :math:`n`. Each patch length represents a different granularity. These patched
- features are linearly transformed into :math:`{x}_e^{(i)}` and augmented into :math:`\widetilde{x}_e^{(i)}`.
- c) The final patch embedding :math:`{x}^{(i)}` fuses augmented :math:`\widetilde{x}_e^{(i)}` with the
- positional embedding :math:`{W}_{\text{pos}}` and the granularity embedding :math:`{W}_{\text{gr}}^{(i)}`.
+ features are linearly transformed into :math:`{x}_e^{(i)}` and augmented into :math:`\\widetilde{x}_e^{(i)}`.
+ c) The final patch embedding :math:`{x}^{(i)}` fuses augmented :math:`\\widetilde{{x}}_e^{(i)}` with the
+ positional embedding :math:`{W}_{\\text{pos}}` and the granularity embedding :math:`{W}_{\\text{gr}}^{(i)}`.
 Each granularity employs a router :math:`{u}^{(i)}` to capture aggregated information.
 Intra-granularity attention focuses within individual granularities, and inter-granularity attention
 leverages the routers to integrate information across granularities.
@@ -116,7 +115,6 @@ class MEDFormer(EEGModuleMixin, nn.Module):
 **Role.** Learns representations and correlations within and across temporal scales while
 reducing complexity from :math:`O((\sum_i N_i)^2)` to
 :math:`O(\sum_i N_i^2 + n^2)` through the router mechanism.
-
 .. rubric:: Temporal, Spatial, and Spectral Encoding
 
 - **Temporal:** Multiple patch lengths in :attr:`patch_len_list` capture features at several
@@ -130,7 +128,7 @@ class MEDFormer(EEGModuleMixin, nn.Module):
 .. rubric:: Additional Mechanisms
 
 - **Granularity router:** Each granularity :math:`i` receives a dedicated router token
- :math:`\mathbf{u}^{(i)}`. Intra-attention updates the token, and inter-attention exchanges
+ :math:`\\mathbf{u}^{(i)}`. Intra-attention updates the token, and inter-attention exchanges
 aggregated information across scales.
 - **Complexity:** Router-mediated two-stage attention maintains :math:`O(T^2)` complexity for
 suitable patch lengths (e.g., power series), preserving transformer-like efficiency while
@@ -141,7 +139,7 @@ class MEDFormer(EEGModuleMixin, nn.Module):
 patch_len_list : list of int, optional
 Patch lengths for multi-granularity patching; each entry selects a temporal scale.
 The default is ``[14, 44, 45]``.
- embed_dim : int, optional
+ d_model : int, optional
 Embedding dimensionality. The default is ``128``.
 num_heads : int, optional
 Number of attention heads, which must divide :attr:`d_model`. The default is ``8``.
@@ -149,7 +147,7 @@ class MEDFormer(EEGModuleMixin, nn.Module):
 Dropout probability. The default is ``0.1``.
 no_inter_attn : bool, optional
 If ``True``, disables inter-granularity attention. The default is ``False``.
- num_layers : int, optional
+ n_layers : int, optional
 Number of encoder layers. The default is ``6``.
 dim_feedforward : int, optional
 Feedforward dimensionality. The default is ``256``.
@@ -191,16 +189,16 @@ class MEDFormer(EEGModuleMixin, nn.Module):
 sfreq=None,
 # Model parameters
 patch_len_list: Optional[List[int]] = None,
- embed_dim: int = 128,
+ d_model: int = 128,
 num_heads: int = 8,
 drop_prob: float = 0.1,
 no_inter_attn: bool = False,
- num_layers: int = 6,
+ n_layers: int = 6,
 dim_feedforward: int = 256,
- activation_trans: type[nn.Module] | None = nn.ReLU,
+ activation_trans: Optional[nn.Module] = nn.ReLU,
 single_channel: bool = False,
 output_attention: bool = True,
- activation_class: type[nn.Module] | None = nn.GELU,
+ activation_class: Optional[nn.Module] = nn.GELU,
 ):
 super().__init__(
 n_outputs=n_outputs,
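Likewise, a minimal sketch of constructing MEDFormer with the keyword arguments renamed in this hunk (the import path and the n_outputs/n_chans/n_times values are assumptions; d_model and n_layers repeat the new defaults):

    from braindecode.models import MEDFormer  # assumed export path

    model = MEDFormer(
        n_outputs=2,
        n_chans=22,
        n_times=1000,
        d_model=128,  # was embed_dim
        n_layers=6,   # was num_layers
    )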
@@ -217,11 +215,11 @@ class MEDFormer(EEGModuleMixin, nn.Module):
 # - enc_in refers to the number of time points
 
 # Save model parameters as instance variables
- self.embed_dim = embed_dim
+ self.d_model = d_model
 self.num_heads = num_heads
 self.drop_prob = drop_prob
 self.no_inter_attn = no_inter_attn
- self.num_layers = num_layers
+ self.n_layers = n_layers
 self.dim_feedforward = dim_feedforward
 self.activation_trans = activation_trans
 self.output_attention = output_attention
@@ -244,7 +242,7 @@ class MEDFormer(EEGModuleMixin, nn.Module):
 # Initialize the embedding layer.
 self.enc_embedding = _ListPatchEmbedding(
 enc_in=self.n_times,
- d_model=self.embed_dim,
+ d_model=self.d_model,
 seq_len=self.n_chans,
 patch_len_list=self.patch_len_list,
 stride_list=self.stride_list,
@@ -259,22 +257,22 @@ class MEDFormer(EEGModuleMixin, nn.Module):
 _EncoderLayer(
 attention=_MedformerLayer(
 num_blocks=len(self.patch_len_list),
- d_model=self.embed_dim,
+ d_model=self.d_model,
 num_heads=self.num_heads,
 dropout=self.drop_prob,
 output_attention=self.output_attention,
 no_inter=self.no_inter_attn,
 ),
- d_model=self.embed_dim,
+ d_model=self.d_model,
 dim_feedforward=self.dim_feedforward,
 dropout=self.drop_prob,
 activation=self.activation_trans()
 if self.activation_trans is not None
 else nn.ReLU(),
 )
- for _ in range(self.num_layers)
+ for _ in range(self.n_layers)
 ],
- norm_layer=torch.nn.LayerNorm(self.embed_dim),
+ norm_layer=torch.nn.LayerNorm(self.d_model),
 )
 
 # For classification tasks, add additional layers.
@@ -283,7 +281,7 @@ class MEDFormer(EEGModuleMixin, nn.Module):
 )
 self.dropout = nn.Dropout(self.drop_prob)
 self.final_layer = nn.Linear(
- self.embed_dim
+ self.d_model
 * len(self.patch_num_list)
 * (1 if not self.single_channel else self.n_chans),
 self.n_outputs,
@@ -11,9 +11,9 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class MSVTNet(EEGModuleMixin, nn.Module):
- r"""MSVTNet model from Liu K et al (2024) from [msvt2024]_.
+ """MSVTNet model from Liu K et al (2024) from [msvt2024]_.
 
- :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Attention/Transformer`
+ :bdg-success:`Convolution` :bdg-secondary:`Recurrent` :bdg-info:`Small Attention`
 
 This model implements a multi-scale convolutional transformer network
 for EEG signal classification, as described in [msvt2024]_.
@@ -41,9 +41,9 @@ class MSVTNet(EEGModuleMixin, nn.Module):
 Dropout probability for convolutional layers, by default 0.3.
 num_heads : int, optional
 Number of attention heads in the transformer encoder, by default 8.
- ffn_expansion_factor : float, optional
+ feedforward_ratio : float, optional
 Ratio to compute feedforward dimension in the transformer, by default 1.
- att_drop_prob : float, optional
+ drop_prob_trans : float, optional
 Dropout probability for the transformer, by default 0.5.
 num_layers : int, optional
 Number of transformer encoder layers, by default 2.
@@ -85,8 +85,8 @@ class MSVTNet(EEGModuleMixin, nn.Module):
 pool2_size: int = 7,
 drop_prob: float = 0.3,
 num_heads: int = 8,
- ffn_expansion_factor: float = 1,
- att_drop_prob: float = 0.5,
+ feedforward_ratio: float = 1,
+ drop_prob_trans: float = 0.5,
 num_layers: int = 2,
 activation: Type[nn.Module] = nn.ELU,
 return_features: bool = False,
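And a corresponding sketch for MSVTNet with the transformer arguments renamed in this hunk (import path and shape values are assumptions; the two keyword values repeat the new defaults):

    from braindecode.models import MSVTNet  # assumed export path

    model = MSVTNet(
        n_outputs=4,
        n_chans=22,
        n_times=1000,
        feedforward_ratio=1,   # was ffn_expansion_factor
        drop_prob_trans=0.5,   # was att_drop_prob
    )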
@@ -139,8 +139,8 @@ class MSVTNet(EEGModuleMixin, nn.Module):
 seq_len,
 d_model,
 num_heads,
- ffn_expansion_factor,
- att_drop_prob,
+ feedforward_ratio,
+ drop_prob_trans,
 num_layers,
 )
 
@@ -193,7 +193,7 @@ class MSVTNet(EEGModuleMixin, nn.Module):
 
 
 class _TSConv(nn.Sequential):
- r"""
+ """
 Time-Distributed Separable Convolution block.
 
 The architecture consists of:
@@ -280,7 +280,7 @@ class _TSConv(nn.Sequential):
 
 
 class _PositionalEncoding(nn.Module):
- r"""
+ """
 Positional encoding module that adds learnable positional embeddings.
 
 Parameters
@@ -303,7 +303,7 @@ class _PositionalEncoding(nn.Module):
 
 
 class _Transformer(nn.Module):
- r"""
+ """
 Transformer encoder module with learnable class token and positional encoding.
 
 Parameters
@@ -314,7 +314,7 @@ class _Transformer(nn.Module):
 Dimensionality of the model.
 num_heads : int
 Number of heads in the multihead attention.
- ffn_expansion_factor : float
+ feedforward_ratio : float
 Ratio to compute the dimension of the feedforward network.
 drop_prob : float, optional
 Dropout probability, by default 0.5.
@@ -327,7 +327,7 @@ class _Transformer(nn.Module):
 seq_length: int,
 d_model: int,
 num_heads: int,
- ffn_expansion_factor: float,
+ feedforward_ratio: float,
 drop_prob: float = 0.5,
 num_layers: int = 4,
 ) -> None:
@@ -335,7 +335,7 @@ class _Transformer(nn.Module):
 self.cls_embedding = nn.Parameter(torch.zeros(1, 1, d_model))
 self.pos_embedding = _PositionalEncoding(seq_length + 1, d_model)
 
- dim_ff = int(d_model * ffn_expansion_factor)
+ dim_ff = int(d_model * feedforward_ratio)
 self.dropout = nn.Dropout(drop_prob)
 self.trans = nn.TransformerEncoder(
 nn.TransformerEncoderLayer(
@@ -359,7 +359,7 @@ class _Transformer(nn.Module):
 
 
 class _DenseLayers(nn.Sequential):
- r"""
+ """
 Final classification layers.
 
 Parameters