rxnn 0.1.24__py3-none-any.whl → 0.1.26__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions exactly as they appear in their public registries.
rxnn/experimental/attention.py CHANGED
@@ -238,298 +238,6 @@ class DeepMoeAttention(GroupedMoeAttention):
         # Key/Value processing
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
-# Vectorized
-
-class GroupedMoeAttentionVectorized(GroupedQueryAttention):
-    """
-    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
-    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
-    experts - it has to be tested.
-
-    Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
-
-    Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
-    number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
-    - with num_groups set to 1, it will be MoE MultiQueryAttention
-
-    Compared to traditional GQA/MQA, it should provide better performance, because lot less data could be lost using
-    this approach - we are training the full number of keys/values heads, while using only a group.
-
-    In case of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
-
-    Optionally, it could use even more expert heads than attention heads - in example:
-    - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
-      4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
-
-    © 2025 Adam Filipek
-    """
-
-    def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_groups: int,
-            dropout: float = 0.0,
-            rope: RotaryPositionalEmbedding = None,
-            rope_only_for_query: bool = False,
-            use_relative_embeddings: bool = False,
-            max_seq_len: int = 1024,
-            use_flash_attention: bool = False,
-            is_causal: bool = False,
-            use_bias: bool = False,
-            num_experts: int = None,
-            *args,
-            **kwargs,
-    ):
-        self.num_experts = num_experts if num_experts is not None else num_heads
-        super(GroupedMoeAttentionVectorized, self).__init__(
-            embed_dim,
-            num_heads,
-            num_groups=num_groups,
-            dropout=dropout,
-            rope=rope,
-            rope_only_for_query=rope_only_for_query,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            *args,
-            **kwargs,
-        )
-
-    def _init_kv(self, embed_dim: int):
-        self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
-        hidden_dim = embed_dim // self.num_heads
-        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
-        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
-        self._init_experts()
-
-    def _init_experts(self):
-        torch.nn.init.xavier_uniform_(self.wk)
-        torch.nn.init.xavier_uniform_(self.wv)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bk)
-            torch.nn.init.zeros_(self.bv)
-
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
-                     skip_query_processing: bool = False):
-        # Indexed version may cause memory overflow
-        #
-        # head_dim = d // self.num_heads
-        #
-        # # Process Query as in GQA
-        # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1,
-        #     2) if not skip_query_processing else query
-        #
-        # # Process Key and Value with MoE routing
-        # key_flat = key.view(-1, d)  # (B*S, d)
-        # value_flat = value.view(-1, d)  # (B*S, d)
-        #
-        # # Get routing indices and weights for K
-        # weights_k, indices_k = self.router(key_flat)
-        # indices_k = indices_k.view(-1, self.top_k)  # (B*S, top_k)
-        # weights_k = weights_k.view(-1, self.top_k, 1)  # (B*S, top_k, 1)
-        #
-        # # Select and compute K projections for only the top_k experts
-        # selected_k_weights = self.k_experts[indices_k]  # (B*S, top_k, d, k_out_dim)
-        # k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
-        # selected_k = (k_proj * weights_k).sum(dim=1)  # (B*S, k_out_dim)
-        # selected_k = selected_k.view(b, key.size(1), -1)  # (B, S, k_out_dim)
-        #
-        # # Compute V using the same indices as K (since they share the same router)
-        # selected_v_weights = self.v_experts[indices_k]
-        # v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
-        # selected_v = (v_proj * weights_k).sum(dim=1)
-        # selected_v = selected_v.view(b, value.size(1), -1)  # (B, S, k_out_dim)
-        #
-        # # Reshape to GQA format: (B, G, S, head_dim)
-        # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
-        # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
-        #
-        # if not self.use_flash_attention:
-        #     group_heads = self.num_heads // self.num_groups
-        #
-        #     k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-        #     v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-        #
-        #     k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-        #     v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-        #
-        # return q, k, v
-
-        head_dim = d // self.num_heads
-
-        # Process Query as in GQA
-        q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
-
-        # Process Key and Value with MoE routing
-        key_flat = key.view(-1, d)
-        weights, indices = self.router(key_flat)
-        weights = weights.view(b, key.size(1), self.num_groups, 1)
-        indices = indices.view(b, key.size(1), self.num_groups)
-
-        # Compute all experts' K and V projections
-        # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
-        k_all = torch.einsum(
-            'be, ehd -> bedh',
-            key_flat,
-            self.wk.view(self.num_experts, d, -1)
-        ).view(b, key.size(1), self.num_experts, -1)
-
-        v_all = torch.einsum(
-            'be, ehd -> bedh',
-            value.view(-1, d),
-            self.wv.view(self.num_experts, d, -1)
-        ).view(b, value.size(1), self.num_experts, -1)
-
-        # Select top_k experts and compute weighted sum
-        selected_k = torch.gather(
-            k_all,
-            2,
-            indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
-        )
-        selected_v = torch.gather(
-            v_all,
-            2,
-            indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
-        )
-
-        selected_k = (selected_k * weights).sum(dim=2)
-        selected_v = (selected_v * weights).sum(dim=2)
-        # Reshape to GQA format: (B, G, S, head_dim)
-        k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
-        v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
-
-        if not self.use_flash_attention:
-            group_heads = self.num_heads // self.num_groups
-
-            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
-
-            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
-
-        return q, k, v
-
-
-class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
-    """
-    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
-    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
-    experts - it has to be tested.
-
-    Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
-
-    In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
-    query heads - with that approach, each token could attend to every other token, but only partially - only some part of
-    information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
-    sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
-
-    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
-    a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.
-
-    © 2025 Adam Filipek
-    """
-
-    def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_groups: int,
-            dropout: float = 0.0,
-            rope: RotaryPositionalEmbedding = None,
-            rope_only_for_query: bool = False,
-            use_relative_embeddings: bool = False,
-            max_seq_len: int = 1024,
-            use_flash_attention: bool = False,
-            is_causal: bool = False,
-            use_bias: bool = False,
-            num_experts: int = None,
-            num_query_experts: int = None,
-            num_query_groups: int = None,
-            *args,
-            **kwargs,
-    ):
-        self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
-        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
-        super(DeepMoeAttentionVectorized, self).__init__(
-            embed_dim,
-            num_heads,
-            num_groups=num_groups,
-            dropout=dropout,
-            rope=rope,
-            rope_only_for_query=rope_only_for_query,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-            *args,
-            **kwargs,
-        )
-
-    def _init_q(self, embed_dim: int):
-        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
-        hidden_dim = embed_dim // self.num_heads
-        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
-        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
-        self._init_query_experts()
-
-    def _init_query_experts(self):
-        torch.nn.init.xavier_uniform_(self.wq)
-        if self.use_bias:
-            torch.nn.init.zeros_(self.bq)
-
-    def _init_out(self, embed_dim: int):
-        """Initialize output projection"""
-        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
-
-    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
-        # Indexed version may cause memory overflow
-        #
-        # head_dim = d // self.num_heads
-        #
-        # # Process Query with MoE routing
-        # query_flat = query.view(-1, d)  # (B*T, d)
-        # weights_q, indices_q = self.query_router(query_flat)
-        # indices_q = indices_q.view(-1, self.num_query_groups)  # (B*T, top_k_q)
-        # weights_q = weights_q.view(-1, self.num_query_groups, 1)  # (B*T, top_k_q, 1)
-        #
-        # # Select and compute Q projections for top_k experts
-        # selected_q_weights = self.wq[indices_q]  # (B*T, top_k_q, d, head_dim*num_heads)
-        # q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
-        # selected_q = (q_proj * weights_q).sum(dim=1)  # (B*T, head_dim*num_heads)
-        # selected_q = selected_q.view(b, t, -1)  # (B, T, head_dim*num_heads)
-        head_dim = d // self.num_heads
-
-        # Process Query with MoE routing
-        query_flat = query.view(b * t, d)
-        weights_q, indices_q = self.query_router(query_flat)
-        weights_q = weights_q.view(b, t, self.num_query_groups, 1)
-        indices_q = indices_q.view(b, t, self.num_query_groups)
-
-        # Compute all experts' Q projections
-        q_all = torch.einsum(
-            'be, ehd -> bedh',
-            query_flat,
-            self.wq.view(self.num_query_experts, d, -1)
-        ).view(b, t, self.num_query_experts, -1)
-
-        selected_q = torch.gather(
-            q_all,
-            2,
-            indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.shape[-1])
-        )
-        selected_q = (selected_q * weights_q).sum(dim=2)
-
-        q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2)  # (B, H, T, head_dim)
-
-        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
-
 
 # Others
 
@@ -681,8 +389,7 @@ def init_moe_attention(
         num_query_experts: int = None,
         num_query_groups: int = None,
 ) -> GroupedQueryAttention:
-    assert attention_type == 'gma' or attention_type == 'dma' or attention_type == 'gma_v' or attention_type == 'dma_v', \
-        "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
+    assert attention_type in ['gma', 'dma'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -699,7 +406,7 @@ def init_moe_attention(
             use_bias=use_bias,
             num_experts=num_experts,
         )
-    elif attention_type == "dma":
+    else:
         return DeepMoeAttention(
             embed_dim,
             num_heads,
@@ -716,35 +423,3 @@ def init_moe_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
-    elif attention_type == "gma_v":
-        return GroupedMoeAttentionVectorized(
-            embed_dim,
-            num_heads,
-            gqa_groups,
-            dropout=dropout,
-            rope=rope,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            rope_only_for_query=rope_only_for_query,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-        )
-    else:
-        return DeepMoeAttentionVectorized(
-            embed_dim,
-            num_heads,
-            gqa_groups,
-            dropout=dropout,
-            rope=rope,
-            use_relative_embeddings=use_relative_embeddings,
-            max_seq_len=max_seq_len,
-            rope_only_for_query=rope_only_for_query,
-            use_flash_attention=use_flash_attention,
-            is_causal=is_causal,
-            use_bias=use_bias,
-            num_experts=num_experts,
-            num_query_experts=num_query_experts,
-            num_query_groups=num_query_groups,
-        )
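
Note: as of 0.1.26 only 'gma' and 'dma' remain as MoE attention types. For orientation, a minimal sketch of calling the surviving factory; module paths are inferred from this wheel's RECORD, the RoPE constructor follows the RotaryPositionalEmbedding(embed_dim // att_heads, seq_len) call visible in the models.py hunks below, and all concrete values are illustrative assumptions, not taken from the package.

# Sketch only: paths and values are assumptions inferred from this diff, not verified against the package.
from rxnn.experimental.attention import init_moe_attention
from rxnn.transformers.positional import RotaryPositionalEmbedding

rope = RotaryPositionalEmbedding(512 // 16, 1024)  # (head_dim, seq_len)

attention = init_moe_attention(
    embed_dim=512,
    num_heads=16,
    gqa_groups=4,
    attention_type='dma',  # 'gma_v' / 'dma_v' were removed in this release
    rope=rope,
    max_seq_len=1024,
    num_experts=16,
    num_query_experts=16,
    num_query_groups=4,
)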
rxnn/experimental/models.py CHANGED
@@ -35,7 +35,7 @@ class MoeAttentionTransformerConfig(TypedDict):
 
 
 class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
-    """Research model for experiments with Mixture-of-Experts Attention"""
+    """Research decoder model for experiments with Mixture-of-Experts Attention"""
 
     def __init__(
             self,
@@ -65,8 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v',
-                            'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v"'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma"'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -111,6 +110,5 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
     def load_shared_embedding(self, embedding: nn.Embedding):
         self.model.embedding = embedding
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> Union[
-        torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
rxnn/training/bml.py CHANGED
@@ -50,10 +50,14 @@ class MLMTrainer(BaseTrainer):
             vocab_size: int,
             use_amp: bool = False,
             dtype: torch.dtype = None,
+            use_moe_aux_loss: bool = False,
+            moe_aux_loss_scale: float = 0.01,
             **kwargs
     ):
         super(MLMTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype, **kwargs)
         self.vocab_size = vocab_size
+        self.use_moe_aux_loss = use_moe_aux_loss
+        self.moe_aux_loss_scale = moe_aux_loss_scale
 
     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         inputs = batch['input_ids']
@@ -65,11 +69,32 @@ class MLMTrainer(BaseTrainer):
             attention_mask=attention_mask
         )
 
-        return F.cross_entropy(
+        loss = F.cross_entropy(
             logits.view(-1, self.vocab_size),
             labels.view(-1),
             ignore_index=-100
-        ), logits
+        )
+
+        return self._moe_aux_loss(loss), logits
+
+    def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
+        if not self.use_moe_aux_loss:
+            return main_loss
+
+        model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
+
+        router_loss = model.encoder.model.moe_router_loss()
+        loss = main_loss + self.moe_aux_loss_scale * router_loss
+
+        if self.writer is not None:
+            if self.model.training:
+                self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
+                self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
+            else:
+                self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
+                self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
+
+        return loss
 
     def validate(self, batch_size: int) -> tuple[float, dict]:
         self.model.eval()
@@ -113,11 +138,15 @@ class AutoregressiveTrainer(BaseTrainer):
             vocab_size: int,
             use_amp: bool = False,
             dtype: torch.dtype = None,
+            use_moe_aux_loss: bool = False,
+            moe_aux_loss_scale: float = 0.01,
             **kwargs
     ):
         super(AutoregressiveTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
                                                     target_field_name='targets', **kwargs)
         self.vocab_size = vocab_size
+        self.use_moe_aux_loss = use_moe_aux_loss
+        self.moe_aux_loss_scale = moe_aux_loss_scale
 
     def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         inputs = batch['input_ids']
@@ -132,84 +161,21 @@ class AutoregressiveTrainer(BaseTrainer):
         shifted_logits = outputs[:, :-1].contiguous()
         shifted_targets = targets[:, 1:].contiguous()
 
-        return F.cross_entropy(
+        loss = F.cross_entropy(
             shifted_logits.view(-1, self.vocab_size),
             shifted_targets.view(-1)
-        ), outputs
-
-    def validate(self, batch_size: int) -> tuple[float, dict]:
-        self.model.eval()
-        val_dataloader = self._valid_loader(batch_size)
-        val_loss = torch.tensor(0.0).to(self.device)
-        correct = torch.tensor(0).to(self.device)
-        total = torch.tensor(0).to(self.device)
-
-        with torch.no_grad():
-            for batch in val_dataloader:
-                if self.get_batch_size(batch) == batch_size:
-                    loss, logits = self.valid_step(batch)
-                    val_loss += loss
-                    shifted_logits = logits[:, :-1].contiguous()
-                    shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
-                    valid_indices = shifted_targets != -100
-                    if valid_indices.any():
-                        preds = shifted_logits.argmax(-1)
-                        correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
-                        total += valid_indices.sum()
-
-        avg_loss = (val_loss / len(val_dataloader)).item()
-        acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
-        node_acc = acc.item()
-        if self.use_ddp:
-            dist.all_reduce(acc, op=dist.ReduceOp.SUM)
-            acc = acc / dist.get_world_size()
-
-        metrics = {
-            'accuracy': acc.item(),
-            'node_accuracy': node_acc,
-        }
-        self.model.train()
-        return avg_loss, metrics
-
-
-class AutoregressiveMoeTrainer(BaseTrainer):
-    def __init__(
-            self,
-            model: ReactiveTransformerDecoder,
-            device: torch.device,
-            vocab_size: int,
-            use_amp: bool = False,
-            dtype: torch.dtype = None,
-            router_loss_scale: float = 0.1,
-            **kwargs
-    ):
-        super(AutoregressiveMoeTrainer, self).__init__(model, device, use_amp=use_amp, dtype=dtype,
-                                                       target_field_name='targets', **kwargs)
-        self.vocab_size = vocab_size
-        self.router_loss_scale = router_loss_scale
-
-    def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        inputs = batch['input_ids']
-        attention_mask = batch['attention_mask']
-        targets = batch['targets']
-
-        outputs = self.model(
-            inputs,
-            attention_mask=attention_mask
         )
 
-        shifted_logits = outputs[:, :-1].contiguous()
-        shifted_targets = targets[:, 1:].contiguous()
+        return self._moe_aux_loss(loss), outputs
 
-        main_loss = F.cross_entropy(
-            shifted_logits.view(-1, self.vocab_size),
-            shifted_targets.view(-1)
-        )
+    def _moe_aux_loss(self, main_loss: torch.Tensor) -> torch.Tensor:
+        if not self.use_moe_aux_loss:
+            return main_loss
 
         model = next(self.model.children()) if isinstance(self.model, DistributedDataParallel) else self.model
 
         router_loss = model.model.moe_router_loss()
-        loss = main_loss + self.router_loss_scale * router_loss
+        loss = main_loss + self.moe_aux_loss_scale * router_loss
 
         if self.writer is not None:
             if self.model.training:
@@ -219,7 +185,7 @@ class AutoregressiveMoeTrainer(BaseTrainer):
                 self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
                 self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
 
-        return loss, outputs
+        return loss
 
     def validate(self, batch_size: int) -> tuple[float, dict]:
         self.model.eval()
@@ -279,6 +245,9 @@ class JointTrainingModel(nn.Module):
 
 
 class JointLMTrainer(BaseTrainer):
+    """"
+    It's not recommended to use Joint LM Training in current implementation. More info soon
+    """
     def __init__(
             self,
             model: JointTrainingModel,
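
Both trainers now gate the router auxiliary loss behind use_moe_aux_loss and scale it with moe_aux_loss_scale (default 0.01). A self-contained sketch of the combination performed by _moe_aux_loss, with illustrative loss values (the real trainer obtains router_loss from the model's moe_router_loss()):

# Illustrative values only - not taken from the package.
import torch

main_loss = torch.tensor(2.31)    # cross-entropy returned by compute_loss
router_loss = torch.tensor(0.85)  # auxiliary loss from the MoE router
moe_aux_loss_scale = 0.01         # default from the constructors above

total_loss = main_loss + moe_aux_loss_scale * router_loss
print(total_loss)  # approximately tensor(2.3185)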
rxnn/transformers/layers.py CHANGED
@@ -132,9 +132,10 @@ class ClassicTransformerLayer(nn.Module):
 
         if use_gated:
             if use_moe:
-                self.ff = GatedMoeFeedForward(embed_dim, ff_dim, num_experts, top_k=moe_top_k, dropout=ff_dropout)
+                self.ff = GatedMoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
+                                              dropout=ff_dropout)
             else:
-                self.ff = GatedFeedForward(embed_dim, ff_dim, dropout=ff_dropout)
+                self.ff = GatedFeedForward(embed_dim, ff_dim, ff_activation, dropout=ff_dropout)
         else:
             if use_moe:
                 self.ff = MoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
rxnn-0.1.24.dist-info/METADATA → rxnn-0.1.26.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.24
+Version: 0.1.26
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.1.24.dist-info/RECORD → rxnn-0.1.26.dist-info/RECORD
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=r3y0_BaweONrkm6Z-9zn56U9jlBMfYBs8NlWNm7rR90,32424
-rxnn/experimental/models.py,sha256=-XkEHsyT8iNAjhZbgC7N_5nzP4ENVJLwxSoLHgMfA0I,4668
+rxnn/experimental/attention.py,sha256=rJSQjA7_9YqcM4Y8SyJSuZQjyz8j4XPhC5jcrcrRK2M,17891
+rxnn/experimental/models.py,sha256=8KAo7BtRkke9qRlzGRtQa9-EZ34roGWrn0N_T6L-6Wc,4561
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -10,7 +10,7 @@ rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=QD8uS14jSyR5Y_8BgCaBQTKpsarerU3lyufsWsCq_6o,11227
-rxnn/training/bml.py,sha256=o_88ZL1YWd5gWXaBqYPK2UzSTbJaiTiw96E6z73LeOQ,18660
+rxnn/training/bml.py,sha256=HtxSzI-WcpRclAuIccF_WoTZ24KzH5ZWKe8SnWgjjm4,17581
 rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
 rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
@@ -18,14 +18,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=FHATZVf_kt3OHnG02zEeG9QdUXLncKDjrhyT28Pk0E4,14185
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=ZJfNdgCv9dzrKqsWIMf99Ryzgs494ZhkwK4zSBYLvQ4,6880
+rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
 rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.24.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.24.dist-info/METADATA,sha256=_h4mqmSKPEr0mxc2CaMn-yzvmZ5Lqlk_H4parGt-eHk,16627
-rxnn-0.1.24.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.24.dist-info/RECORD,,
+rxnn-0.1.26.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.26.dist-info/METADATA,sha256=qKx2OFl-ca7OCZoSf99l5lOmg2czSCN8UN_Esqmfco0,16627
+rxnn-0.1.26.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.26.dist-info/RECORD,,