rxnn 0.1.35__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +4 -4
- rxnn/experimental/models.py +1 -1
- rxnn/transformers/attention.py +6 -2
- {rxnn-0.1.35.dist-info → rxnn-0.1.37.dist-info}/METADATA +1 -1
- {rxnn-0.1.35.dist-info → rxnn-0.1.37.dist-info}/RECORD +7 -7
- {rxnn-0.1.35.dist-info → rxnn-0.1.37.dist-info}/LICENSE +0 -0
- {rxnn-0.1.35.dist-info → rxnn-0.1.37.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -614,7 +614,7 @@ def init_moe_attention(
         num_query_experts: int = None,
         num_query_groups: int = None,
 ) -> GroupedQueryAttention:
-    assert attention_type in ['gma', 'dma', '
+    assert attention_type in ['gma', 'dma', 'gma_s', 'dma_s'], "Error, attention type should be one of: 'gma', 'dma', 'gma_s', 'dma_s'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -648,8 +648,8 @@ def init_moe_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
-    elif attention_type == "
-        return
+    elif attention_type == "gma_s":
+        return GroupedMoeAttentionSimplified(
             embed_dim,
             num_heads,
             gqa_groups,
@@ -664,7 +664,7 @@ def init_moe_attention(
             num_experts=num_experts,
         )
     else:
-        return
+        return DeepMoeAttentionSimplified(
             embed_dim,
             num_heads,
             gqa_groups,
rxnn/experimental/models.py
CHANGED
@@ -65,7 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', '
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_s', 'dma_s'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_s", "dma_s".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
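A hypothetical sketch of constructing MoeAttentionTransformer with one of the newly allowed self-attention types. Only names appearing in the diff context (vocab_size, embed_dim, att_heads, seq_len, ff_activation, att_type) are used as keyword arguments; the real constructor likely takes additional required parameters not shown here.

    # Sketch only: the constructor signature is assumed from the visible context.
    from rxnn.experimental.models import MoeAttentionTransformer

    model = MoeAttentionTransformer(
        vocab_size=32000,
        embed_dim=512,
        att_heads=16,
        seq_len=1024,
        ff_activation='gelu',
        att_type='gma_s',  # rejected by the assert in 0.1.35, accepted in 0.1.37
    )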
rxnn/transformers/attention.py
CHANGED
@@ -102,8 +102,12 @@ class MultiHeadAttention(nn.Module):
 
     def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                          mask: torch.Tensor = None, enable_gqa: bool = False):
-
-
+        # After ~6h of fighthing, PyTorch based is still now working so I decided to use FlashAttention directly
+        # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
+        #     return self._torch_attention(q, k, v, b, t, d, mask=mask, enable_gqa=enable_gqa)
+        from flash_attn import flash_attn_func
+        attn_output = flash_attn_func(q, k, v, dropout_p=self.dropout.p if self.training else 0.0, is_causal=self.is_causal)
+        return self._transpose_output(attn_output, b, t, d)
 
     def _torch_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                          mask: torch.Tensor = None, enable_gqa: bool = False):
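A standalone sketch, not the package's code, comparing the two attention paths touched by this change: PyTorch's scaled_dot_product_attention and a direct flash-attn call. It assumes flash-attn is installed and a CUDA device with fp16 tensors. Note that in the flash-attn releases I am aware of, the keyword is `causal` (not `is_causal`) and the expected layout is (batch, seq_len, num_heads, head_dim), whereas SDPA expects (batch, num_heads, seq_len, head_dim); verify against the installed version.

    import torch
    import torch.nn.functional as F

    b, h, t, d = 2, 8, 128, 64
    q = torch.randn(b, h, t, d, device='cuda', dtype=torch.float16)
    k = torch.randn(b, h, t, d, device='cuda', dtype=torch.float16)
    v = torch.randn(b, h, t, d, device='cuda', dtype=torch.float16)

    # Path 1: PyTorch SDPA (may itself dispatch to a flash kernel when eligible).
    out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Path 2: flash-attn called directly, as the new _flash_attention does.
    from flash_attn import flash_attn_func
    out_flash = flash_attn_func(
        q.transpose(1, 2),  # -> (batch, seq_len, num_heads, head_dim)
        k.transpose(1, 2),
        v.transpose(1, 2),
        dropout_p=0.0,
        causal=True,
    ).transpose(1, 2)       # back to (batch, num_heads, seq_len, head_dim)

    # Results should agree within fp16 tolerance.
    print(torch.allclose(out_sdpa, out_flash, atol=1e-2))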
{rxnn-0.1.35.dist-info → rxnn-0.1.37.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256=
+rxnn/experimental/attention.py,sha256=PjmVwNeJXDy72LJr5cl9JD1oqjlwYK-Ahx1K1gLQgf8,29426
+rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=bsuAXCKR0WbOxgu-IkJHgn7jUu2CK4hqNw60IZbGTEE,15698
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.37.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.37.dist-info/METADATA,sha256=9cKmCtODY8tw_VimbuCN9787asAvNbNylZBK2gOBzLE,16627
+rxnn-0.1.37.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.37.dist-info/RECORD,,
{rxnn-0.1.35.dist-info → rxnn-0.1.37.dist-info}/LICENSE
File without changes
{rxnn-0.1.35.dist-info → rxnn-0.1.37.dist-info}/WHEEL
File without changes