rxnn 0.1.48__tar.gz → 0.1.49__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. {rxnn-0.1.48 → rxnn-0.1.49}/PKG-INFO +1 -1
  2. {rxnn-0.1.48 → rxnn-0.1.49}/pyproject.toml +1 -1
  3. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/experimental/attention.py +15 -15
  4. {rxnn-0.1.48 → rxnn-0.1.49}/LICENSE +0 -0
  5. {rxnn-0.1.48 → rxnn-0.1.49}/README.md +0 -0
  6. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/__init__.py +0 -0
  7. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/experimental/__init__.py +0 -0
  8. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/experimental/models.py +0 -0
  9. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/experimental/moe.py +0 -0
  10. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/memory/__init__.py +0 -0
  11. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/memory/norm.py +0 -0
  12. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/memory/stm.py +0 -0
  13. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/rxt/__init__.py +0 -0
  14. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/rxt/models.py +0 -0
  15. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/training/__init__.py +0 -0
  16. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/training/base.py +0 -0
  17. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/training/bml.py +0 -0
  18. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/training/callbacks.py +0 -0
  19. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/training/dataset.py +0 -0
  20. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/training/scheduler.py +0 -0
  21. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/training/tokenizer.py +0 -0
  22. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/transformers/__init__.py +0 -0
  23. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/transformers/attention.py +0 -0
  24. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/transformers/ff.py +0 -0
  25. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/transformers/layers.py +0 -0
  26. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/transformers/mask.py +0 -0
  27. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/transformers/models.py +0 -0
  28. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/transformers/moe.py +0 -0
  29. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/transformers/positional.py +0 -0
  30. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/transformers/sampler.py +0 -0
  31. {rxnn-0.1.48 → rxnn-0.1.49}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.48
3
+ Version: 0.1.49
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "rxnn"
7
- version = "0.1.48"
7
+ version = "0.1.49"
8
8
  description = "RxNN: Reactive Neural Networks Platform"
9
9
 
10
10
  license = "Apache-2.0"
@@ -94,16 +94,17 @@ class GroupedMoeAttention(GroupedQueryAttention):
94
94
  B, S, D = key.shape
95
95
  key_flat = key.reshape(-1, D)
96
96
  weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
97
- weights = weights.view(B, self.num_groups, S, 1)
98
- indices = indices.view(B, self.num_groups, S).unsqueeze(-1).expand(-1, -1, S, -1)
97
+ weights = weights.view(B, S, self.num_groups, 1)
98
+ indices = indices.view(B, S, self.num_groups)
99
99
 
100
100
  # Compute all experts' projections
101
- k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
102
- v_all = self.v_proj(value).view(B, S, self.num_experts, -1).permute(0, 2, 1, 3) # [B, num_experts, S, head_dim]
101
+ k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
102
+ v_all = self.v_proj(value).view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
103
103
 
104
104
  # Gather top-k experts using expanded indices
105
- selected_k = torch.gather(k_all, 1, indices) # [B, num_groups, S, head_dim]
106
- selected_v = torch.gather(v_all, 1, indices) # [B, num_groups, S, head_dim]
105
+ expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, S, num_groups, head_dim]
106
+ selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
107
+ selected_v = torch.gather(v_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
107
108
 
108
109
  # Weighted
109
110
  weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype) # [B, S, num_groups, head_dim]
@@ -122,8 +123,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
122
123
  k = k.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
123
124
  v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
124
125
 
125
- print(q.size(), k.size(), v.size())
126
-
127
126
  return q, k, v
128
127
 
129
128
 
@@ -208,17 +207,18 @@ class DeepMoeAttention(GroupedMoeAttention):
208
207
  def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
209
208
  B, T, D = query.shape
210
209
  query_flat = query.reshape(-1, D)
211
- weights, indices = self.query_router(query_flat)
212
- weights = weights.view(B, self.num_query_groups, T, 1)
213
- indices = indices.view(B, self.num_query_groups, T).unsqueeze(-1).expand(-1, -1, T, -1) # [B, num_query_groups, T, head_dim]
210
+ weights_q, indices_q = self.query_router(query_flat)
211
+ weights_q = weights_q.view(B, T, self.num_query_groups, 1)
212
+ indices_q = indices_q.view(B, T, self.num_query_groups)
214
213
 
215
- q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1).permute(0, 2, 1, 3) # [B, num_query_experts, T, head_dim]
214
+ q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1) # [B, T, num_query_experts, head_dim]
216
215
 
217
- # Gather top-k experts using expanded indices
218
- selected_q = torch.gather(q_all, 1, indices) # [B, num_query_groups, T, head_dim]
216
+ # Gather top-k experts
217
+ expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1)) # [B, T, num_query_groups, head_dim]
218
+ selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
219
219
 
220
220
  # Weighted sum
221
- q = (selected_q * weights).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
221
+ q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype) # [B, T, num_query_groups, head_dim]
222
222
  q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
223
223
 
224
224
  return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes