rxnn 0.2.55__tar.gz → 0.2.56__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.55 → rxnn-0.2.56}/PKG-INFO +1 -1
- {rxnn-0.2.55 → rxnn-0.2.56}/pyproject.toml +1 -1
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/attention.py +18 -1
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/models.py +1 -1
- {rxnn-0.2.55 → rxnn-0.2.56}/LICENSE +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/README.md +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/models.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/mrl.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/reward.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/utils.py +0 -0
{rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/attention.py

```diff
@@ -571,7 +571,7 @@ def init_experimental_attention(
         num_global_tokens: int = 16,
         window_size: int = 128,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"
+    assert attention_type in ['gma', 'dma', 'sqa', 'flex', 'flex-sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex', 'flex-sqa"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -622,6 +622,23 @@ def init_experimental_attention(
             num_global_tokens=num_global_tokens,
             window_size=window_size,
         )
+    elif attention_type == "flex-sqa":
+        return FlexSparseQueryAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            num_query_groups,
+            dropout=dropout,
+            rope=rope,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
     else:
         return SparseQueryAttention(
             embed_dim,
```
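The change above registers a new `'flex-sqa'` attention variant built by `FlexSparseQueryAttention`. A minimal construction sketch follows; it assumes `init_experimental_attention` accepts the arguments it forwards (`embed_dim`, `num_heads`, `gqa_groups`, `num_query_groups`, and so on) as keyword arguments and that `RotaryPositionalEmbedding` is importable from `rxnn.transformers.positional`, neither of which is confirmed by this diff.

```python
# Hedged sketch, not from the package docs: selecting the 'flex-sqa' branch
# added in 0.2.56. Keyword names mirror the arguments forwarded to
# FlexSparseQueryAttention in the hunk above; the exact signature of
# init_experimental_attention and the import paths are assumptions.
from rxnn.experimental.attention import init_experimental_attention
from rxnn.transformers.positional import RotaryPositionalEmbedding  # assumed import path

embed_dim, num_heads, seq_len = 512, 16, 1024
rope = RotaryPositionalEmbedding(embed_dim // num_heads, seq_len)

attn = init_experimental_attention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    gqa_groups=4,               # grouped key/value heads (assumed keyword)
    num_query_groups=4,         # sparse query groups (assumed keyword)
    attention_type='flex-sqa',  # new option in this release
    dropout=0.1,
    rope=rope,
    max_seq_len=seq_len,
    num_global_tokens=16,       # defaults visible in the hunk above
    window_size=128,
)
# Per the return type hint in the diff, the resulting module is a
# MultiHeadAttention subclass (here, FlexSparseQueryAttention).
```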
{rxnn-0.2.55 → rxnn-0.2.56}/src/rxnn/experimental/models.py

```diff
@@ -73,7 +73,7 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex', 'flex-sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex", "flex-sqa".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
```