rxnn 0.1.19__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +24 -25
- rxnn/transformers/moe.py +0 -6
- {rxnn-0.1.19.dist-info → rxnn-0.1.20.dist-info}/METADATA +1 -1
- {rxnn-0.1.19.dist-info → rxnn-0.1.20.dist-info}/RECORD +6 -6
- {rxnn-0.1.19.dist-info → rxnn-0.1.20.dist-info}/LICENSE +0 -0
- {rxnn-0.1.19.dist-info → rxnn-0.1.20.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -65,9 +65,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
 
         hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts,
+        self.wk = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts,
+        self.wv = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
         self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
         self._init_experts()
 
@@ -80,34 +80,34 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
     def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
         B, S, G = indices.shape
-        x_flat = x.view(-1, x.size(-1))
+        x_flat = x.view(-1, x.size(-1))  # [B*S, D]
 
-
-
-        weights_flat = weights.view(-1, G, 1)
+        indices_flat = indices.view(-1, G)  # [B*S, G]
+        weights_flat = weights.view(-1, G)  # [B*S, G]
 
-
-        mask = torch.zeros(B * S, self.num_experts, device=x.device, dtype=torch.bool)
-        for g in range(G):
-            mask.scatter_(1, indices_flat[:, g].unsqueeze(1), True)
-
-        output = torch.zeros(B * S, G, w.size(2), device=x.device, dtype=x.dtype)
+        output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype)  # [B*S, G, hidden_dim]
 
         for e in range(self.num_experts):
-
-
+            # 1. Find tokens where expert `e` is used in ANY group
+            expert_mask = (indices_flat == e).any(dim=1)  # [B*S]
+            if not expert_mask.any():
                 continue
 
-            #
-            x_slice = x_flat[
-            proj = F.linear(x_slice, w[e]
+            # 2. Project tokens using expert `e`
+            x_slice = x_flat[expert_mask]  # [num_selected, D]
+            proj = F.linear(x_slice, w[e], b[e] if b is not None else None)  # [num_selected, hidden_dim]
 
-            #
-
+            # 3. Scatter projections into correct groups
+            for g in range(G):
+                group_mask = indices_flat[expert_mask, g] == e  # [num_selected]
+                if not group_mask.any():
+                    continue
 
-
-
-
+                # Get tokens in this group using expert `e`
+                group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
+                # Weight and scatter
+                weighted_proj = proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
+                output[group_tokens, g] += weighted_proj
 
         return output.view(B, S, G, -1)
 
@@ -118,7 +118,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
         # Key/Value processing
         B, S, D = key.shape
         key_flat = key.view(-1, D)
-        print('key_flat: ', key_flat.shape)
         weights_k_flat, indices_k_flat = self.router(key_flat)
         # Reshape back to original dimensions
         weights_k = weights_k_flat.view(B, S, -1)
@@ -199,7 +198,7 @@ class DeepMoeAttention(GroupedMoeAttention):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
 
         hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts,
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, hidden_dim, embed_dim))
         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
         self._init_query_experts()
 
@@ -217,7 +216,7 @@ class DeepMoeAttention(GroupedMoeAttention):
         # Query processing
         B, T, D = query.shape
         # Flatten for query routing
-        query_flat = query.view(
+        query_flat = query.view(-1, D)
         weights_q_flat, indices_q_flat = self.query_router(query_flat)
         # Reshape back
         weights_q = weights_q_flat.view(B, T, -1)
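The main change in this file is the rewritten expert dispatch in `_process_grouped_experts`: the old boolean-mask bookkeeping is replaced by a per-expert gather, one `F.linear` projection per expert, and a per-group weighted scatter. Below is a minimal, self-contained sketch of that dispatch loop; all sizes (`B`, `S`, `G`, `num_experts`, `hidden_dim`) are illustrative placeholders, not values from the package:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only (hypothetical, not taken from rxnn).
B, S, D = 2, 8, 32                 # batch, sequence length, embed_dim
G, num_experts, hidden_dim = 2, 4, 16

x = torch.randn(B, S, D)
w = torch.randn(num_experts, hidden_dim, D) * 0.02   # new layout: [experts, hidden_dim, embed_dim]
b = torch.zeros(num_experts, hidden_dim)
weights = torch.rand(B, S, G)                        # router weights, one per group
indices = torch.randint(0, num_experts, (B, S, G))   # chosen expert per group

x_flat = x.view(-1, D)              # [B*S, D]
indices_flat = indices.view(-1, G)  # [B*S, G]
weights_flat = weights.view(-1, G)  # [B*S, G]
output = torch.zeros(B * S, G, hidden_dim)

for e in range(num_experts):
    # Tokens that use expert `e` in ANY group
    expert_mask = (indices_flat == e).any(dim=1)
    if not expert_mask.any():
        continue
    # One projection per expert; w[e] is [hidden_dim, D], as F.linear expects
    proj = F.linear(x_flat[expert_mask], w[e], b[e])  # [num_selected, hidden_dim]
    # Scatter the projections back into the groups that picked this expert
    for g in range(G):
        group_mask = indices_flat[expert_mask, g] == e
        if not group_mask.any():
            continue
        group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
        output[group_tokens, g] += proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)

print(output.view(B, S, G, -1).shape)  # torch.Size([2, 8, 2, 16])
```

Running the sketch also shows why the parameter shapes changed to `[num_experts, hidden_dim, embed_dim]`: `F.linear` expects each expert's weight as `[out_features, in_features]`, so `w[e]` can be passed to it directly.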
rxnn/transformers/moe.py
CHANGED
@@ -16,26 +16,20 @@ class MoeRouter(nn.Module):
 
     def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
         expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
-        print('expert mask: ', expert_mask.shape)
         expert_usage = expert_mask.sum(dim=0).mean(dim=0)
-        print('expert usage: ', expert_usage.shape)
         mean_probs = probs.mean(dim=0)
-        print('mean probs: ', mean_probs.shape)
         return (expert_usage * mean_probs).sum() * self.num_experts
 
 
     def forward(self, x: torch.Tensor):
         # Input shape: [batch*seq_len, embed_dim]
         logits = self.gate(x)
-        print('router logits: ', logits.shape)
         probs = F.softmax(logits, dim=-1)
-        print('router probs: ', probs.shape)
         # Get top-k experts for each token
         top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
 
         # Normalize weights (sum to 1 for each token)
         top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
-        print('top k: ', top_k_weights.shape, top_k_indices.shape)
         # Load Balance Loss
         self.aux_loss = self.calculate_aux_loss(top_k_indices, probs)
 
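This file's only change is removing the debug `print` calls; the routing logic is untouched. For reference, a minimal sketch of the cleaned-up routing and load-balance loss on dummy inputs (the sizes and the standalone `gate` layer here are illustrative assumptions, not the package's actual configuration):

```python
import torch
import torch.nn.functional as F

num_experts, top_k = 4, 2
tokens, embed_dim = 16, 32
gate = torch.nn.Linear(embed_dim, num_experts)  # stand-in for MoeRouter's gate

x = torch.randn(tokens, embed_dim)              # [batch*seq_len, embed_dim]
probs = F.softmax(gate(x), dim=-1)              # [tokens, num_experts]
top_k_weights, top_k_indices = probs.topk(top_k, dim=-1)

# Normalize weights (sum to 1 for each token)
top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)

# Load-balance loss, mirroring calculate_aux_loss:
expert_mask = F.one_hot(top_k_indices, num_experts).float()  # [tokens, top_k, num_experts]
expert_usage = expert_mask.sum(dim=0).mean(dim=0)            # avg token count routed to each expert
mean_probs = probs.mean(dim=0)                               # mean gate probability per expert
aux_loss = (expert_usage * mean_probs).sum() * num_experts
```

The loss is smallest when routing probability mass and actual expert usage are both spread evenly across experts, which is the usual switch-style balancing objective.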
{rxnn-0.1.19.dist-info → rxnn-0.1.20.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=_8dgNPxZRmplZb_k86ejZsCxhUx60mOtB8M8ZcnrTpI,32173
 rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -21,11 +21,11 @@ rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/moe.py,sha256=msspVdefdt2ekIN8aT-V8DolK4taESQL_NVsSGOepIs,4739
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.20.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.20.dist-info/METADATA,sha256=lEI864O9VwZMOqxSmf2a4IPpiXIhq9SANZGrRlLJxYc,16627
+rxnn-0.1.20.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.20.dist-info/RECORD,,
{rxnn-0.1.19.dist-info → rxnn-0.1.20.dist-info}/LICENSE
File without changes
{rxnn-0.1.19.dist-info → rxnn-0.1.20.dist-info}/WHEEL
File without changes