rxnn 0.1.28__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -330,8 +330,8 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
|
|
330
330
|
|
331
331
|
# Compute all experts' projections
|
332
332
|
# Shape: (B*S, num_experts, head_dim)
|
333
|
-
k_all = torch.einsum('
|
334
|
-
v_all = torch.einsum('
|
333
|
+
k_all = torch.einsum('bd,edh->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
|
334
|
+
v_all = torch.einsum('bd,edh->beh', value.view(B*S, D), self.wv)
|
335
335
|
|
336
336
|
# Reshape to [B, S, num_experts, head_dim]
|
337
337
|
k_all = k_all.view(B, S, self.num_experts, -1)
|
@@ -450,7 +450,7 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
|
|
450
450
|
indices_q = indices_q.view(B, T, self.num_query_groups)
|
451
451
|
|
452
452
|
# Compute all query experts
|
453
|
-
q_all = torch.einsum('
|
453
|
+
q_all = torch.einsum('bd,edh->beh', query_flat, self.wq) # [B*T, num_query_experts, head_dim]
|
454
454
|
q_all = q_all.view(B, T, self.num_query_experts, -1)
|
455
455
|
|
456
456
|
# Gather top-k experts
|
@@ -1,6 +1,6 @@
|
|
1
1
|
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
rxnn/experimental/attention.py,sha256=
|
3
|
+
rxnn/experimental/attention.py,sha256=FLIyfRWZC1Ppfx9u8qUcNcfMSscuibhKEQ8zCJHXcfk,29439
|
4
4
|
rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
|
5
5
|
rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
|
6
6
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.29.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.29.dist-info/METADATA,sha256=u3tlE_n8tZsqgW44vPssCmmJ03wHKM57tGS1xXI3HcQ,16627
|
30
|
+
rxnn-0.1.29.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.29.dist-info/RECORD,,
|
File without changes
|
File without changes
|