braindecode 1.3.0.dev168011974__py3-none-any.whl → 1.3.0.dev171478045__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -22,6 +22,8 @@ from braindecode.modules import MLP, DropPath
 class Labram(EEGModuleMixin, nn.Module):
     """Labram from Jiang, W B et al (2024) [Jiang2024]_.
 
+    :bdg-danger:`Large Brain Model`
+
     .. figure:: https://arxiv.org/html/2405.18765v1/x1.png
        :align: center
        :alt: Labram Architecture.
@@ -61,13 +63,24 @@ class Labram(EEGModuleMixin, nn.Module):
 
     .. versionadded:: 0.9
 
+
+    Examples on how to load pre-trained weights:
+    --------------------------------------------
+    >>> import torch
+    >>> from braindecode.models import Labram
+    >>> model = Labram(n_times=1600, n_chans=64, n_outputs=4)
+    >>> url = 'https://huggingface.co/braindecode/Labram-Braindecode/blob/main/braindecode_labram_base.pt'
+    >>> state = torch.hub.load_state_dict_from_url(url, progress=True)
+    >>> model.load_state_dict(state)
+
+
     Parameters
     ----------
     patch_size : int
         The size of the patch to be used in the patch embedding.
     emb_size : int
         The dimension of the embedding.
-    in_channels : int
+    in_conv_channels : int
         The number of convolutional input channels.
     out_channels : int
         The number of convolutional output channels.
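The example added to the docstring above fetches a checkpoint with torch.hub.load_state_dict_from_url. A standalone sketch of the same workflow follows; note that the Hugging Face URL in the docstring is a /blob/ page URL, so the raw-file /resolve/ form used here is an assumption (as are map_location and strict=False, which only make the sketch more forgiving about device placement and head mismatches).

    import torch
    from braindecode.models import Labram

    model = Labram(n_times=1600, n_chans=64, n_outputs=4)

    # Assumption: /resolve/ instead of /blob/ to download the raw checkpoint file;
    # the repository and file name are taken verbatim from the docstring above.
    url = (
        "https://huggingface.co/braindecode/Labram-Braindecode/"
        "resolve/main/braindecode_labram_base.pt"
    )
    state = torch.hub.load_state_dict_from_url(url, map_location="cpu", progress=True)

    # strict=False tolerates keys that differ between the checkpoint and this instance.
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(missing, unexpected)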
@@ -79,8 +92,10 @@ class Labram(EEGModuleMixin, nn.Module):
         The expansion ratio of the mlp layer
     qkv_bias : bool (default=False)
         If True, add a learnable bias to the query, key, and value tensors.
-    qk_norm : Pytorch Normalize layer (default=None)
-        If not None, apply LayerNorm to the query and key tensors
+    qk_norm : Pytorch Normalize layer (default=nn.LayerNorm)
+        If not None, apply LayerNorm to the query and key tensors.
+        Default is nn.LayerNorm for better weight transfer from original LaBraM.
+        Set to None to disable Q,K normalization.
     qk_scale : float (default=None)
         If not None, use this value as the scale factor. If None,
         use head_dim**-0.5, where head_dim = dim // num_heads.
@@ -92,9 +107,10 @@ class Labram(EEGModuleMixin, nn.Module):
         Dropout rate for the attention weights used on DropPath.
     norm_layer : Pytorch Normalize layer (default=nn.LayerNorm)
         The normalization layer to be used.
-    init_values : float (default=None)
+    init_values : float (default=0.1)
         If not None, use this value to initialize the gamma_1 and gamma_2
-        parameters.
+        parameters for residual scaling. Default is 0.1 for better weight
+        transfer from original LaBraM. Set to None to disable.
     use_abs_pos_emb : bool (default=True)
         If True, use absolute position embedding.
     use_mean_pooling : bool (default=True)
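Taken together, the two docstring hunks above record a change of effective defaults: Q/K normalization is now on (nn.LayerNorm) and residual scaling is initialized at 0.1. A small sketch of opting back into the previous behaviour, using only constructor arguments visible in this diff (the shape arguments are illustrative):

    from braindecode.models import Labram

    # New defaults: LayerNorm on queries/keys and init_values=0.1,
    # chosen for closer weight transfer from the original LaBraM checkpoints.
    model_new_defaults = Labram(n_times=1600, n_chans=64, n_outputs=4)

    # Reproducing the pre-change behaviour by disabling both features explicitly.
    model_old_behaviour = Labram(
        n_times=1600,
        n_chans=64,
        n_outputs=4,
        qk_norm=None,
        init_values=None,
    )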
@@ -135,19 +151,19 @@ class Labram(EEGModuleMixin, nn.Module):
         input_window_seconds=None,
         patch_size=200,
         emb_size=200,
-        in_channels=1,
+        in_conv_channels=1,
         out_channels=8,
         n_layers=12,
         att_num_heads=10,
         mlp_ratio=4.0,
         qkv_bias=False,
-        qk_norm=None,
+        qk_norm=nn.LayerNorm,
         qk_scale=None,
         drop_prob=0.0,
         attn_drop_prob=0.0,
         drop_path_prob=0.0,
         norm_layer=nn.LayerNorm,
-        init_values=None,
+        init_values=0.1,
         use_abs_pos_emb=True,
         use_mean_pooling=True,
         init_scale=0.001,
@@ -183,15 +199,15 @@ class Labram(EEGModuleMixin, nn.Module):
         self.patch_size = patch_size
         self.n_path = self.n_times // self.patch_size
 
-        if neural_tokenizer and in_channels != 1:
+        if neural_tokenizer and in_conv_channels != 1:
             warn(
                 "The model is in Neural Tokenizer mode, but the variable "
-                + "`in_channels` is different from the default values."
-                + "`in_channels` is only needed for the Neural Decoder mode."
-                + "in_channels is not used in the Neural Tokenizer mode.",
+                + "`in_conv_channels` is different from the default values."
+                + "`in_conv_channels` is only needed for the Neural Decoder mode."
+                + "in_conv_channels is not used in the Neural Tokenizer mode.",
                 UserWarning,
             )
-            in_channels = 1
+            in_conv_channels = 1
         # If you can use the model in Neural Tokenizer mode,
         # temporal conv layer will be use over the patched dataset
         if neural_tokenizer:
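The constructor hunk above also renames the keyword that this warning refers to. A short usage sketch of the renamed argument; the warning path assumes neural_tokenizer keeps its existing default of True, which is not shown in this diff:

    import warnings

    from braindecode.models import Labram

    # The default value 1 never triggers the warning.
    model = Labram(n_times=1600, n_chans=64, n_outputs=4, in_conv_channels=1)

    # Any other value in Neural Tokenizer mode emits the UserWarning shown above
    # and the value is reset to 1 internally.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _ = Labram(n_times=1600, n_chans=64, n_outputs=4, in_conv_channels=8)
    print([str(w.message) for w in caught])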
@@ -228,7 +244,7 @@ class Labram(EEGModuleMixin, nn.Module):
                 _PatchEmbed(
                     n_times=self.n_times,
                     patch_size=patch_size,
-                    in_channels=in_channels,
+                    in_channels=in_conv_channels,
                     emb_dim=self.emb_size,
                 ),
             )
@@ -373,8 +389,7 @@ class Labram(EEGModuleMixin, nn.Module):
         Parameters
         ----------
         x : torch.Tensor
-            The input data with shape (batch, n_chans, n_patches, patch size),
-            if neural decoder or (batch, n_chans, n_times), if neural tokenizer.
+            The input data with shape (batch, n_chans, n_times).
         input_chans : int
             The number of input channels.
         return_patch_tokens : bool
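With the docstring change above, both modes now take the same raw input layout, (batch, n_chans, n_times). A minimal smoke-test sketch of that contract (the (batch, n_outputs) output shape is the usual classifier output and is assumed here rather than stated in the diff):

    import torch

    from braindecode.models import Labram

    model = Labram(n_times=1600, n_chans=64, n_outputs=4)
    x = torch.randn(2, 64, 1600)  # (batch, n_chans, n_times)

    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: torch.Size([2, 4])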
@@ -387,37 +402,72 @@ class Labram(EEGModuleMixin, nn.Module):
         x : torch.Tensor
             The output of the model.
         """
+        batch_size = x.shape[0]
+
         if self.neural_tokenizer:
-            batch_size, nch, n_patch, temporal = self.patch_embed.segment_patch(x).shape
+            # For neural tokenizer: input is (batch, n_chans, n_times)
+            # patch_embed returns (batch, n_chans, emb_dim)
+            x = self.patch_embed(x)
+            # x shape: (batch, n_chans, emb_dim)
+            n_patch = self.n_chans
+            temporal = self.emb_size
         else:
-            batch_size, nch, n_patch = self.patch_embed(x).shape
-            x = self.patch_embed(x)
+            # For neural decoder: input is (batch, n_chans, n_times)
+            # patch_embed returns (batch, n_patchs, emb_dim)
+            x = self.patch_embed(x)
+            # x shape: (batch, n_patchs, emb_dim)
+            batch_size, n_patch, temporal = x.shape
+
         # add the [CLS] token to the embedded patch tokens
         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
 
+        # Concatenate cls token with patch/channel embeddings
         x = torch.cat((cls_tokens, x), dim=1)
 
         # Positional Embedding
-        if input_chans is not None:
-            pos_embed_used = self.position_embedding[:, input_chans]
-        else:
-            pos_embed_used = self.position_embedding
-
         if self.position_embedding is not None:
-            pos_embed = self._adj_position_embedding(
-                pos_embed_used=pos_embed_used, batch_size=batch_size
-            )
+            if self.neural_tokenizer:
+                # In tokenizer mode, use channel-based position embedding
+                if input_chans is not None:
+                    pos_embed_used = self.position_embedding[:, input_chans]
+                else:
+                    pos_embed_used = self.position_embedding
+
+                pos_embed = self._adj_position_embedding(
+                    pos_embed_used=pos_embed_used, batch_size=batch_size
+                )
+            else:
+                # In decoder mode, we have different number of patches
+                # Adapt position embedding for n_patch patches
+                # Use the first n_patch+1 positions from position_embedding
+                n_pos = min(self.position_embedding.shape[1], n_patch + 1)
+                pos_embed_used = self.position_embedding[:, :n_pos, :]
+                pos_embed = pos_embed_used.expand(batch_size, -1, -1)
+
             x += pos_embed
 
         # The time embedding is added across the channels after the [CLS] token
         if self.neural_tokenizer:
             num_ch = self.n_chans
+            time_embed = self._adj_temporal_embedding(
+                num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
+            )
+            x[:, 1:, :] += time_embed
         else:
-            num_ch = n_patch
-        time_embed = self._adj_temporal_embedding(
-            num_ch=num_ch, batch_size=batch_size, dim_embed=temporal
-        )
-        x[:, 1:, :] += time_embed
+            # In decoder mode, we have n_patch patches and don't need to expand
+            # Just broadcast the temporal embedding
+            if temporal is None:
+                temporal = self.emb_size
+
+            # Get temporal embeddings for n_patch patches
+            n_time_tokens = min(n_patch, self.temporal_embedding.shape[1] - 1)
+            time_embed = self.temporal_embedding[
+                :, 1 : n_time_tokens + 1, :
+            ]  # (1, n_patch, emb_dim)
+            time_embed = time_embed.expand(
+                batch_size, -1, -1
+            )  # (batch, n_patch, emb_dim)
+            x[:, 1:, :] += time_embed
 
         x = self.pos_drop(x)
 
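In decoder mode the new branch above slices the stored position embedding to at most n_patch + 1 positions and broadcasts it over the batch instead of going through _adj_position_embedding. A standalone tensor sketch of that slice-and-broadcast (the sizes are illustrative, not the model's real buffers). Note that the min() only guards the slice itself; if the stored table held fewer than n_patch + 1 positions, the subsequent x += pos_embed would still fail on the length mismatch.

    import torch

    batch_size, n_patch, emb_dim = 2, 8, 200
    # Illustrative stand-in for self.position_embedding: (1, max_positions, emb_dim)
    position_embedding = torch.zeros(1, 65, emb_dim)

    # Same logic as the added decoder-mode branch.
    n_pos = min(position_embedding.shape[1], n_patch + 1)      # 9
    pos_embed_used = position_embedding[:, :n_pos, :]          # (1, 9, 200)
    pos_embed = pos_embed_used.expand(batch_size, -1, -1)      # (2, 9, 200)
    print(pos_embed.shape)  # matches [CLS] + n_patch tokens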
@@ -428,10 +478,10 @@ class Labram(EEGModuleMixin, nn.Module):
         if self.fc_norm is not None:
             if return_all_tokens:
                 return self.fc_norm(x)
-            temporal = x[:, 1:, :]
+            tokens = x[:, 1:, :]
             if return_patch_tokens:
-                return self.fc_norm(temporal)
-            return self.fc_norm(temporal.mean(1))
+                return self.fc_norm(tokens)
+            return self.fc_norm(tokens.mean(1))
         else:
             if return_all_tokens:
                 return x
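The rename from temporal to tokens avoids reusing the name of the temporal dimension variable set earlier in forward; the pooling itself is unchanged. As a shape sketch of what this branch computes:

    import torch

    batch, n_tokens, emb_dim = 2, 65, 200   # [CLS] + 64 channel/patch tokens
    x = torch.randn(batch, n_tokens, emb_dim)

    tokens = x[:, 1:, :]        # drop the [CLS] token -> (2, 64, 200)
    pooled = tokens.mean(1)     # average over tokens  -> (2, 200)
    print(tokens.shape, pooled.shape)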
@@ -505,14 +555,16 @@ class Labram(EEGModuleMixin, nn.Module):
     def _adj_temporal_embedding(self, num_ch, batch_size, dim_embed=None):
         """
         Adjust the dimensions of the time embedding to match the
-        number of channels.
+        number of channels or patches.
 
         Parameters
         ----------
         num_ch : int
-            The number of channels or number of code books vectors.
+            The number of channels or number of patches.
         batch_size : int
             Batch size of the input data.
+        dim_embed : int
+            The embedding dimension (temporal feature dimension).
 
         Returns
         -------
@@ -523,17 +575,24 @@ class Labram(EEGModuleMixin, nn.Module):
         if dim_embed is None:
             cut_dimension = self.patch_size
         else:
-            cut_dimension = dim_embed
-        # first step will be match the time_embed to the number of channels
-        temporal_embedding = self.temporal_embedding[:, 1:cut_dimension, :]
+            cut_dimension = min(dim_embed, self.temporal_embedding.shape[1] - 1)
+
+        # Get the temporal embedding: (1, temporal_embedding_dim, emb_size)
+        # Slice to cut_dimension: (1, cut_dimension, emb_size)
+        temporal_embedding = self.temporal_embedding[:, 1 : cut_dimension + 1, :]
+
         # Add a new dimension to the time embedding
-        # e.g. (batch, 62, 200) -> (batch, 1, 62, 200)
+        # e.g. (1, 5, 200) -> (1, 1, 5, 200)
         temporal_embedding = temporal_embedding.unsqueeze(1)
-        # Expand the time embedding to match the number of channels
-        # or number of patches from
+
+        # Expand the time embedding to match the number of channels or patches
+        # (1, 1, cut_dimension, 200) -> (batch_size, num_ch, cut_dimension, 200)
         temporal_embedding = temporal_embedding.expand(batch_size, num_ch, -1, -1)
+
         # Flatten the intermediate dimensions
+        # (batch_size, num_ch, cut_dimension, 200) -> (batch_size, num_ch * cut_dimension, 200)
         temporal_embedding = temporal_embedding.flatten(1, 2)
+
         return temporal_embedding
 
     def _adj_position_embedding(self, pos_embed_used, batch_size):
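The reworked _adj_temporal_embedding clamps the slice length to the size of the stored table and takes 1 : cut_dimension + 1 (one more position than the previous 1:cut_dimension), then expands and flattens as before. A standalone shape walk-through with illustrative sizes:

    import torch

    batch_size, num_ch, emb_size = 2, 64, 200
    # Illustrative stand-in for self.temporal_embedding: (1, table_size, emb_size)
    temporal_embedding_table = torch.zeros(1, 17, emb_size)

    dim_embed = 8
    cut_dimension = min(dim_embed, temporal_embedding_table.shape[1] - 1)   # 8

    te = temporal_embedding_table[:, 1 : cut_dimension + 1, :]   # (1, 8, 200)
    te = te.unsqueeze(1)                                          # (1, 1, 8, 200)
    te = te.expand(batch_size, num_ch, -1, -1)                    # (2, 64, 8, 200)
    te = te.flatten(1, 2)                                         # (2, 512, 200)
    print(te.shape)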
@@ -679,25 +738,27 @@ class _SegmentPatch(nn.Module):
 
 
 class _PatchEmbed(nn.Module):
-    """EEG to Patch Embedding.
+    """EEG to Patch Embedding for Neural Decoder mode.
 
     This code is used when we want to apply the patch embedding
-    after the codebook layer.
+    after the codebook layer (Neural Decoder mode).
+
+    The input is expected to be in the format (Batch, n_channels, n_times),
+    but the original LaBraM expects pre-patched data (Batch, n_channels, n_patches, patch_size).
+    This class reshapes the input to the pre-patched format, then applies a 2D
+    convolution to project this pre-patched data to the embedding dimension,
+    and finally flattens across channels to produce a unified embedding.
 
     Parameters:
     -----------
     n_times: int (default=2000)
-        Number of temporal components of the input tensor.
+        Number of temporal components of the input tensor (used for dimension calculation).
     patch_size: int (default=200)
         Size of the patch, default is 1-seconds with 200Hz.
     in_channels: int (default=1)
-        Number of input channels for to be used in the convolution.
+        Number of input channels (from VQVAE codebook).
     emb_dim: int (default=200)
-        Number of out_channes to be used in the convolution, here,
-        we used the same as patch_size.
-    n_codebooks: int (default=62)
-        Number of patches to be used in the convolution, here,
-        we used the same as n_times // patch_size.
+        Number of output embedding dimension.
     """
 
     def __init__(
@@ -707,10 +768,13 @@ class _PatchEmbed(nn.Module):
         self.n_times = n_times
         self.patch_size = patch_size
         self.patch_shape = (1, self.n_times // self.patch_size)
-        n_patchs = n_codebooks * (self.n_times // self.patch_size)
-
-        self.n_patchs = n_patchs
+        self.n_patchs = self.n_times // self.patch_size
+        self.emb_dim = emb_dim
+        self.in_channels = in_channels
 
+        # 2D Conv to project the pre-patched data
+        # Input: (Batch, in_channels, n_patches, patch_size)
+        # After proj: (Batch, emb_dim, n_patches, 1)
         self.proj = nn.Conv2d(
             in_channels=in_channels,
             out_channels=emb_dim,
@@ -718,27 +782,64 @@ class _PatchEmbed(nn.Module):
             stride=(1, self.patch_size),
         )
 
-        self.merge_transpose = Rearrange(
-            "Batch ch patch spatch -> Batch patch spatch ch",
-        )
-
     def forward(self, x):
         """
-        Apply the convolution to the input tensor.
-        then merge the output tensor to the desired shape.
+        Apply the temporal projection to the input tensor after grouping channels.
 
-        Parameters:
-        -----------
-        x: torch.Tensor
-            Input tensor of shape (Batch, Channels, n_patchs, patch_size).
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (Batch, n_channels, n_times) or
+            (Batch, n_channels, n_patches, patch_size).
 
-        Return:
+        Returns
         -------
-        x: torch.Tensor
-            Output tensor of shape (Batch, n_patchs, patch_size, channels).
+        torch.Tensor
+            Output tensor of shape (Batch, n_patchs, emb_dim).
         """
+        if x.ndim == 4:
+            batch_size, n_channels, n_patchs, patch_len = x.shape
+            if patch_len != self.patch_size:
+                raise ValueError(
+                    "When providing a 4D tensor, the last dimension "
+                    f"({patch_len}) must match patch_size ({self.patch_size})."
+                )
+            n_times = n_patchs * patch_len
+            x = x.reshape(batch_size, n_channels, n_times)
+        elif x.ndim == 3:
+            batch_size, n_channels, n_times = x.shape
+        else:
+            raise ValueError(
+                "Input must be either 3D (batch, channels, times) or "
+                "4D (batch, channels, n_patches, patch_size)."
+            )
+
+        if n_times % self.patch_size != 0:
+            raise ValueError(
+                f"n_times ({n_times}) must be divisible by patch_size ({self.patch_size})."
+            )
+        if n_channels % self.in_channels != 0:
+            raise ValueError(
+                "The input channel dimension "
+                f"({n_channels}) must be divisible by in_channels ({self.in_channels})."
+            )
+
+        group_size = n_channels // self.in_channels
+
+        # Reshape so Conv2d sees `in_channels` feature maps and uses the grouped
+        # EEG channels as the spatial height dimension.
+        # Shape after view: (Batch, in_channels, group_size, n_times)
+        x = x.view(batch_size, self.in_channels, group_size, n_times)
+
+        # Apply the temporal projection per group.
+        # Output shape: (Batch, emb_dim, group_size, n_patchs)
         x = self.proj(x)
-        x = self.merge_transpose(x)
+
+        # THIS IS braindecode's MODIFICATION:
+        # Average over the grouped channel dimension and permute to (Batch, n_patchs, emb_dim)
+        x = x.mean(dim=2)
+        x = x.transpose(1, 2).contiguous()
+
         return x
 
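The new _PatchEmbed.forward above accepts either a raw (Batch, n_channels, n_times) tensor or the pre-patched 4D layout, groups the EEG channels into in_channels feature maps, applies the temporal convolution, and averages over the grouped axis. A standalone sketch of that shape flow with a throwaway Conv2d, assuming the projection uses kernel_size=(1, patch_size) to match the stride shown in the diff:

    import torch
    import torch.nn as nn

    batch_size, n_channels, n_times = 2, 64, 1600
    patch_size, in_channels, emb_dim = 200, 1, 200

    proj = nn.Conv2d(
        in_channels=in_channels,
        out_channels=emb_dim,
        kernel_size=(1, patch_size),
        stride=(1, patch_size),
    )

    x = torch.randn(batch_size, n_channels, n_times)

    group_size = n_channels // in_channels                     # 64
    x = x.view(batch_size, in_channels, group_size, n_times)   # (2, 1, 64, 1600)
    x = proj(x)                                                # (2, 200, 64, 8)
    x = x.mean(dim=2)                                          # (2, 200, 8)
    x = x.transpose(1, 2).contiguous()                         # (2, 8, 200)
    print(x.shape)  # (batch, n_patchs, emb_dim)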