rxnn 0.1.31__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- rxnn/experimental/attention.py +23 -32
- rxnn/transformers/attention.py +10 -8
- {rxnn-0.1.31.dist-info → rxnn-0.1.33.dist-info}/METADATA +1 -1
- {rxnn-0.1.31.dist-info → rxnn-0.1.33.dist-info}/RECORD +6 -6
- {rxnn-0.1.31.dist-info → rxnn-0.1.33.dist-info}/LICENSE +0 -0
- {rxnn-0.1.31.dist-info → rxnn-0.1.33.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -323,7 +323,7 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
 
         # Key/Value MoE routing
         B, S, D = key.shape
-        key_flat = key.reshape(
+        key_flat = key.reshape(-1, D)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
         weights = weights.view(B, S, self.num_groups, 1)
         indices = indices.view(B, S, self.num_groups)
@@ -331,35 +331,28 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
         # Compute all experts' projections
         # Shape: (B*S, num_experts, head_dim)
         k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
-        v_all = torch.einsum('bd,edh->beh', value.view(
+        v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
 
-
-
-
-
-        print('k_all', k_all.size())
-        print('v_all', v_all.size())
-
-
-        # Gather top-k experts and weights
-        # Expand indices to [B, S, num_groups, head_dim]
-        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
-        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
-        selected_v = torch.gather(v_all, 2, expanded_indices)
+        if self.use_bias:
+            k_all += self.bk
+            v_all += self.bv
 
-
-
+        # Get results for all heads
+        k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
+        v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
 
-        #
-
-
+        # Gather top-k experts using expanded indices
+        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
+        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
+        selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
 
-
-
+        # Weighted
+        weighted_k = selected_k * weights  # [B, S, num_groups, head_dim]
+        weighted_v = selected_v * weights  # [B, S, num_groups, head_dim]
 
         # Reshape to GQA format
-        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
-        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
+        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
+        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
 
         if not self.use_flash_attention:
             group_heads = self.num_heads // self.num_groups
@@ -370,10 +363,6 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
         k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
         v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
 
-        print('q', q.size())
-        print('k', k.size())
-        print('v', v.size())
-
         return q, k, v
 
 
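The hunks above drop the leftover debug `print` calls, complete the previously truncated `reshape`/`einsum` lines, and add an optional per-expert bias before the routed experts are gathered per token. Below is a minimal, self-contained sketch of the gather-and-weight pattern the new code relies on; all shapes and names (`wk`, the router outputs, etc.) are illustrative stand-ins, not the package's actual attributes.

```python
import torch

# Hypothetical sizes for illustration only.
B, S, D, num_experts, num_groups, head_dim = 2, 4, 16, 8, 2, 8

key = torch.randn(B, S, D)
wk = torch.randn(num_experts, D, head_dim)           # one projection matrix per expert
key_flat = key.reshape(-1, D)                        # (B*S, D)

k_all = torch.einsum('bd,edh->beh', key_flat, wk)    # (B*S, num_experts, head_dim)
k_all = k_all.view(B, S, num_experts, head_dim)

# Stand-in for the router output: top-k expert indices and weights per token.
weights = torch.rand(B, S, num_groups, 1)
indices = torch.randint(0, num_experts, (B, S, num_groups))

# Expand indices so gather can select along the expert dimension.
expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)  # (B, S, num_groups, head_dim)
selected_k = torch.gather(k_all, 2, expanded)                  # (B, S, num_groups, head_dim)
weighted_k = selected_k * weights                              # scale by router weights

k = weighted_k.permute(0, 2, 1, 3)                   # (B, num_groups, S, head_dim), GQA layout
print(k.shape)                                       # torch.Size([2, 2, 4, 8])
```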
@@ -458,13 +447,16 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
 
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
-        query_flat = query.reshape(
+        query_flat = query.reshape(-1, D)
         weights_q, indices_q = self.query_router(query_flat)
         weights_q = weights_q.view(B, T, self.num_query_groups, 1)
         indices_q = indices_q.view(B, T, self.num_query_groups)
 
         # Compute all query experts
         q_all = torch.einsum('bd,edh->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
+        if self.use_bias:
+            q_all += self.bq
+
         q_all = q_all.view(B, T, self.num_query_experts, -1)
 
         # Gather top-k experts
@@ -472,12 +464,11 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
         selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
 
         # Weighted sum
-        q =
-        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B,
+        q = selected_q * weights_q  # [B, T, num_query_groups, head_dim]
+        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]
 
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
-
 # Others
 
 class FlexAttention(MultiHeadAttention):
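The query path in `DeepMoeAttentionVectorized` mirrors the key/value path: project the flattened queries through all query experts, optionally add a per-expert bias, gather the routed experts, and apply the router weights before reshaping to `[B, num_query_groups, T, head_dim]`. A short sketch of the bias broadcast and the now-completed weighted sum, again with made-up shapes rather than the library's real module:

```python
import torch

# Hypothetical sizes for illustration only.
B, T, D, num_query_experts, num_query_groups, head_dim = 2, 4, 16, 8, 4, 8

query = torch.randn(B, T, D)
wq = torch.randn(num_query_experts, D, head_dim)
bq = torch.zeros(num_query_experts, head_dim)         # per-expert bias

q_all = torch.einsum('bd,edh->beh', query.reshape(-1, D), wq)  # (B*T, num_query_experts, head_dim)
q_all = q_all + bq                                     # bias broadcasts over the token dimension
q_all = q_all.view(B, T, num_query_experts, head_dim)

# Stand-in router output.
weights_q = torch.rand(B, T, num_query_groups, 1)
indices_q = torch.randint(0, num_query_experts, (B, T, num_query_groups))

expanded = indices_q.unsqueeze(-1).expand(-1, -1, -1, head_dim)
selected_q = torch.gather(q_all, 2, expanded)          # (B, T, num_query_groups, head_dim)
q = (selected_q * weights_q).permute(0, 2, 1, 3)       # (B, num_query_groups, T, head_dim)
print(q.shape)                                         # torch.Size([2, 4, 4, 8])
```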
rxnn/transformers/attention.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.backends.cuda import sdp_kernel, SDPBackend
 import math
 from .positional import RotaryPositionalEmbedding, RelativePositionalEmbedding
 
@@ -17,7 +18,7 @@ class MultiHeadAttention(nn.Module):
             rope_only_for_query: bool = False,
             use_relative_embeddings: bool = False,
             max_seq_len: int = 1024,
-            use_flash_attention: bool =
+            use_flash_attention: bool = True,
             is_causal: bool = False,
             use_bias: bool = False,
             *args,
@@ -101,13 +102,14 @@ class MultiHeadAttention(nn.Module):
 
     def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                          mask: torch.Tensor = None, enable_gqa: bool = False):
-
-
-
-
-
-
-
+        with sdp_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
+            attn_output = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=mask if not self.is_causal else None,
+                dropout_p=self.dropout.p if self.training else 0.0,
+                is_causal=self.is_causal,
+                enable_gqa=enable_gqa,
+            )
         return self._transpose_output(attn_output, b, t, d)
 
     def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
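The new `_flash_attention` body pins PyTorch's scaled dot-product attention to the flash backend via the `sdp_kernel` context manager imported at the top of the file. Note that this context-manager API has changed across PyTorch releases: `torch.backends.cuda.sdp_kernel` is the older form, PyTorch 2.3 and later expose `torch.nn.attention.sdpa_kernel`, and the `enable_gqa` argument of `F.scaled_dot_product_attention` only exists in newer releases. A minimal sketch using the newer API, assuming a CUDA device and fp16 tensors (the flash kernel typically rejects fp32 inputs):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend  # PyTorch >= 2.3

# Hypothetical (batch, heads, seq, head_dim) tensors on a CUDA device.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the flash backend only, causal masking, no dropout.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=0.0)

print(out.shape)  # torch.Size([2, 8, 128, 64])
```

Restricting the backend list this way means SDPA raises an error instead of silently falling back to the math or memory-efficient kernels when the flash kernel cannot handle the inputs.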
{rxnn-0.1.31.dist-info → rxnn-0.1.33.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=h9pEv_70NKpD5KOQOFP3h-IJKzh7Wbnaxka4Bd3rdd8,29745
 rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=Nox986BH9qq4rDYLiYmfj1DeMeULF3akexIl99MPccM,14331
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.33.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.33.dist-info/METADATA,sha256=m3DWDnTu7Lx1kHYPIAQCdKU8t4QZBdqG0QcSIFvB924,16627
+rxnn-0.1.33.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.33.dist-info/RECORD,,
{rxnn-0.1.31.dist-info → rxnn-0.1.33.dist-info}/LICENSE
File without changes

{rxnn-0.1.31.dist-info → rxnn-0.1.33.dist-info}/WHEEL
File without changes