rxnn 0.1.46__py3-none-any.whl → 0.1.48__py3-none-any.whl
This diff compares the contents of two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -94,8 +94,8 @@ class GroupedMoeAttention(GroupedQueryAttention):
|
|
94
94
|
B, S, D = key.shape
|
95
95
|
key_flat = key.reshape(-1, D)
|
96
96
|
weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
|
97
|
-
weights = weights.view(B,
|
98
|
-
indices = indices.view(B,
|
97
|
+
weights = weights.view(B, self.num_groups, S, 1)
|
98
|
+
indices = indices.view(B, self.num_groups, S).unsqueeze(-1).expand(-1, -1, S, -1)
|
99
99
|
|
100
100
|
# Compute all experts' projections
|
101
101
|
k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
|
@@ -122,6 +122,8 @@ class GroupedMoeAttention(GroupedQueryAttention):
|
|
122
122
|
k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
|
123
123
|
v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
|
124
124
|
|
125
|
+
print(q.size(), k.size(), v.size())
|
126
|
+
|
125
127
|
return q, k, v
|
126
128
|
|
127
129
|
|
@@ -207,8 +209,8 @@ class DeepMoeAttention(GroupedMoeAttention):
|
|
207
209
|
B, T, D = query.shape
|
208
210
|
query_flat = query.reshape(-1, D)
|
209
211
|
weights, indices = self.query_router(query_flat)
|
210
|
-
weights = weights.view(B,
|
211
|
-
indices = indices.view(B,
|
212
|
+
weights = weights.view(B, self.num_query_groups, T, 1)
|
213
|
+
indices = indices.view(B, self.num_query_groups, T).unsqueeze(-1).expand(-1, -1, T, -1) # [B, num_query_groups, T, head_dim]
|
212
214
|
|
213
215
|
q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3) # [B, num_query_experts, T, head_dim]
|
214
216
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
rxnn/experimental/attention.py,sha256=
|
3
|
+
rxnn/experimental/attention.py,sha256=XaN7IAHhp1DxFtZxeTEP-EZN7PWnjocC42ndSDB9RvY,29432
|
4
4
|
rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
|
5
5
|
rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
|
6
6
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.48.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.48.dist-info/METADATA,sha256=YelXCeEWnK9llWGzSHOiThlUIz8ttWH-KsL_68pJ9-Y,16627
|
30
|
+
rxnn-0.1.48.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.48.dist-info/RECORD,,
|
File without changes
|
File without changes
|