rxnn 0.1.23__py3-none-any.whl → 0.1.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +13 -327
- rxnn/experimental/models.py +3 -5
- rxnn/training/bml.py +41 -72
- rxnn/transformers/attention.py +6 -4
- rxnn/transformers/layers.py +28 -2
- rxnn/transformers/moe.py +0 -1
- {rxnn-0.1.23.dist-info → rxnn-0.1.25.dist-info}/METADATA +1 -1
- {rxnn-0.1.23.dist-info → rxnn-0.1.25.dist-info}/RECORD +10 -10
- {rxnn-0.1.23.dist-info → rxnn-0.1.25.dist-info}/LICENSE +0 -0
- {rxnn-0.1.23.dist-info → rxnn-0.1.25.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -61,6 +61,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
             **kwargs,
         )
 
+    def router_loss(self):
+        return self.router.aux_loss
+
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
 
@@ -194,6 +197,9 @@ class DeepMoeAttention(GroupedMoeAttention):
             **kwargs,
         )
 
+    def router_loss(self):
+        return (self.router.aux_loss + self.query_router.aux_loss) / 2
+
     def _init_q(self, embed_dim: int):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
 
@@ -212,6 +218,11 @@ class DeepMoeAttention(GroupedMoeAttention):
         hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
         self.out_proj = nn.Linear(hidden_dim, embed_dim)
 
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        hidden_dim = d // self.num_heads * self.num_query_groups
+        return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
+
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
         # Query processing
         B, T, D = query.shape
@@ -227,298 +238,6 @@ class DeepMoeAttention(GroupedMoeAttention):
         # Key/Value processing
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
-# Vectorized
-
-class GroupedMoeAttentionVectorized(GroupedQueryAttention):
-    """
-    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
-    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
-    experts - it has to be tested.
-
-    Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
-
-    Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
-    number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
-    - with num_groups set to 1, it will be MoE MultiQueryAttention
-
-    Compared to traditional GQA/MQA, it should provide better performance, because lot less data could be lost using
-    this approach - we are training the full number of keys/values heads, while using only a group.
-
-    In case of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
-
-    Optionally, it could use even more expert heads than attention heads - in example:
-    - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
-    4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
-
-    © 2025 Adam Filipek
-    """
-
-    def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_groups: int,
-            dropout: float = 0.0,
-            rope: RotaryPositionalEmbedding = None,
-            rope_only_for_query: bool = False,
-            use_relative_embeddings: bool = False,
-            max_seq_len: int = 1024,
-            use_flash_attention: bool = False,
-            is_causal: bool = False,
-            use_bias: bool = False,
-            num_experts: int = None,
-            *args,
-            **kwargs,
-    ):
-        self.num_experts = num_experts if num_experts is not None else num_heads
-        super(GroupedMoeAttentionVectorized, self).__init__(
-            embed_dim,
-            num_heads,
-            num_groups=num_groups,
-            dropout=dropout,
-            rope=rope,
-            rope_only_for_query=rope_only_for_query,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            *args,
-            **kwargs,
-        )
-
-    def _init_kv(self, embed_dim: int):
-        self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
-        hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self._init_experts()
-
-    def _init_experts(self):
-        torch.nn.init.xavier_uniform_(self.wk)
-        torch.nn.init.xavier_uniform_(self.wv)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bk)
-            torch.nn.init.zeros_(self.bv)
-
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
-                     skip_query_processing: bool = False):
-        # Indexed version may cause memory overflow
-        #
-        # head_dim = d // self.num_heads
-        #
-        # # Process Query as in GQA
-        # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1,
-        #                                                                       2) if not skip_query_processing else query
-        #
-        # # Process Key and Value with MoE routing
-        # key_flat = key.view(-1, d)  # (B*S, d)
-        # value_flat = value.view(-1, d)  # (B*S, d)
-        #
-        # # Get routing indices and weights for K
-        # weights_k, indices_k = self.router(key_flat)
-        # indices_k = indices_k.view(-1, self.top_k)  # (B*S, top_k)
-        # weights_k = weights_k.view(-1, self.top_k, 1)  # (B*S, top_k, 1)
-        #
-        # # Select and compute K projections for only the top_k experts
-        # selected_k_weights = self.k_experts[indices_k]  # (B*S, top_k, d, k_out_dim)
-        # k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
-        # selected_k = (k_proj * weights_k).sum(dim=1)  # (B*S, k_out_dim)
-        # selected_k = selected_k.view(b, key.size(1), -1)  # (B, S, k_out_dim)
-        #
-        # # Compute V using the same indices as K (since they share the same router)
-        # selected_v_weights = self.v_experts[indices_k]
-        # v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
-        # selected_v = (v_proj * weights_k).sum(dim=1)
-        # selected_v = selected_v.view(b, value.size(1), -1)  # (B, S, k_out_dim)
-        #
-        # # Reshape to GQA format: (B, G, S, head_dim)
-        # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
-        # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
-        #
-        # if not self.use_flash_attention:
-        #     group_heads = self.num_heads // self.num_groups
-        #
-        #     k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-        #     v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-        #
-        #     k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-        #     v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-        #
-        # return q, k, v
-
-        head_dim = d // self.num_heads
-
-        # Process Query as in GQA
-        q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
-
-        # Process Key and Value with MoE routing
-        key_flat = key.view(-1, d)
-        weights, indices = self.router(key_flat)
-        weights = weights.view(b, key.size(1), self.num_groups, 1)
-        indices = indices.view(b, key.size(1), self.num_groups)
-
-        # Compute all experts' K and V projections
-        # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
-        k_all = torch.einsum(
-            'be, ehd -> bedh',
-            key_flat,
-            self.wk.view(self.num_experts, d, -1)
-        ).view(b, key.size(1), self.num_experts, -1)
-
-        v_all = torch.einsum(
-            'be, ehd -> bedh',
-            value.view(-1, d),
-            self.wv.view(self.num_experts, d, -1)
-        ).view(b, value.size(1), self.num_experts, -1)
-
-        # Select top_k experts and compute weighted sum
-        selected_k = torch.gather(
-            k_all,
-            2,
-            indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
-        )
-        selected_v = torch.gather(
-            v_all,
-            2,
-            indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
-        )
-
-        selected_k = (selected_k * weights).sum(dim=2)
-        selected_v = (selected_v * weights).sum(dim=2)
-        # Reshape to GQA format: (B, G, S, head_dim)
-        k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
-        v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
-
-        if not self.use_flash_attention:
-            group_heads = self.num_heads // self.num_groups
-
-            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-
-            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-
-        return q, k, v
-
-
-class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
-    """
-    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
-    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
-    experts - it has to be tested.
-
-    Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
-
-    In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
-    query heads - with that approach, each token could attend to every other token, but only partially - only some part of
-    information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
-    sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
-
-    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
-    a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.
-
-    © 2025 Adam Filipek
-    """
-
-    def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_groups: int,
-            dropout: float = 0.0,
-            rope: RotaryPositionalEmbedding = None,
-            rope_only_for_query: bool = False,
-            use_relative_embeddings: bool = False,
-            max_seq_len: int = 1024,
-            use_flash_attention: bool = False,
-            is_causal: bool = False,
-            use_bias: bool = False,
-            num_experts: int = None,
-            num_query_experts: int = None,
-            num_query_groups: int = None,
-            *args,
-            **kwargs,
-    ):
-        self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
-        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
-        super(DeepMoeAttentionVectorized, self).__init__(
-            embed_dim,
-            num_heads,
-            num_groups=num_groups,
-            dropout=dropout,
-            rope=rope,
-            rope_only_for_query=rope_only_for_query,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-            *args,
-            **kwargs,
-        )
-
-    def _init_q(self, embed_dim: int):
-        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
-        hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
-        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
-        self._init_query_experts()
-
-    def _init_query_experts(self):
-        torch.nn.init.xavier_uniform_(self.wq)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bq)
-
-    def _init_out(self, embed_dim: int):
-        """Initialize output projection"""
-        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
-
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
-        # Indexed version may cause memory overflow
-        #
-        # head_dim = d // self.num_heads
-        #
-        # # Process Query with MoE routing
-        # query_flat = query.view(-1, d)  # (B*T, d)
-        # weights_q, indices_q = self.query_router(query_flat)
-        # indices_q = indices_q.view(-1, self.num_query_groups)  # (B*T, top_k_q)
-        # weights_q = weights_q.view(-1, self.num_query_groups, 1)  # (B*T, top_k_q, 1)
-        #
-        # # Select and compute Q projections for top_k experts
-        # selected_q_weights = self.wq[indices_q]  # (B*T, top_k_q, d, head_dim*num_heads)
-        # q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
-        # selected_q = (q_proj * weights_q).sum(dim=1)  # (B*T, head_dim*num_heads)
-        # selected_q = selected_q.view(b, t, -1)  # (B, T, head_dim*num_heads)
-        head_dim = d // self.num_heads
-
-        # Process Query with MoE routing
-        query_flat = query.view(b * t, d)
-        weights_q, indices_q = self.query_router(query_flat)
-        weights_q = weights_q.view(b, t, self.num_query_groups, 1)
-        indices_q = indices_q.view(b, t, self.num_query_groups)
-
-        # Compute all experts' Q projections
-        q_all = torch.einsum(
-            'be, ehd -> bedh',
-            query_flat,
-            self.wq.view(self.num_query_experts, d, -1)
-        ).view(b, t, self.num_query_experts, -1)
-
-        selected_q = torch.gather(
-            q_all,
-            2,
-            indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.shape[-1])
-        )
-        selected_q = (selected_q * weights_q).sum(dim=2)
-
-        q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2)  # (B, H, T, head_dim)
-
-        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
-
 
 # Others
 
@@ -670,8 +389,7 @@ def init_moe_attention(
         num_query_experts: int = None,
         num_query_groups: int = None,
 ) -> GroupedQueryAttention:
-    assert attention_type
-    "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
+    assert attention_type in ['gma', 'dma'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -688,7 +406,7 @@ def init_moe_attention(
             use_bias=use_bias,
             num_experts=num_experts,
         )
-    elif attention_type == "dma":
+    else:
         return DeepMoeAttention(
             embed_dim,
             num_heads,
@@ -705,35 +423,3 @@ def init_moe_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
-    elif attention_type == "gma_v":
-        return GroupedMoeAttentionVectorized(
-            embed_dim,
-            num_heads,
-            gqa_groups,
-            dropout=dropout,
-            rope=rope,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            rope_only_for_query=rope_only_for_query,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-        )
-    else:
-        return DeepMoeAttentionVectorized(
-            embed_dim,
-            num_heads,
-            gqa_groups,
-            dropout=dropout,
-            rope=rope,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            rope_only_for_query=rope_only_for_query,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-            num_query_experts=num_query_experts,
-            num_query_groups=num_query_groups,
-        )
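Net effect of this file's changes: the vectorized GMA/DMA variants are removed, init_moe_attention now accepts only 'gma' and 'dma', and both remaining attention classes expose a router_loss() accessor. A minimal usage sketch, assuming the keyword names visible in the hunks above are the factory's actual parameters (its full signature is not shown in this diff):

# Hedged sketch - module paths match the files in this diff, but the exact
# init_moe_attention signature and defaults are assumptions.
from rxnn.experimental.attention import init_moe_attention
from rxnn.transformers.positional import RotaryPositionalEmbedding

rope = RotaryPositionalEmbedding(512 // 16, 1024)
attn = init_moe_attention(
    embed_dim=512,
    num_heads=16,
    gqa_groups=4,
    attention_type='dma',      # only 'gma' and 'dma' pass the new assert
    rope=rope,
    num_experts=24,
    num_query_experts=24,
    num_query_groups=8,
)
# Read after a forward pass, once the routers have recorded aux_loss;
# for 'dma' this is (router.aux_loss + query_router.aux_loss) / 2.
aux_loss = attn.router_loss()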
rxnn/experimental/models.py
CHANGED
@@ -35,7 +35,7 @@ class MoeAttentionTransformerConfig(TypedDict):
 
 
 class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
-    """Research model for experiments with Mixture-of-Experts Attention"""
+    """Research decoder model for experiments with Mixture-of-Experts Attention"""
 
     def __init__(
             self,
@@ -65,8 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v',
-                            'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v"'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma"'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -111,6 +110,5 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
     def load_shared_embedding(self, embedding: nn.Embedding):
         self.model.embedding = embedding
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
-        torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
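Two behavioural changes worth noting: the att_type assert no longer accepts the removed 'gma_v'/'dma_v' variants, and forward now returns a single logits tensor instead of a Union with a tuple. A small sketch of the tightened validation, mirroring the new assert (check_att_type is a hypothetical standalone helper, not part of rxnn):

def check_att_type(att_type: str) -> None:
    # Mirrors the assert added to MoeAttentionTransformer.__init__ in 0.1.25.
    assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma'], \
        'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma"'

check_att_type('gma')        # accepted
try:
    check_att_type('gma_v')  # removed in 0.1.25, now raises AssertionError
except AssertionError as err:
    print(err)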
rxnn/training/bml.py
CHANGED
@@ -50,10 +50,14 @@ class MLMTrainer(BaseTrainer):
             vocab_size: int,
             use_amp: bool = False,
             dtype: torch.dtype = None,
+            use_moe_aux_loss: bool = False,
+            moe_aux_loss_scale: float = 0.01,
             **kwargs
     ):
         super(MLMTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype, **kwargs)
         self.vocab_size = vocab_size
+        self.use_moe_aux_loss = use_moe_aux_loss
+        self.moe_aux_loss_scale = moe_aux_loss_scale
 
     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         inputs = batch['input_ids']
@@ -65,11 +69,32 @@ class MLMTrainer(BaseTrainer):
             attention_mask=attention_mask
         )
 
-        return F.cross_entropy(
+        loss = F.cross_entropy(
             logits.view(-1, self.vocab_size),
             labels.view(-1),
             ignore_index=-100
-        ), logits
+        )
+
+        return self._moe_aux_loss(loss), logits
+
+    def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
+        if not self.use_moe_aux_loss:
+            return main_loss
+
+        model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
+
+        router_loss = model.encoder.model.moe_router_loss()
+        loss = main_loss + self.moe_aux_loss_scale * router_loss
+
+        if self.writer is not None:
+            if self.model.training:
+                self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
+                self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
+            else:
+                self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
+                self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
+
+        return loss
 
     def validate(self, batch_size: int) -> tuple[float, dict]:
         self.model.eval()
@@ -113,11 +138,15 @@ class AutoregressiveTrainer(BaseTrainer):
             vocab_size: int,
             use_amp: bool = False,
             dtype: torch.dtype = None,
+            use_moe_aux_loss: bool = False,
+            moe_aux_loss_scale: float = 0.01,
             **kwargs
     ):
         super(AutoregressiveTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
                                                     target_field_name='targets', **kwargs)
         self.vocab_size = vocab_size
+        self.use_moe_aux_loss = use_moe_aux_loss
+        self.moe_aux_loss_scale = moe_aux_loss_scale
 
     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         inputs = batch['input_ids']
@@ -132,84 +161,21 @@ class AutoregressiveTrainer(BaseTrainer):
         shifted_logits = outputs[:, :-1].contiguous()
         shifted_targets = targets[:, 1:].contiguous()
 
-        return F.cross_entropy(
+        loss = F.cross_entropy(
             shifted_logits.view(-1, self.vocab_size),
             shifted_targets.view(-1)
-        ), outputs
-
-    def validate(self, batch_size: int) -> tuple[float, dict]:
-        self.model.eval()
-        val_dataloader = self._valid_loader(batch_size)
-        val_loss = torch.tensor(0.0).to(self.device)
-        correct = torch.tensor(0).to(self.device)
-        total = torch.tensor(0).to(self.device)
-
-        with torch.no_grad():
-            for batch in val_dataloader:
-                if self.get_batch_size(batch) == batch_size:
-                    loss, logits = self.valid_step(batch)
-                    val_loss += loss
-                    shifted_logits = logits[:, :-1].contiguous()
-                    shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
-                    valid_indices = shifted_targets != -100
-                    if valid_indices.any():
-                        preds = shifted_logits.argmax(-1)
-                        correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
-                        total += valid_indices.sum()
-
-        avg_loss = (val_loss / len(val_dataloader)).item()
-        acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
-        node_acc = acc.item()
-        if self.use_ddp:
-            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
-            acc = acc / dist.get_world_size()
-
-        metrics = {
-            'accuracy': acc.item(),
-            'node_accuracy': node_acc,
-        }
-        self.model.train()
-        return avg_loss, metrics
-
-
-class AutoregressiveMoeTrainer(BaseTrainer):
-    def __init__(
-            self,
-            model: ReactiveTransformerDecoder,
-            device: torch.device,
-            vocab_size: int,
-            use_amp: bool = False,
-            dtype: torch.dtype = None,
-            router_loss_scale: float = 0.1,
-            **kwargs
-    ):
-        super(AutoregressiveMoeTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
-                                                       target_field_name='targets', **kwargs)
-        self.vocab_size = vocab_size
-        self.router_loss_scale = router_loss_scale
-
-    def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        inputs = batch['input_ids']
-        attention_mask = batch['attention_mask']
-        targets = batch['targets']
-
-        outputs = self.model(
-            inputs,
-            attention_mask=attention_mask
         )
 
-        shifted_logits = outputs[:, :-1].contiguous()
-        shifted_targets = targets[:, 1:].contiguous()
+        return self._moe_aux_loss(loss), outputs
 
-        main_loss = F.cross_entropy(
-            shifted_logits.view(-1, self.vocab_size),
-            shifted_targets.view(-1)
-        )
+    def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
+        if not self.use_moe_aux_loss:
+            return main_loss
 
         model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
 
         router_loss = model.model.moe_router_loss()
-        loss = main_loss + self.router_loss_scale * router_loss
+        loss = main_loss + self.moe_aux_loss_scale * router_loss
 
         if self.writer is not None:
             if self.model.training:
@@ -219,7 +185,7 @@ class AutoregressiveMoeTrainer(BaseTrainer):
                 self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
                 self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
 
-        return loss
+        return loss
 
     def validate(self, batch_size: int) -> tuple[float, dict]:
         self.model.eval()
@@ -279,6 +245,9 @@ class JointTrainingModel(nn.Module):
 
 
 class JointLMTrainer(BaseTrainer):
+    """"
+    It's not recommended to use Joint LM Training in current implementation. More info soon
+    """
    def __init__(
            self,
            model: JointTrainingModel,
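Both MLMTrainer and AutoregressiveTrainer gain use_moe_aux_loss / moe_aux_loss_scale, and the dedicated AutoregressiveMoeTrainer is removed; its router-loss handling is now folded into _moe_aux_loss. The core arithmetic is a scaled sum, sketched standalone below (the trainer's TensorBoard logging and DDP unwrapping are omitted; combine_losses is a hypothetical helper, not part of rxnn):

import torch

def combine_losses(main_loss: torch.Tensor, router_loss: torch.Tensor,
                   use_moe_aux_loss: bool = True,
                   moe_aux_loss_scale: float = 0.01) -> torch.Tensor:
    # Same rule as _moe_aux_loss above: skip when disabled, otherwise add the
    # scaled router auxiliary loss to the task loss.
    if not use_moe_aux_loss:
        return main_loss
    return main_loss + moe_aux_loss_scale * router_loss

total = combine_losses(torch.tensor(2.31), torch.tensor(0.80))
print(total)  # tensor(2.3180) -> 2.31 + 0.01 * 0.80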
rxnn/transformers/attention.py
CHANGED
@@ -91,9 +91,13 @@ class MultiHeadAttention(nn.Module):
             attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
         return F.softmax(attn_logits, dim=-1)
 
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        return attn_output.transpose(1, 2).contiguous().view(b, t, d)
+
     def _calculate_output(self, attn_weights: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int):
         """Calculate the output by multiplying attention weights with values and concatenating heads"""
-        return torch.matmul(attn_weights, v)
+        return self._transpose_output(torch.matmul(attn_weights, v), b, t, d)
 
     def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                          mask: torch.Tensor = None, enable_gqa: bool = False):
@@ -104,9 +108,7 @@ class MultiHeadAttention(nn.Module):
             is_causal=self.is_causal,
             enable_gqa=enable_gqa,
         )
-
-        # Reshape back to (B, T, D)
-        return attn_output.transpose(1, 2).contiguous().view(b, t, d)
+        return self._transpose_output(attn_output, b, t, d)
 
     def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                                    mask: torch.Tensor = None):
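The reshape that previously lived only in the flash-attention path is now factored into _transpose_output, so both the eager and flash paths share it and subclasses (such as DeepMoeAttention above) can override just the output shape. A standalone shape sketch of what the default helper computes:

import torch

# Hedged sketch of the default _transpose_output behaviour: merge the heads back
# into the model dimension. Shapes follow the "(B, T, D)" comment in the diff.
b, num_heads, t, head_dim = 2, 8, 16, 64
attn_output = torch.randn(b, num_heads, t, head_dim)             # (B, H, T, head_dim)
d = num_heads * head_dim
merged = attn_output.transpose(1, 2).contiguous().view(b, t, d)  # (B, T, D)
assert merged.shape == (b, t, d)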
rxnn/transformers/layers.py
CHANGED
@@ -60,7 +60,23 @@ class ReactiveTransformerLayer(nn.Module):
             param.requires_grad_(is_trainable)
 
     def moe_router_loss(self):
-
+        ff_router_loss = self.ff.router_loss() if self.use_moe else None
+        att_router_loss = None
+        if self.attention.router_loss is not None and self.memory_cross_attention.router_loss is not None:
+            att_router_loss = (self.attention.router_loss() + self.memory_cross_attention.router_loss()) / 2
+        elif self.attention.router_loss is not None:
+            att_router_loss = self.attention.router_loss()
+        elif self.memory_cross_attention.router_loss is not None:
+            att_router_loss = self.memory_cross_attention.router_loss()
+
+        if ff_router_loss is not None and att_router_loss is not None:
+            return (ff_router_loss + att_router_loss) / 2
+        elif ff_router_loss is not None:
+            return ff_router_loss
+        elif att_router_loss is not None:
+            return att_router_loss
+        else:
+            return None
 
     def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
         # First step, self-attention
@@ -136,7 +152,17 @@ class ClassicTransformerLayer(nn.Module):
         self.use_moe = use_moe
 
     def moe_router_loss(self):
-
+        ff_router_loss = self.ff.router_loss() if self.use_moe else None
+        att_router_loss = self.attention.router_loss() if self.attention.router_loss is not None else None
+
+        if ff_router_loss is not None and att_router_loss is not None:
+            return (ff_router_loss + att_router_loss) / 2
+        elif ff_router_loss is not None:
+            return ff_router_loss
+        elif att_router_loss is not None:
+            return att_router_loss
+        else:
+            return None
 
     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
         # First step, self-attention
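Both layer classes now aggregate router losses instead of returning a single value: feed-forward and attention losses are averaged when both are present, otherwise whichever exists is returned, and None when neither is. The averaging rule, sketched with plain floats standing in for the tensor losses (average_router_losses is a hypothetical helper, not part of rxnn):

from typing import Optional

def average_router_losses(ff_loss: Optional[float], att_loss: Optional[float]) -> Optional[float]:
    # Same branching as moe_router_loss() above, reduced to two optional inputs.
    if ff_loss is not None and att_loss is not None:
        return (ff_loss + att_loss) / 2
    return ff_loss if ff_loss is not None else att_loss

print(average_router_losses(0.2, 0.4))    # 0.3
print(average_router_losses(None, 0.4))   # 0.4
print(average_router_losses(None, None))  # None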
rxnn/transformers/moe.py
CHANGED

{rxnn-0.1.23.dist-info → rxnn-0.1.25.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256
+rxnn/experimental/attention.py,sha256=rJSQjA7_9YqcM4Y8SyJSuZQjyz8j4XPhC5jcrcrRK2M,17891
+rxnn/experimental/models.py,sha256=8KAo7BtRkke9qRlzGRtQa9-EZ34roGWrn0N_T6L-6Wc,4561
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -10,22 +10,22 @@ rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=QD8uS14jSyR5Y_8BgCaBQTKpsarerU3lyufsWsCq_6o,11227
-rxnn/training/bml.py,sha256=
+rxnn/training/bml.py,sha256=HtxSzI-WcpRclAuIccF_WoTZ24KzH5ZWKe8SnWgjjm4,17581
 rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
 rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=FHATZVf_kt3OHnG02zEeG9QdUXLncKDjrhyT28Pk0E4,14185
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=ZJfNdgCv9dzrKqsWIMf99Ryzgs494ZhkwK4zSBYLvQ4,6880
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.25.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.25.dist-info/METADATA,sha256=kRnBikg0u4uhy5IwcFpEE3gFEB8wTvEpg6cFcAU8FGs,16627
+rxnn-0.1.25.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.25.dist-info/RECORD,,

{rxnn-0.1.23.dist-info → rxnn-0.1.25.dist-info}/LICENSE
File without changes

{rxnn-0.1.23.dist-info → rxnn-0.1.25.dist-info}/WHEEL
File without changes
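Each RECORD entry pairs a path with a URL-safe base64 sha256 digest (without padding) and the file size, which is how the updated hashes above are produced. A sketch of recomputing such an entry to spot-check a file from the unpacked wheel (the example path is illustrative):

import base64
import hashlib
import pathlib

def record_entry(path: str) -> str:
    # Build "path,sha256=<urlsafe-b64 digest without padding>,<size>" as used in RECORD.
    data = pathlib.Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f"{path},sha256={digest},{len(data)}"

# Example, run inside an unpacked rxnn-0.1.25 wheel (hypothetical location):
# print(record_entry('rxnn/transformers/moe.py'))  # expect ...sha256=6Cffyo0Qjm...,4738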