rxnn 0.1.27__py3-none-any.whl → 0.1.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -330,8 +330,8 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
330
330
 
331
331
  # Compute all experts' projections
332
332
  # Shape: (B*S, num_experts, head_dim)
333
- k_all = torch.einsum('be,ehd->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
334
- v_all = torch.einsum('be,ehd->beh', value.view(B*S, D), self.wv)
333
+ k_all = torch.einsum('bd,edh->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
334
+ v_all = torch.einsum('bd,edh->beh', value.view(B*S, D), self.wv)
335
335
 
336
336
  # Reshape to [B, S, num_experts, head_dim]
337
337
  k_all = k_all.view(B, S, self.num_experts, -1)
@@ -434,12 +434,13 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
434
434
 
435
435
  def _init_out(self, embed_dim: int):
436
436
  """Initialize output projection"""
437
- self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
437
+ out_hidden_dim = embed_dim // self.num_heads * self.num_query_groups
438
+ self.out_proj = nn.Linear(out_hidden_dim, embed_dim)
438
439
 
439
440
  def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
440
441
  """Transpose attention output back to (B, T, D) shape"""
441
- hidden_dim = d // self.num_heads * self.num_query_groups
442
- return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
442
+ out_hidden_dim = d // self.num_heads * self.num_query_groups
443
+ return attn_output.transpose(1, 2).contiguous().view(b, t, out_hidden_dim)
443
444
 
444
445
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
445
446
  B, T, D = query.shape
@@ -449,7 +450,7 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
449
450
  indices_q = indices_q.view(B, T, self.num_query_groups)
450
451
 
451
452
  # Compute all query experts
452
- q_all = torch.einsum('be,ehd->beh', query_flat, self.wq) # [B*T, num_query_experts, head_dim]
453
+ q_all = torch.einsum('bd,edh->beh', query_flat, self.wq) # [B*T, num_query_experts, head_dim]
453
454
  q_all = q_all.view(B, T, self.num_query_experts, -1)
454
455
 
455
456
  # Gather top-k experts
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.27
3
+ Version: 0.1.29
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -1,6 +1,6 @@
1
1
  rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- rxnn/experimental/attention.py,sha256=csasMRxL4nq2dS7pc9WdS4bvCB70ZVgsR7LTHV2jEJ0,29388
3
+ rxnn/experimental/attention.py,sha256=FLIyfRWZC1Ppfx9u8qUcNcfMSscuibhKEQ8zCJHXcfk,29439
4
4
  rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
5
5
  rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
6
6
  rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
25
25
  rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
26
26
  rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
27
27
  rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
28
- rxnn-0.1.27.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
- rxnn-0.1.27.dist-info/METADATA,sha256=XjcqSdhjTRsCvvP-o981Ihp4k5PFRCQUSLVsPZ_NVPw,16627
30
- rxnn-0.1.27.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
- rxnn-0.1.27.dist-info/RECORD,,
28
+ rxnn-0.1.29.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
29
+ rxnn-0.1.29.dist-info/METADATA,sha256=u3tlE_n8tZsqgW44vPssCmmJ03wHKM57tGS1xXI3HcQ,16627
30
+ rxnn-0.1.29.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
31
+ rxnn-0.1.29.dist-info/RECORD,,
File without changes
File without changes