rxnn 0.1.30__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -323,7 +323,7 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
|
|
323
323
|
|
324
324
|
# Key/Value MoE routing
|
325
325
|
B, S, D = key.shape
|
326
|
-
key_flat = key.reshape(
|
326
|
+
key_flat = key.reshape(-1, D)
|
327
327
|
weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
|
328
328
|
weights = weights.view(B, S, self.num_groups, 1)
|
329
329
|
indices = indices.view(B, S, self.num_groups)
|
@@ -331,25 +331,28 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
|
|
331
331
|
# Compute all experts' projections
|
332
332
|
# Shape: (B*S, num_experts, head_dim)
|
333
333
|
k_all = torch.einsum('bd,edh->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
|
334
|
-
v_all = torch.einsum('bd,edh->beh', value.view(
|
334
|
+
v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
|
335
335
|
|
336
|
-
|
337
|
-
|
338
|
-
|
336
|
+
if self.use_bias:
|
337
|
+
k_all += self.bk
|
338
|
+
v_all += self.bv
|
339
339
|
|
340
|
-
#
|
341
|
-
|
342
|
-
|
343
|
-
selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
|
344
|
-
selected_v = torch.gather(v_all, 2, expanded_indices)
|
340
|
+
# Get results for all heads
|
341
|
+
k_all = k_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
|
342
|
+
v_all = v_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
|
345
343
|
|
346
|
-
#
|
347
|
-
|
348
|
-
|
344
|
+
# Gather top-k experts using expanded indices
|
345
|
+
expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, S, num_groups, head_dim]
|
346
|
+
selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
|
347
|
+
selected_v = torch.gather(v_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
|
348
|
+
|
349
|
+
# Weighted
|
350
|
+
weighted_k = selected_k * weights # [B, S, num_groups, head_dim]
|
351
|
+
weighted_v = selected_v * weights # [B, S, num_groups, head_dim]
|
349
352
|
|
350
353
|
# Reshape to GQA format
|
351
|
-
k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
|
352
|
-
v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
|
354
|
+
k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
|
355
|
+
v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
|
353
356
|
|
354
357
|
if not self.use_flash_attention:
|
355
358
|
group_heads = self.num_heads // self.num_groups
|
@@ -360,10 +363,6 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
|
|
360
363
|
k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
|
361
364
|
v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
|
362
365
|
|
363
|
-
print('q', q.size())
|
364
|
-
print('k', k.size())
|
365
|
-
print('v', v.size())
|
366
|
-
|
367
366
|
return q, k, v
|
368
367
|
|
369
368
|
|
@@ -448,13 +447,16 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
|
|
448
447
|
|
449
448
|
def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
|
450
449
|
B, T, D = query.shape
|
451
|
-
query_flat = query.reshape(
|
450
|
+
query_flat = query.reshape(-1, D)
|
452
451
|
weights_q, indices_q = self.query_router(query_flat)
|
453
452
|
weights_q = weights_q.view(B, T, self.num_query_groups, 1)
|
454
453
|
indices_q = indices_q.view(B, T, self.num_query_groups)
|
455
454
|
|
456
455
|
# Compute all query experts
|
457
456
|
q_all = torch.einsum('bd,edh->beh', query_flat, self.wq) # [B*T, num_query_experts, head_dim]
|
457
|
+
if self.use_bias:
|
458
|
+
q_all += self.bq
|
459
|
+
|
458
460
|
q_all = q_all.view(B, T, self.num_query_experts, -1)
|
459
461
|
|
460
462
|
# Gather top-k experts
|
@@ -462,12 +464,11 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
|
|
462
464
|
selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
|
463
465
|
|
464
466
|
# Weighted sum
|
465
|
-
q =
|
466
|
-
q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B,
|
467
|
+
q = selected_q * weights_q # [B, T, num_query_groups, head_dim]
|
468
|
+
q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
|
467
469
|
|
468
470
|
return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
|
469
471
|
|
470
|
-
|
471
472
|
# Others
|
472
473
|
|
473
474
|
class FlexAttention(MultiHeadAttention):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
rxnn/experimental/attention.py,sha256=
|
3
|
+
rxnn/experimental/attention.py,sha256=h9pEv_70NKpD5KOQOFP3h-IJKzh7Wbnaxka4Bd3rdd8,29745
|
4
4
|
rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
|
5
5
|
rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
|
6
6
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.32.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.32.dist-info/METADATA,sha256=oGoeF0F8LvsoUc8qGSd6n0oCIZmoj59bj42QwkaWDuM,16627
|
30
|
+
rxnn-0.1.32.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.32.dist-info/RECORD,,
|
File without changes
|
File without changes
|