rxnn 0.2.48__tar.gz → 0.2.50__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rxnn-0.2.48 → rxnn-0.2.50}/PKG-INFO +2 -1
- {rxnn-0.2.48 → rxnn-0.2.50}/pyproject.toml +2 -2
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/mrl.py +13 -7
- rxnn-0.2.50/src/rxnn/training/reward.py +323 -0
- rxnn-0.2.48/src/rxnn/training/reward.py +0 -216
- {rxnn-0.2.48 → rxnn-0.2.50}/LICENSE +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/README.md +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/.DS_Store +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/__init__.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/experimental/__init__.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/experimental/attention.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/experimental/models.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/experimental/moe.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/memory/__init__.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/memory/attention.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/memory/norm.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/memory/stm.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/rxt/__init__.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/rxt/models.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/__init__.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/base.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/bml.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/callbacks.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/dataset.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/ddp.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/models.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/rl.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/scheduler.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/tokenizer.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/utils.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/__init__.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/attention.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/ff.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/layers.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/mask.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/models.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/moe.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/positional.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/sampler.py +0 -0
- {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/utils.py +0 -0
{rxnn-0.2.48 → rxnn-0.2.50}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.48
+Version: 0.2.50
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning

@@ -15,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Requires-Dist: datasets (>=3.5.0,<4.0.0)
 Requires-Dist: huggingface-hub (>=0.30.0,<0.31.0)
+Requires-Dist: nltk (>=3.9.1,<4.0.0)
 Requires-Dist: tensorboard (>=2.19.0,<3.0.0)
 Requires-Dist: tokenizers (>=0.21.0,<0.22.0)
 Requires-Dist: torch (>=2.6.0,<3.0.0)
{rxnn-0.2.48 → rxnn-0.2.50}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.48"
+version = "0.2.50"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"

@@ -30,4 +30,4 @@ datasets = "^3.5.0"
 tokenizers = "^0.21.0"
 huggingface-hub = "^0.30.0"
 tensorboard = "^2.19.0"
-
+nltk = "^3.9.1"
{rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/mrl.py

@@ -71,6 +71,7 @@ class CurriculumConfig(TypedDict):
     update_epochs: Optional[int]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    teacher_forcing: Optional[bool]
 
 
 class SamplerConfig(TypedDict):

@@ -215,6 +216,7 @@ class MRLTrainer:
         self.callbacks = []
         self.global_epoch = 0
         self.global_epochs_count = 0
+        self.teacher_forcing = False
 
     def _init_optimizers(
             self,

@@ -452,8 +454,10 @@
 
             # 11. Update STM with generated response (except last interaction, it's not needed)
             if not is_last_interaction:
-                self.encode_and_update_stm(
-
+                self.encode_and_update_stm(
+                    next_query,
+                    self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                )  # update with generated_answer on GPU
 
             # 12. Store trajectory step
             trajectory: MrlTrajectoryStep = {

@@ -470,7 +474,7 @@
             # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
             if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                 prev_interaction = (query, answer)
-            query, answer = interaction['query'], detached_answer
+            query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
 
         # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
         episode_trajectory: MrlTrajectoryEpisode = {

@@ -857,7 +861,10 @@
 
             # 10. Encode and update memory for the next interaction
             if not is_last_interaction:
-                self.encode_and_update_stm(
+                self.encode_and_update_stm(
+                    next_query,
+                    self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                )
 
             # 11. Accumulate rewards
             step_reward = torch.tensor(reward).mean().to(self.device)

@@ -870,7 +877,7 @@
             # 12. Save previous interaction
             if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                 prev_interaction = (query, answer)
-            query, answer = interaction['query'], detached_answer
+            query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
         avg_episode_reward = (episode_reward / episode_interactions).item()
         # 13. Run eval TensorBoard writer with average episode reward
         self._eval_writer(avg_episode_reward, epoch)

@@ -1000,8 +1007,7 @@
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs)  # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-
-
+        self.teacher_forcing = config.get('teacher_forcing', False)
 
         def has_param(field: OptimField) -> bool:
             return field in config and config[field] is not None
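The practical effect of the new flag: `teacher_forcing` is read per curriculum stage via `config.get('teacher_forcing', False)`, and when enabled the trainer updates STM from the dataset's reference answer (moved to the device) and carries that answer forward, instead of using the generated/detached answer. A minimal sketch of opting in for a stage, assuming stage configs are plain `CurriculumConfig`-style dicts; the example values are illustrative, and only the fields visible in this diff are used:

# Hypothetical stage config; 'teacher_forcing' is the new optional key in 0.2.50,
# the other keys are existing CurriculumConfig fields shown in this diff.
stage_config = {
    'update_epochs': 2,          # existing Optional[int] field (illustrative value)
    'freeze_embeddings': True,   # existing Optional[bool] field (illustrative value)
    'embedding_lr': 1e-4,        # existing Optional[float] field (illustrative value)
    'teacher_forcing': True,     # new: update STM with dataset answers instead of generated ones
}

With the flag omitted or set to False, behavior matches 0.2.48: STM updates and the carried answer keep using the generated/detached response.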
rxnn-0.2.50/src/rxnn/training/reward.py (new file)

@@ -0,0 +1,323 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+from enum import Enum
+from typing import Optional
+from .utils import TokenizedDict
+
+
+class MrlRewardMode(Enum):
+    STANDARD = 1
+    NEGATIVE = 2
+    LONG_RANGE = 3
+
+
+class MrlRewardModel:
+    def __init__(
+            self,
+            shared_embedding: nn.Embedding,
+            device: torch.device,
+            bleu_with_saved_data: bool = False,
+            bleu_factor: float = 0.5,
+            bleu_ref_factor: float = 0.5,
+            bleu_saved_factor: float = 0.5,
+            cos_factor: float = 0.5,
+            cos_ref_factor: float = 0.5,
+            cos_saved_factor: float = 0.5,
+            multi_cos_ref_factor: float = 0.3,
+            multi_cos_saved_factor: float = 0.5,
+            multi_cos_running_mean_factor: float = 0.2,
+            neg_bleu_factor: Optional[float] = None,
+            neg_cos_factor: Optional[float] = None,
+            neg_cos_ref_factor: Optional[float] = None,
+            neg_cos_saved_factor: Optional[float] = None,
+            neg_bleu_ref_factor: float = 0.5,
+            neg_bleu_saved_factor: float = 0.5,
+            allow_not_summing_factors: bool = False,
+            reward_len: bool = False,
+            neg_reward_len: bool = False,
+            max_rewarded_len: int = None,
+            target_len_as_ref: bool = False,
+            len_factor: int = None,
+            use_running_mean: bool = True,
+            running_mean_decay: float = 0.2,
+            bleu_saved_weights: tuple = (0.5, 0.5),
+            bleu_ref_weights: tuple = (0.5, 0.5),
+            tanh_reward_scale: bool = False,
+            rewards_scale: float = 1.0,
+            debug_mode: int = 0,
+    ):
+        self.shared_embedding = shared_embedding.to(device)
+        self.device = device
+        self.bleu_with_saved_data = bleu_with_saved_data
+
+        self.bleu_factor = bleu_factor
+        self.bleu_ref_factor = bleu_ref_factor
+        self.bleu_saved_factor = bleu_saved_factor
+        self.cos_factor = cos_factor
+        self.cos_ref_factor = cos_ref_factor
+        self.cos_saved_factor = cos_saved_factor
+        self.multi_cos_ref_factor = multi_cos_ref_factor
+        self.multi_cos_saved_factor = multi_cos_saved_factor
+        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
+        self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
+        self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
+        self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
+        self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
+        self.neg_bleu_ref_factor = neg_bleu_ref_factor
+        self.neg_bleu_saved_factor = neg_bleu_saved_factor
+        self.reward_len = reward_len
+        self.neg_reward_len = neg_reward_len
+        self.max_rewarded_len = max_rewarded_len
+        self.target_len_as_ref = target_len_as_ref
+        self.len_factor = len_factor
+        self.use_running_mean = use_running_mean
+        self.running_mean_decay = running_mean_decay
+        self.bleu_ref_weights = bleu_ref_weights
+        self.bleu_saved_weights = bleu_saved_weights
+        self.tanh_reward_scale = tanh_reward_scale
+        self.rewards_scale = rewards_scale
+        self.bleu_smoothing = SmoothingFunction().method4
+        self.debug_mode = debug_mode
+
+        self.prev_data_running_mean = None
+
+        if not allow_not_summing_factors:
+            if reward_len:
+                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+            else:
+                assert self.bleu_factor + self.cos_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+
+    def _sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                       masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))
+
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
+            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+    def _negative_sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                                masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))
+
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
+            saved_bleu = 1 - saved_bleu
+
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
+            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+    def batch_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def negative_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[
+        float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._negative_sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def _sequence_embedding(self, sequence: TokenizedDict) -> torch.Tensor:
+        input_ids = sequence['input_ids']
+        attention_mask = sequence['attention_mask']
+
+        # Get embeddings
+        embeddings = self.shared_embedding(input_ids.to(self.device))
+
+        # Apply attention mask
+        mask_expanded = attention_mask.unsqueeze(-1).to(self.device)
+        masked_embeddings = embeddings * mask_expanded
+
+        # Compute mean with masking
+        sum_embeddings = torch.sum(masked_embeddings, dim=1)
+        token_counts = torch.sum(mask_expanded, dim=1)
+        token_counts = torch.clamp(token_counts, min=1e-8)  # Avoid division by zero
+
+        return sum_embeddings / token_counts
+
+    def _cosine_sim(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0)
+
+    def _cosine_sim_running_mean(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+        running_emb = F.normalize(self.prev_data_running_mean, dim=-1)
+
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+        gen_and_mean = F.cosine_similarity(generated_emb, running_emb, dim=1)
+
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+            print('GEN AND MEAN: ', gen_and_mean.mean())
+
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0), torch.clamp(gen_and_mean, min=0)
+
+    def batch_cosine(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict,
+                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
+        if self.use_running_mean and negative_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
+                    1 - gen_and_mean)
+        elif self.use_running_mean and include_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
+        else:
+            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
+
+    def negative_cosine(self, generated: TokenizedDict, reference: TokenizedDict,
+                        saved_data: TokenizedDict) -> torch.Tensor:
+        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+
+        return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
+
+    def len_reward(self, generated: TokenizedDict, reference: TokenizedDict) -> torch.Tensor:
+        target_lens = reference['attention_mask'].to(self.device).sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
+        lens = generated['attention_mask'].to(self.device).sum(dim=1)
+        neg_lens = target_lens / lens if self.neg_reward_len else 1.0
+        len_reward = torch.where(lens >= target_lens, neg_lens, lens / target_lens)
+        return len_reward
+
+    def reset_running_mean(self):
+        self.prev_data_running_mean = None
+
+    def init_running_mean(self, prev_data: TokenizedDict):
+        self.prev_data_running_mean = self._sequence_embedding(prev_data)
+
+    def update_running_mean(self, prev_data: TokenizedDict):
+        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
+            prev_data) + self.running_mean_decay * self.prev_data_running_mean
+
+    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
+        if self.tanh_reward_scale:
+            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
+        else:
+            return rewards
+
+    def __call__(
+            self,
+            generated: TokenizedDict,
+            reference: TokenizedDict,
+            saved_data: TokenizedDict,
+            prev_data: TokenizedDict = None,
+            mode: MrlRewardMode = MrlRewardMode.STANDARD
+    ) -> list[float]:
+        if prev_data is not None:
+            if self.prev_data_running_mean is None:
+                self.init_running_mean(prev_data)
+            else:
+                self.update_running_mean(prev_data)
+
+        if mode == MrlRewardMode.STANDARD:
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data, include_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('STANDARD MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        elif mode == MrlRewardMode.LONG_RANGE:
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data,
+                                       negative_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('LONG MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        else:
+            bleu = self.negative_bleu(generated, reference, saved_data)
+            cosine = self.negative_cosine(generated, reference, saved_data)
+
+            if self.debug_mode >= 1:
+                print('NEGATIVE MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
+
+        if self.reward_len:
+            len_reward = self.len_reward(generated, reference)
+
+            if self.debug_mode >= 1:
+                print('REWARD LEN: ', (len_reward.sum() / len_reward.size(0)).item())
+
+            rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * len_reward) * self.rewards_scale
+        else:
+            rewards = self._pre_scale_rewards(sim_rewards) * self.rewards_scale
+
+        return rewards.tolist()
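For orientation, a minimal usage sketch of the new reward model, assuming an import path of rxnn.training.reward (inferred from the file location), a toy embedding, and illustrative tensor shapes; the default factor values already satisfy the constructor's sum-to-one assertions:

import torch
import torch.nn as nn

from rxnn.training.reward import MrlRewardModel, MrlRewardMode  # path inferred from the src layout

device = torch.device('cpu')
embedding = nn.Embedding(1000, 64)  # toy vocab size / dim; normally the actor's shared embedding

reward_model = MrlRewardModel(embedding, device, bleu_with_saved_data=True)

def toy_batch(batch_size=2, seq_len=16):
    # TokenizedDict-style dict: 'input_ids' plus 'attention_mask'
    return {
        'input_ids': torch.randint(0, 1000, (batch_size, seq_len)),
        'attention_mask': torch.ones(batch_size, seq_len, dtype=torch.long),
    }

generated, reference, saved = toy_batch(), toy_batch(), toy_batch()
rewards = reward_model(generated, reference, saved, mode=MrlRewardMode.STANDARD)
print(rewards)  # per-example floats mixing the BLEU and cosine terms, scaled by rewards_scale

Passing prev_data maintains the running-mean embedding used by the running-mean and LONG_RANGE variants of batch_cosine.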
rxnn-0.2.48/src/rxnn/training/reward.py (removed file)

@@ -1,216 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from enum import Enum
-from typing import Optional
-from .utils import TokenizedDict
-
-
-class MrlRewardMode(Enum):
-    STANDARD = 1
-    NEGATIVE = 2
-    LONG_RANGE = 3
-
-
-class MrlRewardModel:
-    def __init__(
-            self,
-            shared_embedding: nn.Embedding,
-            device: torch.device,
-            bleu_with_saved_data: bool = False,
-            bleu_factor: float = 0.5,
-            bleu_ref_factor: float = 0.5,
-            bleu_saved_factor: float = 0.5,
-            cos_factor: float = 0.5,
-            cos_ref_factor: float = 0.5,
-            cos_saved_factor: float = 0.5,
-            multi_cos_ref_factor: float = 0.3,
-            multi_cos_saved_factor: float = 0.5,
-            multi_cos_running_mean_factor: float = 0.2,
-            neg_bleu_factor: Optional[float] = None,
-            neg_cos_factor: Optional[float] = None,
-            neg_cos_ref_factor: Optional[float] = None,
-            neg_cos_saved_factor: Optional[float] = None,
-            neg_bleu_ref_factor: float = 0.5,
-            neg_bleu_saved_factor: float = 0.5,
-            allow_not_summing_factors: bool = False,
-            reward_len: bool = False,
-            neg_reward_len: bool = False,
-            max_rewarded_len: int = None,
-            len_factor: int = None,
-            use_running_mean: bool = True,
-            running_mean_decay: float = 0.2,
-            bleu_saved_weights: tuple = (0.5, 0.5),
-            bleu_ref_weights: tuple = (0.5, 0.5),
-            tanh_reward_scale: bool = False,
-            rewards_scale: float = 1.0,
-    ):
-        self.shared_embedding = shared_embedding.to(device)
-        self.device = device
-        self.bleu_with_saved_data = bleu_with_saved_data
-
-        self.bleu_factor = bleu_factor
-        self.bleu_ref_factor = bleu_ref_factor
-        self.bleu_saved_factor = bleu_saved_factor
-        self.cos_factor = cos_factor
-        self.cos_ref_factor = cos_ref_factor
-        self.cos_saved_factor = cos_saved_factor
-        self.multi_cos_ref_factor = multi_cos_ref_factor
-        self.multi_cos_saved_factor = multi_cos_saved_factor
-        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
-        self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
-        self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
-        self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
-        self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
-        self.neg_bleu_ref_factor = neg_bleu_ref_factor
-        self.neg_bleu_saved_factor = neg_bleu_saved_factor
-        self.reward_len = reward_len
-        self.neg_reward_len = neg_reward_len
-        self.max_rewarded_len = max_rewarded_len
-        self.len_factor = len_factor
-        self.use_running_mean = use_running_mean
-        self.running_mean_decay = running_mean_decay
-        self.bleu_ref_weights = bleu_ref_weights
-        self.bleu_saved_weights = bleu_saved_weights
-        self.tanh_reward_scale = tanh_reward_scale
-        self.rewards_scale = rewards_scale
-
-        self.prev_data_running_mean = None
-
-        if not allow_not_summing_factors:
-            if reward_len:
-                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
-                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
-                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
-                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
-                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
-                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
-                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
-            else:
-                assert self.bleu_factor + self.cos_factor == 1.0
-                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
-                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
-                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
-                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
-                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
-                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
-
-    def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
-        from nltk.translate.bleu_score import sentence_bleu
-
-        if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
-            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
-        else:
-            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-
-
-    def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor,
-                                saved_data: torch.Tensor) -> float:
-        from nltk.translate.bleu_score import sentence_bleu
-
-        if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
-            saved_bleu = 1 - saved_bleu
-
-            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
-        else:
-            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-
-    def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
-        batch_size = generated.size(0)
-        return [self._sentence_bleu(generated[i], reference[i], saved_data[i]) for i in range(batch_size)]
-
-    def _sequence_embedding(self, sequence: torch.Tensor) -> torch.Tensor:
-        embedding = self.shared_embedding(sequence.to(self.device))
-        return embedding.mean(dim=1)
-
-    def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
-        generated_emb = self._sequence_embedding(generated)
-
-        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
-        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
-        return gen_and_saved, gen_and_ref
-
-    def _cosine_sim_running_mean(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
-        generated_emb = self._sequence_embedding(generated)
-
-        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
-        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
-        gen_and_mean = (F.cosine_similarity(generated_emb, self.prev_data_running_mean) + 1) / 2
-        return gen_and_saved, gen_and_ref, gen_and_mean
-
-    def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor,
-                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
-        if self.use_running_mean and negative_running_mean:
-            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
-            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
-                    1 - gen_and_mean)
-        elif self.use_running_mean and include_running_mean:
-            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
-            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
-        else:
-            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
-            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
-
-    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor,
-                        saved_data: torch.Tensor) -> torch.Tensor:
-        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
-
-        return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
-
-    def len_reward(self, generated: TokenizedDict):
-        lens = generated['attention_mask'].sum(dim=1)
-        neg_lens = self.max_rewarded_len / lens if self.neg_reward_len else 1.0
-        len_reward = torch.where(lens >= self.max_rewarded_len, neg_lens, lens / self.max_rewarded_len)
-        return len_reward
-
-    def reset_running_mean(self):
-        self.prev_data_running_mean = None
-
-    def init_running_mean(self, prev_data: torch.Tensor):
-        self.prev_data_running_mean = self._sequence_embedding(prev_data)
-
-    def update_running_mean(self, prev_data: torch.Tensor):
-        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
-            prev_data) + self.running_mean_decay * self.prev_data_running_mean
-
-    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
-        if self.tanh_reward_scale:
-            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
-        else:
-            return rewards
-
-    def __call__(
-            self,
-            generated: TokenizedDict,
-            reference: TokenizedDict,
-            saved_data: TokenizedDict,
-            prev_data: TokenizedDict = None,
-            mode: MrlRewardMode = MrlRewardMode.STANDARD
-    ) -> list[float]:
-        if prev_data is not None:
-            if self.prev_data_running_mean is None:
-                self.init_running_mean(prev_data['input_ids'])
-            else:
-                self.update_running_mean(prev_data['input_ids'])
-
-        if mode == MrlRewardMode.STANDARD:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
-                                       include_running_mean=prev_data is not None)
-            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
-        elif mode == MrlRewardMode.LONG_RANGE:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
-                                       negative_running_mean=prev_data is not None)
-            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
-        else:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
-
-        rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
-        return rewards.tolist()
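One behavioral change worth noting between the two versions of _sequence_embedding: the 0.2.48 code averaged token embeddings with a plain mean(dim=1), so padding positions were included, while the 0.2.50 code masks padding with attention_mask before averaging. A small self-contained sketch of that masked-mean step, with toy shapes chosen only for illustration:

import torch

embeddings = torch.randn(2, 4, 3)                    # (batch, seq_len, dim), toy values
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])        # 1 = real token, 0 = padding

# 0.2.48 behavior: padding positions are averaged in
plain_mean = embeddings.mean(dim=1)

# 0.2.50 behavior: zero out padded positions, then divide by the real token count
mask = attention_mask.unsqueeze(-1).float()
masked_mean = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-8)

print(plain_mean.shape, masked_mean.shape)  # both (2, 3); masked_mean ignores padding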