rxnn 0.1.44__py3-none-any.whl → 0.1.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -95,16 +95,15 @@ class GroupedMoeAttention(GroupedQueryAttention):
|
|
95
95
|
key_flat = key.reshape(-1, D)
|
96
96
|
weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
|
97
97
|
weights = weights.view(B, S, self.num_groups, 1)
|
98
|
-
indices = indices.view(B, S, self.num_groups)
|
98
|
+
indices = indices.view(B, S, self.num_groups).unsqueeze(-1).transpose(1, 2).expand(-1, -1, S, -1)
|
99
99
|
|
100
100
|
# Compute all experts' projections
|
101
101
|
k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
|
102
102
|
v_all = self.v_proj(value).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
|
103
103
|
|
104
104
|
# Gather top-k experts using expanded indices
|
105
|
-
|
106
|
-
|
107
|
-
selected_v = torch.gather(v_all, 1, expanded_indices) # [B, num_groups, S, head_dim]
|
105
|
+
selected_k = torch.gather(k_all, 1, indices) # [B, num_groups, S, head_dim]
|
106
|
+
selected_v = torch.gather(v_all, 1, indices) # [B, num_groups, S, head_dim]
|
108
107
|
|
109
108
|
# Weighted
|
110
109
|
weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype) # [B, S, num_groups, head_dim]
|
@@ -209,13 +208,12 @@ class DeepMoeAttention(GroupedMoeAttention):
|
|
209
208
|
query_flat = query.reshape(-1, D)
|
210
209
|
weights, indices = self.query_router(query_flat)
|
211
210
|
weights = weights.view(B, T, self.num_query_groups, 1)
|
212
|
-
indices = indices.view(B, T, self.num_query_groups)
|
211
|
+
indices = indices.view(B, T, self.num_query_groups).unsqueeze(-1).transpose(1, 2).expand(-1, -1, T, -1) # [B, num_query_groups, T, head_dim]
|
213
212
|
|
214
213
|
q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3) # [B, num_query_experts, T, head_dim]
|
215
214
|
|
216
215
|
# Gather top-k experts using expanded indices
|
217
|
-
|
218
|
-
selected_q = torch.gather(q_all, 1, expanded_indices) # [B, num_query_groups, T, head_dim]
|
216
|
+
selected_q = torch.gather(q_all, 1, indices) # [B, num_query_groups, T, head_dim]
|
219
217
|
|
220
218
|
# Weighted sum
|
221
219
|
q = (selected_q * weights).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
|
@@ -1,6 +1,6 @@
|
|
1
1
|
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
rxnn/experimental/attention.py,sha256
|
3
|
+
rxnn/experimental/attention.py,sha256=KGBh3KdsR92QgzwYx9I5ay0baGACojg_ixxhraF8Vi0,29419
|
4
4
|
rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
|
5
5
|
rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
|
6
6
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.46.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.46.dist-info/METADATA,sha256=cJyWr6q9bpjsEnsR9IFd3tJte3WV34qHLporCJhiEHc,16627
|
30
|
+
rxnn-0.1.46.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.46.dist-info/RECORD,,
|
File without changes
|
File without changes
|