rxnn 0.2.68__tar.gz → 0.2.69__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.68 → rxnn-0.2.69}/PKG-INFO +1 -1
  2. {rxnn-0.2.68 → rxnn-0.2.69}/pyproject.toml +1 -1
  3. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/mrl.py +11 -3
  4. {rxnn-0.2.68 → rxnn-0.2.69}/LICENSE +0 -0
  5. {rxnn-0.2.68 → rxnn-0.2.69}/README.md +0 -0
  6. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/.DS_Store +0 -0
  7. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/__init__.py +0 -0
  8. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/experimental/__init__.py +0 -0
  9. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/experimental/attention.py +0 -0
  10. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/experimental/models.py +0 -0
  11. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/experimental/moe.py +0 -0
  12. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/memory/__init__.py +0 -0
  13. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/memory/attention.py +0 -0
  14. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/memory/norm.py +0 -0
  15. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/memory/stm.py +0 -0
  16. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/rxt/__init__.py +0 -0
  17. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/rxt/models.py +0 -0
  18. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/__init__.py +0 -0
  19. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/base.py +0 -0
  20. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/bml.py +0 -0
  21. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/callbacks.py +0 -0
  22. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/dataset.py +0 -0
  23. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/ddp.py +0 -0
  24. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/models.py +0 -0
  25. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/reward.py +0 -0
  26. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/rl.py +0 -0
  27. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/scheduler.py +0 -0
  28. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/tokenizer.py +0 -0
  29. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/utils.py +0 -0
  30. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/transformers/__init__.py +0 -0
  31. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/transformers/attention.py +0 -0
  32. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/transformers/ff.py +0 -0
  33. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/transformers/layers.py +0 -0
  34. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/utils.py +0 -0
{rxnn-0.2.68 → rxnn-0.2.69}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.68
+Version: 0.2.69
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.68 → rxnn-0.2.69}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.68"
+version = "0.2.69"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.68 → rxnn-0.2.69}/src/rxnn/training/mrl.py
@@ -41,6 +41,7 @@ class MrlConfig(TypedDict):
     use_memory_warmup: Optional[bool]
     debug_mode: Optional[bool]
     debug_interval: Optional[int]
+    clamp_logits: Optional[bool]
 
 
 class MrlStrategy(Enum):
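The release adds a single opt-in flag, clamp_logits, to the MRL config. A minimal sketch of how a caller might enable it (only clamp_logits is new in this diff; debug_mode and debug_interval come from the surrounding hunks, and the idea that this dict is passed to MRLTrainer as config mirrors the config.get(...) calls below):

# Hypothetical usage sketch, not taken from the package's own examples.
config = {
    'debug_mode': True,      # print gradient/logit diagnostics
    'debug_interval': 10,    # every 10 training steps
    'clamp_logits': True,    # new in 0.2.69: clamp decoder logits to [-20, 20]
}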
@@ -152,6 +153,7 @@ class MRLTrainer:
         self.use_memory_warmup = config.get('use_memory_warmup', False)
         self.debug_mode = config.get('debug_mode', False)
         self.debug_interval = config.get('debug_interval', 10)
+        self.clamp_logits = config.get('clamp_logits', False)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -594,7 +596,9 @@ class MRLTrainer:
         else:
             return main_loss
 
-    def _log_gradients(self):
+    def _log_gradients(self, logits: torch.Tensor):
+        print(
+            f"Returned logits stats: min={logits.min().item():.4f}, max={logits.max().item():.4f}")
         encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
         decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
         mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
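_log_gradients now takes the decoder logits and prints their min/max before the per-module gradient norms, so debug output ties logit explosions to the step where they occur. get_gradient_norms is defined elsewhere in the package; purely as an assumption, a helper of that name plausibly computes something like:

import torch

def get_gradient_norms(module: torch.nn.Module) -> tuple[float, float]:
    # Hypothetical sketch, not the package's actual implementation:
    # total L2 norm across all parameter gradients, plus the mean
    # of the per-parameter norms.
    norms = [p.grad.detach().norm(2).item()
             for p in module.parameters() if p.grad is not None]
    total = sum(n ** 2 for n in norms) ** 0.5
    mean = sum(norms) / len(norms) if norms else 0.0
    return total, mean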
@@ -633,6 +637,8 @@ class MRLTrainer:
                                  pad_token_id=self.pad_token_id)
             logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                 action=MrlActorAction.DECODE)
+            if self.clamp_logits:
+                logits = logits.clamp(min=-20.0, max=20.0)
             # 4.2 Calculate policy loss with selected algorithm
             policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
                                                         advantages)
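Clamping to ±20 bounds the log-prob gaps that importance-ratio objectives of this shape (old log-probs plus advantages) exponentiate, which keeps the ratio finite when logits explode. A minimal illustration with made-up numbers (only the ±20 bound comes from the diff):

import torch

# Illustrative values only. An exploded logit gap overflows float32
# when exponentiated into a ratio; the clamp keeps it finite.
new_lp, old_lp = torch.tensor(95.0), torch.tensor(-30.0)
print(torch.exp(new_lp - old_lp))    # inf: exp(125) overflows float32
new_c, old_c = new_lp.clamp(-20.0, 20.0), old_lp.clamp(-20.0, 20.0)
print(torch.exp(new_c - old_c))      # ~2.35e17: large but finite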
@@ -645,7 +651,7 @@ class MRLTrainer:
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
-                self._log_gradients()
+                self._log_gradients(logits)
             # 4.5 Run scaled optimization step
             self.scaler.step(self.optimizer)
             self.scaler.update()
@@ -655,6 +661,8 @@ class MRLTrainer:
                                  pad_token_id=self.pad_token_id)
             logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                 action=MrlActorAction.DECODE)
+            if self.clamp_logits:
+                logits = logits.clamp(min=-20.0, max=20.0)
             # 4.2 Calculate policy loss with selected algorithm
             policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
             policy_loss = self._moe_aux_loss(policy_loss)
@@ -664,7 +672,7 @@ class MRLTrainer:
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
-                self._log_gradients()
+                self._log_gradients(logits)
             # 4.5 Run scaled optimization step
             self.optimizer.step()
             # 5. Get float loss value for callbacks/writer
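The last two pairs of hunks are the same change applied to two code paths: one that steps through a GradScaler (mixed precision) and one that calls optimizer.step() directly. A generic sketch of that standard PyTorch pattern, as an assumed reading of the trainer's structure rather than its actual code:

import torch

def optimization_step(loss, model, optimizer, scaler=None):
    # Generic AMP-vs-standard pattern mirroring the two branches above.
    optimizer.zero_grad()
    if scaler is not None:                  # mixed-precision branch
        scaler.scale(loss).backward()       # scale loss to avoid fp16 underflow
        scaler.unscale_(optimizer)          # unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0,
                                       error_if_nonfinite=False)
        scaler.step(optimizer)              # skips the step on inf/nan grads
        scaler.update()
    else:                                   # full-precision branch
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0,
                                       error_if_nonfinite=False)
        optimizer.step()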