rxnn 0.2.18__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/mrl.py +103 -78
- rxnn/training/reward.py +119 -22
- {rxnn-0.2.18.dist-info → rxnn-0.2.19.dist-info}/METADATA +1 -1
- {rxnn-0.2.18.dist-info → rxnn-0.2.19.dist-info}/RECORD +6 -6
- {rxnn-0.2.18.dist-info → rxnn-0.2.19.dist-info}/LICENSE +0 -0
- {rxnn-0.2.18.dist-info → rxnn-0.2.19.dist-info}/WHEEL +0 -0
rxnn/training/mrl.py
CHANGED
@@ -41,6 +41,7 @@ class CurriculumConfig(TypedDict):
     random_resets: Optional[bool]
     random_resets_from: Optional[int]
     random_resets_ratio: Optional[float]
+    reward_model: Optional[MrlRewardModel]


 class SamplerConfig(TypedDict):
@@ -90,6 +91,7 @@ class MRLTrainer:
         """
         self.actor = actor
         self.critic = critic
+        self.shared_reward_model = reward
         self.reward = reward
         self.device = device
         self.max_seq_len = config.get('max_seq_len', 256)
@@ -221,21 +223,29 @@ class MRLTrainer:

         return generated_answer, log_probs

+    def _calculate_reward(self, generated: TokenizedDict, reference: TokenizedDict,
+                          saved_query: TokenizedDict, saved_answer: TokenizedDict,
+                          mode: MrlRewardMode = MrlRewardMode.STANDARD,
+                          prev_data: tuple[TokenizedDict, TokenizedDict] = None):
+        saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
+                                         pad_token_id=self.pad_token_id)
+        prev_data = smart_concat(prev_data[0], prev_data[1], self.max_seq_len,
+                                 self.pad_token_id) if prev_data is not None else None
+        return self.reward(generated, reference, saved_interaction, mode=mode, prev_data=prev_data), saved_interaction
+
     def compute_reward(self, generated: TokenizedDict, reference: TokenizedDict,
                        saved_data: tuple[TokenizedDict, TokenizedDict], mode: MrlRewardMode = MrlRewardMode.STANDARD,
-                       eval_mode: bool = False) -> list[float]:
+                       eval_mode: bool = False, prev_data: tuple[TokenizedDict, TokenizedDict] = None) -> list[float]:
         """Compute reward based on memory retention (e.g., BLEU-4)."""
         saved_query, saved_answer = saved_data
         # 1. Concat saved (previous) interaction and calculate reward using generated sequence, reference and saved data - with autocast on/off
         if self.use_amp:
             with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
-                saved_interaction =
-
-                reward = self.reward(generated, reference, saved_interaction, mode=mode)
+                reward, saved_interaction = self._calculate_reward(generated, reference, saved_query, saved_answer,
+                                                                   mode=mode, prev_data=prev_data)
         else:
-            saved_interaction =
-
-            reward = self.reward(generated, reference, saved_interaction, mode=mode)
+            reward, saved_interaction = self._calculate_reward(generated, reference, saved_query, saved_answer,
+                                                               mode=mode, prev_data=prev_data)

         # 2. Run 'on reward' callbacks
         for cb in self.callbacks:
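The refactor above folds both autocast branches into a single _calculate_reward helper that concatenates the saved query/answer pair (and, when given, the previous interaction) before calling the reward model. For reference, here is a minimal standalone sketch of what a smart_concat-style join of two tokenized batches could look like; it is a naive, hypothetical stand-in (this diff does not show rxnn's actual smart_concat, which may handle the padding between the two parts differently):

import torch
import torch.nn.functional as F

def concat_interaction(query: dict, answer: dict, max_length: int, pad_token_id: int) -> dict:
    # Naive join: concatenate token ids and masks, truncate to max_length, pad the rest.
    ids = torch.cat([query['input_ids'], answer['input_ids']], dim=-1)[..., :max_length]
    mask = torch.cat([query['attention_mask'], answer['attention_mask']], dim=-1)[..., :max_length]
    missing = max_length - ids.size(-1)
    if missing > 0:
        ids = F.pad(ids, (0, missing), value=pad_token_id)
        mask = F.pad(mask, (0, missing), value=0)
    return {'input_ids': ids, 'attention_mask': mask}

# Example: two 1-sample "interactions" of 3 and 4 tokens joined into a length-8 sequence.
q = {'input_ids': torch.tensor([[5, 6, 7]]), 'attention_mask': torch.ones(1, 3, dtype=torch.long)}
a = {'input_ids': torch.tensor([[8, 9, 10, 11]]), 'attention_mask': torch.ones(1, 4, dtype=torch.long)}
joined = concat_interaction(q, a, max_length=8, pad_token_id=0)
print(joined['input_ids'])  # tensor([[ 5,  6,  7,  8,  9, 10, 11,  0]])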
@@ -289,22 +299,27 @@ class MRLTrainer:
             # state from existing one, instead of new random one)
             reset_done = self.reset_stm()

-            # 4.
+            # 4. Reset reward prev data running mean - it's calculated for multi-step retention, we have to reset it before episode
+            self.reward.reset_running_mean()
+
+            # 5. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
             first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
             interactions = interactions[:self.curriculum_steps]
             interactions_len = len(interactions)
-            #
+            # 6. Encode and update STM with data to save from first interaction
             self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))

-            #
+            # 7. Save first interaction as data to save (for trajectory state)
             query, answer = first_query, first_answer

-            #
+            # 8. Run training strategy for follow-up interactions
             episode_steps = []
             episode_rewards = []

+            prev_interaction = None
+
             for i, interaction in enumerate(interactions):
-                #
+                # 9. Generate batch of answers based on batch of follow-up queries
                 next_query = self._move_batch(interaction['query'])
                 generated_answer, log_probs = self.generate_answer(next_query)

@@ -312,7 +327,7 @@ class MRLTrainer:

                 detached_answer = self._cpu_detach(generated_answer)  # detach and keep states on CPU

-                #
+                # 10. Depending on strategy compute reward
                 if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
                     # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
                     reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
@@ -320,18 +335,19 @@ class MRLTrainer:
                 elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
                     # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
                     reward = self.compute_reward(detached_answer, interaction['answer'],
-                                                 (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE
+                                                 (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                 prev_data=prev_interaction)
                 else:
                     # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
                     reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                 mode=MrlRewardMode.STANDARD)
+                                                 mode=MrlRewardMode.STANDARD, prev_data=prev_interaction)

-                #
+                # 11. Update STM with generated response (except last interaction, it's not needed)
                 if not is_last_interaction:
                     self.encode_and_update_stm(next_query,
                                                generated_answer)  # update with generated_answer on GPU

-                #
+                # 12. Store trajectory step
                 trajectory: MrlTrajectoryStep = {
                     'state': (query, answer, interaction['query']),
                     'action': detached_answer,
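For orientation, the reward-mode selection above reduces to a small standalone function; the strings below stand in for MrlRewardMode members, and this is only an illustration of the branch logic under the long-range strategy, not rxnn code:

def long_range_mode(i: int, total: int) -> str:
    if i == 0:
        return 'NEGATIVE'    # topic change: the answer should not reuse saved data
    if i + 1 == total:
        return 'LONG_RANGE'  # final step must recall the first interaction
    return 'STANDARD'

print([long_range_mode(i, 4) for i in range(4)])
# ['NEGATIVE', 'STANDARD', 'STANDARD', 'LONG_RANGE']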
@@ -342,10 +358,12 @@ class MRLTrainer:
                 episode_steps.append(trajectory)
                 episode_rewards.append(reward)

-                #
+                # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
+                if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
+                    prev_interaction = (query, answer)
                 query, answer = interaction['query'], detached_answer

-            #
+            # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
             episode_trajectory: MrlTrajectoryEpisode = {
                 'reset_stm': reset_done,
                 'steps': episode_steps,
@@ -356,7 +374,7 @@ class MRLTrainer:

             self._collect_writer(mean_episode_reward, epoch)

-            #
+            # 15. Run "on episode collected" callbacks
             for cb in self.callbacks:
                 cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)

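The new prev_interaction bookkeeping lags the loop by one step and is skipped for the first interaction of the long-range strategy (which changes topic). A toy loop with strings in place of tokenized batches shows the intended threading, assuming it mirrors the diff above:

LONG_RANGE = True
interactions = ['q1', 'q2', 'q3']
query, answer = 'q0', 'a0'  # first (saved) interaction
prev_interaction = None

for i, q in enumerate(interactions):
    generated = f'gen({q})'
    # compute_reward(..., prev_data=prev_interaction) would be called here
    if not (LONG_RANGE and i == 0):
        prev_interaction = (query, answer)
    query, answer = q, generated

print(prev_interaction)  # ('q2', 'gen(q2)')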
@@ -595,63 +613,70 @@ class MRLTrainer:
         for batch in dataloader:
             with torch.no_grad():
                 if batch['query']['input_ids'].size(0) == batch_size:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    self._increment_steps('eval')
+                    # 3. Reset STM with random resets ratio and reward model running mean
+                    self.reset_stm()
+                    self.reward.reset_running_mean()
+
+                    # 4. Get batches for first queries, answers and all follow-up interactions
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    # 5. Encode and update STM with initial interactions (batch)
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save follow-up interactions len and first query and answer as previous one for iteration
+                    interactions_len = len(interactions)
+                    query, answer = first_query, first_answer
+                    episode_reward = torch.tensor(0.0).to(self.device)
+                    episode_interactions = torch.tensor(0).to(self.device)
+
+                    prev_interaction = None
+
+                    # 7. Run all follow-up interactions
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers
+                        next_query = self._move_batch(interaction['query'])
+                        generated_answer, _ = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)
+
+                        # 9. Depending on current strategy and step, compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE, eval_mode=True)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                         eval_mode=True, prev_data=prev_interaction)
+                        else:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD, eval_mode=True,
+                                                         prev_data=prev_interaction)
+
+                        # 10. Encode and update memory for the next interaction
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)
+
+                        # 11. Accumulate rewards
+                        step_reward = torch.tensor(reward).mean().to(self.device)
+                        # total
+                        total_reward += step_reward
+                        count += 1
+                        # episode
+                        episode_reward += step_reward
+                        episode_interactions += 1
+                        # 12. Save previous interaction
+                        if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
+                            prev_interaction = (query, answer)
+                        query, answer = interaction['query'], detached_answer
+                    avg_episode_reward = (episode_reward / episode_interactions).item()
+                    # 13. Run eval TensorBoard writer with average episode reward
+                    self._eval_writer(avg_episode_reward, epoch)
+
+                    # 14. Run "on eval episode end" callbacks
+                    for cb in self.callbacks:
+                        cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)

         # 15. Calculate average reward
         if self.use_ddp:
@@ -679,6 +704,7 @@ class MRLTrainer:
                                         self.shared_callbacks)  # trainer callbacks for current curriculum stage
         self.strategy = config.get('strategy',
                                    MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
+        self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage

         # 2. Get epochs and random resets configs
         epochs = config.get('epochs', 5)  # number of epochs for current stage
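With the new reward_model key, each curriculum stage can override the shared reward model; stages that omit it fall back to shared_reward_model. A hedged sketch of such an override is shown below - the import path follows the package layout in the RECORD, while the nn.Embedding stand-in for shared_embedding and the partial config dict are assumptions for illustration only:

import torch
from rxnn.training.reward import MrlRewardModel

# Stand-in embedding layer; assumed acceptable wherever shared_embedding is expected.
embedding = torch.nn.Embedding(32000, 256)

stage_reward_model = MrlRewardModel(
    shared_embedding=embedding,
    device=torch.device('cpu'),
    bleu_with_saved_data=True,
)

stage_config = {
    'epochs': 5,
    'reward_model': stage_reward_model,  # optional; when omitted, shared_reward_model is used
    # ... remaining CurriculumConfig fields for the stage
}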
@@ -805,4 +831,3 @@ class MRLTrainer:
         # 21. Close writer
         if self.writer:
             self.writer.close()
-
rxnn/training/reward.py
CHANGED
@@ -11,6 +11,7 @@ class MrlRewardMode(Enum):
     NEGATIVE = 2
     LONG_RANGE = 3

+
 class MrlRewardModel:
     def __init__(
             self,
@@ -18,9 +19,14 @@ class MrlRewardModel:
             device: torch.device,
             bleu_with_saved_data: bool = False,
             bleu_factor: float = 0.5,
+            bleu_ref_factor: float = 0.5,
+            bleu_saved_factor: float = 0.5,
             cos_factor: float = 0.5,
             cos_ref_factor: float = 0.5,
             cos_saved_factor: float = 0.5,
+            multi_cos_ref_factor: float = 0.3,
+            multi_cos_saved_factor: float = 0.5,
+            multi_cos_running_mean_factor: float = 0.2,
             neg_bleu_factor: Optional[float] = None,
             neg_cos_factor: Optional[float] = None,
             neg_cos_ref_factor: Optional[float] = None,
@@ -28,45 +34,88 @@ class MrlRewardModel:
             neg_bleu_ref_factor: float = 0.5,
             neg_bleu_saved_factor: float = 0.5,
             allow_not_summing_factors: bool = False,
+            reward_len: bool = False,
+            neg_reward_len: bool = False,
+            max_rewarded_len: int = None,
+            len_factor: int = None,
+            use_running_mean: bool = True,
+            running_mean_decay: float = 0.2,
+            bleu_saved_weights: tuple = (0.5, 0.5),
+            bleu_ref_weights: tuple = (0.5, 0.5),
+            rewards_scale: float = 1.0,
     ):
         self.shared_embedding = shared_embedding.to(device)
         self.device = device
         self.bleu_with_saved_data = bleu_with_saved_data

         self.bleu_factor = bleu_factor
+        self.bleu_ref_factor = bleu_ref_factor
+        self.bleu_saved_factor = bleu_saved_factor
         self.cos_factor = cos_factor
         self.cos_ref_factor = cos_ref_factor
         self.cos_saved_factor = cos_saved_factor
+        self.multi_cos_ref_factor = multi_cos_ref_factor
+        self.multi_cos_saved_factor = multi_cos_saved_factor
+        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
         self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
         self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
         self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
         self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
         self.neg_bleu_ref_factor = neg_bleu_ref_factor
         self.neg_bleu_saved_factor = neg_bleu_saved_factor
+        self.reward_len = reward_len
+        self.neg_reward_len = neg_reward_len
+        self.max_rewarded_len = max_rewarded_len
+        self.len_factor = len_factor
+        self.use_running_mean = use_running_mean
+        self.running_mean_decay = running_mean_decay
+        self.bleu_ref_weights = bleu_ref_weights
+        self.bleu_saved_weights = bleu_saved_weights
+        self.rewards_scale = rewards_scale
+
+        self.prev_data_running_mean = None

         if not allow_not_summing_factors:
-
-
-
-
-
+            if reward_len:
+                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+            else:
+                assert self.bleu_factor + self.cos_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0

     def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
         from nltk.translate.bleu_score import sentence_bleu
-        refs = [reference, saved_data] if self.bleu_with_saved_data else [reference]
-        return sentence_bleu(refs, generated, weights=(0.25, 0.25, 0.25, 0.25))

-
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
+            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+
+    def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor,
+                                saved_data: torch.Tensor) -> float:
         from nltk.translate.bleu_score import sentence_bleu

         if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=
-            saved_bleu = sentence_bleu([saved_data], generated, weights=
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
             saved_bleu = 1 - saved_bleu

-            return
+            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
         else:
-            return sentence_bleu([reference], generated, weights=
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)

     def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
         batch_size = generated.size(0)
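The constructor now asserts that related factor groups sum to 1.0 (with an extra len_factor term when reward_len is enabled) and blends reference and saved-data BLEU via bleu_ref_factor / bleu_saved_factor. Below is a standalone arithmetic check of one valid configuration; it mirrors the sums asserted above but uses a small tolerance instead of exact equality:

# Factor groups that satisfy the constructor's sum checks when reward_len=True.
bleu_factor, cos_factor, len_factor = 0.4, 0.4, 0.2
bleu_ref_factor, bleu_saved_factor = 0.5, 0.5
multi_cos_ref, multi_cos_saved, multi_cos_mean = 0.3, 0.5, 0.2

assert abs(bleu_factor + cos_factor + len_factor - 1.0) < 1e-9
assert abs(bleu_ref_factor + bleu_saved_factor - 1.0) < 1e-9
assert abs(multi_cos_ref + multi_cos_saved + multi_cos_mean - 1.0) < 1e-9

# Blended sentence BLEU, as in _sentence_bleu with bleu_with_saved_data=True:
ref_bleu, saved_bleu = 0.6, 0.3  # example per-sample BLEU scores
blended = bleu_ref_factor * ref_bleu + bleu_saved_factor * saved_bleu
print(round(blended, 2))  # 0.45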
@@ -79,33 +128,81 @@ class MrlRewardModel:
     def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
         generated_emb = self._sequence_embedding(generated)

-        gen_and_saved = F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data))
-        gen_and_ref = F.cosine_similarity(generated_emb, self._sequence_embedding(reference))
+        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
+        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
         return gen_and_saved, gen_and_ref

-    def
-
+    def _cosine_sim_running_mean(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
+        generated_emb = self._sequence_embedding(generated)

-
+        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
+        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
+        gen_and_mean = (F.cosine_similarity(generated_emb, self.prev_data_running_mean) + 1) / 2
+        return gen_and_saved, gen_and_ref, gen_and_mean
+
+    def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor,
+                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
+        if self.use_running_mean and negative_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
+                    1 - gen_and_mean)
+        elif self.use_running_mean and include_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
+        else:
+            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref

-    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor,
+    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor,
+                        saved_data: torch.Tensor) -> torch.Tensor:
         gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)

         return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref

+    def len_reward(self, generated: TokenizedDict):
+        lens = generated['attention_mask'].sum(dim=1)
+        neg_lens = self.max_rewarded_len / lens if self.neg_reward_len else 1.0
+        len_reward = torch.where(lens >= self.max_rewarded_len, neg_lens, lens / self.max_rewarded_len)
+        return len_reward
+
+    def reset_running_mean(self):
+        self.prev_data_running_mean = None
+
+    def init_running_mean(self, prev_data: torch.Tensor):
+        self.prev_data_running_mean = self._sequence_embedding(prev_data)
+
+    def update_running_mean(self, prev_data: torch.Tensor):
+        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
+            prev_data) + self.running_mean_decay * self.prev_data_running_mean
+
     def __call__(
             self,
             generated: TokenizedDict,
             reference: TokenizedDict,
             saved_data: TokenizedDict,
+            prev_data: TokenizedDict = None,
             mode: MrlRewardMode = MrlRewardMode.STANDARD
     ) -> list[float]:
-        if
+        if prev_data is not None:
+            if self.prev_data_running_mean is None:
+                self.init_running_mean(prev_data['input_ids'])
+            else:
+                self.update_running_mean(prev_data['input_ids'])
+
+        if mode == MrlRewardMode.STANDARD:
+            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
+                                       include_running_mean=prev_data is not None)
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        elif mode == MrlRewardMode.LONG_RANGE:
             bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids']
-
+            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
+                                       negative_running_mean=prev_data is not None)
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
         else:
             bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
             cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-
+            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine

+        rewards = (sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+        return rewards.tolist()
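The running mean lets multi-step rewards compare the generated answer against an exponential moving average of previous-interaction embeddings, with cosine similarities rescaled from [-1, 1] to [0, 1]. The standalone sketch below reproduces that arithmetic with random embeddings in place of _sequence_embedding outputs; it does not call rxnn code, and assigning multi_cos_running_mean_factor to the running-mean term is an assumption for illustration:

import torch
import torch.nn.functional as F

decay = 0.2         # matches the running_mean_decay default above
running_mean = None  # stands in for prev_data_running_mean

def update_running_mean(prev_emb: torch.Tensor) -> None:
    # EMA over previous-interaction embeddings, as in update_running_mean above:
    # the newest embedding gets weight (1 - decay), the old mean gets weight decay.
    global running_mean
    running_mean = prev_emb if running_mean is None else (1 - decay) * prev_emb + decay * running_mean

# Random embeddings stand in for _sequence_embedding outputs (batch of 4, dim 256).
gen, ref, saved = torch.randn(4, 256), torch.randn(4, 256), torch.randn(4, 256)
update_running_mean(torch.randn(4, 256))  # first previous interaction
update_running_mean(torch.randn(4, 256))  # second one, blended with the first

# Cosine similarities rescaled to [0, 1], as in the new _cosine_sim* methods.
gen_saved = (F.cosine_similarity(gen, saved) + 1) / 2
gen_ref = (F.cosine_similarity(gen, ref) + 1) / 2
gen_mean = (F.cosine_similarity(gen, running_mean) + 1) / 2

# Multi-factor mix for STANDARD mode with prev_data present, weighting the saved,
# reference and running-mean terms with the 0.5 / 0.3 / 0.2 defaults respectively.
cosine_reward = 0.5 * gen_saved + 0.3 * gen_ref + 0.2 * gen_mean
print(cosine_reward.shape)  # torch.Size([4])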
{rxnn-0.2.18.dist-info → rxnn-0.2.19.dist-info}/RECORD
CHANGED
@@ -16,8 +16,8 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=
-rxnn/training/reward.py,sha256=
+rxnn/training/mrl.py,sha256=RSbeJRRjAH1lzkySzeoDmng6hleRmUfnNcM1YVv57as,41388
+rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.19.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.19.dist-info/METADATA,sha256=y3om6t6e6WreQXmVjEfmr_vSkqBl-R04Tmch9Qk6rQg,25960
+rxnn-0.2.19.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.19.dist-info/RECORD,,
{rxnn-0.2.18.dist-info → rxnn-0.2.19.dist-info}/LICENSE
File without changes
{rxnn-0.2.18.dist-info → rxnn-0.2.19.dist-info}/WHEEL
File without changes