rxnn 0.1.48__py3-none-any.whl → 0.1.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -94,16 +94,17 @@ class GroupedMoeAttention(GroupedQueryAttention):
94
94
  B, S, D = key.shape
95
95
  key_flat = key.reshape(-1, D)
96
96
  weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
97
- weights = weights.view(B, self.num_groups, S, 1)
98
- indices = indices.view(B, self.num_groups, S).unsqueeze(-1).expand(-1, -1, S, -1)
97
+ weights = weights.view(B, S, self.num_groups, 1)
98
+ indices = indices.view(B, S, self.num_groups)
99
99
 
100
100
  # Compute all experts' projections
101
- k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
102
- v_all = self.v_proj(value).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
101
+ k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
102
+ v_all = self.v_proj(value).view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
103
103
 
104
104
  # Gather top-k experts using expanded indices
105
- selected_k = torch.gather(k_all, 1, indices) # [B, num_groups, S, head_dim]
106
- selected_v = torch.gather(v_all, 1, indices) # [B, num_groups, S, head_dim]
105
+ expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, num_groups, S, head_dim]
106
+ selected_k = torch.gather(k_all, 2, expanded_indices) # [B, num_groups, S, head_dim]
107
+ selected_v = torch.gather(v_all, 2, expanded_indices) # [B, num_groups, S, head_dim]
107
108
 
108
109
  # Weighted
109
110
  weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype) # [B, S, num_groups, head_dim]
@@ -122,8 +123,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
122
123
  k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
123
124
  v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
124
125
 
125
- print(q.size(), k.size(), v.size())
126
-
127
126
  return q, k, v
128
127
 
129
128
 
@@ -208,17 +207,18 @@ class DeepMoeAttention(GroupedMoeAttention):
208
207
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
209
208
  B, T, D = query.shape
210
209
  query_flat = query.reshape(-1, D)
211
- weights, indices = self.query_router(query_flat)
212
- weights = weights.view(B, self.num_query_groups, T, 1)
213
- indices = indices.view(B, self.num_query_groups, T).unsqueeze(-1).expand(-1, -1, T, -1) # [B, num_query_groups, T, head_dim]
210
+ weights_q, indices_q = self.query_router(query_flat)
211
+ weights_q = weights_q.view(B, T, self.num_query_groups, 1)
212
+ indices_q = indices_q.view(B, T, self.num_query_groups)
214
213
 
215
- q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3) # [B, num_query_experts, T, head_dim]
214
+ q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1) # [B, num_groups, S, head_dim]
216
215
 
217
- # Gather top-k experts using expanded indices
218
- selected_q = torch.gather(q_all, 1, indices) # [B, num_query_groups, T, head_dim]
216
+ # Gather top-k experts
217
+ expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1)) # [B, T, num_query_groups, head_dim]
218
+ selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
219
219
 
220
220
  # Weighted sum
221
- q = (selected_q * weights).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
221
+ q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
222
222
  q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
223
223
 
224
224
  return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.48
3
+ Version: 0.1.49
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -1,6 +1,6 @@
1
1
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- rxnn/experimental/attention.py,sha256=XaN7IAHhp1DxFtZxeTEP-EZN7PWnjocC42ndSDB9RvY,29432
3
+ rxnn/experimental/attention.py,sha256=22Qb4jYN6QaqibTU8bwD8x2FaOKCxvWglM2eK9EuOlo,29468
4
4
  rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
5
5
  rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
6
6
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
25
25
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
26
26
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
27
27
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
28
- rxnn-0.1.48.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
- rxnn-0.1.48.dist-info/METADATA,sha256=YelXCeEWnK9llWGzSHOiThlUIz8ttWH-KsL_68pJ9-Y,16627
30
- rxnn-0.1.48.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
- rxnn-0.1.48.dist-info/RECORD,,
28
+ rxnn-0.1.49.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
+ rxnn-0.1.49.dist-info/METADATA,sha256=PijR2z5P5nuTlOaWn-ylU_Loluy-e2HRgpMEc4TCohk,16627
30
+ rxnn-0.1.49.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
+ rxnn-0.1.49.dist-info/RECORD,,
File without changes
File without changes