rxnn 0.1.31.tar.gz → 0.1.32.tar.gz

This diff shows the changes between publicly released package versions, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (31)
  1. {rxnn-0.1.31 → rxnn-0.1.32}/PKG-INFO +1 -1
  2. {rxnn-0.1.31 → rxnn-0.1.32}/pyproject.toml +1 -1
  3. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/experimental/attention.py +23 -32
  4. {rxnn-0.1.31 → rxnn-0.1.32}/LICENSE +0 -0
  5. {rxnn-0.1.31 → rxnn-0.1.32}/README.md +0 -0
  6. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/__init__.py +0 -0
  7. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/experimental/__init__.py +0 -0
  8. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/experimental/models.py +0 -0
  9. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/experimental/moe.py +0 -0
  10. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/memory/__init__.py +0 -0
  11. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/memory/norm.py +0 -0
  12. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/memory/stm.py +0 -0
  13. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/rxt/__init__.py +0 -0
  14. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/rxt/models.py +0 -0
  15. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/__init__.py +0 -0
  16. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/base.py +0 -0
  17. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/bml.py +0 -0
  18. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/callbacks.py +0 -0
  19. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/dataset.py +0 -0
  20. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/scheduler.py +0 -0
  21. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/training/tokenizer.py +0 -0
  22. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/__init__.py +0 -0
  23. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/attention.py +0 -0
  24. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/ff.py +0 -0
  25. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/layers.py +0 -0
  26. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/mask.py +0 -0
  27. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/models.py +0 -0
  28. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/moe.py +0 -0
  29. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/positional.py +0 -0
  30. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/transformers/sampler.py +0 -0
  31. {rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/utils.py +0 -0
{rxnn-0.1.31 → rxnn-0.1.32}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.31
+Version: 0.1.32
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.1.31 → rxnn-0.1.32}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.1.31"
+version = "0.1.32"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.1.31 → rxnn-0.1.32}/src/rxnn/experimental/attention.py
@@ -323,7 +323,7 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
 
         # Key/Value MoE routing
         B, S, D = key.shape
-        key_flat = key.reshape(B * S, D)
+        key_flat = key.reshape(-1, D)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
         weights = weights.view(B, S, self.num_groups, 1)
         indices = indices.view(B, S, self.num_groups)
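The routing call in this hunk assumes a router that maps flattened tokens to per-token expert weights and indices, each of shape (B*S, num_groups). As a reading aid only, here is a minimal, hypothetical sketch of that interface (a linear gate, softmax, and top-k); the package's actual router implementation may differ:

```python
import torch
import torch.nn as nn

class TopKRouterSketch(nn.Module):
    """Hypothetical router with the interface the hunk above assumes:
    (B*S, D) -> weights (B*S, top_k), indices (B*S, top_k)."""

    def __init__(self, dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor):
        scores = torch.softmax(self.gate(x), dim=-1)  # (B*S, num_experts)
        # Keep the top_k highest-scoring experts per token
        weights, indices = torch.topk(scores, self.top_k, dim=-1)
        return weights, indices
```

Here `top_k` plays the role of `num_groups`: each token selects one expert per key/value group.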
@@ -331,35 +331,28 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
         # Compute all experts' projections
         # Shape: (B*S, num_experts, head_dim)
         k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
-        v_all = torch.einsum('bd,edh->beh', value.view(B*S, D), self.wv)
+        v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
 
-        # Reshape to [B, S, num_experts, head_dim]
-        k_all = k_all.view(B, S, self.num_experts, -1)
-        v_all = v_all.view(B, S, self.num_experts, -1)
-
-        print('k_all', k_all.size())
-        print('v_all', v_all.size())
-
-
-        # Gather top-k experts and weights
-        # Expand indices to [B, S, num_groups, head_dim]
-        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
-        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
-        selected_v = torch.gather(v_all, 2, expanded_indices)
+        if self.use_bias:
+            k_all += self.bk
+            v_all += self.bv
 
-        print('selected_k', selected_k.size())
-        print('selected_v', selected_v.size())
+        # Get results for all heads
+        k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
+        v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
 
-        # Weighted sum
-        weighted_k = (selected_k * weights).sum(dim=2)  # [B, S, head_dim]
-        weighted_v = (selected_v * weights).sum(dim=2)
+        # Gather top-k experts using expanded indices
+        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
+        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
+        selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
 
-        print('weighted_k', weighted_k.size())
-        print('weighted_v', weighted_v.size())
+        # Weighted
+        weighted_k = selected_k * weights  # [B, S, num_groups, head_dim]
+        weighted_v = selected_v * weights  # [B, S, num_groups, head_dim]
 
         # Reshape to GQA format
-        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, G, S, head_dim]
-        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
+        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
+        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
 
         if not self.use_flash_attention:
             group_heads = self.num_heads // self.num_groups
@@ -370,10 +363,6 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
             k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
             v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
 
-        print('q', q.size())
-        print('k', k.size())
-        print('v', v.size())
-
         return q, k, v
 
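Taken together, these hunks add the expert bias terms, drop the debug prints, and replace the old weighted sum over selected experts with a layout that keeps one projected head per selected group. The gather-and-weight pattern can be exercised in isolation; the sketch below uses dummy sizes and a fabricated router output (assumptions for illustration only, not values from the package):

```python
import torch

# Dummy sizes, chosen only to illustrate shapes
B, S, D = 2, 8, 32
num_experts, num_groups, head_dim = 4, 2, 16

key_flat = torch.randn(B * S, D)
wk = torch.randn(num_experts, D, head_dim)

# All-expert projection in one einsum, as in the new code path
k_all = torch.einsum('bd,edh->beh', key_flat, wk)  # (B*S, num_experts, head_dim)
k_all = k_all.view(B, S, num_experts, -1)          # (B, S, num_experts, head_dim)

# Fabricated router output: top num_groups experts per token
weights = torch.rand(B, S, num_groups, 1)
indices = torch.randint(0, num_experts, (B, S, num_groups))

# Expand indices over head_dim so gather can select whole expert outputs
expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
selected_k = torch.gather(k_all, 2, expanded)      # (B, S, num_groups, head_dim)
weighted_k = selected_k * weights                  # one K head per group, kept separate

k = weighted_k.permute(0, 2, 1, 3)                 # (B, num_groups, S, head_dim)
assert k.shape == (B, num_groups, S, head_dim)
```

The old path's `.sum(dim=2)` collapsed the `num_groups` selected experts into a single `head_dim` vector, which the later `view(B, S, self.num_groups, -1)` then had to split apart; keeping the group axis makes that `view` a no-op and gives each GQA group its own expert output.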
@@ -458,13 +447,16 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
 
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
-        query_flat = query.reshape(B * T, D)
+        query_flat = query.reshape(-1, D)
         weights_q, indices_q = self.query_router(query_flat)
         weights_q = weights_q.view(B, T, self.num_query_groups, 1)
         indices_q = indices_q.view(B, T, self.num_query_groups)
 
         # Compute all query experts
         q_all = torch.einsum('bd,edh->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
+        if self.use_bias:
+            q_all += self.bq
+
         q_all = q_all.view(B, T, self.num_query_experts, -1)
 
         # Gather top-k experts
@@ -472,12 +464,11 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
         selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
 
         # Weighted sum
-        q = (selected_q * weights_q).sum(dim=2)  # [B, T, head_dim]
-        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, H_q, T, head_dim]
+        q = selected_q * weights_q  # [B, T, num_query_groups, head_dim]
+        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]
 
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
-
 # Others
 
 class FlexAttention(MultiHeadAttention):
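The query-side change mirrors the key/value one: instead of pooling the selected query experts into one vector and re-splitting it, each selected expert now stays a separate query head. A minimal sketch of the shape difference, using illustrative sizes (assumed for this example, not taken from the package):

```python
import torch

# Illustrative sizes only
B, T, num_query_groups, head_dim = 2, 4, 3, 8
selected_q = torch.randn(B, T, num_query_groups, head_dim)
weights_q = torch.rand(B, T, num_query_groups, 1)

# Old path: groups collapsed into a single pooled vector per token
q_old = (selected_q * weights_q).sum(dim=2)            # (B, T, head_dim)

# New path: one weighted head per selected query expert
q_new = (selected_q * weights_q).permute(0, 2, 1, 3)   # (B, num_query_groups, T, head_dim)

assert q_old.shape == (B, T, head_dim)
assert q_new.shape == (B, num_query_groups, T, head_dim)
```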