rxnn 0.1.83__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,30 +2,31 @@ import os, traceback, shutil
 import numpy as np
 import torch
 import torch.nn as nn
-from typing import Union
+from typing import Union, Optional
 from torch.nn.parallel import DistributedDataParallel
 from huggingface_hub import PyTorchModelHubMixin
 from ..utils import human_format

+
 class TrainerCallback:
-    def on_epoch_start(self, model: torch.nn.Module, epoch: int) -> None:
+    def on_epoch_start(self, model: nn.Module, epoch: int) -> None:
         pass

-    def on_epoch_end(self, model: torch.nn.Module, epoch: int) -> Union[bool, None]:
+    def on_epoch_end(self, model: nn.Module, epoch: int) -> Union[bool, None]:
         pass

-    def on_batch_start(self, model: torch.nn.Module, batch_idx: int, batch: dict[str, torch.Tensor]) -> None:
+    def on_batch_start(self, model: nn.Module, batch_idx: int, batch: dict[str, torch.Tensor]) -> None:
         pass

-    def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: float, batch: dict[str, torch.Tensor]) -> \
-            Union[
-        bool, None]:
+    def on_batch_end(self, model: nn.Module, batch_idx: int, loss: float, batch: dict[str, torch.Tensor]) -> \
+            Union[
+        bool, None]:
         pass

-    def on_training_end(self, model: torch.nn.Module) -> None:
+    def on_training_end(self, model: nn.Module) -> None:
         pass

-    def on_validation_end(self, model: torch.nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> Union[
+    def on_validation_end(self, model: nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> Union[
         bool, None]:
         pass

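The changes in this first hunk are typing cleanups: the hooks now annotate models with the already-imported `nn` alias instead of `torch.nn.Module`, and `Optional` joins the typing imports for the MRL callbacks added later in the file. As a hedged sketch of how these hooks are used (the trainer-side wiring, e.g. a `callbacks=[...]` argument, is an assumption and not shown in this diff), a custom callback only overrides the hooks it needs and can return True from a validation hook to request early stopping:

# Minimal sketch of a custom callback against the hooks above.
# ThresholdStopCallback and max_val_loss are hypothetical names; TrainerCallback
# comes from the callbacks module shown in this diff, and the trainer that
# invokes these hooks is assumed, not shown here.
import torch
import torch.nn as nn
from typing import Union

class ThresholdStopCallback(TrainerCallback):
    def __init__(self, max_val_loss: float = 5.0) -> None:
        self.max_val_loss = max_val_loss

    def on_batch_end(self, model: nn.Module, batch_idx: int, loss: float, batch: dict[str, torch.Tensor]) -> Union[bool, None]:
        if batch_idx % 100 == 0:
            print(f'batch {batch_idx}: loss {loss:.4f}')
        return None

    def on_validation_end(self, model: nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> Union[bool, None]:
        # A truthy return asks the trainer to stop, mirroring EarlyStoppageCallback below.
        return True if val_loss > self.max_val_loss else None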
@@ -111,7 +112,7 @@ class TokenCounterCallback(TrainerCallback):
             print(f'Reached a limit of {human_format(self.limit)} processed tokens - stopping training')
         return should_stop_training

-    def on_training_end(self, model: torch.nn.Module) -> None:
+    def on_training_end(self, model: nn.Module) -> None:
         print(f'Total training tokens: {human_format(self.total_tokens)}')

     def get_total_tokens(self):
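TokenCounterCallback keeps the same behaviour, with only the type annotation switched to the `nn` alias: it prints a notice and returns a stop flag once the configured token limit is crossed, and reports the grand total when training ends. A hedged usage sketch follows; the constructor signature (a single token limit) is inferred from `self.limit` and the stop message above, not shown in this hunk:

# Sketch only: the limit argument is an assumption inferred from self.limit above.
# TokenCounterCallback and human_format come from the module shown in this diff.
token_counter = TokenCounterCallback(10_000_000_000)  # stop once ~10B tokens were processed
# ... after training ...
print(human_format(token_counter.get_total_tokens()))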
@@ -122,7 +123,6 @@ class ModelSaveCallback(TrainerCallback):
     def __init__(
             self,
             save_dir: str,
-            save_best_only: bool = True,
             max_keep: int = 3,
             push_to_hub: bool = False,
             hub_model_id: str = None,
@@ -136,7 +136,6 @@ class ModelSaveCallback(TrainerCallback):
             use_ddp: bool = False,
     ):
         self.save_dir = save_dir
-        self.save_best_only = save_best_only
         self.max_keep = max_keep
         self.best_loss = float('inf')
         self.ckpt_paths = []
@@ -152,7 +151,7 @@ class ModelSaveCallback(TrainerCallback):
         self.display_exc_trace = display_exc_trace
         self.rank = int(os.environ['RANK']) if use_ddp else 0

-    def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: int, batch: dict[str, torch.Tensor]) -> Union[
+    def on_batch_end(self, model: nn.Module, batch_idx: int, loss: int, batch: dict[str, torch.Tensor]) -> Union[
         bool, None]:
         if self.rank == 0 and self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
             if isinstance(model, DistributedDataParallel):
@@ -195,7 +194,7 @@ class ModelSaveCallback(TrainerCallback):

     def on_validation_end(
             self,
-            model: Union[torch.nn.Module, PyTorchModelHubMixin],
+            model: Union[nn.Module, PyTorchModelHubMixin],
             epoch: int,
             val_loss: float,
             val_metrics: dict
@@ -252,7 +251,7 @@ class ModelSaveCallback(TrainerCallback):
             if self.display_exc_trace:
                 traceback.print_exc()

-    def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
+    def on_training_end(self, model: Union[nn.Module, PyTorchModelHubMixin]):
         if self.rank == 0:
             if isinstance(model, DistributedDataParallel):
                 model = next(model.children())
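Across both save callbacks, 0.2.0 removes the `save_best_only` flag from the constructor; checkpoint rotation is governed by `max_keep` and the tracked `best_loss` alone. A hedged construction sketch using only parameters visible in these hunks (other keyword arguments keep their defaults; the repo id is a placeholder):

# Sketch: constructing ModelSaveCallback without the removed save_best_only flag.
# Only arguments visible in this diff are passed; 'your-org/your-model' is a placeholder.
save_callback = ModelSaveCallback(
    './checkpoints',
    max_keep=3,
    push_to_hub=True,
    hub_model_id='your-org/your-model',
    use_ddp=False,
)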
@@ -291,7 +290,6 @@ class JointModelSaveCallback(TrainerCallback):
     def __init__(
             self,
             save_dir: str,
-            save_best_only: bool = True,
             max_keep: int = 3,
             push_to_hub: bool = False,
             hub_model_decoder: str = None,
@@ -308,7 +306,6 @@ class JointModelSaveCallback(TrainerCallback):
             use_ddp: bool = False,
     ):
         self.save_dir = save_dir
-        self.save_best_only = save_best_only
         self.max_keep = max_keep
         self.best_loss = float('inf')
         self.ckpt_paths = []
@@ -369,7 +366,7 @@ class JointModelSaveCallback(TrainerCallback):
             if self.display_exc_trace:
                 traceback.print_exc()

-    def on_batch_end(self, model: torch.nn.Module, batch_idx: int, loss: int, batch: dict[str, torch.Tensor]) -> Union[
+    def on_batch_end(self, model: nn.Module, batch_idx: int, loss: int, batch: dict[str, torch.Tensor]) -> Union[
         bool, None]:
         if self.rank == 0 and self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
             if isinstance(model, DistributedDataParallel):
@@ -434,7 +431,7 @@ class JointModelSaveCallback(TrainerCallback):

     def on_validation_end(
             self,
-            model: Union[torch.nn.Module, PyTorchModelHubMixin],
+            model: Union[nn.Module, PyTorchModelHubMixin],
             epoch: int,
             val_loss: float,
             val_metrics: dict
@@ -491,7 +488,7 @@ class JointModelSaveCallback(TrainerCallback):
             if self.display_exc_trace:
                 traceback.print_exc()

-    def on_training_end(self, model: Union[torch.nn.Module, PyTorchModelHubMixin]):
+    def on_training_end(self, model: Union[nn.Module, PyTorchModelHubMixin]):
         if self.rank == 0:
             if isinstance(model, DistributedDataParallel):
                 model = next(model.children())
@@ -500,23 +497,289 @@ class JointModelSaveCallback(TrainerCallback):
             self._save_final(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
             self._save_final(model.mlm_head, 'head', hub_id=self.hub_model_head)

+
 class EarlyStoppageCallback(TrainerCallback):
-    def __init__(self, num_plateau_epochs: int = 3) -> None:
-        super().__init__()
-        self.num_plateau_epochs = num_plateau_epochs
-        self.best_loss = 9999.0
-        self.best_loss_epoch = 0
-
-    def on_validation_end(
-            self,
-            model: torch.nn.Module,
-            epoch: int,
-            val_loss: float,
-            val_metrics: dict
-    ):
-        if val_loss < self.best_loss:
-            self.best_loss = val_loss
-            self.best_loss_epoch = epoch
-        elif epoch - self.best_loss_epoch > self.num_plateau_epochs:
-            return True
-        return None
+    def __init__(self, num_plateau_epochs: int = 3) -> None:
+        super().__init__()
+        self.num_plateau_epochs = num_plateau_epochs
+        self.best_loss = 9999.0
+        self.best_loss_epoch = 0
+
+    def on_validation_end(
+            self,
+            model: nn.Module,
+            epoch: int,
+            val_loss: float,
+            val_metrics: dict
+    ):
+        if val_loss < self.best_loss:
+            self.best_loss = val_loss
+            self.best_loss_epoch = epoch
+        elif epoch - self.best_loss_epoch >= self.num_plateau_epochs:
+            return True
+        return None
+
+
+class MrlTrainerCallback:
+    def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, global_epoch: int,
+                       global_epochs: int, curriculum_config: dict) -> None:
+        pass
+
+    def on_epoch_end(self, actor: nn.Module, epoch: int, stage_epochs: int, policy_loss: float,
+                     critic_loss: float, global_epoch: int, global_epochs: int) -> None:
+        pass
+
+    def on_episode_collected(self, actor: nn.Module, batch_idx: int, episode_trajectories: list[dict],
+                             reward: float) -> None:
+        pass
+
+    def on_reward(self, actor: nn.Module, reward: float, generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
+        pass
+
+    def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
+        pass
+
+    def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
+                          critic_loss: float) -> None:
+        pass
+
+    def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
+        pass
+
+    def on_eval_end(self, actor: nn.Module, critic: nn.Module, epoch: int, eval_mean_reward: float) -> Union[bool, None]:
+        pass
+
+    def on_eval_episode_end(self, actor: nn.Module, epoch: int, batch_idx: int, reward: float) -> None:
+        pass
+
+
+class MrlPrintCallback(MrlTrainerCallback):
+    def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
+                       global_epoch: int, global_epochs: int) -> None:
+        print(
+            f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config['steps']} steps in {curriculum_config['strategy']} strategy.')
+
+    def on_epoch_end(self, actor: nn.Module, epoch: int, stage_epochs: int, policy_loss: float,
+                     critic_loss: float, global_epoch: int, global_epochs: int) -> None:
+        print(f'Finished epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global)')
+        print(f'Policy mean loss: {policy_loss} | Critic mean loss: {critic_loss}')
+
+    def on_episode_collected(self, actor: nn.Module, batch_idx: int, episode_trajectories: list[dict],
+                             reward: float) -> None:
+        print(f'Collected {batch_idx} episode | mean reward {reward}')
+
+    def on_reward(self, actor: nn.Module, reward: float, generated: dict[str, torch.Tensor],
+                  reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
+        print(f"{'Eval' if eval_mode else 'Train'} | Collected reward {reward}")
+
+    def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
+        print(f'Epoch {epoch} | Step {step} - updated policy loss {policy_loss}')
+
+    def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
+                          critic_loss: float) -> None:
+        print(f'Epoch {epoch} | Step {step} - updated policy loss {critic_loss}')
+
+    def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
+        print(f'Finished training for {curriculum_config['steps']} steps in {curriculum_config['strategy']} strategy.')
+
+    def on_eval_end(self, actor: nn.Module, critic: nn.Module, epoch: int, eval_mean_reward: float) -> None:
+        print(f'Eval epoch {epoch} - mean reward {eval_mean_reward}')
+
+    def on_eval_episode_end(self, actor: nn.Module, epoch: int, batch_idx: int, reward: float) -> None:
+        print(f'Eval epoch {epoch} / Episode {batch_idx} - mean reward {reward}')
+
+
+class MrlEarlyStoppageCallback(MrlTrainerCallback):
+    def __init__(self, num_plateau_epochs: int = 2, threshold: Optional[float] = None) -> None:
+        super().__init__()
+        self.num_plateau_epochs = num_plateau_epochs
+        self.best_reward = -9999.0
+        self.best_reward_epoch = 0
+        self.threshold = threshold
+
+    def on_eval_end(self, _actor: nn.Module, _critic: nn.Module, epoch: int, eval_mean_reward: float) -> Union[bool, None]:
+        if self.threshold is not None:
+            if eval_mean_reward > self.threshold:
+                return True
+
+        if eval_mean_reward > self.best_reward:
+            self.best_reward = eval_mean_reward
+            self.best_reward_epoch = epoch
+        elif epoch - self.best_reward_epoch >= self.num_plateau_epochs:
+            return True
+        return None
+
+class MrlModelSaveCallback(MrlTrainerCallback):
+    def __init__(
+            self,
+            save_dir: str,
+            max_keep: int = 3,
+            push_to_hub: bool = False,
+            hub_model_decoder: str = None,
+            hub_model_encoder: str = None,
+            hub_model_memory_attention: str = None,
+            hub_model_critic: str = None,
+            private_repo: bool = False,
+            hf_token: str = None,
+            push_checkpoint_weights: bool = True,
+            final_commit_message: str = None,
+            display_exc_trace: bool = False,
+            use_ddp: bool = False,
+    ):
+        self.save_dir = save_dir
+        self.max_keep = max_keep
+        self.best_reward = float('-inf')
+        self.ckpt_paths = []
+        self.push_to_hub = push_to_hub
+        self.hub_model_decoder = hub_model_decoder
+        self.hub_model_encoder = hub_model_encoder
+        self.hub_model_memory_attention = hub_model_memory_attention
+        self.hub_model_critic = hub_model_critic
+        self.private_repo = private_repo
+        self.hf_token = hf_token
+        self.push_checkpoint_weights = push_checkpoint_weights
+        self.final_commit_message = final_commit_message
+        self.finished_epochs = 0
+        self.display_exc_trace = display_exc_trace
+        self.rank = int(os.environ['RANK']) if use_ddp else 0
+
+    def _save_eval(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, epoch: int,
+                   reward: float, hub_id: str = None):
+        try:
+            if model.save_pretrained is not None:
+                ckpt_path = os.path.join(
+                    self.save_dir,
+                    component,
+                    f'epoch_{epoch}_eval_reward_{reward:.4f}'
+                )
+                path_exists = os.path.exists(ckpt_path)
+                if not path_exists:
+                    os.makedirs(ckpt_path)
+                model.save_pretrained(save_directory=ckpt_path)
+            else:
+                comp_path = os.path.join(
+                    self.save_dir,
+                    component
+                )
+                path_exists = os.path.exists(comp_path)
+                if not path_exists:
+                    os.makedirs(comp_path)
+                ckpt_path = os.path.join(
+                    comp_path,
+                    f'epoch_{epoch}_eval_reward_{reward:.4f}.pt'
+                )
+                torch.save(model.state_dict(), ckpt_path)
+            self.ckpt_paths.append(ckpt_path)
+
+            # Keep only N best checkpoints
+            if len(self.ckpt_paths) > self.max_keep:
+                oldest_path = self.ckpt_paths.pop(0)
+                if model.save_pretrained is not None:
+                    shutil.rmtree(oldest_path)
+                else:
+                    os.remove(oldest_path)
+        except Exception as e:
+            print(f"Error saving epoch checkpoint: {str(e)}")
+            if self.display_exc_trace:
+                traceback.print_exc()
+
+        try:
+            if self.push_to_hub and self.push_checkpoint_weights and model.push_to_hub is not None and hub_id:
+                model.push_to_hub(
+                    repo_id=hub_id,
+                    commit_message=f'Epoch {epoch} - Eval reward {reward:.4f}',
+                    token=self.hf_token,
+                    private=self.private_repo,
+                )
+        except Exception as e:
+            print(f"Error pushing epoch checkpoint: {str(e)}")
+            if self.display_exc_trace:
+                traceback.print_exc()
+
+    def on_eval_end(self, actor: nn.Module, critic: nn.Module, epoch: int, eval_mean_reward: float) -> None:
+        if self.rank == 0:
+            self.finished_epochs += 1
+            if eval_mean_reward > self.best_reward:
+                self.best_reward = eval_mean_reward
+                if isinstance(actor, DistributedDataParallel):
+                    actor = next(actor.children())
+                self._save_eval(actor.encoder, 'encoder', epoch, eval_mean_reward, hub_id=self.hub_model_encoder)
+                self._save_eval(actor.decoder, 'decoder', epoch, eval_mean_reward, hub_id=self.hub_model_decoder)
+                self._save_eval(actor.memory_attention, 'memory_attention', epoch, eval_mean_reward, hub_id=self.hub_model_memory_attention)
+                if isinstance(critic, DistributedDataParallel):
+                    critic = next(critic.children())
+                self._save_eval(critic, 'critic', epoch, eval_mean_reward, hub_id=self.hub_model_critic)
+
+    def _save_final(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, hub_id: str = None):
+        try:
+            # Save final model
+            if model.save_pretrained is not None:
+                ckpt_path = os.path.join(
+                    self.save_dir,
+                    component,
+                    'final_model'
+                )
+                path_exists = os.path.exists(ckpt_path)
+                if not path_exists:
+                    os.makedirs(ckpt_path)
+                model.save_pretrained(save_directory=ckpt_path)
+            else:
+                comp_path = os.path.join(
+                    self.save_dir,
+                    component
+                )
+                path_exists = os.path.exists(comp_path)
+                if not path_exists:
+                    os.makedirs(comp_path)
+                ckpt_path = os.path.join(comp_path, 'final_model.pt')
+                torch.save(model.state_dict(), ckpt_path)
+            print(f"Final model saved to {ckpt_path}")
+        except Exception as e:
+            print(f"Error saving final model: {str(e)}")
+            if self.display_exc_trace:
+                traceback.print_exc()
+        try:
+            if self.push_to_hub and model.push_to_hub is not None and hub_id:
+                model.push_to_hub(
+                    repo_id=hub_id,
+                    commit_message=self.final_commit_message or f'Model after full curriculum stage, after {self.finished_epochs} epochs',
+                    token=self.hf_token,
+                    private=self.private_repo,
+                )
+        except Exception as e:
+            print(f"Error pushing final model: {str(e)}")
+            if self.display_exc_trace:
+                traceback.print_exc()
+
+    def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
+        if self.rank == 0:
+            if isinstance(actor, DistributedDataParallel):
+                actor = next(actor.children())
+            self._save_final(actor.encoder, 'encoder', hub_id=self.hub_model_encoder)
+            self._save_final(actor.decoder, 'decoder', hub_id=self.hub_model_decoder)
+            self._save_final(actor.memory_attention, 'memory_attention', hub_id=self.hub_model_memory_attention)
+            if isinstance(critic, DistributedDataParallel):
+                critic = next(critic.children())
+            self._save_final(critic, 'critic', hub_id=self.hub_model_critic)
+
+class MrlGeneratedTokensCallback(MrlTrainerCallback):
+    def __init__(self, steps_log_interval: int = 100):
+        self.total_tokens = 0
+        self.steps_log_interval = steps_log_interval
+        self.step = 0
+
+    def on_reward(self, actor: nn.Module, reward: float, generated: dict[str, torch.Tensor],
+                  reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
+        self.step += 1
+        attention_mask = generated['attention_mask']
+        batch_tokens = attention_mask.sum().item()
+        self.total_tokens += batch_tokens
+        if self.step != 0 and self.step % self.steps_log_interval == 0:
+            print(f'Total processed tokens: {human_format(self.total_tokens)}')
+
+    def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
+        print(f'Total training tokens: {human_format(self.total_tokens)}')
+
+    def get_total_tokens(self):
+        return self.total_tokens
+
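The bulk of the 0.2.0 addition is the `MrlTrainerCallback` family: hooks for curriculum-driven reinforcement learning with an actor and a critic, plus ready-made callbacks for logging, early stopping on evaluation reward, checkpointing the actor components (encoder, decoder, memory attention) and the critic, and counting generated tokens. A hedged wiring sketch follows, using only constructor parameters visible above; the trainer object and its `callbacks` argument are assumptions, not part of this diff:

# Sketch: combining the new MRL callbacks. Constructor arguments come from the
# __init__ signatures added above; the MRL trainer wiring itself is assumed.
mrl_callbacks = [
    MrlPrintCallback(),
    MrlGeneratedTokensCallback(steps_log_interval=100),
    MrlEarlyStoppageCallback(num_plateau_epochs=2, threshold=4.5),
    MrlModelSaveCallback(
        './mrl_checkpoints',
        max_keep=3,
        push_to_hub=False,
        display_exc_trace=True,
        use_ddp=False,
    ),
]
# trainer = MRLTrainer(..., callbacks=mrl_callbacks)  # hypothetical wiring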