braindecode 1.3.0.dev177069446__py3-none-any.whl → 1.3.0.dev177628147__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. braindecode/augmentation/__init__.py +3 -5
  2. braindecode/augmentation/base.py +5 -8
  3. braindecode/augmentation/functional.py +22 -25
  4. braindecode/augmentation/transforms.py +42 -51
  5. braindecode/classifier.py +16 -11
  6. braindecode/datasets/__init__.py +3 -5
  7. braindecode/datasets/base.py +13 -17
  8. braindecode/datasets/bbci.py +14 -13
  9. braindecode/datasets/bcicomp.py +5 -4
  10. braindecode/datasets/{bids/datasets.py → bids.py} +18 -12
  11. braindecode/datasets/{bids/iterable.py → experimental.py} +6 -8
  12. braindecode/datasets/{bids/hub.py → hub.py} +350 -375
  13. braindecode/datasets/{bids/hub_validation.py → hub_validation.py} +1 -2
  14. braindecode/datasets/mne.py +19 -19
  15. braindecode/datasets/moabb.py +10 -10
  16. braindecode/datasets/nmt.py +56 -58
  17. braindecode/datasets/sleep_physio_challe_18.py +5 -3
  18. braindecode/datasets/sleep_physionet.py +5 -5
  19. braindecode/datasets/tuh.py +18 -21
  20. braindecode/datasets/xy.py +9 -10
  21. braindecode/datautil/__init__.py +3 -3
  22. braindecode/datautil/serialization.py +20 -22
  23. braindecode/datautil/util.py +7 -120
  24. braindecode/eegneuralnet.py +52 -22
  25. braindecode/functional/functions.py +10 -7
  26. braindecode/functional/initialization.py +2 -3
  27. braindecode/models/__init__.py +3 -5
  28. braindecode/models/atcnet.py +39 -43
  29. braindecode/models/attentionbasenet.py +41 -37
  30. braindecode/models/attn_sleep.py +24 -26
  31. braindecode/models/base.py +6 -6
  32. braindecode/models/bendr.py +26 -50
  33. braindecode/models/biot.py +30 -61
  34. braindecode/models/contrawr.py +5 -5
  35. braindecode/models/ctnet.py +35 -35
  36. braindecode/models/deep4.py +5 -5
  37. braindecode/models/deepsleepnet.py +7 -7
  38. braindecode/models/eegconformer.py +26 -31
  39. braindecode/models/eeginception_erp.py +2 -2
  40. braindecode/models/eeginception_mi.py +6 -6
  41. braindecode/models/eegitnet.py +5 -5
  42. braindecode/models/eegminer.py +1 -1
  43. braindecode/models/eegnet.py +3 -3
  44. braindecode/models/eegnex.py +2 -2
  45. braindecode/models/eegsimpleconv.py +2 -2
  46. braindecode/models/eegsym.py +7 -7
  47. braindecode/models/eegtcnet.py +6 -6
  48. braindecode/models/fbcnet.py +2 -2
  49. braindecode/models/fblightconvnet.py +3 -3
  50. braindecode/models/fbmsnet.py +3 -3
  51. braindecode/models/hybrid.py +2 -2
  52. braindecode/models/ifnet.py +5 -5
  53. braindecode/models/labram.py +46 -70
  54. braindecode/models/luna.py +5 -60
  55. braindecode/models/medformer.py +21 -23
  56. braindecode/models/msvtnet.py +15 -15
  57. braindecode/models/patchedtransformer.py +55 -55
  58. braindecode/models/sccnet.py +2 -2
  59. braindecode/models/shallow_fbcsp.py +3 -5
  60. braindecode/models/signal_jepa.py +12 -39
  61. braindecode/models/sinc_shallow.py +4 -3
  62. braindecode/models/sleep_stager_blanco_2020.py +2 -2
  63. braindecode/models/sleep_stager_chambon_2018.py +2 -2
  64. braindecode/models/sparcnet.py +8 -8
  65. braindecode/models/sstdpn.py +869 -869
  66. braindecode/models/summary.csv +17 -19
  67. braindecode/models/syncnet.py +2 -2
  68. braindecode/models/tcn.py +5 -5
  69. braindecode/models/tidnet.py +3 -3
  70. braindecode/models/tsinception.py +3 -3
  71. braindecode/models/usleep.py +7 -7
  72. braindecode/models/util.py +14 -165
  73. braindecode/modules/__init__.py +1 -9
  74. braindecode/modules/activation.py +3 -29
  75. braindecode/modules/attention.py +0 -123
  76. braindecode/modules/blocks.py +1 -53
  77. braindecode/modules/convolution.py +0 -53
  78. braindecode/modules/filter.py +0 -31
  79. braindecode/modules/layers.py +0 -84
  80. braindecode/modules/linear.py +1 -22
  81. braindecode/modules/stats.py +0 -10
  82. braindecode/modules/util.py +0 -9
  83. braindecode/modules/wrapper.py +0 -17
  84. braindecode/preprocessing/preprocess.py +0 -3
  85. braindecode/regressor.py +18 -15
  86. braindecode/samplers/ssl.py +1 -1
  87. braindecode/util.py +28 -38
  88. braindecode/version.py +1 -1
  89. braindecode-1.3.0.dev177628147.dist-info/METADATA +202 -0
  90. braindecode-1.3.0.dev177628147.dist-info/RECORD +114 -0
  91. braindecode/datasets/bids/__init__.py +0 -54
  92. braindecode/datasets/bids/format.py +0 -717
  93. braindecode/datasets/bids/hub_format.py +0 -717
  94. braindecode/datasets/bids/hub_io.py +0 -197
  95. braindecode/datasets/chb_mit.py +0 -163
  96. braindecode/datasets/siena.py +0 -162
  97. braindecode/datasets/utils.py +0 -67
  98. braindecode/models/brainmodule.py +0 -845
  99. braindecode/models/config.py +0 -233
  100. braindecode/models/reve.py +0 -843
  101. braindecode-1.3.0.dev177069446.dist-info/METADATA +0 -230
  102. braindecode-1.3.0.dev177069446.dist-info/RECORD +0 -124
  103. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/WHEEL +0 -0
  104. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/LICENSE.txt +0 -0
  105. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/licenses/NOTICE.txt +0 -0
  106. {braindecode-1.3.0.dev177069446.dist-info → braindecode-1.3.0.dev177628147.dist-info}/top_level.txt +0 -0
braindecode/models/patchedtransformer.py
@@ -14,9 +14,9 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class PBT(EEGModuleMixin, nn.Module):
-    r"""Patched Brain Transformer (PBT) model from Klein et al (2025) [pbt]_.
+    r"""Patched Brain Transformer (PBT) model from Klein et al. (2025) [pbt]_.
 
-    :bdg-danger:`Foundation Model`
+    :bdg-danger:`Large Brain Model`
 
     This implementation was based in https://github.com/timonkl/PatchedBrainTransformer/
 
@@ -136,7 +136,7 @@ class PBT(EEGModuleMixin, nn.Module):
 
     .. warning::
 
-        **Important:** As the other Foundation Models in Braindecode, :class:`PBT` is
+        **Important:** As the other Large Brain Models in Braindecode, :class:`PBT` is
         designed for large-scale pre-training and fine-tuning. Training from
         scratch on small datasets may lead to suboptimal results. Cross-Dataset
         pre-training and subsequent fine-tuning is recommended to leverage the
@@ -146,9 +146,9 @@ class PBT(EEGModuleMixin, nn.Module):
     ----------
     d_input : int, optional
        Size (in samples) of each patch (token) extracted along the time axis.
-    embed_dim : int, optional
+    d_model : int, optional
        Transformer embedding dimensionality.
-    num_layers : int, optional
+    n_blocks : int, optional
        Number of Transformer encoder layers.
    num_heads : int, optional
        Number of attention heads.
@@ -190,13 +190,13 @@ class PBT(EEGModuleMixin, nn.Module):
        sfreq=None,
        # Model parameters
        d_input: int = 64,
-        embed_dim: int = 128,
-        num_layers: int = 4,
+        d_model: int = 128,
+        n_blocks: int = 4,
        num_heads: int = 4,
        drop_prob: float = 0.1,
        learnable_cls=True,
        bias_transformer=False,
-        activation: type[nn.Module] = nn.GELU,
+        activation: nn.Module = nn.GELU,
    ) -> None:
        super().__init__(
            n_outputs=n_outputs,
@@ -209,8 +209,8 @@ class PBT(EEGModuleMixin, nn.Module):
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        # Store hyperparameters
        self.d_input = d_input
-        self.embed_dim = embed_dim
-        self.num_layers = num_layers
+        self.d_model = d_model
+        self.n_blocks = n_blocks
        self.num_heads = num_heads
        self.drop_prob = drop_prob
 
@@ -219,11 +219,11 @@ class PBT(EEGModuleMixin, nn.Module):
 
        # Classification token (learnable or fixed zero)
        if learnable_cls:
-            self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim) * 0.002)
+            self.cls_token = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.002)
        else:
            # non-learnable zeroed tensor
            self.cls_token = torch.full(
-                size=(1, 1, self.embed_dim),
+                size=(1, 1, self.d_model),
                fill_value=0,
                requires_grad=False,
                dtype=torch.float32,
@@ -234,20 +234,20 @@ class PBT(EEGModuleMixin, nn.Module):
            n_chans=self.n_chans, n_times=self.n_times, d_input=self.d_input
        )
 
-        # Linear patch projection from token raw-size -> embed_dim
+        # Linear patch projection from token raw-size -> d_model
        self.patching_projection = nn.Linear(
-            in_features=self.d_input, out_features=self.embed_dim, bias=False
+            in_features=self.d_input, out_features=self.d_model, bias=False
        )
 
-        # actual embedding table mapping indices -> embed_dim
+        # actual embedding table mapping indices -> d_model
        self.pos_embedding = nn.Embedding(
-            num_embeddings=self.num_embeddings + 1, embedding_dim=self.embed_dim
+            num_embeddings=self.num_embeddings + 1, embedding_dim=self.d_model
        )
 
        # Transformer encoder stack
        self.transformer_encoder = _TransformerEncoder(
-            num_layers=num_layers,
-            embed_dim=self.embed_dim,
+            n_blocks=n_blocks,
+            d_model=self.d_model,
            n_head=num_heads,
            drop_prob=drop_prob,
            bias=bias_transformer,
@@ -256,7 +256,7 @@ class PBT(EEGModuleMixin, nn.Module):
 
        # classification head on classify token - CLS token
        self.final_layer = nn.Linear(
-            in_features=embed_dim, out_features=self.n_outputs, bias=True
+            in_features=d_model, out_features=self.n_outputs, bias=True
        )
 
        # initialize weights
@@ -305,7 +305,7 @@ class PBT(EEGModuleMixin, nn.Module):
 
 
 class _LayerNorm(nn.Module):
-    r"""Layer normalization with optional bias.
+    """Layer normalization with optional bias.
 
    Simple wrapper around :func:`torch.nn.functional.layer_norm` exposing a
    learnable scale and optional bias.
@@ -346,7 +346,7 @@ class _LayerNorm(nn.Module):
 
 
 class _MHSA(nn.Module):
-    r"""Multi-head self-attention (MHSA) block.
+    """Multi-head self-attention (MHSA) block.
 
    Implements a standard multi-head attention mechanism with optional
    use of PyTorch's scaled_dot_product_attention (FlashAttention) when
@@ -354,7 +354,7 @@ class _MHSA(nn.Module):
 
    Parameters
    ----------
-    embed_dim : int
+    d_model : int
        Dimensionality of the model / embeddings.
    n_head : int
        Number of attention heads.
@@ -366,25 +366,25 @@ class _MHSA(nn.Module):
 
    def __init__(
        self,
-        embed_dim: int,
+        d_model: int,
        n_head: int,
        bias: bool,
        drop_prob: float = 0.0,
    ) -> None:
        super().__init__()
 
-        assert embed_dim % n_head == 0, "embed_dim must be divisible by n_head"
+        assert d_model % n_head == 0, "d_model must be divisible by n_head"
 
        # qkv and output projection
-        self.attn = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
-        self.proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.attn = nn.Linear(d_model, 3 * d_model, bias=bias)
+        self.proj = nn.Linear(d_model, d_model, bias=bias)
 
        # dropout modules
        self.attn_drop_prob = nn.Dropout(drop_prob)
        self.resid_drop_prob = nn.Dropout(drop_prob)
 
        self.n_head = n_head
-        self.embed_dim = embed_dim
+        self.d_model = d_model
        self.drop_prob = drop_prob
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -393,7 +393,7 @@ class _MHSA(nn.Module):
        Parameters
        ----------
        x : torch.Tensor
-            Input tensor of shape (B, T, C) where C == embed_dim.
+            Input tensor of shape (B, T, C) where C == d_model.
 
        Returns
        -------
@@ -404,7 +404,7 @@ class _MHSA(nn.Module):
        B, T, C = x.size()
 
        # project to q, k, v and reshape for multi-head attention
-        q, k, v = self.attn(x).split(self.embed_dim, dim=2)
+        q, k, v = self.attn(x).split(self.d_model, dim=2)
 
        # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
@@ -429,14 +429,14 @@ class _MHSA(nn.Module):
 
 
 class _FeedForward(nn.Module):
-    r"""Position-wise feed-forward network from Transformer.
+    """Position-wise feed-forward network from Transformer.
 
    Implements the two-layer MLP with GELU activation and dropout used in
    Transformer architectures.
 
    Parameters
    ----------
-    embed_dim : int
+    d_model : int
        Input and output dimensionality.
    dim_feedforward : int, optional
        Hidden dimensionality of the feed-forward layer. If None, must be provided by caller.
@@ -448,20 +448,20 @@ class _FeedForward(nn.Module):
 
    def __init__(
        self,
-        embed_dim: int,
+        d_model: int,
        dim_feedforward: Optional[int] = None,
        drop_prob: float = 0.0,
        bias: bool = False,
-        activation: type[nn.Module] = nn.GELU,
+        activation: nn.Module = nn.GELU,
    ) -> None:
        super().__init__()
 
        if dim_feedforward is None:
            raise ValueError("dim_feedforward must be provided")
 
-        self.proj_in = nn.Linear(embed_dim, dim_feedforward, bias=bias)
+        self.proj_in = nn.Linear(d_model, dim_feedforward, bias=bias)
        self.activation = activation()
-        self.proj = nn.Linear(dim_feedforward, embed_dim, bias=bias)
+        self.proj = nn.Linear(dim_feedforward, d_model, bias=bias)
        self.drop_prob = nn.Dropout(drop_prob)
        self.drop_prob1 = nn.Dropout(drop_prob)
 
@@ -476,7 +476,7 @@ class _FeedForward(nn.Module):
 
 
 class _TransformerEncoderLayer(nn.Module):
-    r"""Single Transformer encoder layer (pre-norm) combining MHSA and feed-forward.
+    """Single Transformer encoder layer (pre-norm) combining MHSA and feed-forward.
 
    The block follows the pattern:
        x <- x + MHSA(_LayerNorm(x))
@@ -485,27 +485,27 @@ class _TransformerEncoderLayer(nn.Module):
 
    def __init__(
        self,
-        embed_dim: int,
+        d_model: int,
        n_head: int,
        drop_prob: float = 0.0,
        dim_feedforward: Optional[int] = None,
        bias: bool = False,
-        activation: type[nn.Module] = nn.GELU,
+        activation: nn.Module = nn.GELU,
    ) -> None:
        super().__init__()
 
        if dim_feedforward is None:
-            dim_feedforward = 4 * embed_dim
+            dim_feedforward = 4 * d_model
            # note: preserve the original behaviour (print) from the provided code
            print(
-                "dim_feedforward is set to 4*embed_dim, the default in Vaswani et al. (Attention is all you need)"
+                "dim_feedforward is set to 4*d_model, the default in Vaswani et al. (Attention is all you need)"
            )
 
-        self.layer_norm_att = _LayerNorm(embed_dim, bias=bias)
-        self.mhsa = _MHSA(embed_dim, n_head, bias, drop_prob=drop_prob)
-        self.layer_norm_ff = _LayerNorm(embed_dim, bias=bias)
+        self.layer_norm_att = _LayerNorm(d_model, bias=bias)
+        self.mhsa = _MHSA(d_model, n_head, bias, drop_prob=drop_prob)
+        self.layer_norm_ff = _LayerNorm(d_model, bias=bias)
        self.feed_forward = _FeedForward(
-            embed_dim=embed_dim,
+            d_model=d_model,
            dim_feedforward=dim_feedforward,
            drop_prob=drop_prob,
            bias=bias,
@@ -518,7 +518,7 @@ class _TransformerEncoderLayer(nn.Module):
        Parameters
        ----------
        x : torch.Tensor
-            Input of shape (B, T, embed_dim).
+            Input of shape (B, T, d_model).
 
        Returns
        -------
@@ -531,13 +531,13 @@ class _TransformerEncoderLayer(nn.Module):
 
 
 class _TransformerEncoder(nn.Module):
-    r"""Stack of Transformer encoder layers.
+    """Stack of Transformer encoder layers.
 
    Parameters
    ----------
-    num_layers : int
+    n_blocks : int
        Number of encoder layers to stack.
-    embed_dim : int
+    d_model : int
        Dimensionality of embeddings.
    n_head : int
        Number of attention heads per layer.
@@ -549,26 +549,26 @@ class _TransformerEncoder(nn.Module):
 
    def __init__(
        self,
-        num_layers: int,
-        embed_dim: int,
+        n_blocks: int,
+        d_model: int,
        n_head: int,
        drop_prob: float,
        bias: bool,
-        activation: type[nn.Module] = nn.GELU,
+        activation: nn.Module = nn.GELU,
    ) -> None:
        super().__init__()
 
        self.encoder_block = nn.ModuleList(
            [
                _TransformerEncoderLayer(
-                    embed_dim=embed_dim,
+                    d_model=d_model,
                    n_head=n_head,
                    drop_prob=drop_prob,
                    dim_feedforward=None,
                    bias=bias,
                    activation=activation,
                )
-                for _ in range(num_layers)
+                for _ in range(n_blocks)
            ]
        )
 
@@ -576,7 +576,7 @@ class _TransformerEncoder(nn.Module):
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("proj.weight"):
-                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * num_layers))
+                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_blocks))
 
    @staticmethod
    def _init_weights(module: nn.Module) -> None:
@@ -593,7 +593,7 @@ class _TransformerEncoder(nn.Module):
 
 
 class _Patcher(nn.Module):
-    r"""Patching encoding helper.
+    """Patching encoding helper.
 
    This module "patchifies" the original X entry in a ViT manner.
 
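The hunks above rename PBT's public constructor arguments: `embed_dim` becomes `d_model` and `num_layers` becomes `n_blocks`, while `num_heads`, `drop_prob`, and the remaining arguments keep their names. A minimal call-site sketch of the rename; the channel, time, and output counts below are placeholder values, not taken from the diff:

    import torch
    from torch import nn

    from braindecode.models import PBT

    # Placeholder dimensions, chosen only for illustration.
    n_chans, n_times, n_outputs = 22, 1024, 4

    # Before this diff: PBT(..., embed_dim=128, num_layers=4, ...)
    # After this diff: the same hyperparameters are spelled d_model / n_blocks.
    model = PBT(
        n_outputs=n_outputs,
        n_chans=n_chans,
        n_times=n_times,
        d_input=64,       # patch length in samples
        d_model=128,      # was embed_dim
        n_blocks=4,       # was num_layers
        num_heads=4,
        activation=nn.GELU,  # still passed as a class, instantiated inside the model
    )

    x = torch.randn(8, n_chans, n_times)  # (batch, channels, time)
    logits = model(x)  # expected shape: (8, n_outputs)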
braindecode/models/sccnet.py
@@ -15,7 +15,7 @@ from braindecode.modules import LogActivation
 
 
 class SCCNet(EEGModuleMixin, nn.Module):
-    r"""SCCNet from Wei, C S (2019) [sccnet]_.
+    """SCCNet from Wei, C S (2019) [sccnet]_.
 
    :bdg-success:`Convolution`
 
@@ -155,7 +155,7 @@ class SCCNet(EEGModuleMixin, nn.Module):
        n_spatial_filters: int = 22,
        n_spatial_filters_smooth: int = 20,
        drop_prob: float = 0.5,
-        activation: type[nn.Module] = LogActivation,
+        activation: nn.Module = LogActivation,
        batch_norm_momentum: float = 0.1,
    ):
        super().__init__(
braindecode/models/shallow_fbcsp.py
@@ -2,8 +2,6 @@
 #
 # License: BSD (3-clause)
 
-from typing import Callable
-
 from einops.layers.torch import Rearrange
 from torch import nn
 from torch.nn import init
@@ -20,7 +18,7 @@ from braindecode.modules import (
 
 
 class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
-    r"""Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
+    """Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.
 
    :bdg-success:`Convolution`
 
@@ -83,9 +81,9 @@ class ShallowFBCSPNet(EEGModuleMixin, nn.Sequential):
        pool_time_length=75,
        pool_time_stride=15,
        final_conv_length="auto",
-        conv_nonlin: Callable = square,
+        conv_nonlin=square,
        pool_mode="mean",
-        activation_pool_nonlin: type[nn.Module] = SafeLog,
+        activation_pool_nonlin: nn.Module = SafeLog,
        split_first_layer=True,
        batch_norm=True,
        batch_norm_alpha=0.1,
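In ShallowFBCSPNet the `typing.Callable` import is dropped and the annotations are loosened, but the defaults are unchanged: `conv_nonlin` still takes a plain function and `activation_pool_nonlin` still takes an `nn.Module` subclass. A hedged usage sketch, assuming `square` and `SafeLog` are importable from `braindecode.functional` and `braindecode.modules` as this file's imports suggest:

    from braindecode.functional import square      # assumption: exposed from this package
    from braindecode.models import ShallowFBCSPNet
    from braindecode.modules import SafeLog

    # Placeholder dimensions, for illustration only.
    model = ShallowFBCSPNet(
        n_chans=22,
        n_outputs=4,
        n_times=1125,
        conv_nonlin=square,              # plain function; type annotation removed by this diff
        activation_pool_nonlin=SafeLog,  # class, instantiated inside the model
    )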
braindecode/models/signal_jepa.py
@@ -24,7 +24,7 @@ _DEFAULT_CONV_LAYER_SPEC = ( # downsampling: 128Hz -> 1Hz, receptive field 1.18
 
 
 class _BaseSignalJEPA(EEGModuleMixin, nn.Module):
-    r"""Base class for the SignalJEPA models
+    """Base class for the SignalJEPA models
 
    Parameters
    ----------
@@ -144,9 +144,9 @@ class _BaseSignalJEPA(EEGModuleMixin, nn.Module):
 
 
 class SignalJEPA(_BaseSignalJEPA):
-    r"""Architecture introduced in signal-JEPA for self-supervised pre-training, Guetschel, P et al (2024) [1]_
+    """Architecture introduced in signal-JEPA for self-supervised pre-training, Guetschel, P et al (2024) [1]_
 
-    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`
+    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
 
    This model is not meant for classification but for SSL pre-training.
    Its output shape depends on the input shape.
@@ -232,9 +232,9 @@ class SignalJEPA(_BaseSignalJEPA):
 
 
 class SignalJEPA_Contextual(_BaseSignalJEPA):
-    r"""Contextual downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.
+    """Contextual downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.
 
-    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`
+    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
 
    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.
@@ -405,9 +405,9 @@ class SignalJEPA_Contextual(_BaseSignalJEPA):
 
 
 class SignalJEPA_PostLocal(_BaseSignalJEPA):
-    r"""Post-local downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.
+    """Post-local downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.
 
-    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`
+    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
 
    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.
@@ -556,9 +556,9 @@ class SignalJEPA_PostLocal(_BaseSignalJEPA):
 
 
 class SignalJEPA_PreLocal(_BaseSignalJEPA):
-    r"""Pre-local downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.
+    """Pre-local downstream architecture introduced in signal-JEPA Guetschel, P et al (2024) [1]_.
 
-    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Foundation Model`
+    :bdg-success:`Convolution` :bdg-dark-line:`Channel` :bdg-danger:`Large Brain Model`
 
    This architecture is one of the variants of :class:`SignalJEPA`
    that can be used for classification purposes.
@@ -569,33 +569,6 @@ class SignalJEPA_PreLocal(_BaseSignalJEPA):
 
    .. versionadded:: 0.9
 
-    .. important::
-        **Pre-trained Weights Available**
-
-        This model has pre-trained weights available on the Hugging Face Hub.
-        You can load them using:
-
-        .. code-block:: python
-
-            from braindecode.models import SignalJEPA_PreLocal
-
-            # Load pre-trained model from Hugging Face Hub
-            model = SignalJEPA_PreLocal.from_pretrained(
-                "braindecode/SignalJEPA-PreLocal-pretrained"
-            )
-
-        To push your own trained model to the Hub:
-
-        .. code-block:: python
-
-            # After training your model
-            model.push_to_hub(
-                repo_id="username/my-sjepa-model",
-                commit_message="Upload trained SignalJEPA model",
-            )
-
-        Requires installing ``braindecode[hug]`` for Hub integration.
-
    Parameters
    ----------
    n_spat_filters : int
@@ -745,7 +718,7 @@ class SignalJEPA_PreLocal(_BaseSignalJEPA):
 
 
 class _ConvFeatureEncoder(nn.Sequential):
-    r"""Convolutional feature encoder for EEG data.
+    """Convolutional feature encoder for EEG data.
 
    Computes successive 1D convolutions (with activations) over the time
    dimension of the input EEG signal.
@@ -865,7 +838,7 @@ class _ConvFeatureEncoder(nn.Sequential):
 
 
 class _ChannelEmbedding(nn.Embedding):
-    r"""Embedding layer for EEG channels.
+    """Embedding layer for EEG channels.
 
    The difference with a regular :class:`nn.Embedding` is that the embedding
    vectors are initialized with a positional encodding of the channel locations.
@@ -926,7 +899,7 @@ class _ChannelEmbedding(nn.Embedding):
 
 
 class _PosEncoder(nn.Module):
-    r"""Positional encoder for EEG data.
+    """Positional encoder for EEG data.
 
    Parameters
    ----------
braindecode/models/sinc_shallow.py
@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -9,7 +10,7 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class SincShallowNet(EEGModuleMixin, nn.Module):
-    r"""Sinc-ShallowNet from Borra, D et al (2020) [borra2020]_.
+    """Sinc-ShallowNet from Borra, D et al (2020) [borra2020]_.
 
    :bdg-success:`Convolution` :bdg-warning:`Interpretability`
 
@@ -92,7 +93,7 @@ class SincShallowNet(EEGModuleMixin, nn.Module):
        num_time_filters: int = 32,
        time_filter_len: int = 33,
        depth_multiplier: int = 2,
-        activation: type[nn.Module] | None = nn.ELU,
+        activation: Optional[nn.Module] = nn.ELU,
        drop_prob: float = 0.5,
        first_freq: float = 5.0,
        min_freq: float = 1.0,
@@ -210,7 +211,7 @@ class SincShallowNet(EEGModuleMixin, nn.Module):
 
 
 class _SincFilter(nn.Module):
-    r"""Sinc-Based Convolutional Layer for Band-Pass Filtering from Ravanelli and Bengio (2018) [ravanelli]_.
+    """Sinc-Based Convolutional Layer for Band-Pass Filtering from Ravanelli and Bengio (2018) [ravanelli]_.
 
    The `SincFilter` layer implements a convolutional layer where each kernel is
    defined using a parametrized sinc function.
braindecode/models/sleep_stager_blanco_2020.py
@@ -9,7 +9,7 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
-    r"""Sleep staging architecture from Blanco et al (2020) from [Blanco2020]_
+    """Sleep staging architecture from Blanco et al. (2020) from [Blanco2020]_
 
    :bdg-success:`Convolution`
 
@@ -68,7 +68,7 @@ class SleepStagerBlanco2020(EEGModuleMixin, nn.Module):
        drop_prob=0.5,
        apply_batch_norm=False,
        return_feats=False,
-        activation: type[nn.Module] = nn.ReLU,
+        activation: nn.Module = nn.ReLU,
        chs_info=None,
        n_times=None,
    ):
braindecode/models/sleep_stager_chambon_2018.py
@@ -11,7 +11,7 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
-    r"""Sleep staging architecture from Chambon et al. (2018) [Chambon2018]_.
+    """Sleep staging architecture from Chambon et al. (2018) [Chambon2018]_.
 
    :bdg-success:`Convolution`
 
@@ -70,7 +70,7 @@ class SleepStagerChambon2018(EEGModuleMixin, nn.Module):
        time_conv_size_s=0.5,
        max_pool_size_s=0.125,
        pad_size_s=0.25,
-        activation: type[nn.Module] = nn.ReLU,
+        activation: nn.Module = nn.ReLU,
        input_window_seconds=None,
        n_outputs=5,
        drop_prob=0.25,
braindecode/models/sparcnet.py
@@ -11,7 +11,7 @@ from braindecode.models.base import EEGModuleMixin
 
 
 class SPARCNet(EEGModuleMixin, nn.Module):
-    r"""Seizures, Periodic and Rhythmic pattern Continuum Neural Network (SPaRCNet) from Jing et al (2023) [jing2023]_.
+    """Seizures, Periodic and Rhythmic pattern Continuum Neural Network (SPaRCNet) from Jing et al. (2023) [jing2023]_.
 
    :bdg-success:`Convolution`
 
@@ -69,7 +69,7 @@ class SPARCNet(EEGModuleMixin, nn.Module):
        drop_prob: float = 0.5,
        conv_bias: bool = True,
        batch_norm: bool = True,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
        kernel_size_conv0: int = 7,
        kernel_size_conv1: int = 1,
        kernel_size_conv2: int = 3,
@@ -213,7 +213,7 @@ class SPARCNet(EEGModuleMixin, nn.Module):
 
 
 class _DenseLayer(nn.Sequential):
-    r"""
+    """
    A densely connected layer with batch normalization and dropout.
 
    Parameters
@@ -252,7 +252,7 @@ class _DenseLayer(nn.Sequential):
        drop_prob: float = 0.5,
        conv_bias: bool = True,
        batch_norm: bool = True,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
        kernel_size_conv1: int = 1,
        kernel_size_conv2: int = 3,
        stride_conv1: int = 1,
@@ -302,7 +302,7 @@ class _DenseLayer(nn.Sequential):
 
 
 class _DenseBlock(nn.Sequential):
-    r"""
+    """
    A densely connected block that uses DenseLayers.
 
    Parameters
@@ -344,7 +344,7 @@ class _DenseBlock(nn.Sequential):
        drop_prob=0.5,
        conv_bias=True,
        batch_norm=True,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
        kernel_size_conv1: int = 1,
        kernel_size_conv2: int = 3,
        stride_conv1: int = 1,
@@ -371,7 +371,7 @@ class _DenseBlock(nn.Sequential):
 
 
 class _TransitionLayer(nn.Sequential):
-    r"""
+    """
    A pooling transition layer.
 
    Parameters
@@ -403,7 +403,7 @@ class _TransitionLayer(nn.Sequential):
        out_channels,
        conv_bias=True,
        batch_norm=True,
-        activation: type[nn.Module] = nn.ELU,
+        activation: nn.Module = nn.ELU,
        kernel_size_trans: int = 2,
        stride_trans: int = 2,
    ):
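Across SPARCNet and the other models touched above, the only signature change is the annotation `type[nn.Module]` -> `nn.Module` on `activation`; the defaults remain classes such as `nn.ELU`, and the blocks still instantiate the activation internally (see the `self.activation = activation()` line in `_FeedForward`). A short sketch with placeholder shapes, showing that call sites keep passing the activation class rather than an instance:

    from torch import nn

    from braindecode.models import SPARCNet

    # Placeholder dimensions, for illustration only.
    model = SPARCNet(
        n_chans=19,
        n_times=2000,
        n_outputs=6,
        activation=nn.ELU,  # pass the class, not nn.ELU(); the blocks call activation()
    )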