rxnn 0.1.47__py3-none-any.whl → 0.1.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -94,16 +94,17 @@ class GroupedMoeAttention(GroupedQueryAttention):
|
|
94
94
|
B, S, D = key.shape
|
95
95
|
key_flat = key.reshape(-1, D)
|
96
96
|
weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
|
97
|
-
weights = weights.view(B, self.num_groups,
|
98
|
-
indices = indices.view(B, self.num_groups
|
97
|
+
weights = weights.view(B, S, self.num_groups, 1)
|
98
|
+
indices = indices.view(B, S, self.num_groups)
|
99
99
|
|
100
100
|
# Compute all experts' projections
|
101
|
-
k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1)
|
102
|
-
v_all = self.v_proj(value).view(B, S, self.num_experts, -1)
|
101
|
+
k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
|
102
|
+
v_all = self.v_proj(value).view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
|
103
103
|
|
104
104
|
# Gather top-k experts using expanded indices
|
105
|
-
|
106
|
-
|
105
|
+
expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, num_groups, S, head_dim]
|
106
|
+
selected_k = torch.gather(k_all, 2, expanded_indices) # [B, num_groups, S, head_dim]
|
107
|
+
selected_v = torch.gather(v_all, 2, expanded_indices) # [B, num_groups, S, head_dim]
|
107
108
|
|
108
109
|
# Weighted
|
109
110
|
weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype) # [B, S, num_groups, head_dim]
|
@@ -206,17 +207,18 @@ class DeepMoeAttention(GroupedMoeAttention):
|
|
206
207
|
def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
|
207
208
|
B, T, D = query.shape
|
208
209
|
query_flat = query.reshape(-1, D)
|
209
|
-
|
210
|
-
|
211
|
-
|
210
|
+
weights_q, indices_q = self.query_router(query_flat)
|
211
|
+
weights_q = weights_q.view(B, T, self.num_query_groups, 1)
|
212
|
+
indices_q = indices_q.view(B, T, self.num_query_groups)
|
212
213
|
|
213
|
-
q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1)
|
214
|
+
q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1) # [B, num_groups, S, head_dim]
|
214
215
|
|
215
|
-
# Gather top-k experts
|
216
|
-
|
216
|
+
# Gather top-k experts
|
217
|
+
expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1)) # [B, T, num_query_groups, head_dim]
|
218
|
+
selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
|
217
219
|
|
218
220
|
# Weighted sum
|
219
|
-
q = (selected_q *
|
221
|
+
q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
|
220
222
|
q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
|
221
223
|
|
222
224
|
return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
rxnn/experimental/attention.py,sha256=
|
3
|
+
rxnn/experimental/attention.py,sha256=22Qb4jYN6QaqibTU8bwD8x2FaOKCxvWglM2eK9EuOlo,29468
|
4
4
|
rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
|
5
5
|
rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
|
6
6
|
rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
|
|
25
25
|
rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.49.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.49.dist-info/METADATA,sha256=PijR2z5P5nuTlOaWn-ylU_Loluy-e2HRgpMEc4TCohk,16627
|
30
|
+
rxnn-0.1.49.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
31
|
+
rxnn-0.1.49.dist-info/RECORD,,
|
File without changes
|
File without changes
|