rxnn 0.1.19__tar.gz → 0.1.20__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (31)
  1. {rxnn-0.1.19 → rxnn-0.1.20}/PKG-INFO +1 -1
  2. {rxnn-0.1.19 → rxnn-0.1.20}/pyproject.toml +1 -1
  3. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/experimental/attention.py +24 -25
  4. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/transformers/moe.py +0 -6
  5. {rxnn-0.1.19 → rxnn-0.1.20}/LICENSE +0 -0
  6. {rxnn-0.1.19 → rxnn-0.1.20}/README.md +0 -0
  7. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/__init__.py +0 -0
  8. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/experimental/__init__.py +0 -0
  9. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/experimental/models.py +0 -0
  10. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/experimental/moe.py +0 -0
  11. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/memory/__init__.py +0 -0
  12. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/memory/norm.py +0 -0
  13. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/memory/stm.py +0 -0
  14. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/rxt/__init__.py +0 -0
  15. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/rxt/models.py +0 -0
  16. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/training/__init__.py +0 -0
  17. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/training/base.py +0 -0
  18. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/training/bml.py +0 -0
  19. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/training/callbacks.py +0 -0
  20. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/training/dataset.py +0 -0
  21. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/training/scheduler.py +0 -0
  22. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/training/tokenizer.py +0 -0
  23. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/transformers/__init__.py +0 -0
  24. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/transformers/attention.py +0 -0
  25. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/transformers/ff.py +0 -0
  26. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/transformers/layers.py +0 -0
  27. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/transformers/mask.py +0 -0
  28. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/transformers/models.py +0 -0
  29. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/transformers/positional.py +0 -0
  30. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/transformers/sampler.py +0 -0
  31. {rxnn-0.1.19 → rxnn-0.1.20}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.1.19
3
+ Version: 0.1.20
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "rxnn"
7
- version = "0.1.19"
7
+ version = "0.1.20"
8
8
  description = "RxNN: Reactive Neural Networks Platform"
9
9
 
10
10
  license = "Apache-2.0"
@@ -65,9 +65,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
65
65
  self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
66
66
 
67
67
  hidden_dim = embed_dim // self.num_heads
68
- self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
68
+ self.wk = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
69
69
  self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
70
- self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
70
+ self.wv = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
71
71
  self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
72
72
  self._init_experts()
73
73
 
@@ -80,34 +80,34 @@ class GroupedMoeAttention(GroupedQueryAttention):
80
80
 
81
81
  def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
82
82
  B, S, G = indices.shape
83
- x_flat = x.view(-1, x.size(-1))
83
+ x_flat = x.view(-1, x.size(-1)) # [B*S, D]
84
84
 
85
- # Flatten batch and sequence dimensions
86
- indices_flat = indices.view(-1, G)
87
- weights_flat = weights.view(-1, G, 1)
85
+ indices_flat = indices.view(-1, G) # [B*S, G]
86
+ weights_flat = weights.view(-1, G) # [B*S, G]
88
87
 
89
- # Create expanded indices for expert processing
90
- mask = torch.zeros(B * S, self.num_experts, device=x.device, dtype=torch.bool)
91
- for g in range(G):
92
- mask.scatter_(1, indices_flat[:, g].unsqueeze(1), True)
93
-
94
- output = torch.zeros(B * S, G, w.size(2), device=x.device, dtype=x.dtype)
88
+ output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype) # [B*S, G, hidden_dim]
95
89
 
96
90
  for e in range(self.num_experts):
97
- token_mask = mask[:, e]
98
- if not token_mask.any():
91
+ # 1. Find tokens where expert `e` is used in ANY group
92
+ expert_mask = (indices_flat == e).any(dim=1) # [B*S]
93
+ if not expert_mask.any():
99
94
  continue
100
95
 
101
- # Get positions where expert e is used in any group
102
- x_slice = x_flat[token_mask]
103
- proj = F.linear(x_slice, w[e].t(), b[e] if b is not None else None)
96
+ # 2. Project tokens using expert `e`
97
+ x_slice = x_flat[expert_mask] # [num_selected, D]
98
+ proj = F.linear(x_slice, w[e], b[e] if b is not None else None) # [num_selected, hidden_dim]
104
99
 
105
- # Find which groups use this expert for selected tokens
106
- group_mask = (indices_flat[token_mask] == e)
100
+ # 3. Scatter projections into correct groups
101
+ for g in range(G):
102
+ group_mask = indices_flat[expert_mask, g] == e # [num_selected]
103
+ if not group_mask.any():
104
+ continue
107
105
 
108
- # Accumulate projections for relevant groups
109
- weighted_proj = proj.unsqueeze(1) * weights_flat[token_mask] * group_mask.unsqueeze(-1).float()
110
- output[token_mask] += weighted_proj.sum(dim=1)
106
+ # Get tokens in this group using expert `e`
107
+ group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
108
+ # Weight and scatter
109
+ weighted_proj = proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
110
+ output[group_tokens, g] += weighted_proj
111
111
 
112
112
  return output.view(B, S, G, -1)
113
113
 
@@ -118,7 +118,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
118
118
  # Key/Value processing
119
119
  B, S, D = key.shape
120
120
  key_flat = key.view(-1, D)
121
- print('key_flat: ', key_flat.shape)
122
121
  weights_k_flat, indices_k_flat = self.router(key_flat)
123
122
  # Reshape back to original dimensions
124
123
  weights_k = weights_k_flat.view(B, S, -1)
@@ -199,7 +198,7 @@ class DeepMoeAttention(GroupedMoeAttention):
199
198
  self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
200
199
 
201
200
  hidden_dim = embed_dim // self.num_heads
202
- self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
201
+ self.wq = nn.Parameter(torch.empty(self.num_query_experts, hidden_dim, embed_dim))
203
202
  self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
204
203
  self._init_query_experts()
205
204
 
@@ -217,7 +216,7 @@ class DeepMoeAttention(GroupedMoeAttention):
217
216
  # Query processing
218
217
  B, T, D = query.shape
219
218
  # Flatten for query routing
220
- query_flat = query.view(B * T, D)
219
+ query_flat = query.view(-1, D)
221
220
  weights_q_flat, indices_q_flat = self.query_router(query_flat)
222
221
  # Reshape back
223
222
  weights_q = weights_q_flat.view(B, T, -1)
@@ -16,26 +16,20 @@ class MoeRouter(nn.Module):
16
16
 
17
17
  def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
18
18
  expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
19
- print('expert mask: ', expert_mask.shape)
20
19
  expert_usage = expert_mask.sum(dim=0).mean(dim=0)
21
- print('expert usage: ', expert_usage.shape)
22
20
  mean_probs = probs.mean(dim=0)
23
- print('mean probs: ', mean_probs.shape)
24
21
  return (expert_usage * mean_probs).sum() * self.num_experts
25
22
 
26
23
 
27
24
  def forward(self, x: torch.Tensor):
28
25
  # Input shape: [batch*seq_len, embed_dim]
29
26
  logits = self.gate(x)
30
- print('router logits: ', logits.shape)
31
27
  probs = F.softmax(logits, dim=-1)
32
- print('router probs: ', probs.shape)
33
28
  # Get top-k experts for each token
34
29
  top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
35
30
 
36
31
  # Normalize weights (sum to 1 for each token)
37
32
  top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
38
- print('top k: ', top_k_weights.shape, top_k_indices.shape)
39
33
  # Load Balance Loss
40
34
  self.aux_loss = self.calculate_aux_loss(top_k_indices, probs)
41
35
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes