rxnn 0.2.55__py3-none-any.whl → 0.2.56__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -571,7 +571,7 @@ def init_experimental_attention(
571
571
  num_global_tokens: int = 16,
572
572
  window_size: int = 128,
573
573
  ) -> MultiHeadAttention:
574
- assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"
574
+ assert attention_type in ['gma', 'dma', 'sqa', 'flex', 'flex-sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex', 'flex-sqa"
575
575
 
576
576
  if attention_type == "gma":
577
577
  return GroupedMoeAttention(
@@ -622,6 +622,23 @@ def init_experimental_attention(
622
622
  num_global_tokens=num_global_tokens,
623
623
  window_size=window_size,
624
624
  )
625
+ elif attention_type == "flex-sqa":
626
+ return FlexSparseQueryAttention(
627
+ embed_dim,
628
+ num_heads,
629
+ gqa_groups,
630
+ num_query_groups,
631
+ dropout=dropout,
632
+ rope=rope,
633
+ max_seq_len=max_seq_len,
634
+ rope_only_for_query=rope_only_for_query,
635
+ rope_only_for_keys=rope_only_for_keys,
636
+ use_flash_attention=use_flash_attention,
637
+ is_causal=is_causal,
638
+ use_bias=use_bias,
639
+ num_global_tokens=num_global_tokens,
640
+ window_size=window_size,
641
+ )
625
642
  else:
626
643
  return SparseQueryAttention(
627
644
  embed_dim,
@@ -73,7 +73,7 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
73
73
  assert ff_activation in ['relu', 'gelu',
74
74
  'swish', 'silu', 'linear',
75
75
  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
76
- assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'
76
+ assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex', 'flex-sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex", "flex-sqa".'
77
77
 
78
78
  embedding = nn.Embedding(vocab_size, embed_dim)
79
79
  rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.55
3
+ Version: 0.2.56
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -1,8 +1,8 @@
1
1
  rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
2
2
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- rxnn/experimental/attention.py,sha256=JMs6Wr2rRe5J5m0ULhudmhBrzPicGuOOyg5hO8aLFiQ,27846
5
- rxnn/experimental/models.py,sha256=HPOIRpnX_oiI10wsVC4J6rzo3T6dj10aNWGYpa9S1UU,5115
4
+ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhXg,28455
5
+ rxnn/experimental/models.py,sha256=oJWd56LUsLc9S8eCZw-ShvuWjoQxj4C9GitbohlQ0ok,5139
6
6
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
7
7
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
33
33
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
34
34
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
35
35
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
36
- rxnn-0.2.55.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
- rxnn-0.2.55.dist-info/METADATA,sha256=4XCpsJFv9dpetex6uDRLrzKlMYlQZFLK2H2j---WZmA,25997
38
- rxnn-0.2.55.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
- rxnn-0.2.55.dist-info/RECORD,,
36
+ rxnn-0.2.56.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
+ rxnn-0.2.56.dist-info/METADATA,sha256=qW9X-oP3LWHB0E6S0opHPzWjDmNBRy9DjOz5od4qutc,25997
38
+ rxnn-0.2.56.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
+ rxnn-0.2.56.dist-info/RECORD,,
File without changes
File without changes