rxnn 0.1.19__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py CHANGED
@@ -65,9 +65,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
 
         hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.wk = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.wv = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
         self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
         self._init_experts()
 
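For reference on the shape change above: F.linear treats its weight argument as [out_features, in_features], so storing each expert as [hidden_dim, embed_dim] lets the expert weight be passed to F.linear directly, without the transpose removed in the next hunk. A minimal sketch with hypothetical sizes (not part of the package):

import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
num_experts, embed_dim, hidden_dim = 4, 32, 8
wk = torch.empty(num_experts, hidden_dim, embed_dim).normal_(std=0.02)  # new layout: [experts, out, in]
x = torch.randn(16, embed_dim)  # flattened tokens, [batch*seq, embed_dim]

# F.linear expects weight as [out_features, in_features], so no .t() is needed.
proj = F.linear(x, wk[0])
assert proj.shape == (16, hidden_dim)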
@@ -80,34 +80,34 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
     def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
         B, S, G = indices.shape
-        x_flat = x.view(-1, x.size(-1))
+        x_flat = x.view(-1, x.size(-1))  # [B*S, D]
 
-        # Flatten batch and sequence dimensions
-        indices_flat = indices.view(-1, G)
-        weights_flat = weights.view(-1, G, 1)
+        indices_flat = indices.view(-1, G)  # [B*S, G]
+        weights_flat = weights.view(-1, G)  # [B*S, G]
 
-        # Create expanded indices for expert processing
-        mask = torch.zeros(B * S, self.num_experts, device=x.device, dtype=torch.bool)
-        for g in range(G):
-            mask.scatter_(1, indices_flat[:, g].unsqueeze(1), True)
-
-        output = torch.zeros(B * S, G, w.size(2), device=x.device, dtype=x.dtype)
+        output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype)  # [B*S, G, hidden_dim]
 
         for e in range(self.num_experts):
-            token_mask = mask[:, e]
-            if not token_mask.any():
+            # 1. Find tokens where expert `e` is used in ANY group
+            expert_mask = (indices_flat == e).any(dim=1)  # [B*S]
+            if not expert_mask.any():
                 continue
 
-            # Get positions where expert e is used in any group
-            x_slice = x_flat[token_mask]
-            proj = F.linear(x_slice, w[e].t(), b[e] if b is not None else None)
+            # 2. Project tokens using expert `e`
+            x_slice = x_flat[expert_mask]  # [num_selected, D]
+            proj = F.linear(x_slice, w[e], b[e] if b is not None else None)  # [num_selected, hidden_dim]
 
-            # Find which groups use this expert for selected tokens
-            group_mask = (indices_flat[token_mask] == e)
+            # 3. Scatter projections into correct groups
+            for g in range(G):
+                group_mask = indices_flat[expert_mask, g] == e  # [num_selected]
+                if not group_mask.any():
+                    continue
 
-            # Accumulate projections for relevant groups
-            weighted_proj = proj.unsqueeze(1) * weights_flat[token_mask] * group_mask.unsqueeze(-1).float()
-            output[token_mask] += weighted_proj.sum(dim=1)
+                # Get tokens in this group using expert `e`
+                group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
+                # Weight and scatter
+                weighted_proj = proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
+                output[group_tokens, g] += weighted_proj
 
         return output.view(B, S, G, -1)
 
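To make the rewritten routine easier to follow, here is a standalone sketch of the same per-expert, per-group scatter pattern on toy tensors (hypothetical sizes, plain tensors rather than the module's attributes): each expert projects only the tokens routed to it, and the projections are written back into per-group slots, weighted by the router scores.

import torch
import torch.nn.functional as F

# Toy sketch of the grouped-expert projection above; sizes are hypothetical.
B, S, G, E, D, H = 2, 5, 2, 4, 16, 8
x = torch.randn(B, S, D)
w = torch.randn(E, H, D) * 0.02              # per-expert weights, [experts, hidden_dim, embed_dim]
weights = torch.rand(B, S, G)                # router weights per group
indices = torch.randint(0, E, (B, S, G))     # selected expert per group

x_flat = x.view(-1, D)                       # [B*S, D]
indices_flat = indices.view(-1, G)           # [B*S, G]
weights_flat = weights.view(-1, G)           # [B*S, G]
output = torch.zeros(B * S, G, H)

for e in range(E):
    expert_mask = (indices_flat == e).any(dim=1)        # tokens using expert e in any group
    if not expert_mask.any():
        continue
    proj = F.linear(x_flat[expert_mask], w[e])           # [num_selected, H]
    for g in range(G):
        group_mask = indices_flat[expert_mask, g] == e
        if not group_mask.any():
            continue
        group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
        output[group_tokens, g] += proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)

print(output.view(B, S, G, H).shape)                     # torch.Size([2, 5, 2, 8])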
@@ -118,7 +118,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
         # Key/Value processing
         B, S, D = key.shape
         key_flat = key.view(-1, D)
-        print('key_flat: ', key_flat.shape)
         weights_k_flat, indices_k_flat = self.router(key_flat)
         # Reshape back to original dimensions
         weights_k = weights_k_flat.view(B, S, -1)
@@ -199,7 +198,7 @@ class DeepMoeAttention(GroupedMoeAttention):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
 
         hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, hidden_dim, embed_dim))
         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
         self._init_query_experts()
 
@@ -217,7 +216,7 @@ class DeepMoeAttention(GroupedMoeAttention):
         # Query processing
         B, T, D = query.shape
         # Flatten for query routing
-        query_flat = query.view(B * T, D)
+        query_flat = query.view(-1, D)
         weights_q_flat, indices_q_flat = self.query_router(query_flat)
         # Reshape back
         weights_q = weights_q_flat.view(B, T, -1)
rxnn/transformers/moe.py CHANGED
@@ -16,26 +16,20 @@ class MoeRouter(nn.Module):
 
     def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
         expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
-        print('expert mask: ', expert_mask.shape)
         expert_usage = expert_mask.sum(dim=0).mean(dim=0)
-        print('expert usage: ', expert_usage.shape)
         mean_probs = probs.mean(dim=0)
-        print('mean probs: ', mean_probs.shape)
         return (expert_usage * mean_probs).sum() * self.num_experts
 
 
     def forward(self, x: torch.Tensor):
         # Input shape: [batch*seq_len, embed_dim]
         logits = self.gate(x)
-        print('router logits: ', logits.shape)
         probs = F.softmax(logits, dim=-1)
-        print('router probs: ', probs.shape)
         # Get top-k experts for each token
         top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
 
         # Normalize weights (sum to 1 for each token)
         top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
-        print('top k: ', top_k_weights.shape, top_k_indices.shape)
         # Load Balance Loss
         self.aux_loss = self.calculate_aux_loss(top_k_indices, probs)
 
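For reference on the router changes above, a toy reproduction of the top-k routing and load-balance loss on plain tensors (hypothetical sizes; not the module's public API):

import torch
import torch.nn.functional as F

# Hypothetical sizes, mirroring the routine above on plain tensors.
num_tokens, num_experts, top_k = 6, 4, 2
logits = torch.randn(num_tokens, num_experts)             # gate outputs
probs = F.softmax(logits, dim=-1)
top_k_weights, top_k_indices = probs.topk(top_k, dim=-1)
top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)

# Load balance loss: per-expert usage counts (averaged over the k slots)
# multiplied by mean routing probabilities.
expert_mask = F.one_hot(top_k_indices, num_experts).float()   # [tokens, top_k, experts]
expert_usage = expert_mask.sum(dim=0).mean(dim=0)             # [experts]
mean_probs = probs.mean(dim=0)                                # [experts]
aux_loss = (expert_usage * mean_probs).sum() * num_experts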
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.19
+Version: 0.1.20
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=nvYtC6BdJQ8VUcNc_co2Fe2at7TBzA4OOIfG2tWWVCk,32104
+rxnn/experimental/attention.py,sha256=_8dgNPxZRmplZb_k86ejZsCxhUx60mOtB8M8ZcnrTpI,32173
 rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -21,11 +21,11 @@ rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=gJ-jXKtc01xcBayaYchRZy7imFGnvwVfUflXvFiKjKU,5048
+rxnn/transformers/moe.py,sha256=msspVdefdt2ekIN8aT-V8DolK4taESQL_NVsSGOepIs,4739
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.19.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.19.dist-info/METADATA,sha256=4ul6X1SOT2bzHCxK88SjcYc0-1zy8YAKPCoMtZ2dKrY,16627
-rxnn-0.1.19.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.19.dist-info/RECORD,,
+rxnn-0.1.20.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.20.dist-info/METADATA,sha256=lEI864O9VwZMOqxSmf2a4IPpiXIhq9SANZGrRlLJxYc,16627
+rxnn-0.1.20.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.20.dist-info/RECORD,,