rxnn 0.1.55__py3-none-any.whl → 0.1.57__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that public registry.
- rxnn/experimental/attention.py +3 -3
- rxnn/experimental/models.py +6 -6
- rxnn/rxt/models.py +33 -10
- {rxnn-0.1.55.dist-info → rxnn-0.1.57.dist-info}/METADATA +1 -1
- {rxnn-0.1.55.dist-info → rxnn-0.1.57.dist-info}/RECORD +7 -7
- {rxnn-0.1.55.dist-info → rxnn-0.1.57.dist-info}/LICENSE +0 -0
- {rxnn-0.1.55.dist-info → rxnn-0.1.57.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
```diff
@@ -283,12 +283,12 @@ class SparseQueryAttention(MultiHeadAttention):
         """Override query, key, and value projections for GQA case - split data into heads and groups"""
         head_dim = d // self.num_heads
         if not self.rel_embed:
-            q = self.q_proj(query).view(b, t, self.
+            q = self.q_proj(query).view(b, t, self.num_query_groups, head_dim).transpose(1, 2)
             k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
             v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
         else:
             group_heads = self.num_heads // self.num_groups
-            query_heads = self.
+            query_heads = self.num_heads // self.num_query_groups
             # Process Q
             q = self.q_proj(query).view(b, -1, self.num_query_groups, head_dim).transpose(1, 2)  # (B, Q_G, T, head_dim)

@@ -683,7 +683,7 @@ class InfiniteAttention(MultiHeadAttention):
         attn = torch.softmax(attn, dim=-1)
         return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

-def
+def init_experimental_attention(
         embed_dim: int,
         num_heads: int,
         attention_type: str,
```
rxnn/experimental/models.py
CHANGED
```diff
@@ -8,7 +8,7 @@ from ..transformers.layers import ClassicTransformerLayer
 from ..transformers.models import ClassicTransformerDecoder
 from ..transformers.ff import get_activation_layer
 from ..utils import get_model_size
-from .attention import
+from .attention import init_experimental_attention


 class MoeAttentionTransformerConfig(TypedDict):
@@ -77,11 +77,11 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
                                           max_seq_len=seq_len, is_causal=True)
         else:
-            att_init = lambda:
-
-
-
-
+            att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
+                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                           max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
+                                                           num_query_experts=att_num_query_experts,
+                                                           num_query_groups=att_num_query_groups)

         use_moe_att = att_type in ['gma', 'dma', 'gma_s', 'dma_s']

```
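The `else` branch above routes the experimental attention types to the new factory while keeping the zero-argument `att_init` lambda, so each layer still constructs its own attention module. A condensed, self-contained sketch of that routing follows (the `init_attention` import path is an assumption; the branching and arguments follow the hunk above):

```python
# Condensed sketch of the attention-factory routing shown in this hunk.
# The init_attention import path is assumed; the rest follows the diff.
from rxnn.transformers.attention import init_attention              # assumed path
from rxnn.experimental.attention import init_experimental_attention


def make_att_init(att_type, embed_dim, att_heads, att_groups, rope, seq_len, att_dropout,
                  att_num_experts=None, att_num_query_experts=None, att_num_query_groups=None,
                  use_flash_attention=False):
    if att_type in ['mha', 'gqa', 'mqa']:
        # Standard variants keep the original factory.
        return lambda: init_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                      use_flash_attention=use_flash_attention, dropout=att_dropout,
                                      max_seq_len=seq_len, is_causal=True)
    # Experimental variants ('gma', 'dma', 'sqa', ...) take the extra expert/query-group arguments.
    return lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                               use_flash_attention=use_flash_attention, dropout=att_dropout,
                                               max_seq_len=seq_len, is_causal=True,
                                               num_experts=att_num_experts,
                                               num_query_experts=att_num_query_experts,
                                               num_query_groups=att_num_query_groups)
```

Returning a lambda rather than an instance keeps weights unshared: each transformer layer calls the factory and receives a fresh attention module.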
rxnn/rxt/models.py
CHANGED
```diff
@@ -9,7 +9,7 @@ from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEn
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..utils import get_model_size
-
+from ..experimental.attention import init_experimental_attention

 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
@@ -31,6 +31,9 @@ class RxTAlphaComponentConfig(TypedDict):
     moe_top_k: int
     self_att_type: str
     cross_att_type: str
+    att_num_experts: int
+    att_num_query_experts: int
+    att_num_query_groups: int


 class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
@@ -58,14 +61,17 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
             moe_top_k: int = 1,
             self_att_type: str = 'gqa',
             cross_att_type: str = 'mqa',
+            att_num_experts: int = None,
+            att_num_query_experts: int = None,
+            att_num_query_groups: int = None,
             **kwargs
     ):
         super(RxTAlphaComponentBase, self).__init__(**kwargs)
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert self_att_type in ['mha', 'gqa', 'mqa'], 'Self-attention type could be "mha", "gqa", "mqa"'
-        assert cross_att_type in ['mha', 'gqa', 'mqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa"'
+        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert cross_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'

         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -73,6 +79,28 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):

         ff_activation = get_activation_layer(ff_activation)

+        if self_att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
+                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                              max_seq_len=seq_len, is_causal=True)
+        else:
+            att_init = lambda: init_experimental_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
+                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                           max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
+                                                           num_query_experts=att_num_query_experts,
+                                                           num_query_groups=att_num_query_groups)
+
+        if cross_att_type in ['mha', 'gqa', 'mqa']:
+            cross_att_init = lambda: init_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
+                                                    use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                    max_seq_len=seq_len, is_causal=True)
+        else:
+            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
+                                                                 use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                                 max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
+                                                                 num_query_experts=att_num_query_experts,
+                                                                 num_query_groups=att_num_query_groups)
+
         layers = nn.ModuleList([
             ReactiveTransformerLayer(
                 embed_dim,
@@ -84,13 +112,8 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
                 ff_activation=ff_activation,
                 ff_dropout=ff_dropout,
                 use_rms_norm=use_rms_norm,
-                self_attention=
-
-                                               max_seq_len=seq_len, is_causal=is_causal),
-                memory_cross_attention=init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
-                                                      use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                      max_seq_len=seq_len, rope_only_for_query=True,
-                                                      is_causal=is_causal)
+                self_attention=att_init(),
+                memory_cross_attention=cross_att_init(),
             ) for _ in range(num_layers)
         ])
         self.model = self._init_model(stm, layers, embedding, use_flash_attention, embed_dim, vocab_size)
```
{rxnn-0.1.55.dist-info → rxnn-0.1.57.dist-info}/RECORD
CHANGED
```diff
@@ -1,13 +1,13 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256=
+rxnn/experimental/attention.py,sha256=ivIqIc-15DWA_q-ITy2iaYmB7tffKVtiuqjdSH3mtS4,34548
+rxnn/experimental/models.py,sha256=_i9kvQsAYPyMQo2VfMUTmtBs-mE2w75j1X-OHx03IJk,4743
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
 rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=
+rxnn/rxt/models.py,sha256=L5RvhORONmYSF_pVjP8HwiSeAypSNqfFi6Fogp2oJes,8543
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=gEWASLSuWR8UF8b2e-DYqkBZ1lBx0VsIm4kGf9eWSHM,11678
 rxnn/training/bml.py,sha256=S1ZaXTybzeJH7uVFamCr4TPl2bLyZ5xmn_lSsjThTiM,19162
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.57.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.57.dist-info/METADATA,sha256=K9kcLSS3CUYwFg9N-KhPK5J4tMmsgKYFqF6VkH8689U,16627
+rxnn-0.1.57.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.57.dist-info/RECORD,,
```
{rxnn-0.1.55.dist-info → rxnn-0.1.57.dist-info}/LICENSE
File without changes
{rxnn-0.1.55.dist-info → rxnn-0.1.57.dist-info}/WHEEL
File without changes