braindecode 1.5.0.dev984__py3-none-any.whl → 1.5.0.dev182195895__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/models/biot.py +1 -1
- braindecode/modules/attention.py +45 -16
- braindecode/version.py +1 -1
- {braindecode-1.5.0.dev984.dist-info → braindecode-1.5.0.dev182195895.dist-info}/METADATA +1 -1
- {braindecode-1.5.0.dev984.dist-info → braindecode-1.5.0.dev182195895.dist-info}/RECORD +9 -9
- {braindecode-1.5.0.dev984.dist-info → braindecode-1.5.0.dev182195895.dist-info}/WHEEL +0 -0
- {braindecode-1.5.0.dev984.dist-info → braindecode-1.5.0.dev182195895.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.5.0.dev984.dist-info → braindecode-1.5.0.dev182195895.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.5.0.dev984.dist-info → braindecode-1.5.0.dev182195895.dist-info}/top_level.txt +0 -0
braindecode/models/biot.py
CHANGED
|
@@ -439,7 +439,7 @@ class _BIOTEncoder(nn.Module):
|
|
|
439
439
|
self.channel_tokens = nn.Embedding(
|
|
440
440
|
num_embeddings=n_chans, embedding_dim=emb_size
|
|
441
441
|
)
|
|
442
|
-
self.index
|
|
442
|
+
self.register_buffer("index", torch.arange(n_chans, dtype=torch.long))
|
|
443
443
|
|
|
444
444
|
def stft(self, sample):
|
|
445
445
|
"""
|
braindecode/modules/attention.py
CHANGED
|
@@ -837,6 +837,18 @@ class CATLite(nn.Module):
|
|
|
837
837
|
class MultiHeadAttention(nn.Module):
|
|
838
838
|
"""Multi-head self-attention block.
|
|
839
839
|
|
|
840
|
+
Uses ``F.scaled_dot_product_attention`` for optimized attention
|
|
841
|
+
kernels (PyTorch dispatches to a flash-attention, memory-efficient, or math backend depending on the device, dtype and inputs).
|
|
842
|
+
|
|
843
|
+
Parameters
|
|
844
|
+
----------
|
|
845
|
+
emb_size : int
|
|
846
|
+
The embedding dimension.
|
|
847
|
+
num_heads : int
|
|
848
|
+
Number of attention heads. Must evenly divide ``emb_size``.
|
|
849
|
+
dropout : float, optional
|
|
850
|
+
Dropout probability applied to attention weights (only in training mode). Default: 0.0.
|
|
851
|
+
|
|
840
852
|
Examples
|
|
841
853
|
--------
|
|
842
854
|
>>> import torch
|
|
@@ -848,40 +860,57 @@ class MultiHeadAttention(nn.Module):
|
|
|
848
860
|
torch.Size([2, 10, 32])
|
|
849
861
|
"""
|
|
850
862
|
|
|
851
|
-
def __init__(self, emb_size, num_heads, dropout):
|
|
863
|
+
def __init__(self, emb_size, num_heads, dropout=0.0):
|
|
852
864
|
super().__init__()
|
|
865
|
+
if emb_size % num_heads != 0:
|
|
866
|
+
raise ValueError(
|
|
867
|
+
f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})."
|
|
868
|
+
)
|
|
853
869
|
self.emb_size = emb_size
|
|
854
870
|
self.num_heads = num_heads
|
|
871
|
+
self.head_dim = emb_size // num_heads
|
|
855
872
|
self.keys = nn.Linear(emb_size, emb_size)
|
|
856
873
|
self.queries = nn.Linear(emb_size, emb_size)
|
|
857
874
|
self.values = nn.Linear(emb_size, emb_size)
|
|
858
|
-
self.att_drop =
|
|
875
|
+
self.att_drop = dropout
|
|
859
876
|
self.projection = nn.Linear(emb_size, emb_size)
|
|
860
877
|
|
|
861
878
|
self.rearrange_stack = Rearrange(
|
|
862
|
-
"
|
|
863
|
-
|
|
879
|
+
"batch seq (heads head_dim) -> batch heads seq head_dim",
|
|
880
|
+
heads=num_heads,
|
|
864
881
|
)
|
|
865
882
|
self.rearrange_unstack = Rearrange(
|
|
866
|
-
"
|
|
883
|
+
"batch heads seq head_dim -> batch seq (heads head_dim)",
|
|
867
884
|
)
|
|
868
885
|
|
|
869
886
|
def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
|
887
|
+
"""Forward pass.
|
|
888
|
+
|
|
889
|
+
Parameters
|
|
890
|
+
----------
|
|
891
|
+
x : Tensor
|
|
892
|
+
Input tensor of shape ``(batch, seq, emb_size)``.
|
|
893
|
+
mask : Tensor, optional
|
|
894
|
+
Attention mask following PyTorch SDPA convention: for boolean
|
|
895
|
+
masks ``True`` means that position **takes part** in attention (``False`` masks it out); for float
|
|
896
|
+
masks the values are **added** to attention scores before
|
|
897
|
+
softmax.
|
|
898
|
+
"""
|
|
870
899
|
queries = self.rearrange_stack(self.queries(x))
|
|
871
900
|
keys = self.rearrange_stack(self.keys(x))
|
|
872
901
|
values = self.rearrange_stack(self.values(x))
|
|
873
|
-
|
|
874
|
-
if
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
902
|
+
|
|
903
|
+
dp = self.att_drop if self.training else 0.0
|
|
904
|
+
out = F.scaled_dot_product_attention(
|
|
905
|
+
queries,
|
|
906
|
+
keys,
|
|
907
|
+
values,
|
|
908
|
+
attn_mask=mask,
|
|
909
|
+
dropout_p=dp,
|
|
910
|
+
)
|
|
911
|
+
|
|
882
912
|
out = self.rearrange_unstack(out)
|
|
883
|
-
|
|
884
|
-
return out
|
|
913
|
+
return self.projection(out)
|
|
885
914
|
|
|
886
915
|
|
|
887
916
|
class CrissCrossTransformerEncoderLayer(nn.Module):
|
braindecode/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "1.5.0.dev984"
|
|
1
|
+
__version__ = "1.5.0.dev182195895"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: braindecode
|
|
3
|
-
Version: 1.5.0.dev984
|
|
3
|
+
Version: 1.5.0.dev182195895
|
|
4
4
|
Summary: Deep learning software to decode EEG, ECG or MEG signals
|
|
5
5
|
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
|
|
6
6
|
Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
@@ -3,7 +3,7 @@ braindecode/classifier.py,sha256=7kC_oY_UzHEes_WWdCvEpiA1ZKxMeuLL5tIPp5rfcpg,962
|
|
|
3
3
|
braindecode/eegneuralnet.py,sha256=xjE6aPZdCQPs29NIpy_m1GLMMC2WZ3Db0Fuh1-xE1h4,13827
|
|
4
4
|
braindecode/regressor.py,sha256=KiMJpqCUPWA2k2JWk9HGYTzeoBqJ4gAKEudeUVcFZY4,9266
|
|
5
5
|
braindecode/util.py,sha256=f8bNIwt-SwsHqheH_BADQxTtA9oPt3Lb7GFnoI-Huwc,14101
|
|
6
|
-
braindecode/version.py,sha256=
|
|
6
|
+
braindecode/version.py,sha256=FCjD8jntd2wdAgufj5JsChb_O65vD4HdEFpKyVmbAz0,35
|
|
7
7
|
braindecode/augmentation/__init__.py,sha256=4xune2QUK6KHMKsAqijF7I9eeiVbP0wEoQJjCNLNcKM,1081
|
|
8
8
|
braindecode/augmentation/base.py,sha256=OJ1shOljI1yTY9zh2qWxQwivlY43sfx9Q-MAyMhxtPs,7338
|
|
9
9
|
braindecode/augmentation/functional.py,sha256=q2k6mAXrujYlOZUndcjZN8e8b-6oJF1gGsORAI23hyE,43998
|
|
@@ -45,7 +45,7 @@ braindecode/models/attentionbasenet.py,sha256=k7ar7aEjANudPu7krAZsRx-ag61ugirS1X
|
|
|
45
45
|
braindecode/models/attn_sleep.py,sha256=F9x4spTtzfiCC1h9UYITmIDQeJW6_2CXTZktZX9R0RE,17950
|
|
46
46
|
braindecode/models/base.py,sha256=yGJgr0f5rD-gJZ5Msw9FzGzVO-x_re-6jOSK8Iht6x4,27923
|
|
47
47
|
braindecode/models/bendr.py,sha256=kqIKtIrgHhIdwYfBFwAJ1g69YbdpA8bs2C2e2GUT0rY,25643
|
|
48
|
-
braindecode/models/biot.py,sha256=
|
|
48
|
+
braindecode/models/biot.py,sha256=EJuf77ieWirrkYrVZessQF-MulVBdNTUhpRw3UxHhb8,19348
|
|
49
49
|
braindecode/models/brainmodule.py,sha256=idyQVTp3VBJXKF1YjMx8o1kUKrcL_E_AJsioPjcuqV8,33282
|
|
50
50
|
braindecode/models/cbramod.py,sha256=ZNyM_iBGFGjMpmOiqhLwmgGzy-AowBJqTORzNeNG99I,14309
|
|
51
51
|
braindecode/models/config.py,sha256=9yc7fh2p6txrjaOROyJevJO7ud5fofZk6C76JfwZrX0,8431
|
|
@@ -93,7 +93,7 @@ braindecode/models/usleep.py,sha256=VaIRbTSuKisD_W6fMwLlY5bNiH41fBnfx0k-iiLHWpw,
|
|
|
93
93
|
braindecode/models/util.py,sha256=QSku9ZlXTyJzMuTLG-VlPPifTe2_6huo7i9-zQnhZho,17728
|
|
94
94
|
braindecode/modules/__init__.py,sha256=NChqITh6yMcagoYFVkQC14aJjCAGkRO8fYQrH_0Zv1k,1977
|
|
95
95
|
braindecode/modules/activation.py,sha256=QSlpMj_mfroLvaaIGfnjWJbWgvpEBigq2AU0jj0fglQ,2496
|
|
96
|
-
braindecode/modules/attention.py,sha256=
|
|
96
|
+
braindecode/modules/attention.py,sha256=zHJ6kL54ZGGX4376Xk2rzxRMateOi4AtSSxzLs5hD-o,33629
|
|
97
97
|
braindecode/modules/blocks.py,sha256=Nv-hu1nx0zZBs0aYMe7FDCzRkH1rf34CHXpoTQP6lCE,4993
|
|
98
98
|
braindecode/modules/convolution.py,sha256=b6jvSp1bS8skIUfX2Ztn_JHyxX8q3hvaEqyOmzNnRsc,10786
|
|
99
99
|
braindecode/modules/filter.py,sha256=Lcq5zt3uEmIiFeu8SbAWkhsqIoSU1zS_Mo-0mBFHrJE,25918
|
|
@@ -119,9 +119,9 @@ braindecode/training/scoring.py,sha256=WRkwqbitA3m_dzRnGp2ZIZPge5Nhx9gAEQhIHzeH4
|
|
|
119
119
|
braindecode/visualization/__init__.py,sha256=4EER_xHqZIDzEvmgUEm7K1bgNKpyZAIClR9ZCkMuY4M,240
|
|
120
120
|
braindecode/visualization/confusion_matrices.py,sha256=qIWMLEHow5CJ7PhGggD8mnD55Le6xhma9HSzt4R33fc,9509
|
|
121
121
|
braindecode/visualization/gradients.py,sha256=-NuWIlpOeSR2OWGfUl6jn1RqdNPxmiQNLPCckWb8SRE,2330
|
|
122
|
-
braindecode-1.5.0.
|
|
123
|
-
braindecode-1.5.0.
|
|
124
|
-
braindecode-1.5.0.
|
|
125
|
-
braindecode-1.5.0.
|
|
126
|
-
braindecode-1.5.0.
|
|
127
|
-
braindecode-1.5.0.
|
|
122
|
+
braindecode-1.5.0.dev182195895.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
|
|
123
|
+
braindecode-1.5.0.dev182195895.dist-info/licenses/NOTICE.txt,sha256=sOxuTbalPxTM8H6VqtvGbXCt_BoOF7JevEYG_knqbm4,620
|
|
124
|
+
braindecode-1.5.0.dev182195895.dist-info/METADATA,sha256=VSuDNooSSWCZExCgShqwhOq2EayUOJQ5w_bixsGHC8A,10105
|
|
125
|
+
braindecode-1.5.0.dev182195895.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
126
|
+
braindecode-1.5.0.dev182195895.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
|
|
127
|
+
braindecode-1.5.0.dev182195895.dist-info/RECORD,,
|
|
File without changes
|
{braindecode-1.5.0.dev984.dist-info → braindecode-1.5.0.dev182195895.dist-info}/licenses/LICENSE.txt
RENAMED
|
File without changes
|
{braindecode-1.5.0.dev984.dist-info → braindecode-1.5.0.dev182195895.dist-info}/licenses/NOTICE.txt
RENAMED
|
File without changes
|
{braindecode-1.5.0.dev984.dist-info → braindecode-1.5.0.dev182195895.dist-info}/top_level.txt
RENAMED
|
File without changes
|