rxnn 0.1.35__py3-none-any.whl → 0.1.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -614,7 +614,7 @@ def init_moe_attention(
614
614
  num_query_experts: int = None,
615
615
  num_query_groups: int = None,
616
616
  ) -> GroupedQueryAttention:
617
- assert attention_type in ['gma', 'dma', 'gma_v', 'dma_v'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
617
+ assert attention_type in ['gma', 'dma', 'gma_s', 'dma_s'], "Error, attention type should be one of: 'gma', 'dma', 'gma_s', 'dma_s'"
618
618
 
619
619
  if attention_type == "gma":
620
620
  return GroupedMoeAttention(
@@ -648,8 +648,8 @@ def init_moe_attention(
648
648
  num_query_experts=num_query_experts,
649
649
  num_query_groups=num_query_groups,
650
650
  )
651
- elif attention_type == "gma_v":
652
- return GroupedMoeAttentionVectorized(
651
+ elif attention_type == "gma_s":
652
+ return GroupedMoeAttentionSimplified(
653
653
  embed_dim,
654
654
  num_heads,
655
655
  gqa_groups,
@@ -664,7 +664,7 @@ def init_moe_attention(
664
664
  num_experts=num_experts,
665
665
  )
666
666
  else:
667
- return DeepMoeAttentionVectorized(
667
+ return DeepMoeAttentionSimplified(
668
668
  embed_dim,
669
669
  num_heads,
670
670
  gqa_groups,
@@ -65,7 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
65
65
  assert ff_activation in ['relu', 'gelu',
66
66
  'swish', 'silu', 'linear',
67
67
  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
68
- assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v', 'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v".'
68
+ assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_s', 'dma_s'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_s", "dma_s".'
69
69
 
70
70
  embedding = nn.Embedding(vocab_size, embed_dim)
71
71
  rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.35
3
+ Version: 0.1.36
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -1,7 +1,7 @@
1
1
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- rxnn/experimental/attention.py,sha256=GxbLmOTBvUiYU0Rc_0ju1n_ocJciHC6i3neDGe-rZZc,29426
4
- rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
3
+ rxnn/experimental/attention.py,sha256=PjmVwNeJXDy72LJr5cl9JD1oqjlwYK-Ahx1K1gLQgf8,29426
4
+ rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
5
5
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
6
6
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
25
25
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
26
26
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
27
27
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
28
- rxnn-0.1.35.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
- rxnn-0.1.35.dist-info/METADATA,sha256=aziCzqOeetdE3gMV2i15QoB5O31bGpiZgzcpGM97QPk,16627
30
- rxnn-0.1.35.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
- rxnn-0.1.35.dist-info/RECORD,,
28
+ rxnn-0.1.36.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
+ rxnn-0.1.36.dist-info/METADATA,sha256=cEaouEWWp2OE2dMzM8G5GQe5z_LMMa2UCz03_rHfxhk,16627
30
+ rxnn-0.1.36.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
+ rxnn-0.1.36.dist-info/RECORD,,
File without changes
File without changes