rxnn 0.1.33__py3-none-any.whl → 0.1.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +151 -156
- rxnn/experimental/moe.py +0 -40
- rxnn/transformers/attention.py +45 -37
- {rxnn-0.1.33.dist-info → rxnn-0.1.35.dist-info}/METADATA +1 -1
- {rxnn-0.1.33.dist-info → rxnn-0.1.35.dist-info}/RECORD +7 -7
- {rxnn-0.1.33.dist-info → rxnn-0.1.35.dist-info}/LICENSE +0 -0
- {rxnn-0.1.33.dist-info → rxnn-0.1.35.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -9,6 +9,10 @@ from ..transformers.moe import MoeRouter
 
 class GroupedMoeAttention(GroupedQueryAttention):
     """
+    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
+    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
+    experts - it has to be tested.
+
     Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
 
     Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
@@ -44,7 +48,7 @@ class GroupedMoeAttention(GroupedQueryAttention):
             *args,
             **kwargs,
     ):
-        self.num_experts = num_experts
+        self.num_experts = num_experts if num_experts is not None else num_heads
         super(GroupedMoeAttention, self).__init__(
             embed_dim,
             num_heads,
@@ -61,78 +65,63 @@ class GroupedMoeAttention(GroupedQueryAttention):
             **kwargs,
         )
 
-    def router_loss(self):
-        return self.router.aux_loss
-
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
-
         hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts,
+        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts,
+        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
         self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
         self._init_experts()
 
     def _init_experts(self):
-        nn.init.xavier_uniform_(self.wk)
-        nn.init.xavier_uniform_(self.wv)
+        torch.nn.init.xavier_uniform_(self.wk)
+        torch.nn.init.xavier_uniform_(self.wv)
         if self.use_bias:
-            nn.init.zeros_(self.bk)
-            nn.init.zeros_(self.bv)
-
-    def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
-        B, S, G = indices.shape
-        x_flat = x.view(-1, x.size(-1)) # [B*S, D]
-
-        indices_flat = indices.view(-1, G) # [B*S, G]
-        weights_flat = weights.view(-1, G) # [B*S, G]
+            torch.nn.init.zeros_(self.bk)
+            torch.nn.init.zeros_(self.bv)
 
-
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+                     skip_query_processing: bool = False):
+        head_dim = d // self.num_heads
 
-
-
-            expert_mask = (indices_flat == e).any(dim=1) # [B*S]
-            if not expert_mask.any():
-                continue
+        # Process Query as in GQA
+        q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
 
-
-
-
+        # Key/Value MoE routing
+        B, S, D = key.shape
+        key_flat = key.reshape(-1, D)
+        weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
+        weights = weights.view(B, S, self.num_groups, 1)
+        indices = indices.view(B, S, self.num_groups)
 
-
-
-
-
-                continue
+        # Compute all experts' projections
+        # Shape: (B*S, num_experts, head_dim)
+        k_all = torch.einsum('bd,edh->beh', key_flat, self.wk) # [B*S, num_experts, head_dim]
+        v_all = torch.einsum('bd,edh->beh', value.view(-1, D), self.wv)
 
-
-
-
-                weighted_proj = proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
-                output[group_tokens, g] += weighted_proj
+        if self.use_bias:
+            k_all += self.bk
+            v_all += self.bv
 
-
+        # Get results for all heads
+        k_all = k_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
+        v_all = v_all.view(B, S, self.num_experts, -1) # [B, S, num_experts, head_dim]
 
-
-
-
+        # Gather top-k experts using expanded indices
+        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1)) # [B, S, num_groups, head_dim]
+        selected_k = torch.gather(k_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
+        selected_v = torch.gather(v_all, 2, expanded_indices) # [B, S, num_groups, head_dim]
 
-        #
-        B, S,
-
-        weights_k_flat, indices_k_flat = self.router(key_flat)
-        # Reshape back to original dimensions
-        weights_k = weights_k_flat.view(B, S, -1)
-        indices_k = indices_k_flat.view(B, S, -1)
-        k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
-        v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
+        # Weighted
+        weighted_k = selected_k * weights # [B, S, num_groups, head_dim]
+        weighted_v = selected_v * weights # [B, S, num_groups, head_dim]
 
-        #
-        k =
-        v =
+        # Reshape to GQA format
+        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
+        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3) # [B, num_groups, S, head_dim]
 
-        if not self.
+        if not self.rel_embed:
             group_heads = self.num_heads // self.num_groups
 
             k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
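The new `_forward_qkv` above computes every expert's key/value projection in one einsum and then gathers only the experts picked by the router. Below is a minimal, self-contained sketch of that compute-all-then-gather pattern, using random tensors in place of the router and the projection weights (illustrative names only, not the rxnn code itself):

```python
import torch

B, S, D = 2, 5, 16                 # batch, sequence length, embed dim
num_experts, num_groups = 4, 2     # experts and routed groups (top-k) per token
head_dim = 8

key = torch.randn(B, S, D)
wk = torch.randn(num_experts, D, head_dim)   # one K projection per expert
key_flat = key.reshape(-1, D)                # [B*S, D]

# Stand-in for MoeRouter output: weights and expert indices per token
weights = torch.rand(B, S, num_groups, 1)
indices = torch.randint(0, num_experts, (B, S, num_groups))

# 1) Project every token with every expert in a single einsum
k_all = torch.einsum('bd,edh->beh', key_flat, wk).view(B, S, num_experts, head_dim)

# 2) Keep only the routed experts for each token
expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
selected_k = torch.gather(k_all, 2, expanded)        # [B, S, num_groups, head_dim]

# 3) Apply router weights and move to the GQA layout used downstream
k = (selected_k * weights).permute(0, 2, 1, 3)       # [B, num_groups, S, head_dim]
print(k.shape)  # torch.Size([2, 2, 5, 8])
```

As the new class docstring notes, this computes more projections than strictly needed, but the attention projections are small, so it may still beat the token-filtering variant; the package itself flags this as something to benchmark.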
@@ -197,55 +186,54 @@ class DeepMoeAttention(GroupedMoeAttention):
             **kwargs,
         )
 
-    def router_loss(self):
-        return (self.router.aux_loss + self.query_router.aux_loss) / 2
-
     def _init_q(self, embed_dim: int):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
-
         hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts,
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
         self._init_query_experts()
 
     def _init_query_experts(self):
-        nn.init.xavier_uniform_(self.wq)
+        torch.nn.init.xavier_uniform_(self.wq)
         if self.use_bias:
-            nn.init.zeros_(self.bq)
+            torch.nn.init.zeros_(self.bq)
 
     def _init_out(self, embed_dim: int):
         """Initialize output projection"""
-
-        self.out_proj = nn.Linear(
+        out_hidden_dim = embed_dim // self.num_heads * self.num_query_groups
+        self.out_proj = nn.Linear(out_hidden_dim, embed_dim)
 
     def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
         """Transpose attention output back to (B, T, D) shape"""
-
-        return attn_output.transpose(1, 2).contiguous().view(b, t,
+        out_hidden_dim = d // self.num_heads * self.num_query_groups
+        return attn_output.transpose(1, 2).contiguous().view(b, t, out_hidden_dim)
 
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int
-        # Query processing
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
-
-
-
-
-        weights_q = weights_q_flat.view(B, T, -1)
-        indices_q = indices_q_flat.view(B, T, -1)
-        q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
+        query_flat = query.reshape(-1, D)
+        weights_q, indices_q = self.query_router(query_flat)
+        weights_q = weights_q.view(B, T, self.num_query_groups, 1)
+        indices_q = indices_q.view(B, T, self.num_query_groups)
 
-
-        #
-
+        # Compute all query experts
+        q_all = torch.einsum('bd,edh->beh', query_flat, self.wq) # [B*T, num_query_experts, head_dim]
+        if self.use_bias:
+            q_all += self.bq
 
-
+        q_all = q_all.view(B, T, self.num_query_experts, -1)
 
-
-
-
-
-
+        # Gather top-k experts
+        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
+        selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
+
+        # Weighted sum
+        q = selected_q * weights_q # [B, T, num_query_groups, head_dim]
+        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
+
+        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
+class GroupedMoeAttentionSimplified(GroupedQueryAttention):
+    """
     Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
 
     Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
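The reworked `_init_out` and `_transpose_output` size the output projection from the number of routed query groups instead of the full head count. A quick worked example of that arithmetic, with assumed dimensions (not values taken from the package):

```python
# Hypothetical model dimensions, for illustration only.
embed_dim = 512
num_heads = 16
num_query_groups = 4

head_dim = embed_dim // num_heads                  # 32
out_hidden_dim = head_dim * num_query_groups       # 128 = embed_dim // num_heads * num_query_groups

# The attention output concatenates only the routed query groups,
# so out_proj maps 128 -> 512 rather than 512 -> 512:
# nn.Linear(out_hidden_dim, embed_dim)
print(head_dim, out_hidden_dim)                    # 32 128
```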
@@ -281,8 +269,8 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
             *args,
             **kwargs,
     ):
-        self.num_experts = num_experts
-        super(
+        self.num_experts = num_experts or num_heads
+        super(GroupedMoeAttentionSimplified, self).__init__(
             embed_dim,
             num_heads,
             num_groups=num_groups,
@@ -298,63 +286,78 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
             **kwargs,
         )
 
+    def router_loss(self):
+        return self.router.aux_loss
+
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
+
         hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts,
+        self.wk = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
         self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts,
+        self.wv = nn.Parameter(torch.empty(self.num_experts, hidden_dim, embed_dim))
         self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
         self._init_experts()
 
     def _init_experts(self):
-
-
+        nn.init.xavier_uniform_(self.wk)
+        nn.init.xavier_uniform_(self.wv)
         if self.use_bias:
-
-
+            nn.init.zeros_(self.bk)
+            nn.init.zeros_(self.bv)
 
-    def
-
-
+    def _process_grouped_experts(self, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, weights: torch.Tensor, indices: torch.Tensor):
+        B, S, G = indices.shape
+        x_flat = x.view(-1, x.size(-1)) # [B*S, D]
 
-
-
+        indices_flat = indices.view(-1, G) # [B*S, G]
+        weights_flat = weights.view(-1, G) # [B*S, G]
 
-        #
-        B, S, D = key.shape
-        key_flat = key.reshape(-1, D)
-        weights, indices = self.router(key_flat) # (B*S, num_groups), (B*S, num_groups)
-        weights = weights.view(B, S, self.num_groups, 1)
-        indices = indices.view(B, S, self.num_groups)
+        output = torch.zeros(B * S, G, w.size(1), device=x.device, dtype=x.dtype) # [B*S, G, hidden_dim]
 
-
-
-
-
+        for e in range(self.num_experts):
+            # 1. Find tokens where expert `e` is used in ANY group
+            expert_mask = (indices_flat == e).any(dim=1) # [B*S]
+            if not expert_mask.any():
+                continue
 
-
-
-
+            # 2. Project tokens using expert `e`
+            x_slice = x_flat[expert_mask] # [num_selected, D]
+            proj = F.linear(x_slice, w[e], b[e] if b is not None else None) # [num_selected, hidden_dim]
 
-
-
-
+            # 3. Scatter projections into correct groups
+            for g in range(G):
+                group_mask = indices_flat[expert_mask, g] == e # [num_selected]
+                if not group_mask.any():
+                    continue
 
-
-
-
-
+                # Get tokens in this group using expert `e`
+                group_tokens = expert_mask.nonzero()[group_mask].squeeze(1)
+                # Weight and scatter
+                weighted_proj = proj[group_mask] * weights_flat[group_tokens, g].unsqueeze(-1)
+                output[group_tokens, g] += weighted_proj
 
-
-        weighted_k = selected_k * weights # [B, S, num_groups, head_dim]
-        weighted_v = selected_v * weights # [B, S, num_groups, head_dim]
+        return output.view(B, S, G, -1)
 
-
-
-
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+                     skip_query_processing: bool = False):
+        q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
+
+        # Key/Value processing
+        B, S, D = key.shape
+        key_flat = key.view(-1, D)
+        weights_k_flat, indices_k_flat = self.router(key_flat)
+        # Reshape back to original dimensions
+        weights_k = weights_k_flat.view(B, S, -1)
+        indices_k = indices_k_flat.view(B, S, -1)
+        k = self._process_grouped_experts(key, self.wk, self.bk, weights_k, indices_k)
+        v = self._process_grouped_experts(value, self.wv, self.bv, weights_k, indices_k)
+
+        # Expand to GQA format
+        k = k.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
+        v = v.permute(0, 2, 1, 3).reshape(B, self.num_groups, S, -1)
 
-        if not self.
+        if not self.rel_embed:
             group_heads = self.num_heads // self.num_groups
 
             k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1) # (B, G, group_heads, S, head_dim)
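`GroupedMoeAttentionSimplified` takes the opposite approach to the vectorized class: it loops over experts and projects only the tokens actually routed to each one, then scatters the results back into their groups. A hedged, standalone sketch of that filter-project-scatter loop (random routing stands in for `MoeRouter`; tensor names are illustrative):

```python
import torch
import torch.nn.functional as F

BS, D, H = 10, 16, 8          # flattened tokens (B*S), embed dim, per-head dim
num_experts, G = 4, 2         # experts, routed groups (top-k) per token

x_flat = torch.randn(BS, D)
w = torch.randn(num_experts, H, D)        # per-expert weights in F.linear layout
weights_flat = torch.rand(BS, G)
indices_flat = torch.randint(0, num_experts, (BS, G))

output = torch.zeros(BS, G, H)
for e in range(num_experts):
    token_mask = (indices_flat == e).any(dim=1)    # tokens using expert e in any group
    if not token_mask.any():
        continue
    proj = F.linear(x_flat[token_mask], w[e])      # project only the selected tokens
    for g in range(G):
        group_mask = indices_flat[token_mask, g] == e
        if not group_mask.any():
            continue
        rows = token_mask.nonzero(as_tuple=True)[0][group_mask]
        output[rows, g] += proj[group_mask] * weights_flat[rows, g].unsqueeze(-1)

print(output.shape)  # torch.Size([10, 2, 8])
```

The trade-off against the einsum version is smaller matmuls at the cost of a Python loop over experts and groups, which is exactly the comparison the docstrings say still has to be tested.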
@@ -366,12 +369,8 @@ class GroupedMoeAttentionVectorized(GroupedQueryAttention):
         return q, k, v
 
 
-class
+class DeepMoeAttentionSimplified(GroupedMoeAttentionSimplified):
     """
-    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
-    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
-    experts - it has to be tested.
-
     Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
 
     In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
@@ -406,7 +405,7 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
     ):
         self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
         self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
-        super(
+        super(DeepMoeAttentionSimplified, self).__init__(
             embed_dim,
             num_heads,
             num_groups=num_groups,
@@ -423,52 +422,48 @@ class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
             **kwargs,
         )
 
+    def router_loss(self):
+        return (self.router.aux_loss + self.query_router.aux_loss) / 2
+
     def _init_q(self, embed_dim: int):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
+
         hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts,
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, hidden_dim, embed_dim))
         self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
         self._init_query_experts()
 
     def _init_query_experts(self):
-
+        nn.init.xavier_uniform_(self.wq)
         if self.use_bias:
-
+            nn.init.zeros_(self.bq)
 
     def _init_out(self, embed_dim: int):
         """Initialize output projection"""
-
-        self.out_proj = nn.Linear(
+        hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
+        self.out_proj = nn.Linear(hidden_dim, embed_dim)
 
     def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
         """Transpose attention output back to (B, T, D) shape"""
-
-        return attn_output.transpose(1, 2).contiguous().view(b, t,
+        hidden_dim = d // self.num_heads * self.num_query_groups
+        return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
 
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
+        # Query processing
         B, T, D = query.shape
-
-
-
-
-
-
-
-        if self.use_bias:
-            q_all += self.bq
-
-        q_all = q_all.view(B, T, self.num_query_experts, -1)
-
-        # Gather top-k experts
-        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
-        selected_q = torch.gather(q_all, 2, expanded_indices) # [B, T, num_query_groups, head_dim]
-
-        # Weighted sum
-        q = selected_q * weights_q # [B, T, num_query_groups, head_dim]
-        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3) # [B, num_query_groups, T, head_dim]
+        # Flatten for query routing
+        query_flat = query.view(-1, D)
+        weights_q_flat, indices_q_flat = self.query_router(query_flat)
+        # Reshape back
+        weights_q = weights_q_flat.view(B, T, -1)
+        indices_q = indices_q_flat.view(B, T, -1)
+        q = self._process_grouped_experts(query, self.wq, self.bq, weights_q, indices_q)
 
+        q = q.permute(0, 2, 1, 3).reshape(B, self.num_query_groups, T, -1)
+        # Key/Value processing
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
+
 # Others
 
 class FlexAttention(MultiHeadAttention):
rxnn/experimental/moe.py
CHANGED
@@ -3,46 +3,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from ..transformers.moe import MoeRouter
 
-class DynamicMoeRouter(nn.Module):
-    """Dynamic Mixture-of-Experts Router layer - dynamically selects top-k experts for each token."""
-
-    def __init__(self, embed_dim: int, num_experts: int, top_ks: tuple[int] = (1, 2, 3), *args, **kwargs):
-        super(DynamicMoeRouter, self).__init__(*args, **kwargs)
-        self.top_ks = top_ks
-        self.num_options = len(top_ks)
-        self.num_experts = num_experts
-        self.gate = nn.Linear(embed_dim, num_experts + self.num_options, bias=False)
-        # For expert load balancing
-        self.register_buffer('aux_loss', torch.tensor(0.0), persistent=False)
-
-    def calculate_aux_loss(self, top_k_indices: torch.Tensor, routing_probs: torch.Tensor) -> torch.Tensor:
-        expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
-        expert_usage = expert_mask.sum(dim=0).mean(dim=0)
-        mean_probs = routing_probs.mean(dim=0)
-        return (expert_usage * mean_probs).sum() * self.num_experts
-
-    def forward(self, x: torch.Tensor):
-        # Input shape: [batch*seq_len, embed_dim]
-        all_logits = self.gate(x)
-        routing_logits = all_logits[:, :-self.num_options]
-        options_logits = all_logits[:, -self.num_options:]
-
-        routing_probs = F.softmax(routing_logits, dim=-1)
-        top_k_id = torch.argmax(options_logits, dim=-1).item()
-
-        top_k = self.top_ks[top_k_id]
-
-        # Get top-k experts for each token
-        top_k_weights, top_k_indices = routing_probs.topk(top_k, dim=-1)
-
-        # Normalize weights (sum to 1 for each token)
-        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
-
-        # Load Balance Loss
-        self.aux_loss = self.calculate_aux_loss(top_k_indices, routing_probs)
-
-        return top_k_weights, top_k_indices, top_k
-
 class MoeFeedForwardVectorized(nn.Module):
     """
     Vectorized MoE - current implementation is incorrect - it calculates all the experts, then selects the correct ones.
rxnn/transformers/attention.py
CHANGED
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.
+from torch.nn.attention import sdpa_kernel, SDPBackend
 import math
 from .positional import RotaryPositionalEmbedding, RelativePositionalEmbedding
 
@@ -102,36 +102,41 @@ class MultiHeadAttention(nn.Module):
 
     def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                          mask: torch.Tensor = None, enable_gqa: bool = False):
-        with
-
-
-
-
-
-
-
+        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
+            return self._torch_attention(q, k, v, b, t, d, mask=mask, enable_gqa=enable_gqa)
+
+    def _torch_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
+                         mask: torch.Tensor = None, enable_gqa: bool = False):
+        attn_output = F.scaled_dot_product_attention(
+            q, k, v,
+            attn_mask=mask if not self.is_causal else None,
+            dropout_p=self.dropout.p if self.training else 0.0,
+            is_causal=self.is_causal,
+            enable_gqa=enable_gqa,
+        )
         return self._transpose_output(attn_output, b, t, d)
 
-    def
-
-
-
+    def _calculate_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int, mask: torch.Tensor = None):
+        if self.use_flash_attention:
+            # Compute attention with FlashAttention
+            return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask)
+        else:
+            # Compute attention using optimized PyTorch implementation
+            return self._torch_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask)
+
+    def _calculate_attention_with_relative_embedding(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int, mask: torch.Tensor = None):
+        attn_weights = self._calculate_attn_weight_with_relative_embeddings(q, k, mask=mask)
+        attn_weights = self.dropout(attn_weights)
+        return self._calculate_output(attn_weights, v, b, t, d)
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()
         q, k, v = self._forward_qkv(query, key, value, b, t, d)
-        if self.
+        if not self.rel_embed:
             q, k = self._apply_rope(q, k)
-            attn_output = self.
+            attn_output = self._calculate_attention(q, k, v, b, t, d, mask=mask)
         else:
-
-            attn_weights = self._calculate_attn_weights(q, k, d, mask=mask)
-        else:
-            attn_weights = self._calculate_attn_weight_with_relative_embeddings(q, k, mask=mask)
-
-        attn_weights = self.dropout(attn_weights)
-
-        attn_output = self._calculate_output(attn_weights, v, b, t, d)
+            attn_output = self._calculate_attention_with_relative_embedding(q, k, v, b, t, d, mask=mask)
         return self.out_proj(attn_output)
 
 
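Both `_flash_attention` and `_torch_attention` now reduce to `F.scaled_dot_product_attention`; the flash path just pins the SDPA backend selection. A small sketch of that pattern, assuming PyTorch 2.3+ for `torch.nn.attention.sdpa_kernel` (shapes and the causal flag are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

B, H, S, Dh = 2, 4, 32, 16
q = torch.randn(B, H, S, Dh)
k = torch.randn(B, H, S, Dh)
v = torch.randn(B, H, S, Dh)

# Default: PyTorch picks whichever SDPA backend it considers best.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Restrict kernel selection to FlashAttention, mirroring the _flash_attention wrapper.
# Raises at call time if no listed backend supports the inputs/device.
try:
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
except RuntimeError:
    pass  # FlashAttention backend unavailable for these inputs

print(out.shape)  # torch.Size([2, 4, 32, 16])
```

The GQA and MQA subclasses below reuse the same helpers and pass `enable_gqa` so SDPA can broadcast the smaller number of key/value heads across the query heads.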
@@ -178,7 +183,7 @@ class GroupedQueryAttention(MultiHeadAttention):
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         """Override query, key, and value projections for GQA case - split data into heads and groups"""
         head_dim = d // self.num_heads
-        if self.
+        if not self.rel_embed:
             q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
             k = self.k_proj(key).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
             v = self.v_proj(value).view(b, -1, self.num_groups, head_dim).transpose(1, 2)
@@ -202,12 +207,14 @@ class GroupedQueryAttention(MultiHeadAttention):
         v = v.flatten(start_dim=1, end_dim=2) # (B, H, S, head_dim)
         return q, k, v
 
-    def
-
-
-
-
-
+    def _calculate_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int, mask: torch.Tensor = None):
+        is_gqa = self.num_heads != self.num_groups
+        if self.use_flash_attention:
+            # Compute attention with FlashAttention
+            return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=is_gqa)
+        else:
+            # Compute attention using optimized PyTorch implementation
+            return self._torch_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=is_gqa)
 
 
 class MultiQueryAttention(MultiHeadAttention):
@@ -251,7 +258,7 @@ class MultiQueryAttention(MultiHeadAttention):
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         """Override query, key, and value projections for GQA case - use multiple heads
         for query and single for key/values"""
-        if self.
+        if not self.rel_embed:
             q = self.q_proj(query).view(b, t, self.num_heads, d // self.num_heads).transpose(1, 2)
             k = self.k_proj(key).view(b, -1, 1, d // self.num_heads).transpose(1, 2)
             v = self.v_proj(value).view(b, -1, 1, d // self.num_heads).transpose(1, 2)
@@ -261,12 +268,13 @@ class MultiQueryAttention(MultiHeadAttention):
             v = self.v_proj(value).unsqueeze(1).expand(-1, self.num_heads, -1, -1)
         return q, k, v
 
-    def
-
-
-            q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask,
-
-
+    def _calculate_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int, mask: torch.Tensor = None):
+        if self.use_flash_attention:
+            # Compute attention with FlashAttention
+            return self._flash_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=True)
+        else:
+            # Compute attention using optimized PyTorch implementation
+            return self._torch_attention(q.contiguous(), k.contiguous(), v.contiguous(), b, t, d, mask=mask, enable_gqa=True)
 
 
 def init_attention(
{rxnn-0.1.33.dist-info → rxnn-0.1.35.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=GxbLmOTBvUiYU0Rc_0ju1n_ocJciHC6i3neDGe-rZZc,29426
 rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
-rxnn/experimental/moe.py,sha256=
+rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
 rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=zv0uH3_L39tVmpiwNdmEf6Cp602uqdbr3UQj8Z3hIIk,15349
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.35.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.35.dist-info/METADATA,sha256=aziCzqOeetdE3gMV2i15QoB5O31bGpiZgzcpGM97QPk,16627
+rxnn-0.1.35.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.35.dist-info/RECORD,,
{rxnn-0.1.33.dist-info → rxnn-0.1.35.dist-info}/LICENSE
File without changes
{rxnn-0.1.33.dist-info → rxnn-0.1.35.dist-info}/WHEEL
File without changes