rxnn 0.1.35__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +4 -4
- rxnn/experimental/models.py +1 -1
- {rxnn-0.1.35.dist-info → rxnn-0.1.36.dist-info}/METADATA +1 -1
- {rxnn-0.1.35.dist-info → rxnn-0.1.36.dist-info}/RECORD +6 -6
- {rxnn-0.1.35.dist-info → rxnn-0.1.36.dist-info}/LICENSE +0 -0
- {rxnn-0.1.35.dist-info → rxnn-0.1.36.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -614,7 +614,7 @@ def init_moe_attention(
     num_query_experts: int = None,
     num_query_groups: int = None,
 ) -> GroupedQueryAttention:
-    assert attention_type in ['gma', 'dma', '
+    assert attention_type in ['gma', 'dma', 'gma_s', 'dma_s'], "Error, attention type should be one of: 'gma', 'dma', 'gma_s', 'dma_s'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -648,8 +648,8 @@ def init_moe_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
-    elif attention_type == "
-        return
+    elif attention_type == "gma_s":
+        return GroupedMoeAttentionSimplified(
             embed_dim,
             num_heads,
             gqa_groups,
@@ -664,7 +664,7 @@ def init_moe_attention(
             num_experts=num_experts,
         )
     else:
-        return
+        return DeepMoeAttentionSimplified(
             embed_dim,
             num_heads,
             gqa_groups,
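For context, the hunks above extend the `init_moe_attention` dispatch so that `'gma_s'` and `'dma_s'` select the new simplified attention classes. The following is a minimal, self-contained sketch of that dispatch pattern, not the real rxnn code: the class bodies are stubs, the helper name `pick_attention_class` is hypothetical, and the `'dma'` → `DeepMoeAttention` mapping is assumed because that branch is not visible in this diff.

```python
# Sketch only: stub classes stand in for the real rxnn attention modules.
class GroupedMoeAttention: ...
class DeepMoeAttention: ...               # assumed name; the 'dma' branch is not shown in this diff
class GroupedMoeAttentionSimplified: ...  # returned for the new 'gma_s' option
class DeepMoeAttentionSimplified: ...     # returned for the new 'dma_s' option

# Mapping implied by the updated branches of init_moe_attention.
_ATTENTION_CLASSES = {
    'gma': GroupedMoeAttention,
    'dma': DeepMoeAttention,
    'gma_s': GroupedMoeAttentionSimplified,
    'dma_s': DeepMoeAttentionSimplified,
}

def pick_attention_class(attention_type: str):
    # Mirrors the updated assert: only these four MoE attention variants are accepted.
    assert attention_type in _ATTENTION_CLASSES, \
        "Error, attention type should be one of: 'gma', 'dma', 'gma_s', 'dma_s'"
    return _ATTENTION_CLASSES[attention_type]

print(pick_attention_class('gma_s'))  # <class '__main__.GroupedMoeAttentionSimplified'>
```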
rxnn/experimental/models.py
CHANGED
@@ -65,7 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', '
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_s', 'dma_s'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_s", "dma_s".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
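The matching change in `MoeAttentionTransformer` simply widens the accepted `att_type` values. Below is a minimal sketch of the validation after this release; `validate_transformer_args` is a hypothetical standalone helper, not part of rxnn, and the real constructor takes many more arguments than the two checked here.

```python
# Sketch of the argument checks MoeAttentionTransformer performs after 0.1.36.
FF_ACTIVATIONS = ('relu', 'gelu', 'swish', 'silu', 'linear', 'sigmoid')
ATT_TYPES = ('mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_s', 'dma_s')  # 'gma_s'/'dma_s' are new

def validate_transformer_args(ff_activation: str, att_type: str) -> None:
    # Hypothetical helper mirroring the two asserts in the hunk above.
    assert ff_activation in FF_ACTIVATIONS, \
        'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
    assert att_type in ATT_TYPES, \
        'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_s", "dma_s".'

validate_transformer_args('gelu', 'dma_s')  # accepted after this release; rejected in 0.1.35
```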
{rxnn-0.1.35.dist-info → rxnn-0.1.36.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256=
+rxnn/experimental/attention.py,sha256=PjmVwNeJXDy72LJr5cl9JD1oqjlwYK-Ahx1K1gLQgf8,29426
+rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.36.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.36.dist-info/METADATA,sha256=cEaouEWWp2OE2dMzM8G5GQe5z_LMMa2UCz03_rHfxhk,16627
+rxnn-0.1.36.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.36.dist-info/RECORD,,
{rxnn-0.1.35.dist-info → rxnn-0.1.36.dist-info}/LICENSE
File without changes
{rxnn-0.1.35.dist-info → rxnn-0.1.36.dist-info}/WHEEL
File without changes