rxnn 0.2.47__tar.gz → 0.2.49__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.47 → rxnn-0.2.49}/PKG-INFO +2 -1
- {rxnn-0.2.47 → rxnn-0.2.49}/pyproject.toml +2 -2
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/memory/norm.py +3 -1
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/mrl.py +34 -10
- rxnn-0.2.49/src/rxnn/training/reward.py +323 -0
- rxnn-0.2.47/src/rxnn/training/reward.py +0 -216
- {rxnn-0.2.47 → rxnn-0.2.49}/LICENSE +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/README.md +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/models.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/utils.py +0 -0
{rxnn-0.2.47 → rxnn-0.2.49}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.47
+Version: 0.2.49
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -15,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Requires-Dist: datasets (>=3.5.0,<4.0.0)
 Requires-Dist: huggingface-hub (>=0.30.0,<0.31.0)
+Requires-Dist: nltk (>=3.9.1,<4.0.0)
 Requires-Dist: tensorboard (>=2.19.0,<3.0.0)
 Requires-Dist: tokenizers (>=0.21.0,<0.22.0)
 Requires-Dist: torch (>=2.6.0,<3.0.0)
{rxnn-0.2.47 → rxnn-0.2.49}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.47"
+version = "0.2.49"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
@@ -30,4 +30,4 @@ datasets = "^3.5.0"
 tokenizers = "^0.21.0"
 huggingface-hub = "^0.30.0"
 tensorboard = "^2.19.0"
-
+nltk = "^3.9.1"
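Both metadata changes come from the same source: the rewritten reward.py imports sentence_bleu and SmoothingFunction from NLTK at module level, so nltk becomes a declared runtime dependency rather than a lazy import inside the scoring methods. A minimal sketch of the kind of call the new reward model issues, with illustrative token-id sequences that are made up for this example, not taken from the package:

    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

    # Hypothetical token-id sequences standing in for a generated answer and its reference.
    generated = [12, 55, 97, 8, 3]
    reference = [12, 55, 101, 8, 3]

    # method4 smoothing matches what the new MrlRewardModel configures in its constructor.
    smoothing = SmoothingFunction().method4
    score = sentence_bleu([reference], generated, weights=(0.5, 0.5), smoothing_function=smoothing)
    print(f'sentence BLEU: {score:.3f}')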
{rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/memory/norm.py

@@ -163,7 +163,7 @@ def init_memory_norm(
         init_scale: float = 1.0,
         per_dim_scale: bool = False,
 ) -> nn.Module:
-    assert norm_type in ['layer', 'rms', 'adaptive', 'positional']
+    assert norm_type in ['layer', 'rms', 'adaptive', 'positional', 'classic-rms']
     if norm_type == 'layer':
         return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
     elif norm_type == 'rms':
@@ -172,4 +172,6 @@ def init_memory_norm(
         return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
     elif norm_type == 'positional':
         return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate, per_dim_scale)
+    elif norm_type == 'classic-rms':
+        return nn.RMSNorm(dim)
     return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
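The memory-norm factory gains a 'classic-rms' option that returns PyTorch's built-in nn.RMSNorm instead of one of the library's gated or adaptive variants. A minimal sketch of what selecting it amounts to, assuming the keyword names norm_type and dim that are visible in the function body (the remaining arguments and their defaults are not shown in this diff):

    import torch
    import torch.nn as nn
    from rxnn.memory.norm import init_memory_norm

    dim = 512  # illustrative model dimension

    # 'classic-rms' maps straight to torch.nn.RMSNorm(dim), with no gating or adaptive scaling.
    norm = init_memory_norm(norm_type='classic-rms', dim=dim)
    assert isinstance(norm, nn.RMSNorm)

    stm = torch.randn(2, 16, dim)  # (batch, memory slots, dim)
    print(norm(stm).shape)         # torch.Size([2, 16, 512])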
{rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/mrl.py

@@ -35,6 +35,7 @@ class MrlConfig(TypedDict):
     moe_aux_loss_scale: Optional[float]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    use_memory_warmup: Optional[bool]
 
 
 class MrlStrategy(Enum):
@@ -70,6 +71,7 @@ class CurriculumConfig(TypedDict):
     update_epochs: Optional[int]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    teacher_forcing: Optional[bool]
 
 
 class SamplerConfig(TypedDict):
@@ -136,6 +138,7 @@ class MRLTrainer:
         self.moe_aux_loss_scale = config.get('moe_aux_loss_scale', 0.01)
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
+        self.use_memory_warmup = config.get('use_memory_warmup', False)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -213,6 +216,7 @@ class MRLTrainer:
         self.callbacks = []
         self.global_epoch = 0
         self.global_epochs_count = 0
+        self.teacher_forcing = False
 
     def _init_optimizers(
             self,
@@ -381,6 +385,11 @@ class MRLTrainer:
             self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                    self.stage_step['collect'])
 
+    def memory_warmup(self, query: TokenizedDict, answer: TokenizedDict):
+        if self.use_memory_warmup:
+            with torch.no_grad():
+                self.encode_and_update_stm(query, answer)
+
     def collect_trajectories(self, dataloader: DataLoader, epoch: int, batch_size: int) -> list[MrlTrajectoryEpisode]:
         """Collect trajectories for PPO for current curriculum step."""
         # 1. Init trajectories list
@@ -402,8 +411,13 @@ class MRLTrainer:
                 first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
                 interactions = interactions[:self.curriculum_steps]
                 interactions_len = len(interactions)
+
+                first_interaction = self._move_multiple_batches(first_query, first_answer)
+
+                if reset_done:
+                    self.memory_warmup(*first_interaction)
                 # 6. Encode and update STM with data to save from first interaction
-                self.encode_and_update_stm(*
+                self.encode_and_update_stm(*first_interaction)
 
                 # 7. Save first interaction as data to save (for trajectory state)
                 query, answer = first_query, first_answer
@@ -440,8 +454,10 @@ class MRLTrainer:
 
                     # 11. Update STM with generated response (except last interaction, it's not needed)
                     if not is_last_interaction:
-                        self.encode_and_update_stm(
-
+                        self.encode_and_update_stm(
+                            next_query,
+                            self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                        ) # update with generated_answer on GPU
 
                     # 12. Store trajectory step
                     trajectory: MrlTrajectoryStep = {
@@ -458,7 +474,7 @@ class MRLTrainer:
                     # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
                     if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                         prev_interaction = (query, answer)
-                    query, answer = interaction['query'], detached_answer
+                    query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
 
                 # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
                 episode_trajectory: MrlTrajectoryEpisode = {
@@ -649,6 +665,9 @@ class MRLTrainer:
 
         self.actor.clone_reset_memory()
 
+        if should_reset_stm and step_idx == 0:
+            self.memory_warmup(query, answer)
+
         # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
         if self.memory_aware_critic:
             self.encode_and_update_stm(query, answer)
@@ -798,13 +817,16 @@ class MRLTrainer:
             if batch['query']['input_ids'].size(0) == batch_size:
                 self._increment_steps('eval')
                 # 3. Reset STM with random resets ratio and reward model running mean
-                self.reset_stm()
+                reset_stm = self.reset_stm()
                 self.reward.reset_running_mean()
 
                 # 4. Get batches for first queries, answers and all follow-up interactions
                 first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
                 # 5. Encode and update STM with initial interactions (batch)
-                self.
+                first_interaction = self._move_multiple_batches(first_query, first_answer)
+                if reset_stm:
+                    self.memory_warmup(*first_interaction)
+                self.encode_and_update_stm(*first_interaction)
 
                 # 6. Save follow-up interactions len and first query and answer as previous one for iteration
                 interactions_len = len(interactions)
@@ -839,7 +861,10 @@ class MRLTrainer:
 
                     # 10. Encode and update memory for the next interaction
                     if not is_last_interaction:
-                        self.encode_and_update_stm(
+                        self.encode_and_update_stm(
+                            next_query,
+                            self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                        )
 
                     # 11. Accumulate rewards
                     step_reward = torch.tensor(reward).mean().to(self.device)
@@ -852,7 +877,7 @@ class MRLTrainer:
                     # 12. Save previous interaction
                     if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                         prev_interaction = (query, answer)
-                    query, answer = interaction['query'], detached_answer
+                    query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
                 avg_episode_reward = (episode_reward / episode_interactions).item()
                 # 13. Run eval TensorBoard writer with average episode reward
                 self._eval_writer(avg_episode_reward, epoch)
@@ -982,8 +1007,7 @@ class MRLTrainer:
         self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs) # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-
-
+        self.teacher_forcing = config.get('teacher_forcing', False)
 
         def has_param(field: OptimField) -> bool:
             return field in config and config[field] is not None
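Taken together, the mrl.py hunks add two opt-in switches: use_memory_warmup in MrlConfig, which runs a no-grad encode_and_update_stm pass over the first interaction right after an STM reset before the regular gradient-tracked update, and teacher_forcing in CurriculumConfig, which makes memory updates and the saved interaction state use the dataset answer instead of the generated one. A hedged sketch of how these flags might be passed; the dictionaries below are illustrative fragments, not complete or verified configurations:

    # Illustrative fragments only - the other required MrlConfig / CurriculumConfig keys are omitted.
    mrl_config = {
        # ... existing shared trainer options ...
        'use_memory_warmup': True,  # no-grad STM warm-up on the first interaction after a reset
    }

    curriculum_stage = {
        # ... existing per-stage options ...
        'teacher_forcing': True,    # update memory with the dataset answer instead of the generated one
    }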
rxnn-0.2.49/src/rxnn/training/reward.py

@@ -0,0 +1,323 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+from enum import Enum
+from typing import Optional
+from .utils import TokenizedDict
+
+
+class MrlRewardMode(Enum):
+    STANDARD = 1
+    NEGATIVE = 2
+    LONG_RANGE = 3
+
+
+class MrlRewardModel:
+    def __init__(
+            self,
+            shared_embedding: nn.Embedding,
+            device: torch.device,
+            bleu_with_saved_data: bool = False,
+            bleu_factor: float = 0.5,
+            bleu_ref_factor: float = 0.5,
+            bleu_saved_factor: float = 0.5,
+            cos_factor: float = 0.5,
+            cos_ref_factor: float = 0.5,
+            cos_saved_factor: float = 0.5,
+            multi_cos_ref_factor: float = 0.3,
+            multi_cos_saved_factor: float = 0.5,
+            multi_cos_running_mean_factor: float = 0.2,
+            neg_bleu_factor: Optional[float] = None,
+            neg_cos_factor: Optional[float] = None,
+            neg_cos_ref_factor: Optional[float] = None,
+            neg_cos_saved_factor: Optional[float] = None,
+            neg_bleu_ref_factor: float = 0.5,
+            neg_bleu_saved_factor: float = 0.5,
+            allow_not_summing_factors: bool = False,
+            reward_len: bool = False,
+            neg_reward_len: bool = False,
+            max_rewarded_len: int = None,
+            target_len_as_ref: bool = False,
+            len_factor: int = None,
+            use_running_mean: bool = True,
+            running_mean_decay: float = 0.2,
+            bleu_saved_weights: tuple = (0.5, 0.5),
+            bleu_ref_weights: tuple = (0.5, 0.5),
+            tanh_reward_scale: bool = False,
+            rewards_scale: float = 1.0,
+            debug_mode: int = 0,
+    ):
+        self.shared_embedding = shared_embedding.to(device)
+        self.device = device
+        self.bleu_with_saved_data = bleu_with_saved_data
+
+        self.bleu_factor = bleu_factor
+        self.bleu_ref_factor = bleu_ref_factor
+        self.bleu_saved_factor = bleu_saved_factor
+        self.cos_factor = cos_factor
+        self.cos_ref_factor = cos_ref_factor
+        self.cos_saved_factor = cos_saved_factor
+        self.multi_cos_ref_factor = multi_cos_ref_factor
+        self.multi_cos_saved_factor = multi_cos_saved_factor
+        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
+        self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
+        self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
+        self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
+        self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
+        self.neg_bleu_ref_factor = neg_bleu_ref_factor
+        self.neg_bleu_saved_factor = neg_bleu_saved_factor
+        self.reward_len = reward_len
+        self.neg_reward_len = neg_reward_len
+        self.max_rewarded_len = max_rewarded_len
+        self.target_len_as_ref = target_len_as_ref
+        self.len_factor = len_factor
+        self.use_running_mean = use_running_mean
+        self.running_mean_decay = running_mean_decay
+        self.bleu_ref_weights = bleu_ref_weights
+        self.bleu_saved_weights = bleu_saved_weights
+        self.tanh_reward_scale = tanh_reward_scale
+        self.rewards_scale = rewards_scale
+        self.bleu_smoothing = SmoothingFunction().method4
+        self.debug_mode = debug_mode
+
+        self.prev_data_running_mean = None
+
+        if not allow_not_summing_factors:
+            if reward_len:
+                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+            else:
+                assert self.bleu_factor + self.cos_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+
+    def _sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                       masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))
+
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
+            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+    def _negative_sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                                masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))
+
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
+            saved_bleu = 1 - saved_bleu
+
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
+            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+    def batch_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def negative_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[
+        float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._negative_sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def _sequence_embedding(self, sequence: TokenizedDict) -> torch.Tensor:
+        input_ids = sequence['input_ids']
+        attention_mask = sequence['attention_mask']
+
+        # Get embeddings
+        embeddings = self.shared_embedding(input_ids.to(self.device))
+
+        # Apply attention mask
+        mask_expanded = attention_mask.unsqueeze(-1).to(self.device)
+        masked_embeddings = embeddings * mask_expanded
+
+        # Compute mean with masking
+        sum_embeddings = torch.sum(masked_embeddings, dim=1)
+        token_counts = torch.sum(mask_expanded, dim=1)
+        token_counts = torch.clamp(token_counts, min=1e-8)  # Avoid division by zero
+
+        return sum_embeddings / token_counts
+
+    def _cosine_sim(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0)
+
+    def _cosine_sim_running_mean(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+        running_emb = F.normalize(self.prev_data_running_mean, dim=-1)
+
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+        gen_and_mean = F.cosine_similarity(generated_emb, running_emb, dim=1)
+
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+            print('GEN AND MEAN: ', gen_and_mean.mean())
+
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0), torch.clamp(gen_and_mean, min=0)
+
+    def batch_cosine(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict,
+                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
+        if self.use_running_mean and negative_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
+                    1 - gen_and_mean)
+        elif self.use_running_mean and include_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
+        else:
+            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
+
+    def negative_cosine(self, generated: TokenizedDict, reference: TokenizedDict,
+                        saved_data: TokenizedDict) -> torch.Tensor:
+        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+
+        return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
+
+    def len_reward(self, generated: TokenizedDict, reference: TokenizedDict) -> torch.Tensor:
+        target_lens = reference['attention_mask'].sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
+        lens = generated['attention_mask'].sum(dim=1)
+        neg_lens = target_lens / lens if self.neg_reward_len else 1.0
+        len_reward = torch.where(lens >= target_lens, neg_lens, lens / target_lens)
+        return len_reward
+
+    def reset_running_mean(self):
+        self.prev_data_running_mean = None
+
+    def init_running_mean(self, prev_data: TokenizedDict):
+        self.prev_data_running_mean = self._sequence_embedding(prev_data)
+
+    def update_running_mean(self, prev_data: TokenizedDict):
+        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
+            prev_data) + self.running_mean_decay * self.prev_data_running_mean
+
+    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
+        if self.tanh_reward_scale:
+            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
+        else:
+            return rewards
+
+    def __call__(
+            self,
+            generated: TokenizedDict,
+            reference: TokenizedDict,
+            saved_data: TokenizedDict,
+            prev_data: TokenizedDict = None,
+            mode: MrlRewardMode = MrlRewardMode.STANDARD
+    ) -> list[float]:
+        if prev_data is not None:
+            if self.prev_data_running_mean is None:
+                self.init_running_mean(prev_data)
+            else:
+                self.update_running_mean(prev_data)
+
+        if mode == MrlRewardMode.STANDARD:
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data, include_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('STANDARD MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        elif mode == MrlRewardMode.LONG_RANGE:
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data,
+                                       negative_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('LONG MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        else:
+            bleu = self.negative_bleu(generated, reference, saved_data)
+            cosine = self.negative_cosine(generated, reference, saved_data)
+
+            if self.debug_mode >= 1:
+                print('NEGATIVE MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
+
+        if self.reward_len:
+            len_reward = self.len_reward(generated, reference)
+
+            if self.debug_mode >= 1:
+                print('REWARD LEN: ', (len_reward.sum() / len_reward.size(0)).item())
+
+            rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * len_reward) * self.rewards_scale
+        else:
+            rewards = self._pre_scale_rewards(sim_rewards) * self.rewards_scale
+
+        return rewards.tolist()
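Compared with the deleted version below, the rewritten reward model is attention-mask aware end to end: BLEU is computed on unpadded token lists with method4 smoothing, sequence embeddings are mask-weighted means, cosine similarities are clamped to non-negative values, and the length reward can target the reference length via target_len_as_ref. A minimal usage sketch under assumed inputs; the tokenized batches below are random stand-ins in the TokenizedDict layout (input_ids plus attention_mask), not real trainer data:

    import torch
    import torch.nn as nn
    from rxnn.training.reward import MrlRewardModel, MrlRewardMode

    device = torch.device('cpu')
    embedding = nn.Embedding(1000, 64)  # stand-in for the actor's shared embedding layer

    reward_model = MrlRewardModel(embedding, device, bleu_with_saved_data=True)

    def fake_batch(batch_size: int = 2, seq_len: int = 8) -> dict:
        # Random tokenized batch in the TokenizedDict layout used by the trainer.
        return {
            'input_ids': torch.randint(0, 1000, (batch_size, seq_len)),
            'attention_mask': torch.ones(batch_size, seq_len, dtype=torch.long),
        }

    generated, reference, saved = fake_batch(), fake_batch(), fake_batch()
    rewards = reward_model(generated, reference, saved, mode=MrlRewardMode.STANDARD)
    print(rewards)  # one float reward per sequence in the batch

Setting debug_mode to 1 or 2 in the constructor prints the intermediate BLEU and cosine diagnostics visible in the code above.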
rxnn-0.2.47/src/rxnn/training/reward.py

@@ -1,216 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from enum import Enum
-from typing import Optional
-from .utils import TokenizedDict
-
-
-class MrlRewardMode(Enum):
-    STANDARD = 1
-    NEGATIVE = 2
-    LONG_RANGE = 3
-
-
-class MrlRewardModel:
-    def __init__(
-            self,
-            shared_embedding: nn.Embedding,
-            device: torch.device,
-            bleu_with_saved_data: bool = False,
-            bleu_factor: float = 0.5,
-            bleu_ref_factor: float = 0.5,
-            bleu_saved_factor: float = 0.5,
-            cos_factor: float = 0.5,
-            cos_ref_factor: float = 0.5,
-            cos_saved_factor: float = 0.5,
-            multi_cos_ref_factor: float = 0.3,
-            multi_cos_saved_factor: float = 0.5,
-            multi_cos_running_mean_factor: float = 0.2,
-            neg_bleu_factor: Optional[float] = None,
-            neg_cos_factor: Optional[float] = None,
-            neg_cos_ref_factor: Optional[float] = None,
-            neg_cos_saved_factor: Optional[float] = None,
-            neg_bleu_ref_factor: float = 0.5,
-            neg_bleu_saved_factor: float = 0.5,
-            allow_not_summing_factors: bool = False,
-            reward_len: bool = False,
-            neg_reward_len: bool = False,
-            max_rewarded_len: int = None,
-            len_factor: int = None,
-            use_running_mean: bool = True,
-            running_mean_decay: float = 0.2,
-            bleu_saved_weights: tuple = (0.5, 0.5),
-            bleu_ref_weights: tuple = (0.5, 0.5),
-            tanh_reward_scale: bool = False,
-            rewards_scale: float = 1.0,
-    ):
-        self.shared_embedding = shared_embedding.to(device)
-        self.device = device
-        self.bleu_with_saved_data = bleu_with_saved_data
-
-        self.bleu_factor = bleu_factor
-        self.bleu_ref_factor = bleu_ref_factor
-        self.bleu_saved_factor = bleu_saved_factor
-        self.cos_factor = cos_factor
-        self.cos_ref_factor = cos_ref_factor
-        self.cos_saved_factor = cos_saved_factor
-        self.multi_cos_ref_factor = multi_cos_ref_factor
-        self.multi_cos_saved_factor = multi_cos_saved_factor
-        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
-        self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
-        self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
-        self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
-        self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
-        self.neg_bleu_ref_factor = neg_bleu_ref_factor
-        self.neg_bleu_saved_factor = neg_bleu_saved_factor
-        self.reward_len = reward_len
-        self.neg_reward_len = neg_reward_len
-        self.max_rewarded_len = max_rewarded_len
-        self.len_factor = len_factor
-        self.use_running_mean = use_running_mean
-        self.running_mean_decay = running_mean_decay
-        self.bleu_ref_weights = bleu_ref_weights
-        self.bleu_saved_weights = bleu_saved_weights
-        self.tanh_reward_scale = tanh_reward_scale
-        self.rewards_scale = rewards_scale
-
-        self.prev_data_running_mean = None
-
-        if not allow_not_summing_factors:
-            if reward_len:
-                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
-                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
-                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
-                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
-                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
-                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
-                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
-            else:
-                assert self.bleu_factor + self.cos_factor == 1.0
-                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
-                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
-                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
-                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
-                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
-                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
-
-    def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
-        from nltk.translate.bleu_score import sentence_bleu
-
-        if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
-            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
-        else:
-            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-
-
-    def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor,
-                                saved_data: torch.Tensor) -> float:
-        from nltk.translate.bleu_score import sentence_bleu
-
-        if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
-            saved_bleu = 1 - saved_bleu
-
-            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
-        else:
-            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-
-    def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
-        batch_size = generated.size(0)
-        return [self._sentence_bleu(generated[i], reference[i], saved_data[i]) for i in range(batch_size)]
-
-    def _sequence_embedding(self, sequence: torch.Tensor) -> torch.Tensor:
-        embedding = self.shared_embedding(sequence.to(self.device))
-        return embedding.mean(dim=1)
-
-    def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
-        generated_emb = self._sequence_embedding(generated)
-
-        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
-        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
-        return gen_and_saved, gen_and_ref
-
-    def _cosine_sim_running_mean(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
-        generated_emb = self._sequence_embedding(generated)
-
-        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
-        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
-        gen_and_mean = (F.cosine_similarity(generated_emb, self.prev_data_running_mean) + 1) / 2
-        return gen_and_saved, gen_and_ref, gen_and_mean
-
-    def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor,
-                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
-        if self.use_running_mean and negative_running_mean:
-            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
-            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
-                    1 - gen_and_mean)
-        elif self.use_running_mean and include_running_mean:
-            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
-            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
-        else:
-            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
-            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
-
-    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor,
-                        saved_data: torch.Tensor) -> torch.Tensor:
-        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
-
-        return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
-
-    def len_reward(self, generated: TokenizedDict):
-        lens = generated['attention_mask'].sum(dim=1)
-        neg_lens = self.max_rewarded_len / lens if self.neg_reward_len else 1.0
-        len_reward = torch.where(lens >= self.max_rewarded_len, neg_lens, lens / self.max_rewarded_len)
-        return len_reward
-
-    def reset_running_mean(self):
-        self.prev_data_running_mean = None
-
-    def init_running_mean(self, prev_data: torch.Tensor):
-        self.prev_data_running_mean = self._sequence_embedding(prev_data)
-
-    def update_running_mean(self, prev_data: torch.Tensor):
-        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
-            prev_data) + self.running_mean_decay * self.prev_data_running_mean
-
-    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
-        if self.tanh_reward_scale:
-            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
-        else:
-            return rewards
-
-    def __call__(
-            self,
-            generated: TokenizedDict,
-            reference: TokenizedDict,
-            saved_data: TokenizedDict,
-            prev_data: TokenizedDict = None,
-            mode: MrlRewardMode = MrlRewardMode.STANDARD
-    ) -> list[float]:
-        if prev_data is not None:
-            if self.prev_data_running_mean is None:
-                self.init_running_mean(prev_data['input_ids'])
-            else:
-                self.update_running_mean(prev_data['input_ids'])
-
-        if mode == MrlRewardMode.STANDARD:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
-                                       include_running_mean=prev_data is not None)
-            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
-        elif mode == MrlRewardMode.LONG_RANGE:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
-                                       negative_running_mean=prev_data is not None)
-            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
-        else:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
-
-        rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
-        return rewards.tolist()