rxnn 0.2.68__py3-none-any.whl → 0.2.69__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/mrl.py
CHANGED
@@ -41,6 +41,7 @@ class MrlConfig(TypedDict):
     use_memory_warmup: Optional[bool]
     debug_mode: Optional[bool]
     debug_interval: Optional[int]
+    clamp_logits: Optional[bool]
 
 
 class MrlStrategy(Enum):
@@ -152,6 +153,7 @@ class MRLTrainer:
         self.use_memory_warmup = config.get('use_memory_warmup', False)
         self.debug_mode = config.get('debug_mode', False)
         self.debug_interval = config.get('debug_interval', 10)
+        self.clamp_logits = config.get('clamp_logits', False)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
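The two hunks above add a clamp_logits flag to MrlConfig and read it in the trainer's constructor, defaulting to False. A minimal sketch of the affected fields as they might appear in a user's config dict; the names and defaults are taken from this diff, and every other MrlConfig entry is omitted:

    # Illustrative fragment only, not an official example from the package docs.
    mrl_config_fragment = {
        'use_memory_warmup': False,   # existing option, default shown in the diff
        'debug_mode': True,           # enables the periodic gradient/logits logging
        'debug_interval': 10,         # log every 10 training steps (the default)
        'clamp_logits': True,         # new in 0.2.69: clamp decoder logits to [-20.0, 20.0]
    }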
@@ -594,7 +596,9 @@ class MRLTrainer:
         else:
             return main_loss
 
-    def _log_gradients(self):
+    def _log_gradients(self, logits: torch.Tensor):
+        print(
+            f"Returned logits stats: min={logits.min().item():.4f}, max={logits.max().item():.4f}")
         encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
         decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
         mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
@@ -633,6 +637,8 @@ class MRLTrainer:
                                      pad_token_id=self.pad_token_id)
            logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                action=MrlActorAction.DECODE)
+           if self.clamp_logits:
+               logits = logits.clamp(min=-20.0, max=20.0)
            # 4.2 Calculate policy loss with selected algorithm
            policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs,
                                                        advantages)
@@ -645,7 +651,7 @@ class MRLTrainer:
            torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                           error_if_nonfinite=False)
            if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
-               self._log_gradients()
+               self._log_gradients(logits)
            # 4.5 Run scaled optimization step
            self.scaler.step(self.optimizer)
            self.scaler.update()
@@ -655,6 +661,8 @@ class MRLTrainer:
                                      pad_token_id=self.pad_token_id)
            logits = self.actor(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                action=MrlActorAction.DECODE)
+           if self.clamp_logits:
+               logits = logits.clamp(min=-20.0, max=20.0)
            # 4.2 Calculate policy loss with selected algorithm
            policy_loss = self.rl_algorithm.policy_loss(next_query, action, logits, old_log_probs, advantages)
            policy_loss = self._moe_aux_loss(policy_loss)
@@ -664,7 +672,7 @@ class MRLTrainer:
            torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                           error_if_nonfinite=False)
            if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
-               self._log_gradients()
+               self._log_gradients(logits)
            # 4.5 Run scaled optimization step
            self.optimizer.step()
            # 5. Get float loss value for callbacks/writer
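With clamp_logits enabled, the decoder logits are clamped to [-20.0, 20.0] right before rl_algorithm.policy_loss, and _log_gradients now also prints their min/max so out-of-range values show up in the debug output. A small, standalone illustration of what Tensor.clamp does with that window (illustrative values only, not output from the trainer):

    import torch

    # Tensor.clamp bounds every element to the given range; the trainer applies
    # the same [-20.0, 20.0] window to the decoder logits.
    logits = torch.tensor([[-135.2, 3.7, 88.9]])
    print(logits.clamp(min=-20.0, max=20.0))  # -> tensor([[-20.0000, 3.7000, 20.0000]])

Bounding the logits keeps the downstream log-probability and probability-ratio computations in a numerically safe range, which is presumably why the min/max logging was added alongside the clamp.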
rxnn-0.2.69.dist-info/RECORD
CHANGED
@@ -17,7 +17,7 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=tbtOSYldHnQB6SWgee_yUj9zTbgoEoLFNa6wvUS6Apg,51292
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=KIiOCW0VgKtMA4EMQ---xsVExdI1mBsgWjtRSmJpecA,9033
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=c_7P_DhroK3pQLubfmlVryWBSwlZ0BssU8zZ6UhjOaI,65919
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.69.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.69.dist-info/METADATA,sha256=YcmghdF8ypeyOCmglJaws18cDtTqSIE8P-gReGIMzsU,60420
+rxnn-0.2.69.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.69.dist-info/RECORD,,