braindecode 1.5.0.dev984__py3-none-any.whl → 1.5.0.dev182195895__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -439,7 +439,7 @@ class _BIOTEncoder(nn.Module):
439
439
  self.channel_tokens = nn.Embedding(
440
440
  num_embeddings=n_chans, embedding_dim=emb_size
441
441
  )
442
- self.index = nn.Parameter(torch.LongTensor(range(n_chans)), requires_grad=False)
442
+ self.register_buffer("index", torch.arange(n_chans, dtype=torch.long))
443
443
 
444
444
  def stft(self, sample):
445
445
  """
@@ -837,6 +837,18 @@ class CATLite(nn.Module):
837
837
  class MultiHeadAttention(nn.Module):
838
838
  """Multi-head self-attention block.
839
839
 
840
+ Uses ``F.scaled_dot_product_attention`` for optimized attention
841
+ kernels (flash-attention on CUDA, memory-efficient on other devices).
842
+
843
+ Parameters
844
+ ----------
845
+ emb_size : int
846
+ The embedding dimension.
847
+ num_heads : int
848
+ Number of attention heads. Must evenly divide ``emb_size``.
849
+ dropout : float, optional
850
+ Dropout probability applied to attention weights. Default: 0.0.
851
+
840
852
  Examples
841
853
  --------
842
854
  >>> import torch
@@ -848,40 +860,57 @@ class MultiHeadAttention(nn.Module):
848
860
  torch.Size([2, 10, 32])
849
861
  """
850
862
 
851
- def __init__(self, emb_size, num_heads, dropout):
863
+ def __init__(self, emb_size, num_heads, dropout=0.0):
852
864
  super().__init__()
865
+ if emb_size % num_heads != 0:
866
+ raise ValueError(
867
+ f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})."
868
+ )
853
869
  self.emb_size = emb_size
854
870
  self.num_heads = num_heads
871
+ self.head_dim = emb_size // num_heads
855
872
  self.keys = nn.Linear(emb_size, emb_size)
856
873
  self.queries = nn.Linear(emb_size, emb_size)
857
874
  self.values = nn.Linear(emb_size, emb_size)
858
- self.att_drop = nn.Dropout(dropout)
875
+ self.att_drop = dropout
859
876
  self.projection = nn.Linear(emb_size, emb_size)
860
877
 
861
878
  self.rearrange_stack = Rearrange(
862
- "b n (h d) -> b h n d",
863
- h=num_heads,
879
+ "batch seq (heads head_dim) -> batch heads seq head_dim",
880
+ heads=num_heads,
864
881
  )
865
882
  self.rearrange_unstack = Rearrange(
866
- "b h n d -> b n (h d)",
883
+ "batch heads seq head_dim -> batch seq (heads head_dim)",
867
884
  )
868
885
 
869
886
  def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
887
+ """Forward pass.
888
+
889
+ Parameters
890
+ ----------
891
+ x : Tensor
892
+ Input tensor of shape ``(batch, seq, emb_size)``.
893
+ mask : Tensor, optional
894
+ Attention mask following PyTorch SDPA convention: for boolean
895
+ masks ``True`` means **ignore** that position; for float
896
+ masks the values are **added** to attention scores before
897
+ softmax.
898
+ """
870
899
  queries = self.rearrange_stack(self.queries(x))
871
900
  keys = self.rearrange_stack(self.keys(x))
872
901
  values = self.rearrange_stack(self.values(x))
873
- energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
874
- if mask is not None:
875
- fill_value = float("-inf")
876
- energy = energy.masked_fill(~mask, fill_value)
877
-
878
- scaling = self.emb_size ** (1 / 2)
879
- att = F.softmax(energy / scaling, dim=-1)
880
- att = self.att_drop(att)
881
- out = torch.einsum("bhal, bhlv -> bhav ", att, values)
902
+
903
+ dp = self.att_drop if self.training else 0.0
904
+ out = F.scaled_dot_product_attention(
905
+ queries,
906
+ keys,
907
+ values,
908
+ attn_mask=mask,
909
+ dropout_p=dp,
910
+ )
911
+
882
912
  out = self.rearrange_unstack(out)
883
- out = self.projection(out)
884
- return out
913
+ return self.projection(out)
885
914
 
886
915
 
887
916
  class CrissCrossTransformerEncoderLayer(nn.Module):
braindecode/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.5.0.dev984"
1
+ __version__ = "1.5.0.dev182195895"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.5.0.dev984
3
+ Version: 1.5.0.dev182195895
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -3,7 +3,7 @@ braindecode/classifier.py,sha256=7kC_oY_UzHEes_WWdCvEpiA1ZKxMeuLL5tIPp5rfcpg,962
3
3
  braindecode/eegneuralnet.py,sha256=xjE6aPZdCQPs29NIpy_m1GLMMC2WZ3Db0Fuh1-xE1h4,13827
4
4
  braindecode/regressor.py,sha256=KiMJpqCUPWA2k2JWk9HGYTzeoBqJ4gAKEudeUVcFZY4,9266
5
5
  braindecode/util.py,sha256=f8bNIwt-SwsHqheH_BADQxTtA9oPt3Lb7GFnoI-Huwc,14101
6
- braindecode/version.py,sha256=5WqXRAeL_e0Aq78aGQsveuaWwGcV_OgtE7s9XQXznkc,29
6
+ braindecode/version.py,sha256=FCjD8jntd2wdAgufj5JsChb_O65vD4HdEFpKyVmbAz0,35
7
7
  braindecode/augmentation/__init__.py,sha256=4xune2QUK6KHMKsAqijF7I9eeiVbP0wEoQJjCNLNcKM,1081
8
8
  braindecode/augmentation/base.py,sha256=OJ1shOljI1yTY9zh2qWxQwivlY43sfx9Q-MAyMhxtPs,7338
9
9
  braindecode/augmentation/functional.py,sha256=q2k6mAXrujYlOZUndcjZN8e8b-6oJF1gGsORAI23hyE,43998
@@ -45,7 +45,7 @@ braindecode/models/attentionbasenet.py,sha256=k7ar7aEjANudPu7krAZsRx-ag61ugirS1X
45
45
  braindecode/models/attn_sleep.py,sha256=F9x4spTtzfiCC1h9UYITmIDQeJW6_2CXTZktZX9R0RE,17950
46
46
  braindecode/models/base.py,sha256=yGJgr0f5rD-gJZ5Msw9FzGzVO-x_re-6jOSK8Iht6x4,27923
47
47
  braindecode/models/bendr.py,sha256=kqIKtIrgHhIdwYfBFwAJ1g69YbdpA8bs2C2e2GUT0rY,25643
48
- braindecode/models/biot.py,sha256=k742atP5TMAFkEDQ1hd-AhOJW2a0R6_StRbZGYjWS1U,19358
48
+ braindecode/models/biot.py,sha256=EJuf77ieWirrkYrVZessQF-MulVBdNTUhpRw3UxHhb8,19348
49
49
  braindecode/models/brainmodule.py,sha256=idyQVTp3VBJXKF1YjMx8o1kUKrcL_E_AJsioPjcuqV8,33282
50
50
  braindecode/models/cbramod.py,sha256=ZNyM_iBGFGjMpmOiqhLwmgGzy-AowBJqTORzNeNG99I,14309
51
51
  braindecode/models/config.py,sha256=9yc7fh2p6txrjaOROyJevJO7ud5fofZk6C76JfwZrX0,8431
@@ -93,7 +93,7 @@ braindecode/models/usleep.py,sha256=VaIRbTSuKisD_W6fMwLlY5bNiH41fBnfx0k-iiLHWpw,
93
93
  braindecode/models/util.py,sha256=QSku9ZlXTyJzMuTLG-VlPPifTe2_6huo7i9-zQnhZho,17728
94
94
  braindecode/modules/__init__.py,sha256=NChqITh6yMcagoYFVkQC14aJjCAGkRO8fYQrH_0Zv1k,1977
95
95
  braindecode/modules/activation.py,sha256=QSlpMj_mfroLvaaIGfnjWJbWgvpEBigq2AU0jj0fglQ,2496
96
- braindecode/modules/attention.py,sha256=Rvad4FchhXSjgApi15aQTXRsWYJ43aChOH-akSbRJAI,32689
96
+ braindecode/modules/attention.py,sha256=zHJ6kL54ZGGX4376Xk2rzxRMateOi4AtSSxzLs5hD-o,33629
97
97
  braindecode/modules/blocks.py,sha256=Nv-hu1nx0zZBs0aYMe7FDCzRkH1rf34CHXpoTQP6lCE,4993
98
98
  braindecode/modules/convolution.py,sha256=b6jvSp1bS8skIUfX2Ztn_JHyxX8q3hvaEqyOmzNnRsc,10786
99
99
  braindecode/modules/filter.py,sha256=Lcq5zt3uEmIiFeu8SbAWkhsqIoSU1zS_Mo-0mBFHrJE,25918
@@ -119,9 +119,9 @@ braindecode/training/scoring.py,sha256=WRkwqbitA3m_dzRnGp2ZIZPge5Nhx9gAEQhIHzeH4
119
119
  braindecode/visualization/__init__.py,sha256=4EER_xHqZIDzEvmgUEm7K1bgNKpyZAIClR9ZCkMuY4M,240
120
120
  braindecode/visualization/confusion_matrices.py,sha256=qIWMLEHow5CJ7PhGggD8mnD55Le6xhma9HSzt4R33fc,9509
121
121
  braindecode/visualization/gradients.py,sha256=-NuWIlpOeSR2OWGfUl6jn1RqdNPxmiQNLPCckWb8SRE,2330
122
- braindecode-1.5.0.dev984.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
123
- braindecode-1.5.0.dev984.dist-info/licenses/NOTICE.txt,sha256=sOxuTbalPxTM8H6VqtvGbXCt_BoOF7JevEYG_knqbm4,620
124
- braindecode-1.5.0.dev984.dist-info/METADATA,sha256=gs4siQJaVjCx8UszxeuPfWYFF_FPk2aagFqBzmIZbTI,10099
125
- braindecode-1.5.0.dev984.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
126
- braindecode-1.5.0.dev984.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
127
- braindecode-1.5.0.dev984.dist-info/RECORD,,
122
+ braindecode-1.5.0.dev182195895.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
123
+ braindecode-1.5.0.dev182195895.dist-info/licenses/NOTICE.txt,sha256=sOxuTbalPxTM8H6VqtvGbXCt_BoOF7JevEYG_knqbm4,620
124
+ braindecode-1.5.0.dev182195895.dist-info/METADATA,sha256=VSuDNooSSWCZExCgShqwhOq2EayUOJQ5w_bixsGHC8A,10105
125
+ braindecode-1.5.0.dev182195895.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
126
+ braindecode-1.5.0.dev182195895.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
127
+ braindecode-1.5.0.dev182195895.dist-info/RECORD,,