rxnn 0.1.54__py3-none-any.whl → 0.1.55__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +109 -2
- rxnn/experimental/models.py +1 -1
- {rxnn-0.1.54.dist-info → rxnn-0.1.55.dist-info}/METADATA +1 -1
- {rxnn-0.1.54.dist-info → rxnn-0.1.55.dist-info}/RECORD +6 -6
- {rxnn-0.1.54.dist-info → rxnn-0.1.55.dist-info}/LICENSE +0 -0
- {rxnn-0.1.54.dist-info → rxnn-0.1.55.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED

```diff
@@ -226,6 +226,98 @@ class DeepMoeAttention(GroupedMoeAttention):
 
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
+class SparseQueryAttention(MultiHeadAttention):
+    """Sparse Grouped Query attention layer, with RoPE support"""
+
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            num_query_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = False,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            *args,
+            **kwargs,
+    ):
+        self.num_groups = num_groups
+        self.num_query_groups = num_query_groups
+        super(SparseQueryAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            *args,
+            **kwargs,
+        )
+        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
+
+    def _init_kv(self, embed_dim: int):
+        self.k_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_groups), bias=self.use_bias)
+
+    def _init_q(self, embed_dim: int):
+        self.q_proj = nn.Linear(embed_dim, embed_dim // (self.num_heads // self.num_query_groups))
+
+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_query_groups), embed_dim)
+
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        return attn_output.transpose(1, 2).contiguous().view(b, t, d // (self.num_heads // self.num_query_groups))
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+        """Override query, key, and value projections for GQA case - split data into heads and groups"""
+        head_dim = d // self.num_heads
+        if not self.rel_embed:
+            q = self.q_proj(query).view(b, t, self.num_query_heads, head_dim).transpose(1, 2)
+            k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
+            v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
+        else:
+            group_heads = self.num_heads // self.num_groups
+            query_heads = self.num_query_heads // self.num_query_groups
+            # Process Q
+            q = self.q_proj(query).view(b, -1, self.num_query_groups, head_dim).transpose(1, 2)  # (B, Q_G, T, head_dim)
+
+            # Process K and V
+            k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)  # (B, G, S, head_dim)
+            v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)  # (B, G, S, head_dim)
+
+            # Expand and flatten to 4D tensors
+            q = q.unsqueeze(2).expand(-1, -1, query_heads, -1, -1)  # (B, Q_G, query_heads, T, head_dim)
+            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+
+            q = q.flatten(start_dim=1, end_dim=2)  # (B, Q, T, head_dim)
+            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+        return q, k, v
+
+    def _calculate_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int, mask: torch.Tensor = None):
+        is_gqa = self.num_query_groups != self.num_groups
+        if self.use_flash_attention:
+            # Compute attention with FlashAttention
+            return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=is_gqa)
+        else:
+            # Compute attention using optimized PyTorch implementation
+            return self._torch_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=is_gqa)
+
+
+
+
 class GroupedMoeAttentionSimplified(GroupedQueryAttention):
     """
     Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
```
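The new `SparseQueryAttention` keeps the GQA-style reduction of key/value heads (`num_groups`) and additionally shrinks the number of query heads to `num_query_groups`, so `q_proj` and `out_proj` become smaller along with `k_proj`/`v_proj`. Below is a minimal, self-contained sketch of that shape arithmetic in plain PyTorch, independent of the rxnn base classes; the sizes are made-up example values, not package defaults.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical example sizes (not taken from the package defaults)
embed_dim, num_heads, num_groups, num_query_groups = 512, 16, 4, 8
head_dim = embed_dim // num_heads  # 32
B, T = 2, 10

# Projections sized the same way as _init_q / _init_kv in the hunk above:
# queries keep num_query_groups heads, keys/values keep num_groups heads.
q_proj = nn.Linear(embed_dim, embed_dim // (num_heads // num_query_groups))  # 512 -> 256
k_proj = nn.Linear(embed_dim, embed_dim // (num_heads // num_groups))        # 512 -> 128
v_proj = nn.Linear(embed_dim, embed_dim // (num_heads // num_groups))        # 512 -> 128

x = torch.randn(B, T, embed_dim)
q = q_proj(x).view(B, T, num_query_groups, head_dim).transpose(1, 2)  # (B, 8, T, 32)
k = k_proj(x).view(B, T, num_groups, head_dim).transpose(1, 2)        # (B, 4, T, 32)
v = v_proj(x).view(B, T, num_groups, head_dim).transpose(1, 2)        # (B, 4, T, 32)

# With fewer KV heads than query heads, scaled_dot_product_attention can
# broadcast KV across query heads via enable_gqa (available in PyTorch 2.5+);
# older versions would need k/v repeated to 8 heads first.
out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)  # (B, 8, T, 32)

# Folding the query heads back gives a 256-dim tensor, matching the smaller
# out_proj (embed_dim // (num_heads // num_query_groups) -> embed_dim).
out = out.transpose(1, 2).contiguous().view(B, T, num_query_groups * head_dim)
print(out.shape)  # torch.Size([2, 10, 256])
```

The remaining hunks in this file wire the new layer into the `init_moe_attention` factory: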
```diff
@@ -607,8 +699,8 @@ def init_moe_attention(
         num_experts: int = None,
         num_query_experts: int = None,
         num_query_groups: int = None,
-) -> …
-    assert attention_type in ['gma', 'dma', 'gma_s', 'dma_s'], "Error, attention type should be one of: 'gma', 'dma', 'gma_s', 'dma_s'"
+) -> MultiHeadAttention:
+    assert attention_type in ['gma', 'dma', 'gma_s', 'dma_s', 'sqa'], "Error, attention type should be one of: 'gma', 'dma', 'gma_s', 'dma_s', 'sqa'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -642,6 +734,21 @@ def init_moe_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
+    elif attention_type == 'sqa':
+        return SparseQueryAttention(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            num_query_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+        )
     elif attention_type == "gma_s":
         return GroupedMoeAttentionSimplified(
             embed_dim,
```
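A hedged usage sketch of the extended factory: it assumes `init_moe_attention` accepts the parameters visible in the hunks above as keyword arguments, and the sizes are arbitrary examples, not defaults.

```python
# Hedged sketch, assuming the parameters shown in the diff above are exposed
# as keyword arguments of init_moe_attention; values are illustrative only.
from rxnn.experimental.attention import init_moe_attention

attn = init_moe_attention(
    embed_dim=512,
    num_heads=16,
    attention_type='sqa',   # new option added in 0.1.55
    gqa_groups=4,           # KV head groups (num_groups of SparseQueryAttention)
    num_query_groups=8,     # reduced number of query heads
    dropout=0.1,
    rope=None,              # or a RotaryPositionalEmbedding instance
    max_seq_len=1024,
    use_flash_attention=False,
    is_causal=True,
)
# The annotated return type is now the MultiHeadAttention base class, so the
# result can be used wherever the other attention variants are accepted.
```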
rxnn/experimental/models.py
CHANGED

```diff
@@ -65,7 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex…
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', '…
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
```
{rxnn-0.1.54.dist-info → rxnn-0.1.55.dist-info}/RECORD
CHANGED

```diff
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=…
-rxnn/experimental/models.py,sha256=…
+rxnn/experimental/attention.py,sha256=oPknT_PVcNwZvDwpZM7gmP4M_md_FW8oYwJDdQk1avM,34544
+rxnn/experimental/models.py,sha256=iprFSQDPK75zebDJBJ1i-mnNS9jlGf9RAIk-S0E9D-Q,4689
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.…
-rxnn-0.1.…
-rxnn-0.1.…
-rxnn-0.1.…
+rxnn-0.1.55.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.55.dist-info/METADATA,sha256=qiMp63aMlBdbvVvSJDL2bfW5XoR0PzNwN6pWdkfCuOM,16627
+rxnn-0.1.55.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.55.dist-info/RECORD,,
```

{rxnn-0.1.54.dist-info → rxnn-0.1.55.dist-info}/LICENSE
File without changes

{rxnn-0.1.54.dist-info → rxnn-0.1.55.dist-info}/WHEEL
File without changes