rxnn 0.1.23__py3-none-any.whl → 0.1.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -61,6 +61,9 @@ class GroupedMoeAttention(GroupedQueryAttention):
             **kwargs,
         )
 
+    def router_loss(self):
+        return self.router.aux_loss
+
     def _init_kv(self, embed_dim: int):
         self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
 
@@ -194,6 +197,9 @@ class DeepMoeAttention(GroupedMoeAttention):
             **kwargs,
         )
 
+    def router_loss(self):
+        return (self.router.aux_loss + self.query_router.aux_loss) / 2
+
     def _init_q(self, embed_dim: int):
         self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
 
@@ -212,6 +218,11 @@ class DeepMoeAttention(GroupedMoeAttention):
         hidden_dim = embed_dim // (self.num_heads // self.num_query_groups)
         self.out_proj = nn.Linear(hidden_dim, embed_dim)
 
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        hidden_dim = d // self.num_heads * self.num_query_groups
+        return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
+
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
         # Query processing
         B, T, D = query.shape
@@ -227,298 +238,6 @@ class DeepMoeAttention(GroupedMoeAttention):
         # Key/Value processing
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
-# Vectorized
-
-class GroupedMoeAttentionVectorized(GroupedQueryAttention):
-    """
-    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
-    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
-    experts - it has to be tested.
-
-    Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
-
-    Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
-    number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
-    - with num_groups set to 1, it will be MoE MultiQueryAttention
-
-    Compared to traditional GQA/MQA, it should provide better performance, because lot less data could be lost using
-    this approach - we are training the full number of keys/values heads, while using only a group.
-
-    In case of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
-
-    Optionally, it could use even more expert heads than attention heads - in example:
-    - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
-      4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
-
-    © 2025 Adam Filipek
-    """
-
-    def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_groups: int,
-            dropout: float = 0.0,
-            rope: RotaryPositionalEmbedding = None,
-            rope_only_for_query: bool = False,
-            use_relative_embeddings: bool = False,
-            max_seq_len: int = 1024,
-            use_flash_attention: bool = False,
-            is_causal: bool = False,
-            use_bias: bool = False,
-            num_experts: int = None,
-            *args,
-            **kwargs,
-    ):
-        self.num_experts = num_experts if num_experts is not None else num_heads
-        super(GroupedMoeAttentionVectorized, self).__init__(
-            embed_dim,
-            num_heads,
-            num_groups=num_groups,
-            dropout=dropout,
-            rope=rope,
-            rope_only_for_query=rope_only_for_query,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            *args,
-            **kwargs,
-        )
-
-    def _init_kv(self, embed_dim: int):
-        self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
-        hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self._init_experts()
-
-    def _init_experts(self):
-        torch.nn.init.xavier_uniform_(self.wk)
-        torch.nn.init.xavier_uniform_(self.wv)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bk)
-            torch.nn.init.zeros_(self.bv)
-
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
-                     skip_query_processing: bool = False):
-        # Indexed version may cause memory overflow
-        #
-        # head_dim = d // self.num_heads
-        #
-        # # Process Query as in GQA
-        # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1,
-        #                                                                       2) if not skip_query_processing else query
-        #
-        # # Process Key and Value with MoE routing
-        # key_flat = key.view(-1, d)  # (B*S, d)
-        # value_flat = value.view(-1, d)  # (B*S, d)
-        #
-        # # Get routing indices and weights for K
-        # weights_k, indices_k = self.router(key_flat)
-        # indices_k = indices_k.view(-1, self.top_k)  # (B*S, top_k)
-        # weights_k = weights_k.view(-1, self.top_k, 1)  # (B*S, top_k, 1)
-        #
-        # # Select and compute K projections for only the top_k experts
-        # selected_k_weights = self.k_experts[indices_k]  # (B*S, top_k, d, k_out_dim)
-        # k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
-        # selected_k = (k_proj * weights_k).sum(dim=1)  # (B*S, k_out_dim)
-        # selected_k = selected_k.view(b, key.size(1), -1)  # (B, S, k_out_dim)
-        #
-        # # Compute V using the same indices as K (since they share the same router)
-        # selected_v_weights = self.v_experts[indices_k]
-        # v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
-        # selected_v = (v_proj * weights_k).sum(dim=1)
-        # selected_v = selected_v.view(b, value.size(1), -1)  # (B, S, k_out_dim)
-        #
-        # # Reshape to GQA format: (B, G, S, head_dim)
-        # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
-        # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
-        #
-        # if not self.use_flash_attention:
-        #     group_heads = self.num_heads // self.num_groups
-        #
-        #     k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-        #     v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-        #
-        #     k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-        #     v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-        #
-        # return q, k, v
-
-        head_dim = d // self.num_heads
-
-        # Process Query as in GQA
-        q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
-
-        # Process Key and Value with MoE routing
-        key_flat = key.view(-1, d)
-        weights, indices = self.router(key_flat)
-        weights = weights.view(b, key.size(1), self.num_groups, 1)
-        indices = indices.view(b, key.size(1), self.num_groups)
-
-        # Compute all experts' K and V projections
-        # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
-        k_all = torch.einsum(
-            'be, ehd -> bedh',
-            key_flat,
-            self.wk.view(self.num_experts, d, -1)
-        ).view(b, key.size(1), self.num_experts, -1)
-
-        v_all = torch.einsum(
-            'be, ehd -> bedh',
-            value.view(-1, d),
-            self.wv.view(self.num_experts, d, -1)
-        ).view(b, value.size(1), self.num_experts, -1)
-
-        # Select top_k experts and compute weighted sum
-        selected_k = torch.gather(
-            k_all,
-            2,
-            indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
-        )
-        selected_v = torch.gather(
-            v_all,
-            2,
-            indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
-        )
-
-        selected_k = (selected_k * weights).sum(dim=2)
-        selected_v = (selected_v * weights).sum(dim=2)
-        # Reshape to GQA format: (B, G, S, head_dim)
-        k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
-        v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
-
-        if not self.use_flash_attention:
-            group_heads = self.num_heads // self.num_groups
-
-            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-
-            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-
-        return q, k, v
-
-
-class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
-    """
-    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
-    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
-    experts - it has to be tested.
-
-    Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
-
-    In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
-    query heads - with that approach, each token could attend to every other token, but only partially - only some part of
-    information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
-    sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
-
-    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
-    a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.
-
-    © 2025 Adam Filipek
-    """
-
-    def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_groups: int,
-            dropout: float = 0.0,
-            rope: RotaryPositionalEmbedding = None,
-            rope_only_for_query: bool = False,
-            use_relative_embeddings: bool = False,
-            max_seq_len: int = 1024,
-            use_flash_attention: bool = False,
-            is_causal: bool = False,
-            use_bias: bool = False,
-            num_experts: int = None,
-            num_query_experts: int = None,
-            num_query_groups: int = None,
-            *args,
-            **kwargs,
-    ):
-        self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
-        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
-        super(DeepMoeAttentionVectorized, self).__init__(
-            embed_dim,
-            num_heads,
-            num_groups=num_groups,
-            dropout=dropout,
-            rope=rope,
-            rope_only_for_query=rope_only_for_query,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-            *args,
-            **kwargs,
-        )
-
-    def _init_q(self, embed_dim: int):
-        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
-        hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
-        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
-        self._init_query_experts()
-
-    def _init_query_experts(self):
-        torch.nn.init.xavier_uniform_(self.wq)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bq)
-
-    def _init_out(self, embed_dim: int):
-        """Initialize output projection"""
-        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
-
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
-        # Indexed version may cause memory overflow
-        #
-        # head_dim = d // self.num_heads
-        #
-        # # Process Query with MoE routing
-        # query_flat = query.view(-1, d)  # (B*T, d)
-        # weights_q, indices_q = self.query_router(query_flat)
-        # indices_q = indices_q.view(-1, self.num_query_groups)  # (B*T, top_k_q)
-        # weights_q = weights_q.view(-1, self.num_query_groups, 1)  # (B*T, top_k_q, 1)
-        #
-        # # Select and compute Q projections for top_k experts
-        # selected_q_weights = self.wq[indices_q]  # (B*T, top_k_q, d, head_dim*num_heads)
-        # q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
-        # selected_q = (q_proj * weights_q).sum(dim=1)  # (B*T, head_dim*num_heads)
-        # selected_q = selected_q.view(b, t, -1)  # (B, T, head_dim*num_heads)
-        head_dim = d // self.num_heads
-
-        # Process Query with MoE routing
-        query_flat = query.view(b * t, d)
-        weights_q, indices_q = self.query_router(query_flat)
-        weights_q = weights_q.view(b, t, self.num_query_groups, 1)
-        indices_q = indices_q.view(b, t, self.num_query_groups)
-
-        # Compute all experts' Q projections
-        q_all = torch.einsum(
-            'be, ehd -> bedh',
-            query_flat,
-            self.wq.view(self.num_query_experts, d, -1)
-        ).view(b, t, self.num_query_experts, -1)
-
-        selected_q = torch.gather(
-            q_all,
-            2,
-            indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.shape[-1])
-        )
-        selected_q = (selected_q * weights_q).sum(dim=2)
-
-        q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2)  # (B, H, T, head_dim)
-
-        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
-
 
 # Others
 
@@ -670,8 +389,7 @@ def init_moe_attention(
         num_query_experts: int = None,
         num_query_groups: int = None,
 ) -> GroupedQueryAttention:
-    assert attention_type == 'gma' or attention_type == 'dma' or attention_type == 'gma_v' or attention_type == 'dma_v', \
-        "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
+    assert attention_type in ['gma', 'dma'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -688,7 +406,7 @@ def init_moe_attention(
             use_bias=use_bias,
             num_experts=num_experts,
         )
-    elif attention_type == "dma":
+    else:
         return DeepMoeAttention(
             embed_dim,
             num_heads,
@@ -705,35 +423,3 @@ def init_moe_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
-    elif attention_type == "gma_v":
-        return GroupedMoeAttentionVectorized(
-            embed_dim,
-            num_heads,
-            gqa_groups,
-            dropout=dropout,
-            rope=rope,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            rope_only_for_query=rope_only_for_query,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-        )
-    else:
-        return DeepMoeAttentionVectorized(
-            embed_dim,
-            num_heads,
-            gqa_groups,
-            dropout=dropout,
-            rope=rope,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            rope_only_for_query=rope_only_for_query,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-            num_query_experts=num_query_experts,
-            num_query_groups=num_query_groups,
-        )
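
After this change only the 'gma' and 'dma' paths remain in init_moe_attention, and both returned attention modules expose the new router_loss() method. A minimal, hypothetical usage sketch: keyword arguments are used because the full signature is not shown in this hunk, and all values are illustrative rather than taken from the package:

    attention = init_moe_attention(
        embed_dim=512,
        num_heads=16,
        gqa_groups=4,
        attention_type='gma',   # 'gma_v' and 'dma_v' were removed in 0.1.25
        num_experts=24,
    )
    aux_loss = attention.router_loss()  # auxiliary router (load-balancing) loss added in this release
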
@@ -35,7 +35,7 @@ class MoeAttentionTransformerConfig(TypedDict):
 
 
 class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
-    """Research model for experiments with Mixture-of-Experts Attention"""
+    """Research decoder model for experiments with Mixture-of-Experts Attention"""
 
     def __init__(
             self,
@@ -65,8 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v',
-                            'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v"'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma"'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -111,6 +110,5 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
     def load_shared_embedding(self, embedding: nn.Embedding):
         self.model.embedding = embedding
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
-        torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
rxnn/training/bml.py CHANGED
@@ -50,10 +50,14 @@ class MLMTrainer(BaseTrainer):
             vocab_size: int,
             use_amp: bool = False,
             dtype: torch.dtype = None,
+            use_moe_aux_loss: bool = False,
+            moe_aux_loss_scale: float = 0.01,
             **kwargs
     ):
         super(MLMTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype, **kwargs)
         self.vocab_size = vocab_size
+        self.use_moe_aux_loss = use_moe_aux_loss
+        self.moe_aux_loss_scale = moe_aux_loss_scale
 
     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         inputs = batch['input_ids']
@@ -65,11 +69,32 @@ class MLMTrainer(BaseTrainer):
             attention_mask=attention_mask
         )
 
-        return F.cross_entropy(
+        loss = F.cross_entropy(
             logits.view(-1, self.vocab_size),
             labels.view(-1),
             ignore_index=-100
-        ), logits
+        )
+
+        return self._moe_aux_loss(loss), logits
+
+    def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
+        if not self.use_moe_aux_loss:
+            return main_loss
+
+        model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
+
+        router_loss = model.encoder.model.moe_router_loss()
+        loss = main_loss + self.moe_aux_loss_scale * router_loss
+
+        if self.writer is not None:
+            if self.model.training:
+                self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
+                self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
+            else:
+                self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
+                self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
+
+        return loss
 
     def validate(self, batch_size: int) -> tuple[float, dict]:
         self.model.eval()
@@ -113,11 +138,15 @@ class AutoregressiveTrainer(BaseTrainer):
             vocab_size: int,
             use_amp: bool = False,
             dtype: torch.dtype = None,
+            use_moe_aux_loss: bool = False,
+            moe_aux_loss_scale: float = 0.01,
             **kwargs
     ):
         super(AutoregressiveTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
                                                     target_field_name='targets', **kwargs)
         self.vocab_size = vocab_size
+        self.use_moe_aux_loss = use_moe_aux_loss
+        self.moe_aux_loss_scale = moe_aux_loss_scale
 
     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         inputs = batch['input_ids']
@@ -132,84 +161,21 @@ class AutoregressiveTrainer(BaseTrainer):
         shifted_logits = outputs[:, :-1].contiguous()
         shifted_targets = targets[:, 1:].contiguous()
 
-        return F.cross_entropy(
+        loss = F.cross_entropy(
             shifted_logits.view(-1, self.vocab_size),
             shifted_targets.view(-1)
-        ), outputs
-
-    def validate(self, batch_size: int) -> tuple[float, dict]:
-        self.model.eval()
-        val_dataloader = self._valid_loader(batch_size)
-        val_loss = torch.tensor(0.0).to(self.device)
-        correct = torch.tensor(0).to(self.device)
-        total = torch.tensor(0).to(self.device)
-
-        with torch.no_grad():
-            for batch in val_dataloader:
-                if self.get_batch_size(batch) == batch_size:
-                    loss, logits = self.valid_step(batch)
-                    val_loss += loss
-                    shifted_logits = logits[:, :-1].contiguous()
-                    shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
-                    valid_indices = shifted_targets != -100
-                    if valid_indices.any():
-                        preds = shifted_logits.argmax(-1)
-                        correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
-                        total += valid_indices.sum()
-
-        avg_loss = (val_loss / len(val_dataloader)).item()
-        acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
-        node_acc = acc.item()
-        if self.use_ddp:
-            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
-            acc = acc / dist.get_world_size()
-
-        metrics = {
-            'accuracy': acc.item(),
-            'node_accuracy': node_acc,
-        }
-        self.model.train()
-        return avg_loss, metrics
-
-
-class AutoregressiveMoeTrainer(BaseTrainer):
-    def __init__(
-            self,
-            model: ReactiveTransformerDecoder,
-            device: torch.device,
-            vocab_size: int,
-            use_amp: bool = False,
-            dtype: torch.dtype = None,
-            router_loss_scale: float = 0.1,
-            **kwargs
-    ):
-        super(AutoregressiveMoeTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
-                                                       target_field_name='targets', **kwargs)
-        self.vocab_size = vocab_size
-        self.router_loss_scale = router_loss_scale
-
-    def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        inputs = batch['input_ids']
-        attention_mask = batch['attention_mask']
-        targets = batch['targets']
-
-        outputs = self.model(
-            inputs,
-            attention_mask=attention_mask
         )
 
-        shifted_logits = outputs[:, :-1].contiguous()
-        shifted_targets = targets[:, 1:].contiguous()
+        return self._moe_aux_loss(loss), outputs
 
-        main_loss = F.cross_entropy(
-            shifted_logits.view(-1, self.vocab_size),
-            shifted_targets.view(-1)
-        )
+    def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
+        if not self.use_moe_aux_loss:
+            return main_loss
 
         model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
 
         router_loss = model.model.moe_router_loss()
-        loss = main_loss + self.router_loss_scale * router_loss
+        loss = main_loss + self.moe_aux_loss_scale * router_loss
 
         if self.writer is not None:
             if self.model.training:
@@ -219,7 +185,7 @@ class AutoregressiveMoeTrainer(BaseTrainer):
                 self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
                 self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
 
-        return loss, outputs
+        return loss
 
     def validate(self, batch_size: int) -> tuple[float, dict]:
         self.model.eval()
@@ -279,6 +245,9 @@ class JointTrainingModel(nn.Module):
 
 
 class JointLMTrainer(BaseTrainer):
+    """"
+    It's not recommended to use Joint LM Training in current implementation. More info soon
+    """
     def __init__(
             self,
             model: JointTrainingModel,
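
The hunks above replace the dedicated AutoregressiveMoeTrainer with flags on the base trainers. A minimal sketch of enabling them: only use_moe_aux_loss and moe_aux_loss_scale come from this diff, while the model, device and vocab_size values are placeholders:

    trainer = AutoregressiveTrainer(
        model,
        device,
        vocab_size=50000,
        use_moe_aux_loss=True,     # add the scaled router loss to the main cross-entropy loss
        moe_aux_loss_scale=0.01,   # default scale introduced in this version
    )
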
@@ -91,9 +91,13 @@ class MultiHeadAttention(nn.Module):
         attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
         return F.softmax(attn_logits, dim=-1)
 
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        return attn_output.transpose(1, 2).contiguous().view(b, t, d)
+
     def _calculate_output(self, attn_weights: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int):
         """Calculate the output by multiplying attention weights with values and concatenating heads"""
-        return torch.matmul(attn_weights, v).transpose(1, 2).contiguous().view(b, t, d)
+        return self._transpose_output(torch.matmul(attn_weights, v), b, t, d)
 
     def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                          mask: torch.Tensor = None, enable_gqa: bool = False):
@@ -104,9 +108,7 @@ class MultiHeadAttention(nn.Module):
             is_causal=self.is_causal,
             enable_gqa=enable_gqa,
         )
-
-        # Reshape back to (B, T, D)
-        return attn_output.transpose(1, 2).contiguous().view(b, t, d)
+        return self._transpose_output(attn_output, b, t, d)
 
     def _calculate_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, b: int, t: int, d: int,
                                    mask: torch.Tensor = None):
@@ -60,7 +60,23 @@ class ReactiveTransformerLayer(nn.Module):
             param.requires_grad_(is_trainable)
 
     def moe_router_loss(self):
-        return self.ff.router_loss() if self.use_moe else None
+        ff_router_loss = self.ff.router_loss() if self.use_moe else None
+        att_router_loss = None
+        if self.attention.router_loss is not None and self.memory_cross_attention.router_loss is not None:
+            att_router_loss = (self.attention.router_loss() + self.memory_cross_attention.router_loss()) / 2
+        elif self.attention.router_loss is not None:
+            att_router_loss = self.attention.router_loss()
+        elif self.memory_cross_attention.router_loss is not None:
+            att_router_loss = self.memory_cross_attention.router_loss()
+
+        if ff_router_loss is not None and att_router_loss is not None:
+            return (ff_router_loss + att_router_loss) / 2
+        elif ff_router_loss is not None:
+            return ff_router_loss
+        elif att_router_loss is not None:
+            return att_router_loss
+        else:
+            return None
 
     def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
         # First step, self-attention
@@ -136,7 +152,17 @@ class ClassicTransformerLayer(nn.Module):
         self.use_moe = use_moe
 
     def moe_router_loss(self):
-        return self.ff.router_loss() if self.use_moe else torch.tensor(0.0)
+        ff_router_loss = self.ff.router_loss() if self.use_moe else None
+        att_router_loss = self.attention.router_loss() if self.attention.router_loss is not None else None
+
+        if ff_router_loss is not None and att_router_loss is not None:
+            return (ff_router_loss + att_router_loss) / 2
+        elif ff_router_loss is not None:
+            return ff_router_loss
+        elif att_router_loss is not None:
+            return att_router_loss
+        else:
+            return None
 
     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
         # First step, self-attention
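
Both layer classes now follow the same aggregation rule: average the feed-forward and attention router losses that are actually present, return a single one unchanged, or None when neither exists. A standalone sketch of that rule, with a helper name that is illustrative and not part of the package:

    def average_optional_losses(losses):
        # Keep only the auxiliary losses that exist (not None) and average them
        present = [loss for loss in losses if loss is not None]
        if not present:
            return None
        return sum(present) / len(present)
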
rxnn/transformers/moe.py CHANGED
@@ -20,7 +20,6 @@ class MoeRouter(nn.Module):
         mean_probs = probs.mean(dim=0)
         return (expert_usage * mean_probs).sum() * self.num_experts
 
-
     def forward(self, x: torch.Tensor):
         # Input shape: [batch*seq_len, embed_dim]
         logits = self.gate(x)
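
The context lines above show the load-balancing auxiliary loss computed by MoeRouter: the mean routing probability per expert multiplied by the fraction of tokens routed to it, summed and scaled by the number of experts. A self-contained sketch of that computation, assuming softmax gate probabilities and top-k expert indices as inputs; the function name and argument layout are assumptions, not the package API:

    import torch

    def load_balancing_loss(probs: torch.Tensor, top_k_indices: torch.Tensor, num_experts: int) -> torch.Tensor:
        # probs: [tokens, num_experts] softmax gate probabilities
        # top_k_indices: [tokens, top_k] ids of the experts selected for each token
        one_hot = torch.zeros_like(probs).scatter_(1, top_k_indices, 1.0)
        expert_usage = one_hot.mean(dim=0)  # fraction of tokens routed to each expert
        mean_probs = probs.mean(dim=0)      # mean gate probability per expert
        return (expert_usage * mean_probs).sum() * num_experts
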
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.23
+Version: 0.1.25
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=gMEcFJHGOkz8R_s4dGEJB5cb2K3pbXZi4XBwyhEdB4s,31967
-rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
+rxnn/experimental/attention.py,sha256=rJSQjA7_9YqcM4Y8SyJSuZQjyz8j4XPhC5jcrcrRK2M,17891
+rxnn/experimental/models.py,sha256=8KAo7BtRkke9qRlzGRtQa9-EZ34roGWrn0N_T6L-6Wc,4561
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -10,22 +10,22 @@ rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=QD8uS14jSyR5Y_8BgCaBQTKpsarerU3lyufsWsCq_6o,11227
-rxnn/training/bml.py,sha256=o_88ZL1YWd5gWXaBqYPK2UzSTbJaiTiw96E6z73LeOQ,18660
+rxnn/training/bml.py,sha256=HtxSzI-WcpRclAuIccF_WoTZ24KzH5ZWKe8SnWgjjm4,17581
 rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
 rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=FfEYE0THO73p_1eRupr2mcwfW4UbI_riIxkHfr8X_1c,14022
+rxnn/transformers/attention.py,sha256=FHATZVf_kt3OHnG02zEeG9QdUXLncKDjrhyT28Pk0E4,14185
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
+rxnn/transformers/layers.py,sha256=ZJfNdgCv9dzrKqsWIMf99Ryzgs494ZhkwK4zSBYLvQ4,6880
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=msspVdefdt2ekIN8aT-V8DolK4taESQL_NVsSGOepIs,4739
+rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.23.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.23.dist-info/METADATA,sha256=rZSBuoIgf8jKB11LKgMg7U42Wx7VNT_4EU3FVyED2YQ,16627
-rxnn-0.1.23.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.23.dist-info/RECORD,,
+rxnn-0.1.25.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.25.dist-info/METADATA,sha256=kRnBikg0u4uhy5IwcFpEE3gFEB8wTvEpg6cFcAU8FGs,16627
+rxnn-0.1.25.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.25.dist-info/RECORD,,