rxnn 0.1.30__py3-none-any.whl → 0.1.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -337,16 +337,26 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
|
|
337
337
|
k_all = k_all.view(B, S, self.num_experts, -1)
|
338
338
|
v_all = v_all.view(B, S, self.num_experts, -1)
|
339
339
|
|
340
|
+
print('k_all', k_all.size())
|
341
|
+
print('v_all', v_all.size())
|
342
|
+
|
343
|
+
|
340
344
|
# Gather top-k experts and weights
|
341
345
|
# Expand indices to [B, S, num_groups, head_dim]
|
342
346
|
expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
|
343
347
|
selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
|
344
348
|
selected_v = torch.gather(v_all, 2, expanded_indices)
|
345
349
|
|
350
|
+
print('selected_k', selected_k.size())
|
351
|
+
print('selected_v', selected_v.size())
|
352
|
+
|
346
353
|
# Weighted sum
|
347
354
|
weighted_k = (selected_k * weights).sum(dim=2) # [B, S, head_dim]
|
348
355
|
weighted_v = (selected_v * weights).sum(dim=2)
|
349
356
|
|
357
|
+
print('weighted_k', weighted_k.size())
|
358
|
+
print('weighted_v', weighted_v.size())
|
359
|
+
|
350
360
|
# Reshape to GQA format
|
351
361
|
k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, G, S, head_dim]
|
352
362
|
v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
rxnn/experimental/attention.py,sha256=
|
3
|
+
rxnn/experimental/attention.py,sha256=lqRIR2V7pA5LZQ1x_nhUvQBeJ83DLD6tXdoee-W6ZdU,29833
|
4
4
|
rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
|
5
5
|
rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
|
6
6
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.31.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.31.dist-info/METADATA,sha256=jy1wfyPS-nD34GAMFxcWR_dFn6ZJogJ2-IDxIVmLdmI,16627
|
30
|
+
rxnn-0.1.31.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.31.dist-info/RECORD,,
|
File without changes
|
File without changes
|