rxnn 0.2.55__tar.gz → 0.2.56__tar.gz

This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (39)
  1. {rxnn-0.2.55 → rxnn-0.2.56}/PKG-INFO +1 -1
  2. {rxnn-0.2.55 → rxnn-0.2.56}/pyproject.toml +1 -1
  3. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/attention.py +18 -1
  4. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/models.py +1 -1
  5. {rxnn-0.2.55 → rxnn-0.2.56}/LICENSE +0 -0
  6. {rxnn-0.2.55 → rxnn-0.2.56}/README.md +0 -0
  7. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/.DS_Store +0 -0
  8. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/__init__.py +0 -0
  9. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/__init__.py +0 -0
  10. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/moe.py +0 -0
  11. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/memory/__init__.py +0 -0
  12. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/memory/attention.py +0 -0
  13. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/memory/norm.py +0 -0
  14. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/memory/stm.py +0 -0
  15. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/rxt/__init__.py +0 -0
  16. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/rxt/models.py +0 -0
  17. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/__init__.py +0 -0
  18. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/base.py +0 -0
  19. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/bml.py +0 -0
  20. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/callbacks.py +0 -0
  21. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/dataset.py +0 -0
  22. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/ddp.py +0 -0
  23. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/models.py +0 -0
  24. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/mrl.py +0 -0
  25. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/utils.py +0 -0
{rxnn-0.2.55 → rxnn-0.2.56}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.55
+Version: 0.2.56
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.55 → rxnn-0.2.56}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.55"
+version = "0.2.56"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/attention.py
@@ -571,7 +571,7 @@ def init_experimental_attention(
         num_global_tokens: int = 16,
         window_size: int = 128,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"
+    assert attention_type in ['gma', 'dma', 'sqa', 'flex', 'flex-sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex', 'flex-sqa"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -622,6 +622,23 @@ def init_experimental_attention(
             num_global_tokens=num_global_tokens,
             window_size=window_size,
         )
+    elif attention_type == "flex-sqa":
+        return FlexSparseQueryAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            num_query_groups,
+            dropout=dropout,
+            rope=rope,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
     else:
         return SparseQueryAttention(
             embed_dim,
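The new branch simply forwards the factory's parameters to FlexSparseQueryAttention, in the same pattern as the surrounding branches. A minimal usage sketch follows; the keyword names are taken from the hunks above, the concrete values are illustrative, and any parameters of init_experimental_attention not visible in this diff are assumptions rather than confirmed API.

```python
from rxnn.experimental.attention import init_experimental_attention

# Sketch only: argument names come from the diff context above; values are
# illustrative, and other parameters not shown in the diff may be required.
attn = init_experimental_attention(
    embed_dim=512,
    num_heads=16,
    attention_type='flex-sqa',   # option added in 0.2.56
    gqa_groups=4,
    num_query_groups=8,
    num_global_tokens=16,        # default shown in the signature hunk
    window_size=128,             # default shown in the signature hunk
)
# The factory returns a FlexSparseQueryAttention module, typed as
# MultiHeadAttention per the function's return annotation.
```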
{rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/models.py
@@ -73,7 +73,7 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex', 'flex-sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex", "flex-sqa".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
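The only change in models.py is that ExperimentalAttentionTransformer now accepts 'flex-sqa' as an att_type. Below is a hedged construction sketch: only the names visible in this hunk (vocab_size, embed_dim, att_heads, seq_len, ff_activation, att_type) are taken from the source, the values are illustrative, and the full constructor signature is not part of this diff, so further required arguments are likely.

```python
from rxnn.experimental.models import ExperimentalAttentionTransformer

# Sketch only: names are taken from the hunk above; the full __init__ signature
# is not shown in this diff, so this call may omit required arguments.
model = ExperimentalAttentionTransformer(
    vocab_size=32_000,
    embed_dim=512,
    att_heads=8,
    seq_len=1024,
    ff_activation='gelu',
    att_type='flex-sqa',   # newly allowed by the updated assertion in 0.2.56
)
```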