rxnn 0.2.68__py3-none-any.whl → 0.2.69__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/mrl.py CHANGED
@@ -41,6 +41,7 @@ class MrlConfig(TypedDict):
     use_memory_warmup: Optional[bool]
     debug_mode: Optional[bool]
     debug_interval: Optional[int]
+    clamp_logits: Optional[bool]
 
 
 class MrlStrategy(Enum):
@@ -152,6 +153,7 @@ class MRLTrainer:
         self.use_memory_warmup = config.get('use_memory_warmup', False)
         self.debug_mode = config.get('debug_mode', False)
         self.debug_interval = config.get('debug_interval', 10)
+        self.clamp_logits = config.get('clamp_logits', False)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -594,7 +596,9 @@ class MRLTrainer:
         else:
             return main_loss
 
-    def _log_gradients(self):
+    def _log_gradients(self, logits: torch.Tensor):
+        print(
+            f"Returned logits stats: min={logits.min().item():.4f}, max={logits.max().item():.4f}")
         encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
         decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
         mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
@@ -633,6 +637,8 @@ class MRLTrainer:
                 pad_token_id=self.pad_token_id)
             logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                 action=MrlActorAction.DECODE)
+            if self.clamp_logits:
+                logits = logits.clamp(min=-20.0, max=20.0)
             # 4.2 Calculate policy loss with selected algorithm
             policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
                                                         advantages)
@@ -645,7 +651,7 @@ class MRLTrainer:
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
-                self._log_gradients()
+                self._log_gradients(logits)
             # 4.5 Run scaled optimization step
             self.scaler.step(self.optimizer)
             self.scaler.update()
@@ -655,6 +661,8 @@ class MRLTrainer:
                 pad_token_id=self.pad_token_id)
             logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                 action=MrlActorAction.DECODE)
+            if self.clamp_logits:
+                logits = logits.clamp(min=-20.0, max=20.0)
             # 4.2 Calculate policy loss with selected algorithm
             policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
             policy_loss = self._moe_aux_loss(policy_loss)
@@ -664,7 +672,7 @@ class MRLTrainer:
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
             if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
-                self._log_gradients()
+                self._log_gradients(logits)
             # 4.5 Run scaled optimization step
             self.optimizer.step()
             # 5. Get float loss value for callbacks/writer
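In short, 0.2.69 adds one opt-in setting: a clamp_logits flag in MrlConfig. When it is enabled, MRLTrainer clamps the decoder logits to [-20.0, 20.0] before computing the policy loss (in both the scaler-based and plain optimizer update paths shown above), and the debug-mode gradient log now also prints the min/max of those logits. The snippet below is an illustrative sketch rather than code from the package: the config dict, tensor shape, and values are invented, and only the key names, the default (False), the clamp bounds, and the printed stats line are taken from the diff above.

import torch

# Hypothetical config; only these keys appear in the MrlConfig diff, and
# 'clamp_logits' defaults to False (config.get('clamp_logits', False)).
config = {
    'debug_mode': True,
    'debug_interval': 10,
    'clamp_logits': True,  # new in 0.2.69
}

# Invented decoder output with deliberately extreme values (batch=2, seq=8, vocab=1000).
logits = torch.randn(2, 8, 1000) * 50.0

# Same guard and bounds that MRLTrainer now applies before the policy loss.
if config.get('clamp_logits', False):
    logits = logits.clamp(min=-20.0, max=20.0)

# Mirrors the stats line that _log_gradients() prints in debug mode.
print(f"Returned logits stats: min={logits.min().item():.4f}, max={logits.max().item():.4f}")

The diff does not state the motivation, but clamping at this point is a common guard against overflow or NaNs when extreme logits are later exponentiated in the policy ratio, which fits the new logit-stats line being attached to the existing gradient-norm debug output.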
rxnn-0.2.68.dist-info/METADATA → rxnn-0.2.69.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.68
+Version: 0.2.69
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.68.dist-info/RECORD → rxnn-0.2.69.dist-info/RECORD RENAMED
@@ -17,7 +17,7 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
-rxnn/training/mrl.py,sha256=2J6Wh4xtsVoE6duEevmovDpmSsMkEoH39Ru0bE8lhFo,65481
+rxnn/training/mrl.py,sha256=c_7P_DhroK3pQLubfmlVryWBSwlZ0BssU8zZ6UhjOaI,65919
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.68.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.68.dist-info/METADATA,sha256=w7AYGnPAW9xy8DMmcueWdfoV1oQxqOVVDcRxlvA8gWQ,60420
-rxnn-0.2.68.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.68.dist-info/RECORD,,
+rxnn-0.2.69.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.69.dist-info/METADATA,sha256=YcmghdF8ypeyOCmglJaws18cDtTqSIE8P-gReGIMzsU,60420
+rxnn-0.2.69.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.69.dist-info/RECORD,,