rxnn 0.1.48__py3-none-any.whl → 0.1.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -94,16 +94,17 @@ class GroupedMoeAttention(GroupedQueryAttention):
         B, S, D = key.shape
         key_flat = key.reshape(-1, D)
         weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
-        weights = weights.view(B, self.num_groups,
-        indices = indices.view(B, self.num_groups
+        weights = weights.view(B, S, self.num_groups, 1)
+        indices = indices.view(B, S, self.num_groups)

         # Compute all experts' projections
-        k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1)
-        v_all = self.v_proj(value).view(B, S, self.num_experts, -1)
+        k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
+        v_all = self.v_proj(value).view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]

         # Gather top-k experts using expanded indices
-
-
+        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, num_groups, S, head_dim]
+        selected_k = torch.gather(k_all, 2, expanded_indices) # [B, num_groups, S, head_dim]
+        selected_v = torch.gather(v_all, 2, expanded_indices) # [B, num_groups, S, head_dim]

         # Weighted
         weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype) # [B, S, num_groups, head_dim]
@@ -122,8 +123,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
         k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
         v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)

-        print(q.size(), k.size(), v.size())
-
         return q, k, v

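The two hunks above reshape GroupedMoeAttention's router outputs per token ([B, S, num_groups]) before gathering the selected experts' key/value projections along the expert dimension, and drop a leftover debug print. The snippet below is a minimal, self-contained sketch of that gather pattern, not the library's actual class: the sizes, the softmax/top-k stand-in for self.router, and names such as router_logits and head_dim are assumptions for illustration only.

import torch

# Hypothetical sizes for illustration only.
B, S, D = 2, 16, 64          # batch, sequence length, model dim
num_experts, num_groups = 8, 4
head_dim = D // num_groups   # assumed per-expert output size

key = torch.randn(B, S, D)
key_flat = key.reshape(-1, D)                      # (B*S, D)

# Stand-in for self.router(...): top-k gating weights and expert ids per token.
router_logits = torch.randn(B * S, num_experts)
weights, indices = torch.topk(torch.softmax(router_logits, dim=-1), num_groups, dim=-1)

weights = weights.view(B, S, num_groups, 1)        # per-token gating weights
indices = indices.view(B, S, num_groups)           # chosen expert ids per token

# Stand-in for self.k_proj(key_flat).view(...): all experts' key projections.
k_all = torch.randn(B, S, num_experts, head_dim)

# Gather the selected experts along the expert dimension (dim=2); the index
# tensor must match k_all in every dimension except the gathered one.
expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
selected_k = torch.gather(k_all, 2, expanded_indices)                        # [B, S, num_groups, head_dim]

weighted_k = selected_k * weights                  # [B, S, num_groups, head_dim]
print(weighted_k.shape)                            # torch.Size([2, 16, 4, 16])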
@@ -208,17 +207,18 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
         query_flat = query.reshape(-1, D)
-
-
-
+        weights_q, indices_q = self.query_router(query_flat)
+        weights_q = weights_q.view(B, T, self.num_query_groups, 1)
+        indices_q = indices_q.view(B, T, self.num_query_groups)

-        q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1)
+        q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1) # [B, num_groups, S, head_dim]

-        # Gather top-k experts
-
+        # Gather top-k experts
+        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1)) # [B, T, num_query_groups, head_dim]
+        selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]

         # Weighted sum
-        q = (selected_q *
+        q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
         q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]

         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
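The DeepMoeAttention hunk applies the same per-token routing to the query path and then permutes the gathered groups in front of the time axis so each group acts as a query head. Below is a minimal shape walkthrough under assumed sizes; weights_q, indices_q, and q_all are random stand-ins for the module's query_router and q_proj outputs, not the real API.

import torch

# Hypothetical sizes; the real ones come from the attention module's config.
B, T = 2, 10
num_query_experts, num_query_groups, head_dim = 6, 3, 32

# Stand-ins for the query router and q_proj outputs shown in the hunk above.
weights_q = torch.rand(B, T, num_query_groups, 1)
indices_q = torch.randint(0, num_query_experts, (B, T, num_query_groups))
q_all = torch.randn(B, T, num_query_experts, head_dim)

# Gather the routed experts' query projections, weight them, then move the
# group axis in front of the time axis so groups act as query heads.
expanded = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))  # [B, T, num_query_groups, head_dim]
selected_q = torch.gather(q_all, 2, expanded)                          # [B, T, num_query_groups, head_dim]
q = selected_q * weights_q                                             # [B, T, num_query_groups, head_dim]
q = q.view(B, T, num_query_groups, -1).permute(0, 2, 1, 3)             # [B, num_query_groups, T, head_dim]
print(q.shape)                                                         # torch.Size([2, 3, 10, 32])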
rxnn-0.1.48.dist-info/RECORD → rxnn-0.1.49.dist-info/RECORD
RENAMED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=22Qb4jYN6QaqibTU8bwD8x2FaOKCxvWglM2eK9EuOlo,29468
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.49.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.49.dist-info/METADATA,sha256=PijR2z5P5nuTlOaWn-ylU_Loluy-e2HRgpMEc4TCohk,16627
+rxnn-0.1.49.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.49.dist-info/RECORD,,