rxnn 0.1.24__py3-none-any.whl → 0.1.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +2 -327
- rxnn/experimental/models.py +3 -5
- rxnn/training/bml.py +41 -72
- rxnn/transformers/layers.py +3 -2
- {rxnn-0.1.24.dist-info → rxnn-0.1.26.dist-info}/METADATA +1 -1
- {rxnn-0.1.24.dist-info → rxnn-0.1.26.dist-info}/RECORD +8 -8
- {rxnn-0.1.24.dist-info → rxnn-0.1.26.dist-info}/LICENSE +0 -0
- {rxnn-0.1.24.dist-info → rxnn-0.1.26.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -238,298 +238,6 @@ class DeepMoeAttention(GroupedMoeAttention):
         # Key/Value processing
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
-# Vectorized
-
-class GroupedMoeAttentionVectorized(GroupedQueryAttention):
-    """
-    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
-    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
-    experts - it has to be tested.
-
-    Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
-
-    Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
-    number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
-    - with num_groups set to 1, it will be MoE MultiQueryAttention
-
-    Compared to traditional GQA/MQA, it should provide better performance, because lot less data could be lost using
-    this approach - we are training the full number of keys/values heads, while using only a group.
-
-    In case of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
-
-    Optionally, it could use even more expert heads than attention heads - in example:
-    - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
-    4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
-
-    © 2025 Adam Filipek
-    """
-
-    def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_groups: int,
-            dropout: float = 0.0,
-            rope: RotaryPositionalEmbedding = None,
-            rope_only_for_query: bool = False,
-            use_relative_embeddings: bool = False,
-            max_seq_len: int = 1024,
-            use_flash_attention: bool = False,
-            is_causal: bool = False,
-            use_bias: bool = False,
-            num_experts: int = None,
-            *args,
-            **kwargs,
-    ):
-        self.num_experts = num_experts if num_experts is not None else num_heads
-        super(GroupedMoeAttentionVectorized, self).__init__(
-            embed_dim,
-            num_heads,
-            num_groups=num_groups,
-            dropout=dropout,
-            rope=rope,
-            rope_only_for_query=rope_only_for_query,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            *args,
-            **kwargs,
-        )
-
-    def _init_kv(self, embed_dim: int):
-        self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
-        hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self._init_experts()
-
-    def _init_experts(self):
-        torch.nn.init.xavier_uniform_(self.wk)
-        torch.nn.init.xavier_uniform_(self.wv)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bk)
-            torch.nn.init.zeros_(self.bv)
-
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
-                     skip_query_processing: bool = False):
-        # Indexed version may cause memory overflow
-        #
-        # head_dim = d // self.num_heads
-        #
-        # # Process Query as in GQA
-        # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1,
-        #     2) if not skip_query_processing else query
-        #
-        # # Process Key and Value with MoE routing
-        # key_flat = key.view(-1, d)  # (B*S, d)
-        # value_flat = value.view(-1, d)  # (B*S, d)
-        #
-        # # Get routing indices and weights for K
-        # weights_k, indices_k = self.router(key_flat)
-        # indices_k = indices_k.view(-1, self.top_k)  # (B*S, top_k)
-        # weights_k = weights_k.view(-1, self.top_k, 1)  # (B*S, top_k, 1)
-        #
-        # # Select and compute K projections for only the top_k experts
-        # selected_k_weights = self.k_experts[indices_k]  # (B*S, top_k, d, k_out_dim)
-        # k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
-        # selected_k = (k_proj * weights_k).sum(dim=1)  # (B*S, k_out_dim)
-        # selected_k = selected_k.view(b, key.size(1), -1)  # (B, S, k_out_dim)
-        #
-        # # Compute V using the same indices as K (since they share the same router)
-        # selected_v_weights = self.v_experts[indices_k]
-        # v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
-        # selected_v = (v_proj * weights_k).sum(dim=1)
-        # selected_v = selected_v.view(b, value.size(1), -1)  # (B, S, k_out_dim)
-        #
-        # # Reshape to GQA format: (B, G, S, head_dim)
-        # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
-        # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
-        #
-        # if not self.use_flash_attention:
-        #     group_heads = self.num_heads // self.num_groups
-        #
-        #     k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-        #     v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-        #
-        #     k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-        #     v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-        #
-        # return q, k, v
-
-        head_dim = d // self.num_heads
-
-        # Process Query as in GQA
-        q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
-
-        # Process Key and Value with MoE routing
-        key_flat = key.view(-1, d)
-        weights, indices = self.router(key_flat)
-        weights = weights.view(b, key.size(1), self.num_groups, 1)
-        indices = indices.view(b, key.size(1), self.num_groups)
-
-        # Compute all experts' K and V projections
-        # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
-        k_all = torch.einsum(
-            'be, ehd -> bedh',
-            key_flat,
-            self.wk.view(self.num_experts, d, -1)
-        ).view(b, key.size(1), self.num_experts, -1)
-
-        v_all = torch.einsum(
-            'be, ehd -> bedh',
-            value.view(-1, d),
-            self.wv.view(self.num_experts, d, -1)
-        ).view(b, value.size(1), self.num_experts, -1)
-
-        # Select top_k experts and compute weighted sum
-        selected_k = torch.gather(
-            k_all,
-            2,
-            indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
-        )
-        selected_v = torch.gather(
-            v_all,
-            2,
-            indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
-        )
-
-        selected_k = (selected_k * weights).sum(dim=2)
-        selected_v = (selected_v * weights).sum(dim=2)
-        # Reshape to GQA format: (B, G, S, head_dim)
-        k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
-        v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
-
-        if not self.use_flash_attention:
-            group_heads = self.num_heads // self.num_groups
-
-            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-
-            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-
-        return q, k, v
-
-
-class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
-    """
-    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
-    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
-    experts - it has to be tested.
-
-    Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
-
-    In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
-    query heads - with that approach, each token could attend to every other token, but only partially - only some part of
-    information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
-    sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
-
-    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
-    a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.
-
-    © 2025 Adam Filipek
-    """
-
-    def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_groups: int,
-            dropout: float = 0.0,
-            rope: RotaryPositionalEmbedding = None,
-            rope_only_for_query: bool = False,
-            use_relative_embeddings: bool = False,
-            max_seq_len: int = 1024,
-            use_flash_attention: bool = False,
-            is_causal: bool = False,
-            use_bias: bool = False,
-            num_experts: int = None,
-            num_query_experts: int = None,
-            num_query_groups: int = None,
-            *args,
-            **kwargs,
-    ):
-        self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
-        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
-        super(DeepMoeAttentionVectorized, self).__init__(
-            embed_dim,
-            num_heads,
-            num_groups=num_groups,
-            dropout=dropout,
-            rope=rope,
-            rope_only_for_query=rope_only_for_query,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-            *args,
-            **kwargs,
-        )
-
-    def _init_q(self, embed_dim: int):
-        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
-        hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
-        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
-        self._init_query_experts()
-
-    def _init_query_experts(self):
-        torch.nn.init.xavier_uniform_(self.wq)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bq)
-
-    def _init_out(self, embed_dim: int):
-        """Initialize output projection"""
-        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
-
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
-        # Indexed version may cause memory overflow
-        #
-        # head_dim = d // self.num_heads
-        #
-        # # Process Query with MoE routing
-        # query_flat = query.view(-1, d)  # (B*T, d)
-        # weights_q, indices_q = self.query_router(query_flat)
-        # indices_q = indices_q.view(-1, self.num_query_groups)  # (B*T, top_k_q)
-        # weights_q = weights_q.view(-1, self.num_query_groups, 1)  # (B*T, top_k_q, 1)
-        #
-        # # Select and compute Q projections for top_k experts
-        # selected_q_weights = self.wq[indices_q]  # (B*T, top_k_q, d, head_dim*num_heads)
-        # q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
-        # selected_q = (q_proj * weights_q).sum(dim=1)  # (B*T, head_dim*num_heads)
-        # selected_q = selected_q.view(b, t, -1)  # (B, T, head_dim*num_heads)
-        head_dim = d // self.num_heads
-
-        # Process Query with MoE routing
-        query_flat = query.view(b * t, d)
-        weights_q, indices_q = self.query_router(query_flat)
-        weights_q = weights_q.view(b, t, self.num_query_groups, 1)
-        indices_q = indices_q.view(b, t, self.num_query_groups)
-
-        # Compute all experts' Q projections
-        q_all = torch.einsum(
-            'be, ehd -> bedh',
-            query_flat,
-            self.wq.view(self.num_query_experts, d, -1)
-        ).view(b, t, self.num_query_experts, -1)
-
-        selected_q = torch.gather(
-            q_all,
-            2,
-            indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.shape[-1])
-        )
-        selected_q = (selected_q * weights_q).sum(dim=2)
-
-        q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2)  # (B, H, T, head_dim)
-
-        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
-
 
 # Others
 
@@ -681,8 +389,7 @@ def init_moe_attention(
         num_query_experts: int = None,
         num_query_groups: int = None,
 ) -> GroupedQueryAttention:
-    assert attention_type
-    "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
+    assert attention_type in ['gma', 'dma'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -699,7 +406,7 @@ def init_moe_attention(
             use_bias=use_bias,
             num_experts=num_experts,
         )
-
+    else:
         return DeepMoeAttention(
             embed_dim,
             num_heads,
@@ -716,35 +423,3 @@ def init_moe_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
-    elif attention_type == "gma_v":
-        return GroupedMoeAttentionVectorized(
-            embed_dim,
-            num_heads,
-            gqa_groups,
-            dropout=dropout,
-            rope=rope,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            rope_only_for_query=rope_only_for_query,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-        )
-    else:
-        return DeepMoeAttentionVectorized(
-            embed_dim,
-            num_heads,
-            gqa_groups,
-            dropout=dropout,
-            rope=rope,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            rope_only_for_query=rope_only_for_query,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-            num_query_experts=num_query_experts,
-            num_query_groups=num_query_groups,
-        )
rxnn/experimental/models.py
CHANGED
@@ -35,7 +35,7 @@ class MoeAttentionTransformerConfig(TypedDict):
 
 
 class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
-    """Research model for experiments with Mixture-of-Experts Attention"""
+    """Research decoder model for experiments with Mixture-of-Experts Attention"""
 
     def __init__(
             self,
@@ -65,8 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', '
-                            'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v"'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma"'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -111,6 +110,5 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
     def load_shared_embedding(self, embedding: nn.Embedding):
         self.model.embedding = embedding
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) ->
-    torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
rxnn/training/bml.py
CHANGED
@@ -50,10 +50,14 @@ class MLMTrainer(BaseTrainer):
             vocab_size: int,
             use_amp: bool = False,
             dtype: torch.dtype = None,
+            use_moe_aux_loss: bool = False,
+            moe_aux_loss_scale: float = 0.01,
             **kwargs
     ):
         super(MLMTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype, **kwargs)
         self.vocab_size = vocab_size
+        self.use_moe_aux_loss = use_moe_aux_loss
+        self.moe_aux_loss_scale = moe_aux_loss_scale
 
     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         inputs = batch['input_ids']
@@ -65,11 +69,32 @@ class MLMTrainer(BaseTrainer):
             attention_mask=attention_mask
         )
 
-
+        loss = F.cross_entropy(
             logits.view(-1, self.vocab_size),
             labels.view(-1),
             ignore_index=-100
-        )
+        )
+
+        return self._moe_aux_loss(loss), logits
+
+    def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
+        if not self.use_moe_aux_loss:
+            return main_loss
+
+        model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
+
+        router_loss = model.encoder.model.moe_router_loss()
+        loss = main_loss + self.moe_aux_loss_scale * router_loss
+
+        if self.writer is not None:
+            if self.model.training:
+                self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
+                self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
+            else:
+                self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
+                self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
+
+        return loss
 
     def validate(self, batch_size: int) -> tuple[float, dict]:
         self.model.eval()
@@ -113,11 +138,15 @@ class AutoregressiveTrainer(BaseTrainer):
             vocab_size: int,
             use_amp: bool = False,
             dtype: torch.dtype = None,
+            use_moe_aux_loss: bool = False,
+            moe_aux_loss_scale: float = 0.01,
             **kwargs
     ):
         super(AutoregressiveTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
                                                     target_field_name='targets', **kwargs)
         self.vocab_size = vocab_size
+        self.use_moe_aux_loss = use_moe_aux_loss
+        self.moe_aux_loss_scale = moe_aux_loss_scale
 
     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         inputs = batch['input_ids']
@@ -132,84 +161,21 @@ class AutoregressiveTrainer(BaseTrainer):
         shifted_logits = outputs[:, :-1].contiguous()
         shifted_targets = targets[:, 1:].contiguous()
 
-
+        loss = F.cross_entropy(
             shifted_logits.view(-1, self.vocab_size),
             shifted_targets.view(-1)
-        ), outputs
-
-    def validate(self, batch_size: int) -> tuple[float, dict]:
-        self.model.eval()
-        val_dataloader = self._valid_loader(batch_size)
-        val_loss = torch.tensor(0.0).to(self.device)
-        correct = torch.tensor(0).to(self.device)
-        total = torch.tensor(0).to(self.device)
-
-        with torch.no_grad():
-            for batch in val_dataloader:
-                if self.get_batch_size(batch) == batch_size:
-                    loss, logits = self.valid_step(batch)
-                    val_loss += loss
-                    shifted_logits = logits[:, :-1].contiguous()
-                    shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
-                    valid_indices = shifted_targets != -100
-                    if valid_indices.any():
-                        preds = shifted_logits.argmax(-1)
-                        correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
-                        total += valid_indices.sum()
-
-        avg_loss = (val_loss / len(val_dataloader)).item()
-        acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
-        node_acc = acc.item()
-        if self.use_ddp:
-            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
-            acc = acc / dist.get_world_size()
-
-        metrics = {
-            'accuracy': acc.item(),
-            'node_accuracy': node_acc,
-        }
-        self.model.train()
-        return avg_loss, metrics
-
-
-class AutoregressiveMoeTrainer(BaseTrainer):
-    def __init__(
-            self,
-            model: ReactiveTransformerDecoder,
-            device: torch.device,
-            vocab_size: int,
-            use_amp: bool = False,
-            dtype: torch.dtype = None,
-            router_loss_scale: float = 0.1,
-            **kwargs
-    ):
-        super(AutoregressiveMoeTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
-                                                       target_field_name='targets', **kwargs)
-        self.vocab_size = vocab_size
-        self.router_loss_scale = router_loss_scale
-
-    def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        inputs = batch['input_ids']
-        attention_mask = batch['attention_mask']
-        targets = batch['targets']
-
-        outputs = self.model(
-            inputs,
-            attention_mask=attention_mask
         )
 
-
-        shifted_targets = targets[:, 1:].contiguous()
+        return self._moe_aux_loss(loss), outputs
 
-
-
-
-        )
+    def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
+        if not self.use_moe_aux_loss:
+            return main_loss
 
         model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
 
         router_loss = model.model.moe_router_loss()
-        loss = main_loss + self.
+        loss = main_loss + self.moe_aux_loss_scale * router_loss
 
         if self.writer is not None:
             if self.model.training:
@@ -219,7 +185,7 @@ class AutoregressiveMoeTrainer(BaseTrainer):
                 self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
                 self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
 
-        return loss
+        return loss
 
     def validate(self, batch_size: int) -> tuple[float, dict]:
         self.model.eval()
@@ -279,6 +245,9 @@ class JointTrainingModel(nn.Module):
 
 
 class JointLMTrainer(BaseTrainer):
+    """"
+    It's not recommended to use Joint LM Training in current implementation. More info soon
+    """
     def __init__(
             self,
             model: JointTrainingModel,
rxnn/transformers/layers.py
CHANGED
@@ -132,9 +132,10 @@ class ClassicTransformerLayer(nn.Module):
 
         if use_gated:
             if use_moe:
-                self.ff = GatedMoeFeedForward(embed_dim, ff_dim, num_experts, top_k=moe_top_k,
+                self.ff = GatedMoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
+                                              dropout=ff_dropout)
             else:
-                self.ff = GatedFeedForward(embed_dim, ff_dim, dropout=ff_dropout)
+                self.ff = GatedFeedForward(embed_dim, ff_dim, ff_activation, dropout=ff_dropout)
         else:
             if use_moe:
                 self.ff = MoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
{rxnn-0.1.24.dist-info → rxnn-0.1.26.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256
+rxnn/experimental/attention.py,sha256=rJSQjA7_9YqcM4Y8SyJSuZQjyz8j4XPhC5jcrcrRK2M,17891
+rxnn/experimental/models.py,sha256=8KAo7BtRkke9qRlzGRtQa9-EZ34roGWrn0N_T6L-6Wc,4561
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -10,7 +10,7 @@ rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=QD8uS14jSyR5Y_8BgCaBQTKpsarerU3lyufsWsCq_6o,11227
-rxnn/training/bml.py,sha256=
+rxnn/training/bml.py,sha256=HtxSzI-WcpRclAuIccF_WoTZ24KzH5ZWKe8SnWgjjm4,17581
 rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
 rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
@@ -18,14 +18,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=FHATZVf_kt3OHnG02zEeG9QdUXLncKDjrhyT28Pk0E4,14185
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
 rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.26.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.26.dist-info/METADATA,sha256=qKx2OFl-ca7OCZoSf99l5lOmg2czSCN8UN_Esqmfco0,16627
+rxnn-0.1.26.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.26.dist-info/RECORD,,
{rxnn-0.1.24.dist-info → rxnn-0.1.26.dist-info}/LICENSE
File without changes
{rxnn-0.1.24.dist-info → rxnn-0.1.26.dist-info}/WHEEL
File without changes