rxnn 0.1.83__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/.DS_Store +0 -0
- rxnn/experimental/attention.py +5 -0
- rxnn/memory/attention.py +42 -0
- rxnn/memory/stm.py +55 -12
- rxnn/rxt/models.py +71 -0
- rxnn/training/bml.py +2 -59
- rxnn/training/callbacks.py +302 -39
- rxnn/training/dataset.py +344 -1
- rxnn/training/models.py +142 -0
- rxnn/training/mrl.py +808 -0
- rxnn/training/reward.py +111 -0
- rxnn/training/rl.py +69 -0
- rxnn/training/utils.py +148 -0
- rxnn/transformers/attention.py +10 -0
- rxnn/transformers/layers.py +6 -0
- rxnn/transformers/models.py +16 -4
- rxnn/transformers/positional.py +7 -0
- rxnn/transformers/sampler.py +283 -9
- {rxnn-0.1.83.dist-info → rxnn-0.2.1.dist-info}/METADATA +11 -9
- rxnn-0.2.1.dist-info/RECORD +38 -0
- rxnn-0.1.83.dist-info/RECORD +0 -31
- {rxnn-0.1.83.dist-info → rxnn-0.2.1.dist-info}/LICENSE +0 -0
- {rxnn-0.1.83.dist-info → rxnn-0.2.1.dist-info}/WHEEL +0 -0
rxnn/training/callbacks.py
CHANGED
@@ -2,30 +2,31 @@ import os, traceback, shutil
 import numpy as np
 import torch
 import torch.nn as nn
-from typing import Union
+from typing import Union, Optional
 from torch.nn.parallel import DistributedDataParallel
 from huggingface_hub import PyTorchModelHubMixin
 from ..utils import human_format
 
+
 class TrainerCallback:
-    def on_epoch_start(self, model:
+    def on_epoch_start(self, model: nn.Module, epoch: int) -> None:
         pass
 
-    def on_epoch_end(self, model:
+    def on_epoch_end(self, model: nn.Module, epoch: int) -> Union[bool, None]:
         pass
 
-    def on_batch_start(self, model:
+    def on_batch_start(self, model: nn.Module, batch_idx: int, batch: dict[str, torch.Tensor]) -> None:
         pass
 
-    def on_batch_end(self, model:
-
-
+    def on_batch_end(self, model: nn.Module, batch_idx: int, loss: float, batch: dict[str, torch.Tensor]) -> \
+            Union[
+        bool, None]:
         pass
 
-    def on_training_end(self, model:
+    def on_training_end(self, model: nn.Module) -> None:
         pass
 
-    def on_validation_end(self, model:
+    def on_validation_end(self, model: nn.Module, epoch: int, val_loss: float, val_metrics: dict) -> Union[
         bool, None]:
         pass
 
@@ -111,7 +112,7 @@ class TokenCounterCallback(TrainerCallback):
             print(f'Reached a limit of {human_format(self.limit)} processed tokens - stopping training')
         return should_stop_training
 
-    def on_training_end(self, model:
+    def on_training_end(self, model: nn.Module) -> None:
         print(f'Total training tokens: {human_format(self.total_tokens)}')
 
     def get_total_tokens(self):
@@ -122,7 +123,6 @@ class ModelSaveCallback(TrainerCallback):
    def __init__(
             self,
             save_dir: str,
-            save_best_only: bool = True,
             max_keep: int = 3,
             push_to_hub: bool = False,
             hub_model_id: str = None,
@@ -136,7 +136,6 @@ class ModelSaveCallback(TrainerCallback):
             use_ddp: bool = False,
     ):
         self.save_dir = save_dir
-        self.save_best_only = save_best_only
         self.max_keep = max_keep
         self.best_loss = float('inf')
         self.ckpt_paths = []
@@ -152,7 +151,7 @@ class ModelSaveCallback(TrainerCallback):
         self.display_exc_trace = display_exc_trace
         self.rank = int(os.environ['RANK']) if use_ddp else 0
 
-    def on_batch_end(self, model:
+    def on_batch_end(self, model: nn.Module, batch_idx: int, loss: int, batch: dict[str, torch.Tensor]) -> Union[
         bool, None]:
         if self.rank == 0 and self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
             if isinstance(model, DistributedDataParallel):
@@ -195,7 +194,7 @@ class ModelSaveCallback(TrainerCallback):
 
     def on_validation_end(
             self,
-            model: Union[
+            model: Union[nn.Module, PyTorchModelHubMixin],
             epoch: int,
             val_loss: float,
             val_metrics: dict
@@ -252,7 +251,7 @@ class ModelSaveCallback(TrainerCallback):
             if self.display_exc_trace:
                 traceback.print_exc()
 
-    def on_training_end(self, model: Union[
+    def on_training_end(self, model: Union[nn.Module, PyTorchModelHubMixin]):
         if self.rank == 0:
             if isinstance(model, DistributedDataParallel):
                 model = next(model.children())
@@ -291,7 +290,6 @@ class JointModelSaveCallback(TrainerCallback):
     def __init__(
             self,
             save_dir: str,
-            save_best_only: bool = True,
             max_keep: int = 3,
             push_to_hub: bool = False,
             hub_model_decoder: str = None,
@@ -308,7 +306,6 @@ class JointModelSaveCallback(TrainerCallback):
             use_ddp: bool = False,
     ):
         self.save_dir = save_dir
-        self.save_best_only = save_best_only
         self.max_keep = max_keep
         self.best_loss = float('inf')
         self.ckpt_paths = []
@@ -369,7 +366,7 @@ class JointModelSaveCallback(TrainerCallback):
             if self.display_exc_trace:
                 traceback.print_exc()
 
-    def on_batch_end(self, model:
+    def on_batch_end(self, model: nn.Module, batch_idx: int, loss: int, batch: dict[str, torch.Tensor]) -> Union[
         bool, None]:
         if self.rank == 0 and self.save_checkpoint_after_n_batches is not None and batch_idx != 0 and batch_idx % self.save_checkpoint_after_n_batches == 0:
             if isinstance(model, DistributedDataParallel):
@@ -434,7 +431,7 @@ class JointModelSaveCallback(TrainerCallback):
 
     def on_validation_end(
             self,
-            model: Union[
+            model: Union[nn.Module, PyTorchModelHubMixin],
             epoch: int,
             val_loss: float,
             val_metrics: dict
@@ -491,7 +488,7 @@ class JointModelSaveCallback(TrainerCallback):
             if self.display_exc_trace:
                 traceback.print_exc()
 
-    def on_training_end(self, model: Union[
+    def on_training_end(self, model: Union[nn.Module, PyTorchModelHubMixin]):
         if self.rank == 0:
             if isinstance(model, DistributedDataParallel):
                 model = next(model.children())
@@ -500,23 +497,289 @@ class JointModelSaveCallback(TrainerCallback):
             self._save_final(model.decoder, 'decoder', hub_id=self.hub_model_decoder)
             self._save_final(model.mlm_head, 'head', hub_id=self.hub_model_head)
 
+
 class EarlyStoppageCallback(TrainerCallback):
-    … (19 removed lines of the previous EarlyStoppageCallback body are not rendered in this diff view)
+    def __init__(self, num_plateau_epochs: int = 3) -> None:
+        super().__init__()
+        self.num_plateau_epochs = num_plateau_epochs
+        self.best_loss = 9999.0
+        self.best_loss_epoch = 0
+
+    def on_validation_end(
+            self,
+            model: nn.Module,
+            epoch: int,
+            val_loss: float,
+            val_metrics: dict
+    ):
+        if val_loss < self.best_loss:
+            self.best_loss = val_loss
+            self.best_loss_epoch = epoch
+        elif epoch - self.best_loss_epoch >= self.num_plateau_epochs:
+            return True
+        return None
+
+
+class MrlTrainerCallback:
+    def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, global_epoch: int,
+                       global_epochs: int, curriculum_config: dict) -> None:
+        pass
+
+    def on_epoch_end(self, actor: nn.Module, epoch: int, stage_epochs: int, policy_loss: float,
+                     critic_loss: float, global_epoch: int, global_epochs: int) -> None:
+        pass
+
+    def on_episode_collected(self, actor: nn.Module, batch_idx: int, episode_trajectories: list[dict],
+                             reward: float) -> None:
+        pass
+
+    def on_reward(self, actor: nn.Module, reward: float, generated: str, reference: str, saved_data: str, eval_mode: bool) -> None:
+        pass
+
+    def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
+        pass
+
+    def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
+                          critic_loss: float) -> None:
+        pass
+
+    def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
+        pass
+
+    def on_eval_end(self, actor: nn.Module, critic: nn.Module, epoch: int, eval_mean_reward: float) -> Union[bool, None]:
+        pass
+
+    def on_eval_episode_end(self, actor: nn.Module, epoch: int, batch_idx: int, reward: float) -> None:
+        pass
+
+
+class MrlPrintCallback(MrlTrainerCallback):
+    def on_epoch_start(self, actor: nn.Module, epoch: int, stage_epochs: int, curriculum_config: dict,
+                       global_epoch: int, global_epochs: int) -> None:
+        print(
+            f'Starting epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global) for {curriculum_config['steps']} steps in {curriculum_config['strategy']} strategy.')
+
+    def on_epoch_end(self, actor: nn.Module, epoch: int, stage_epochs: int, policy_loss: float,
+                     critic_loss: float, global_epoch: int, global_epochs: int) -> None:
+        print(f'Finished epoch {epoch}/{stage_epochs} (stage) | {global_epoch}/{global_epochs} (global)')
+        print(f'Policy mean loss: {policy_loss} | Critic mean loss: {critic_loss}')
+
+    def on_episode_collected(self, actor: nn.Module, batch_idx: int, episode_trajectories: list[dict],
+                             reward: float) -> None:
+        print(f'Collected {batch_idx} episode | mean reward {reward}')
+
+    def on_reward(self, actor: nn.Module, reward: float, generated: dict[str, torch.Tensor],
+                  reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
+        print(f"{'Eval' if eval_mode else 'Train'} | Collected reward {reward}")
+
+    def on_batch_updated(self, actor: nn.Module, epoch: int, step: int, policy_loss: float) -> None:
+        print(f'Epoch {epoch} | Step {step} - updated policy loss {policy_loss}')
+
+    def on_critic_updated(self, actor: nn.Module, critic: nn.Module, epoch: int, step: int,
+                          critic_loss: float) -> None:
+        print(f'Epoch {epoch} | Step {step} - updated policy loss {critic_loss}')
+
+    def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
+        print(f'Finished training for {curriculum_config['steps']} steps in {curriculum_config['strategy']} strategy.')
+
+    def on_eval_end(self, actor: nn.Module, critic: nn.Module, epoch: int, eval_mean_reward: float) -> None:
+        print(f'Eval epoch {epoch} - mean reward {eval_mean_reward}')
+
+    def on_eval_episode_end(self, actor: nn.Module, epoch: int, batch_idx: int, reward: float) -> None:
+        print(f'Eval epoch {epoch} / Episode {batch_idx} - mean reward {reward}')
+
+
+class MrlEarlyStoppageCallback(MrlTrainerCallback):
+    def __init__(self, num_plateau_epochs: int = 2, threshold: Optional[float] = None) -> None:
+        super().__init__()
+        self.num_plateau_epochs = num_plateau_epochs
+        self.best_reward = -9999.0
+        self.best_reward_epoch = 0
+        self.threshold = threshold
+
+    def on_eval_end(self, _actor: nn.Module, _critic: nn.Module, epoch: int, eval_mean_reward: float) -> Union[bool, None]:
+        if self.threshold is not None:
+            if eval_mean_reward > self.threshold:
+                return True
+
+        if eval_mean_reward > self.best_reward:
+            self.best_reward = eval_mean_reward
+            self.best_reward_epoch = epoch
+        elif epoch - self.best_reward_epoch >= self.num_plateau_epochs:
+            return True
+        return None
+
+class MrlModelSaveCallback(MrlTrainerCallback):
+    def __init__(
+            self,
+            save_dir: str,
+            max_keep: int = 3,
+            push_to_hub: bool = False,
+            hub_model_decoder: str = None,
+            hub_model_encoder: str = None,
+            hub_model_memory_attention: str = None,
+            hub_model_critic: str = None,
+            private_repo: bool = False,
+            hf_token: str = None,
+            push_checkpoint_weights: bool = True,
+            final_commit_message: str = None,
+            display_exc_trace: bool = False,
+            use_ddp: bool = False,
+    ):
+        self.save_dir = save_dir
+        self.max_keep = max_keep
+        self.best_reward = float('-inf')
+        self.ckpt_paths = []
+        self.push_to_hub = push_to_hub
+        self.hub_model_decoder = hub_model_decoder
+        self.hub_model_encoder = hub_model_encoder
+        self.hub_model_memory_attention = hub_model_memory_attention
+        self.hub_model_critic = hub_model_critic
+        self.private_repo = private_repo
+        self.hf_token = hf_token
+        self.push_checkpoint_weights = push_checkpoint_weights
+        self.final_commit_message = final_commit_message
+        self.finished_epochs = 0
+        self.display_exc_trace = display_exc_trace
+        self.rank = int(os.environ['RANK']) if use_ddp else 0
+
+    def _save_eval(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, epoch: int,
+                   reward: float, hub_id: str = None):
+        try:
+            if model.save_pretrained is not None:
+                ckpt_path = os.path.join(
+                    self.save_dir,
+                    component,
+                    f'epoch_{epoch}_eval_reward_{reward:.4f}'
+                )
+                path_exists = os.path.exists(ckpt_path)
+                if not path_exists:
+                    os.makedirs(ckpt_path)
+                model.save_pretrained(save_directory=ckpt_path)
+            else:
+                comp_path = os.path.join(
+                    self.save_dir,
+                    component
+                )
+                path_exists = os.path.exists(comp_path)
+                if not path_exists:
+                    os.makedirs(comp_path)
+                ckpt_path = os.path.join(
+                    comp_path,
+                    f'epoch_{epoch}_eval_reward_{reward:.4f}.pt'
+                )
+                torch.save(model.state_dict(), ckpt_path)
+            self.ckpt_paths.append(ckpt_path)
+
+            # Keep only N best checkpoints
+            if len(self.ckpt_paths) > self.max_keep:
+                oldest_path = self.ckpt_paths.pop(0)
+                if model.save_pretrained is not None:
+                    shutil.rmtree(oldest_path)
+                else:
+                    os.remove(oldest_path)
+        except Exception as e:
+            print(f"Error saving epoch checkpoint: {str(e)}")
+            if self.display_exc_trace:
+                traceback.print_exc()
+
+        try:
+            if self.push_to_hub and self.push_checkpoint_weights and model.push_to_hub is not None and hub_id:
+                model.push_to_hub(
+                    repo_id=hub_id,
+                    commit_message=f'Epoch {epoch} - Eval reward {reward:.4f}',
+                    token=self.hf_token,
+                    private=self.private_repo,
+                )
+        except Exception as e:
+            print(f"Error pushing epoch checkpoint: {str(e)}")
+            if self.display_exc_trace:
+                traceback.print_exc()
+
+    def on_eval_end(self, actor: nn.Module, critic: nn.Module, epoch: int, eval_mean_reward: float) -> None:
+        if self.rank == 0:
+            self.finished_epochs += 1
+            if eval_mean_reward > self.best_reward:
+                self.best_reward = eval_mean_reward
+                if isinstance(actor, DistributedDataParallel):
+                    actor = next(actor.children())
+                self._save_eval(actor.encoder, 'encoder', epoch, eval_mean_reward, hub_id=self.hub_model_encoder)
+                self._save_eval(actor.decoder, 'decoder', epoch, eval_mean_reward, hub_id=self.hub_model_decoder)
+                self._save_eval(actor.memory_attention, 'memory_attention', epoch, eval_mean_reward, hub_id=self.hub_model_memory_attention)
+                if isinstance(critic, DistributedDataParallel):
+                    critic = next(critic.children())
+                self._save_eval(critic, 'critic', epoch, eval_mean_reward, hub_id=self.hub_model_critic)
+
+    def _save_final(self, model: Union[nn.Module, PyTorchModelHubMixin], component: str, hub_id: str = None):
+        try:
+            # Save final model
+            if model.save_pretrained is not None:
+                ckpt_path = os.path.join(
+                    self.save_dir,
+                    component,
+                    'final_model'
+                )
+                path_exists = os.path.exists(ckpt_path)
+                if not path_exists:
+                    os.makedirs(ckpt_path)
+                model.save_pretrained(save_directory=ckpt_path)
+            else:
+                comp_path = os.path.join(
+                    self.save_dir,
+                    component
+                )
+                path_exists = os.path.exists(comp_path)
+                if not path_exists:
+                    os.makedirs(comp_path)
+                ckpt_path = os.path.join(comp_path, 'final_model.pt')
+                torch.save(model.state_dict(), ckpt_path)
+            print(f"Final model saved to {ckpt_path}")
+        except Exception as e:
+            print(f"Error saving final model: {str(e)}")
+            if self.display_exc_trace:
+                traceback.print_exc()
+        try:
+            if self.push_to_hub and model.push_to_hub is not None and hub_id:
+                model.push_to_hub(
+                    repo_id=hub_id,
+                    commit_message=self.final_commit_message or f'Model after full curriculum stage, after {self.finished_epochs} epochs',
+                    token=self.hf_token,
+                    private=self.private_repo,
+                )
+        except Exception as e:
+            print(f"Error pushing final model: {str(e)}")
+            if self.display_exc_trace:
+                traceback.print_exc()
+
+    def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
+        if self.rank == 0:
+            if isinstance(actor, DistributedDataParallel):
+                actor = next(actor.children())
+            self._save_final(actor.encoder, 'encoder', hub_id=self.hub_model_encoder)
+            self._save_final(actor.decoder, 'decoder', hub_id=self.hub_model_decoder)
+            self._save_final(actor.memory_attention, 'memory_attention', hub_id=self.hub_model_memory_attention)
+            if isinstance(critic, DistributedDataParallel):
+                critic = next(critic.children())
+            self._save_final(critic, 'critic', hub_id=self.hub_model_critic)
+
+class MrlGeneratedTokensCallback(MrlTrainerCallback):
+    def __init__(self, steps_log_interval: int = 100):
+        self.total_tokens = 0
+        self.steps_log_interval = steps_log_interval
+        self.step = 0
+
+    def on_reward(self, actor: nn.Module, reward: float, generated: dict[str, torch.Tensor],
+                  reference: dict[str, torch.Tensor], saved_data: dict[str, torch.Tensor], eval_mode: bool) -> None:
+        self.step += 1
+        attention_mask = generated['attention_mask']
+        batch_tokens = attention_mask.sum().item()
+        self.total_tokens += batch_tokens
+        if self.step != 0 and self.step % self.steps_log_interval == 0:
+            print(f'Total processed tokens: {human_format(self.total_tokens)}')
+
+    def on_training_end(self, actor: nn.Module, critic: nn.Module, curriculum_config: dict) -> None:
+        print(f'Total training tokens: {human_format(self.total_tokens)}')
+
+    def get_total_tokens(self):
+        return self.total_tokens
+
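
The diff above only adds the callback classes; how they get attached to a trainer is defined in rxnn/training/mrl.py, which is not shown here. As a rough illustration of the new MRL callback interface, the following sketch subclasses MrlTrainerCallback and combines it with the built-in callbacks from this release. The MRLTrainer name and its callbacks= argument are assumptions for illustration only, not the package's confirmed API.

# Sketch only: the callback classes and their signatures come from rxnn/training/callbacks.py
# as shown in the diff above; the trainer wiring (MRLTrainer, callbacks=...) is hypothetical.
import torch.nn as nn
from rxnn.training.callbacks import (
    MrlTrainerCallback,
    MrlPrintCallback,
    MrlEarlyStoppageCallback,
    MrlModelSaveCallback,
)


class RewardHistoryCallback(MrlTrainerCallback):
    """Custom callback: keep a running history of per-step training rewards."""

    def __init__(self) -> None:
        super().__init__()
        self.rewards: list[float] = []

    def on_reward(self, actor: nn.Module, reward: float, generated, reference, saved_data, eval_mode: bool) -> None:
        # Only track rewards collected during training, not evaluation
        if not eval_mode:
            self.rewards.append(reward)


callbacks = [
    MrlPrintCallback(),
    MrlEarlyStoppageCallback(num_plateau_epochs=2),
    MrlModelSaveCallback('./checkpoints', max_keep=3, push_to_hub=False),
    RewardHistoryCallback(),
]

# Hypothetical wiring; the actual trainer lives in rxnn/training/mrl.py (not part of this file's diff):
# trainer = MRLTrainer(actor, critic, callbacks=callbacks)
# trainer(curriculum_config)

Note that MrlEarlyStoppageCallback stops training either when eval_mean_reward exceeds an optional threshold or when the best evaluation reward has not improved for num_plateau_epochs epochs, mirroring the plateau logic of the reworked EarlyStoppageCallback for validation loss.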