rxnn 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +31 -25
- rxnn/transformers/moe.py +0 -6
- {rxnn-0.1.19.dist-info → rxnn-0.1.21.dist-info}/METADATA +1 -1
- {rxnn-0.1.19.dist-info → rxnn-0.1.21.dist-info}/RECORD +6 -6
- {rxnn-0.1.19.dist-info → rxnn-0.1.21.dist-info}/LICENSE +0 -0
- {rxnn-0.1.19.dist-info → rxnn-0.1.21.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -65,9 +65,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
 
         hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts,
+        self.wk = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts,
+        self.wv = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
         self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
         self._init_experts()
 
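Context for the shape fix above: `F.linear` expects its weight as `[out_features, in_features]`, so storing each expert as a `[hidden_dim, embed_dim]` slice lets a single `w[e]` project flattened tokens from `embed_dim` down to `hidden_dim`. A minimal sketch with placeholder sizes (none of these values come from the library's defaults):

```python
import torch
import torch.nn.functional as F

# Placeholder sizes for illustration only
num_experts, embed_dim, num_heads = 4, 512, 8
hidden_dim = embed_dim // num_heads

# Same layout as the fixed parameter: [num_experts, hidden_dim, embed_dim]
wk = torch.empty(num_experts, hidden_dim, embed_dim)
torch.nn.init.xavier_uniform_(wk)

tokens = torch.randn(10, embed_dim)   # 10 flattened tokens
out = F.linear(tokens, wk[2])         # expert 2's slice acts as an [out, in] weight
assert out.shape == (10, hidden_dim)
```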
@@ -80,34 +80,34 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
     def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
         B, S, G = indices.shape
-        x_flat = x.view(-1, x.size(-1))
-
-        # Flatten batch and sequence dimensions
-        indices_flat = indices.view(-1, G)
-        weights_flat = weights.view(-1, G, 1)
+        x_flat = x.view(-1, x.size(-1))  # [B*S, D]
 
-
-
-        for g in range(G):
-            mask.scatter_(1, indices_flat[:, g].unsqueeze(1), True)
+        indices_flat = indices.view(-1, G)  # [B*S, G]
+        weights_flat = weights.view(-1, G)  # [B*S, G]
 
-        output = torch.zeros(B * S, G, w.size(
+        output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype)  # [B*S, G, hidden_dim]
 
         for e in range(self.num_experts):
-
-
+            # 1. Find tokens where expert `e` is used in ANY group
+            expert_mask = (indices_flat == e).any(dim=1)  # [B*S]
+            if not expert_mask.any():
                 continue
 
-            #
-            x_slice = x_flat[
-            proj = F.linear(x_slice, w[e]
+            # 2. Project tokens using expert `e`
+            x_slice = x_flat[expert_mask]  # [num_selected, D]
+            proj = F.linear(x_slice, w[e], b[e] if b is not None else None)  # [num_selected, hidden_dim]
 
-            #
-
+            # 3. Scatter projections into correct groups
+            for g in range(G):
+                group_mask = indices_flat[expert_mask, g] == e  # [num_selected]
+                if not group_mask.any():
+                    continue
 
-
-
-
+                # Get tokens in this group using expert `e`
+                group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
+                # Weight and scatter
+                weighted_proj = proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
+                output[group_tokens, g] += weighted_proj
 
         return output.view(B, S, G, -1)
 
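Pulled out of the class for readability, the rewritten dispatch reduces to the following self-contained sketch (names and shapes mirror the diff; the sizes in the smoke test are placeholders):

```python
import torch
import torch.nn.functional as F

def process_grouped_experts(x, w, b, weights, indices):
    # x: [B, S, D], w: [E, hidden_dim, D], b: [E, hidden_dim] or None,
    # weights/indices: [B, S, G] router outputs
    B, S, G = indices.shape
    x_flat = x.view(-1, x.size(-1))        # [B*S, D]
    indices_flat = indices.view(-1, G)     # [B*S, G]
    weights_flat = weights.view(-1, G)     # [B*S, G]
    output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype)

    for e in range(w.size(0)):
        # Tokens where expert e is selected in any group
        expert_mask = (indices_flat == e).any(dim=1)
        if not expert_mask.any():
            continue
        # Project only the selected tokens
        proj = F.linear(x_flat[expert_mask], w[e], b[e] if b is not None else None)
        # Scatter the projections back into their groups, scaled by router weight
        for g in range(G):
            group_mask = indices_flat[expert_mask, g] == e
            if not group_mask.any():
                continue
            group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
            output[group_tokens, g] += proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
    return output.view(B, S, G, -1)

# Smoke test with placeholder sizes
B, S, D, E, G, H = 2, 8, 32, 4, 2, 16
probs = torch.softmax(torch.randn(B, S, E), dim=-1)
weights, indices = probs.topk(G, dim=-1)
out = process_grouped_experts(torch.randn(B, S, D), torch.randn(E, H, D), None, weights, indices)
assert out.shape == (B, S, G, H)
```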
@@ -118,7 +118,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
         # Key/Value processing
         B, S, D = key.shape
         key_flat = key.view(-1, D)
-        print('key_flat: ', key_flat.shape)
         weights_k_flat, indices_k_flat = self.router(key_flat)
         # Reshape back to original dimensions
         weights_k = weights_k_flat.view(B, S, -1)
@@ -126,6 +125,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
         k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
         v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
 
+        print('processed k', k.size())
+        print('processed v', v.size())
+
         # Expand to GQA format
         k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
         v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
@@ -139,6 +141,10 @@ class GroupedMoeAttention(GroupedQueryAttention):
         k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
         v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
 
+        print('q', q.size())
+        print('k', k.size())
+        print('v', v.size())
+
         return q, k, v
 
 
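The reshapes around the new debug prints are easier to follow with explicit shapes. The group-to-head expansion itself sits between these two hunks and is not part of the diff, so the `expand` below is only one plausible way to get from `[B, G, S, hidden_dim]` to the commented `(B, H, S, head_dim)`:

```python
import torch

# Placeholder sizes; head_dim == hidden_dim is an assumption
B, S, num_groups, num_heads, hidden_dim = 2, 16, 4, 8, 64
heads_per_group = num_heads // num_groups

k = torch.randn(B, S, num_groups, hidden_dim)             # _process_grouped_experts output
k = k.permute(0, 2, 1, 3).reshape(B, num_groups, S, -1)   # [B, G, S, hidden_dim]

# Hypothetical expansion to full heads (not shown in the diff)
k = k.unsqueeze(2).expand(B, num_groups, heads_per_group, S, hidden_dim)
k = k.flatten(start_dim=1, end_dim=2)                     # (B, H, S, head_dim)
assert k.shape == (B, num_heads, S, hidden_dim)
```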
@@ -199,7 +205,7 @@ class DeepMoeAttention(GroupedMoeAttention):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
 
         hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts,
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, hidden_dim, embed_dim))
         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
         self._init_query_experts()
 
@@ -217,7 +223,7 @@ class DeepMoeAttention(GroupedMoeAttention):
         # Query processing
         B, T, D = query.shape
         # Flatten for query routing
-        query_flat = query.view(
+        query_flat = query.view(-1, D)
         weights_q_flat, indices_q_flat = self.query_router(query_flat)
         # Reshape back
         weights_q = weights_q_flat.view(B, T, -1)
rxnn/transformers/moe.py
CHANGED
@@ -16,26 +16,20 @@ class MoeRouter(nn.Module):
 
     def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
         expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
-        print('expert mask: ', expert_mask.shape)
         expert_usage = expert_mask.sum(dim=0).mean(dim=0)
-        print('expert usage: ', expert_usage.shape)
         mean_probs = probs.mean(dim=0)
-        print('mean probs: ', mean_probs.shape)
         return (expert_usage * mean_probs).sum() * self.num_experts
 
 
     def forward(self, x: torch.Tensor):
         # Input shape: [batch*seq_len, embed_dim]
         logits = self.gate(x)
-        print('router logits: ', logits.shape)
         probs = F.softmax(logits, dim=-1)
-        print('router probs: ', probs.shape)
         # Get top-k experts for each token
         top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
 
         # Normalize weights (sum to 1 for each token)
         top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
-        print('top k: ', top_k_weights.shape, top_k_indices.shape)
         # Load Balance Loss
         self.aux_loss = self.calculate_aux_loss(top_k_indices, probs)
 
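With the debug prints stripped, the router reduces to top-k gating plus a load-balancing auxiliary loss. A standalone sketch of that logic (class attributes replaced by plain arguments; `gate_weight` is a stand-in for `self.gate`):

```python
import torch
import torch.nn.functional as F

def route(x, gate_weight, top_k, num_experts):
    # x: [tokens, embed_dim], gate_weight: [num_experts, embed_dim]
    logits = F.linear(x, gate_weight)
    probs = F.softmax(logits, dim=-1)
    top_k_weights, top_k_indices = probs.topk(top_k, dim=-1)
    # Normalize so each token's selected weights sum to 1
    top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)

    # Auxiliary loss: per-expert usage counts (averaged over top-k slots)
    # times mean gate probability, scaled by num_experts
    expert_mask = F.one_hot(top_k_indices, num_experts).float()  # [tokens, top_k, E]
    expert_usage = expert_mask.sum(dim=0).mean(dim=0)            # [E]
    mean_probs = probs.mean(dim=0)                               # [E]
    aux_loss = (expert_usage * mean_probs).sum() * num_experts
    return top_k_weights, top_k_indices, aux_loss

weights, indices, aux = route(torch.randn(64, 512), torch.randn(8, 512), top_k=2, num_experts=8)
assert weights.shape == (64, 2) and indices.shape == (64, 2)
```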
{rxnn-0.1.19.dist-info → rxnn-0.1.21.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=6qD3QCpkQHsIaNktjcQrRitQgQ-WkRUVtSFgEDfYGbA,32340
 rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -21,11 +21,11 @@ rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/moe.py,sha256=msspVdefdt2ekIN8aT-V8DolK4taESQL_NVsSGOepIs,4739
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.21.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.21.dist-info/METADATA,sha256=gtCrs3sVTMXB9UNS1-qcJNIPzHNO8d7UaJlfviJNFEI,16627
+rxnn-0.1.21.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.21.dist-info/RECORD,,
{rxnn-0.1.19.dist-info → rxnn-0.1.21.dist-info}/LICENSE
File without changes
{rxnn-0.1.19.dist-info → rxnn-0.1.21.dist-info}/WHEEL
File without changes