rxnn 0.1.29__py3-none-any.whl → 0.1.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -319,7 +319,7 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
|
|
319
319
|
head_dim = d // self.num_heads
|
320
320
|
|
321
321
|
# Process Query as in GQA
|
322
|
-
q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2)
|
322
|
+
q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
|
323
323
|
|
324
324
|
# Key/Value MoE routing
|
325
325
|
B, S, D = key.shape
|
@@ -337,16 +337,26 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
|
|
337
337
|
k_all = k_all.view(B, S, self.num_experts, -1)
|
338
338
|
v_all = v_all.view(B, S, self.num_experts, -1)
|
339
339
|
|
340
|
+
print('k_all', k_all.size())
|
341
|
+
print('v_all', v_all.size())
|
342
|
+
|
343
|
+
|
340
344
|
# Gather top-k experts and weights
|
341
345
|
# Expand indices to [B, S, num_groups, head_dim]
|
342
346
|
expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
|
343
347
|
selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
|
344
348
|
selected_v = torch.gather(v_all, 2, expanded_indices)
|
345
349
|
|
350
|
+
print('selected_k', selected_k.size())
|
351
|
+
print('selected_v', selected_v.size())
|
352
|
+
|
346
353
|
# Weighted sum
|
347
354
|
weighted_k = (selected_k * weights).sum(dim=2) # [B, S, head_dim]
|
348
355
|
weighted_v = (selected_v * weights).sum(dim=2)
|
349
356
|
|
357
|
+
print('weighted_k', weighted_k.size())
|
358
|
+
print('weighted_v', weighted_v.size())
|
359
|
+
|
350
360
|
# Reshape to GQA format
|
351
361
|
k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, G, S, head_dim]
|
352
362
|
v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
|
@@ -360,6 +370,10 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
|
|
360
370
|
k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
|
361
371
|
v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
|
362
372
|
|
373
|
+
print('q', q.size())
|
374
|
+
print('k', k.size())
|
375
|
+
print('v', v.size())
|
376
|
+
|
363
377
|
return q, k, v
|
364
378
|
|
365
379
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
rxnn/experimental/attention.py,sha256=
|
3
|
+
rxnn/experimental/attention.py,sha256=lqRIR2V7pA5LZQ1x_nhUvQBeJ83DLD6tXdoee-W6ZdU,29833
|
4
4
|
rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
|
5
5
|
rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
|
6
6
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.31.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.31.dist-info/METADATA,sha256=jy1wfyPS-nD34GAMFxcWR_dFn6ZJogJ2-IDxIVmLdmI,16627
|
30
|
+
rxnn-0.1.31.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.31.dist-info/RECORD,,
|
File without changes
|
File without changes
|