rxnn 0.1.47__py3-none-any.whl → 0.1.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -94,16 +94,17 @@ class GroupedMoeAttention(GroupedQueryAttention):
94
94
  B, S, D = key.shape
95
95
  key_flat = key.reshape(-1, D)
96
96
  weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
97
- weights = weights.view(B, self.num_groups, S, 1)
98
- indices = indices.view(B, self.num_groups, S).unsqueeze(-1).expand(-1, -1, S, -1)
97
+ weights = weights.view(B, S, self.num_groups, 1)
98
+ indices = indices.view(B, S, self.num_groups)
99
99
 
100
100
  # Compute all experts' projections
101
- k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
102
- v_all = self.v_proj(value).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
101
+ k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
102
+ v_all = self.v_proj(value).view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
103
103
 
104
104
  # Gather top-k experts using expanded indices
105
- selected_k = torch.gather(k_all, 1, indices) # [B, num_groups, S, head_dim]
106
- selected_v = torch.gather(v_all, 1, indices) # [B, num_groups, S, head_dim]
105
+ expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, S, num_groups, head_dim]
106
+ selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
107
+ selected_v = torch.gather(v_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
107
108
 
108
109
  # Weighted
109
110
  weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype) # [B, S, num_groups, head_dim]
@@ -206,17 +207,18 @@ class DeepMoeAttention(GroupedMoeAttention):
206
207
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
207
208
  B, T, D = query.shape
208
209
  query_flat = query.reshape(-1, D)
209
- weights, indices = self.query_router(query_flat)
210
- weights = weights.view(B, self.num_query_groups, T, 1)
211
- indices = indices.view(B, self.num_query_groups, T).unsqueeze(-1).expand(-1, -1, T, -1) # [B, num_query_groups, T, head_dim]
210
+ weights_q, indices_q = self.query_router(query_flat)
211
+ weights_q = weights_q.view(B, T, self.num_query_groups, 1)
212
+ indices_q = indices_q.view(B, T, self.num_query_groups)
212
213
 
213
- q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3) # [B, num_query_experts, T, head_dim]
214
+ q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1) # [B, T, num_query_experts, head_dim]
214
215
 
215
- # Gather top-k experts using expanded indices
216
- selected_q = torch.gather(q_all, 1, indices) # [B, num_query_groups, T, head_dim]
216
+ # Gather top-k experts
217
+ expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1)) # [B, T, num_query_groups, head_dim]
218
+ selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
217
219
 
218
220
  # Weighted sum
219
- q = (selected_q * weights).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
221
+ q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
220
222
  q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
221
223
 
222
224
  return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.47
3
+ Version: 0.1.49
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -1,6 +1,6 @@
1
1
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- rxnn/experimental/attention.py,sha256=iOLPDm45kIx_hd9VUaMgJiXT2AMCwBV9Uj47yHnsOs4,29387
3
+ rxnn/experimental/attention.py,sha256=22Qb4jYN6QaqibTU8bwD8x2FaOKCxvWglM2eK9EuOlo,29468
4
4
  rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
5
5
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
6
6
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
25
25
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
26
26
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
27
27
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
28
- rxnn-0.1.47.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
- rxnn-0.1.47.dist-info/METADATA,sha256=Cm2G_bHp5E1katGPPCj-NE9_dTkXJHKr9NpWMnlSbWw,16627
30
- rxnn-0.1.47.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
- rxnn-0.1.47.dist-info/RECORD,,
28
+ rxnn-0.1.49.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
+ rxnn-0.1.49.dist-info/METADATA,sha256=PijR2z5P5nuTlOaWn-ylU_Loluy-e2HRgpMEc4TCohk,16627
30
+ rxnn-0.1.49.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
+ rxnn-0.1.49.dist-info/RECORD,,
File without changes
File without changes