rxnn 0.1.27__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -434,12 +434,13 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
|
|
434
434
|
|
435
435
|
def _init_out(self, embed_dim: int):
|
436
436
|
"""Initialize output projection"""
|
437
|
-
|
437
|
+
out_hidden_dim = embed_dim // self.num_heads * self.num_query_groups
|
438
|
+
self.out_proj = nn.Linear(out_hidden_dim, embed_dim)
|
438
439
|
|
439
440
|
def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
|
440
441
|
"""Transpose attention output back to (B, T, D) shape"""
|
441
|
-
|
442
|
-
return attn_output.transpose(1, 2).contiguous().view(b, t,
|
442
|
+
out_hidden_dim = d // self.num_heads * self.num_query_groups
|
443
|
+
return attn_output.transpose(1, 2).contiguous().view(b, t, out_hidden_dim)
|
443
444
|
|
444
445
|
def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
|
445
446
|
B, T, D = query.shape
|
@@ -1,6 +1,6 @@
|
|
1
1
|
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
rxnn/experimental/attention.py,sha256=
|
3
|
+
rxnn/experimental/attention.py,sha256=GnK_J7o_4fJ5O50ETx4oG-p7dOCsPRMwVGv3BIbUIbg,29439
|
4
4
|
rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
|
5
5
|
rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
|
6
6
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.28.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.28.dist-info/METADATA,sha256=zpVetjl-0pFz7Z4e4GUlybS-rBHKFk3AYIS6fM46diU,16627
|
30
|
+
rxnn-0.1.28.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.28.dist-info/RECORD,,
|
File without changes
|
File without changes
|