rxnn 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py CHANGED
@@ -65,9 +65,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
 
         hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.wk = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.wv = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
         self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
         self._init_experts()
 
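Context for this change: `F.linear(input, weight, bias)` expects `weight` of shape `[out_features, in_features]` and computes `input @ weight.T + bias`. Storing each expert's weight as `[hidden_dim, embed_dim]` therefore lets the forward pass call `F.linear(x, w[e])` directly, where the old `[embed_dim, hidden_dim]` layout forced the `w[e].t()` transpose removed in the next hunk. A minimal sketch of the convention (dimensions are illustrative, not the package defaults):

```python
import torch
import torch.nn.functional as F

num_experts, embed_dim, hidden_dim = 8, 16, 4

# New layout: one [hidden_dim, embed_dim] weight per expert, matching
# F.linear's [out_features, in_features] convention.
w = torch.empty(num_experts, hidden_dim, embed_dim)
for e in range(num_experts):
    torch.nn.init.xavier_uniform_(w[e])

x = torch.randn(3, embed_dim)   # three flattened tokens
out = F.linear(x, w[0])         # no transpose needed
print(out.shape)                # torch.Size([3, 4])
```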
@@ -80,34 +80,34 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
     def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
         B, S, G = indices.shape
-        x_flat = x.view(-1, x.size(-1))
-
-        # Flatten batch and sequence dimensions
-        indices_flat = indices.view(-1, G)
-        weights_flat = weights.view(-1, G, 1)
+        x_flat = x.view(-1, x.size(-1))  # [B*S, D]
 
-        # Create expanded indices for expert processing
-        mask = torch.zeros(B * S, self.num_experts, device=x.device, dtype=torch.bool)
-        for g in range(G):
-            mask.scatter_(1, indices_flat[:, g].unsqueeze(1), True)
+        indices_flat = indices.view(-1, G)  # [B*S, G]
+        weights_flat = weights.view(-1, G)  # [B*S, G]
 
-        output = torch.zeros(B * S, G, w.size(2), device=x.device, dtype=x.dtype)
+        output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype)  # [B*S, G, hidden_dim]
 
         for e in range(self.num_experts):
-            token_mask = mask[:, e]
-            if not token_mask.any():
+            # 1. Find tokens where expert `e` is used in ANY group
+            expert_mask = (indices_flat == e).any(dim=1)  # [B*S]
+            if not expert_mask.any():
                 continue
 
-            # Get positions where expert e is used in any group
-            x_slice = x_flat[token_mask]
-            proj = F.linear(x_slice, w[e].t(), b[e] if b is not None else None)
+            # 2. Project tokens using expert `e`
+            x_slice = x_flat[expert_mask]  # [num_selected, D]
+            proj = F.linear(x_slice, w[e], b[e] if b is not None else None)  # [num_selected, hidden_dim]
 
-            # Find which groups use this expert for selected tokens
-            group_mask = (indices_flat[token_mask] == e)
+            # 3. Scatter projections into correct groups
+            for g in range(G):
+                group_mask = indices_flat[expert_mask, g] == e  # [num_selected]
+                if not group_mask.any():
+                    continue
 
-            # Accumulate projections for relevant groups
-            weighted_proj = proj.unsqueeze(1) * weights_flat[token_mask] * group_mask.unsqueeze(-1).float()
-            output[token_mask] += weighted_proj.sum(dim=1)
+                # Get tokens in this group using expert `e`
+                group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
+                # Weight and scatter
+                weighted_proj = proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
+                output[group_tokens, g] += weighted_proj
 
         return output.view(B, S, G, -1)
 
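To see the new routing in one runnable piece, here is a self-contained sketch of the rewritten loop (a free function with illustrative dimensions, not the package API): each token is projected once per expert that selected it, and the projection is then scattered into every group slot where that expert appears.

```python
import torch
import torch.nn.functional as F

def process_grouped_experts(x, w, b, weights, indices, num_experts):
    """Standalone version of the new loop: x [B,S,D], w [E,hidden,D],
    b [E,hidden] or None, weights/indices [B,S,G]."""
    B, S, G = indices.shape
    x_flat = x.view(-1, x.size(-1))     # [B*S, D]
    indices_flat = indices.view(-1, G)  # [B*S, G]
    weights_flat = weights.view(-1, G)  # [B*S, G]
    output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype)

    for e in range(num_experts):
        expert_mask = (indices_flat == e).any(dim=1)  # tokens routed to e
        if not expert_mask.any():
            continue
        # One projection per selected token, reused across groups
        proj = F.linear(x_flat[expert_mask], w[e],
                        b[e] if b is not None else None)  # [n_selected, hidden]
        for g in range(G):
            group_mask = indices_flat[expert_mask, g] == e
            if not group_mask.any():
                continue
            group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
            output[group_tokens, g] += proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)

    return output.view(B, S, G, -1)

# Illustrative shapes only
B, S, D, E, G, hidden = 2, 4, 16, 8, 2, 4
x, w = torch.randn(B, S, D), torch.randn(E, hidden, D)
probs = torch.softmax(torch.randn(B, S, E), dim=-1)
weights, indices = probs.topk(G, dim=-1)
print(process_grouped_experts(x, w, None, weights, indices, E).shape)
# torch.Size([2, 4, 2, 4])
```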
@@ -118,7 +118,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
         # Key/Value processing
         B, S, D = key.shape
         key_flat = key.view(-1, D)
-        print('key_flat: ', key_flat.shape)
         weights_k_flat, indices_k_flat = self.router(key_flat)
         # Reshape back to original dimensions
         weights_k = weights_k_flat.view(B, S, -1)
@@ -126,6 +125,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
         k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
         v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
 
+        print('processed k', k.size())
+        print('processed v', v.size())
+
         # Expand to GQA format
         k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
         v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
@@ -139,6 +141,10 @@ class GroupedMoeAttention(GroupedQueryAttention):
         k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
         v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
 
+        print('q', q.size())
+        print('k', k.size())
+        print('v', v.size())
+
         return q, k, v
 
 
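The diff only shows the endpoints of the key/value reshape; the group-to-head expansion in between is not visible here, so the `expand` in the sketch below is an assumption based on standard grouped-query attention, where each group's head is repeated `num_heads // num_groups` times. A shape walkthrough with illustrative sizes:

```python
import torch

# Illustrative sizes; the group->head repetition is an ASSUMPTION, since
# the expand step between the two hunks above is not part of this diff.
B, S, num_heads, num_groups, head_dim = 2, 5, 8, 2, 4

k = torch.randn(B, S, num_groups, head_dim)               # expert output [B, S, G, head_dim]
k = k.permute(0, 2, 1, 3).reshape(B, num_groups, S, -1)   # [B, G, S, head_dim]
k = k.unsqueeze(2).expand(-1, -1, num_heads // num_groups, -1, -1)  # assumed GQA expand
k = k.flatten(start_dim=1, end_dim=2)                     # (B, H, S, head_dim)
print(k.shape)                                            # torch.Size([2, 8, 5, 4])
```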
@@ -199,7 +205,7 @@ class DeepMoeAttention(GroupedMoeAttention):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
 
         hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, hidden_dim, embed_dim))
         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
         self._init_query_experts()
 
@@ -217,7 +223,7 @@ class DeepMoeAttention(GroupedMoeAttention):
         # Query processing
         B, T, D = query.shape
         # Flatten for query routing
-        query_flat = query.view(B * T, D)
+        query_flat = query.view(-1, D)
         weights_q_flat, indices_q_flat = self.query_router(query_flat)
         # Reshape back
         weights_q = weights_q_flat.view(B, T, -1)
rxnn/transformers/moe.py CHANGED
@@ -16,26 +16,20 @@ class MoeRouter(nn.Module):
 
     def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
         expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
-        print('expert mask: ', expert_mask.shape)
         expert_usage = expert_mask.sum(dim=0).mean(dim=0)
-        print('expert usage: ', expert_usage.shape)
         mean_probs = probs.mean(dim=0)
-        print('mean probs: ', mean_probs.shape)
         return (expert_usage * mean_probs).sum() * self.num_experts
 
 
     def forward(self, x: torch.Tensor):
         # Input shape: [batch*seq_len, embed_dim]
         logits = self.gate(x)
-        print('router logits: ', logits.shape)
         probs = F.softmax(logits, dim=-1)
-        print('router probs: ', probs.shape)
         # Get top-k experts for each token
         top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
 
         # Normalize weights (sum to 1 for each token)
         top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
-        print('top k: ', top_k_weights.shape, top_k_indices.shape)
         # Load Balance Loss
         self.aux_loss = self.calculate_aux_loss(top_k_indices, probs)
 
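The surviving router logic can be exercised on its own. The sketch below reproduces `calculate_aux_loss` as it reads after this change (as a free function with stand-in inputs): it penalizes correlation between how often an expert is selected and how much probability mass it receives.

```python
import torch
import torch.nn.functional as F

def calculate_aux_loss(top_k_indices, probs, num_experts):
    # Mirrors MoeRouter.calculate_aux_loss above.
    expert_mask = F.one_hot(top_k_indices, num_experts).float()  # [N, K, E]
    expert_usage = expert_mask.sum(dim=0).mean(dim=0)            # [E]
    mean_probs = probs.mean(dim=0)                               # [E]
    return (expert_usage * mean_probs).sum() * num_experts

# Stand-in routing for 8 tokens, 4 experts, top-2
probs = torch.softmax(torch.randn(8, 4), dim=-1)
top_k_weights, top_k_indices = probs.topk(2, dim=-1)
print(calculate_aux_loss(top_k_indices, probs, 4))
```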
rxnn-0.1.19.dist-info/METADATA → rxnn-0.1.21.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.19
+Version: 0.1.21
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.1.19.dist-info/RECORD → rxnn-0.1.21.dist-info/RECORD CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=nvYtC6BdJQ8VUcNc_co2Fe2at7TBzA4OOIfG2tWWVCk,32104
+rxnn/experimental/attention.py,sha256=6qD3QCpkQHsIaNktjcQrRitQgQ-WkRUVtSFgEDfYGbA,32340
 rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -21,11 +21,11 @@ rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=gJ-jXKtc01xcBayaYchRZy7imFGnvwVfUflXvFiKjKU,5048
+rxnn/transformers/moe.py,sha256=msspVdefdt2ekIN8aT-V8DolK4taESQL_NVsSGOepIs,4739
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.19.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.19.dist-info/METADATA,sha256=4ul6X1SOT2bzHCxK88SjcYc0-1zy8YAKPCoMtZ2dKrY,16627
-rxnn-0.1.19.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.19.dist-info/RECORD,,
+rxnn-0.1.21.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.21.dist-info/METADATA,sha256=gtCrs3sVTMXB9UNS1-qcJNIPzHNO8d7UaJlfviJNFEI,16627
+rxnn-0.1.21.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.21.dist-info/RECORD,,