rxnn-0.1.30-py3-none-any.whl → rxnn-0.1.32-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
@@ -323,7 +323,7 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
 
         # Key/Value MoE routing
         B, S, D = key.shape
-        key_flat = key.reshape(B * S, D)
+        key_flat = key.reshape(-1, D)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
         weights = weights.view(B, S, self.num_groups, 1)
         indices = indices.view(B, S, self.num_groups)
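The routing call above yields one weight and one expert index per token per group. For context, here is a minimal sketch of what such a top-k router can look like, assuming a softmax gate over experts; the class name and internals are illustrative, not the package's actual router:

    import torch
    import torch.nn as nn

    class TopKRouter(nn.Module):
        # Hypothetical router: scores each token against every expert,
        # keeps the top num_groups experts, and renormalizes their weights.
        def __init__(self, dim: int, num_experts: int, num_groups: int):
            super().__init__()
            self.gate = nn.Linear(dim, num_experts)
            self.num_groups = num_groups

        def forward(self, x: torch.Tensor):
            scores = torch.softmax(self.gate(x), dim=-1)  # (B*S, num_experts)
            weights, indices = torch.topk(scores, self.num_groups, dim=-1)
            weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize top-k
            return weights, indices  # (B*S, num_groups), (B*S, num_groups)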
@@ -331,25 +331,28 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
         # Compute all experts' projections
         # Shape: (B*S, num_experts, head_dim)
         k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
-        v_all = torch.einsum('bd,edh->beh', value.view(B * S, D), self.wv)
+        v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
 
-        # Reshape to [B, S, num_experts, head_dim]
-        k_all = k_all.view(B, S, self.num_experts, -1)
-        v_all = v_all.view(B, S, self.num_experts, -1)
+        if self.use_bias:
+            k_all += self.bk
+            v_all += self.bv
 
-        # Gather top-k experts and weights
-        # Expand indices to [B, S, num_groups, head_dim]
-        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
-        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
-        selected_v = torch.gather(v_all, 2, expanded_indices)
+        # Get results for all heads
+        k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
+        v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
 
-        # Weighted sum
-        weighted_k = (selected_k * weights).sum(dim=2)  # [B, S, head_dim]
-        weighted_v = (selected_v * weights).sum(dim=2)
+        # Gather top-k experts using expanded indices
+        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
+        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
+        selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
+
+        # Weighted
+        weighted_k = selected_k * weights  # [B, S, num_groups, head_dim]
+        weighted_v = selected_v * weights  # [B, S, num_groups, head_dim]
 
         # Reshape to GQA format
-        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, G, S, head_dim]
-        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
+        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
+        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
 
         if not self.use_flash_attention:
             group_heads = self.num_heads // self.num_groups
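Two things change in this hunk: the optional expert biases are now applied while the projections are still flat, and the weighted sum over groups is dropped, so each key/value group keeps its own expert mixture instead of being collapsed to a single [B, S, head_dim] tensor. The gather step relies on torch.gather requiring an index tensor shaped like its output, which is why the per-token expert ids are expanded across head_dim first. A minimal shape check, with made-up dimensions:

    import torch

    B, S, num_experts, num_groups, head_dim = 2, 4, 8, 2, 16
    k_all = torch.randn(B, S, num_experts, head_dim)
    indices = torch.randint(0, num_experts, (B, S, num_groups))

    # Gather along the expert dim (2); the index tensor must match the
    # output shape, so expert ids are broadcast across head_dim first.
    expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
    selected = torch.gather(k_all, 2, expanded)
    assert selected.shape == (B, S, num_groups, head_dim)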
@@ -360,10 +363,6 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
             k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
             v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
 
-        print('q', q.size())
-        print('k', k.size())
-        print('v', v.size())
-
         return q, k, v
 
 
@@ -448,13 +447,16 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
 
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
-        query_flat = query.reshape(B * T, D)
+        query_flat = query.reshape(-1, D)
         weights_q, indices_q = self.query_router(query_flat)
         weights_q = weights_q.view(B, T, self.num_query_groups, 1)
         indices_q = indices_q.view(B, T, self.num_query_groups)
 
         # Compute all query experts
         q_all = torch.einsum('bd,edh->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
+        if self.use_bias:
+            q_all += self.bq
+
         q_all = q_all.view(B, T, self.num_query_experts, -1)
 
         # Gather top-k experts
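The einsum 'bd,edh->beh' projects every flattened token through every expert's weight matrix in one call: for token b and expert e, out[b, e] = query_flat[b] @ wq[e]. The new bias branch then broadcasts a per-expert bias (presumably shaped [num_query_experts, head_dim]) over the token dimension. A small equivalence check with made-up sizes:

    import torch

    N, D, E, H = 6, 32, 4, 16  # tokens, model dim, experts, head dim
    query_flat = torch.randn(N, D)
    wq = torch.randn(E, D, H)

    # One projection per (token, expert) pair
    q_all = torch.einsum('bd,edh->beh', query_flat, wq)  # (N, E, H)

    # Same result via batched matmul: (1, N, D) @ (E, D, H) -> (E, N, H)
    ref = torch.matmul(query_flat.unsqueeze(0), wq).transpose(0, 1)
    assert torch.allclose(q_all, ref, atol=1e-5)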
@@ -462,12 +464,11 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
         selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
 
         # Weighted sum
-        q = (selected_q * weights_q).sum(dim=2)  # [B, T, head_dim]
-        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, H_q, T, head_dim]
+        q = selected_q * weights_q  # [B, T, num_query_groups, head_dim]
+        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]
 
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
-
 # Others
 
 class FlexAttention(MultiHeadAttention):
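The replaced lines fix a shape mismatch: the old code summed over the group dimension, producing [B, T, head_dim], and the following view(B, T, num_query_groups, -1) then silently split head_dim across the groups rather than giving each group a full head_dim. The new code keeps the group dimension intact. A minimal shape check with made-up dimensions:

    import torch

    B, T, G, head_dim = 2, 4, 2, 16
    selected_q = torch.randn(B, T, G, head_dim)
    weights_q = torch.randn(B, T, G, 1)

    # Old: sum collapses groups, then view splits head_dim across them
    old_q = (selected_q * weights_q).sum(dim=2).view(B, T, G, -1)
    assert old_q.shape == (B, T, G, head_dim // G)

    # New: group dim preserved, one full head_dim per group
    new_q = (selected_q * weights_q).view(B, T, G, -1).permute(0, 2, 1, 3)
    assert new_q.shape == (B, G, T, head_dim)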
rxnn-0.1.32.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.30
+Version: 0.1.32
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.1.32.dist-info/RECORD
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=dOhnHYknVMiw4kIfaWo08ycz1Kl5KfFAHfZEVljs2n0,29567
+rxnn/experimental/attention.py,sha256=h9pEv_70NKpD5KOQOFP3h-IJKzh7Wbnaxka4Bd3rdd8,29745
 rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.30.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.30.dist-info/METADATA,sha256=E5ScN8-I6sP22N80oe7QhLeszz6w9L3UcBN2GCTfu7Q,16627
-rxnn-0.1.30.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.30.dist-info/RECORD,,
+rxnn-0.1.32.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.32.dist-info/METADATA,sha256=oGoeF0F8LvsoUc8qGSd6n0oCIZmoj59bj42QwkaWDuM,16627
+rxnn-0.1.32.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.32.dist-info/RECORD,,
rxnn-0.1.32.dist-info/LICENSE: file without changes
rxnn-0.1.32.dist-info/WHEEL: file without changes