rxnn-0.2.18-py3-none-any.whl → rxnn-0.2.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/mrl.py CHANGED
@@ -41,6 +41,7 @@ class CurriculumConfig(TypedDict):
     random_resets: Optional[bool]
     random_resets_from: Optional[int]
     random_resets_ratio: Optional[float]
+    reward_model: Optional[MrlRewardModel]
 
 
 class SamplerConfig(TypedDict):
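Note: the new per-stage reward_model field lets a curriculum stage override the trainer's shared reward model. A minimal sketch of a partial stage config; the embedding variable and the concrete values are assumptions for illustration, while 'strategy' and 'epochs' are existing config keys read by the trainer and 'reward_model' is the field added in this hunk:

    # Hypothetical stage overriding the shared reward model (illustrative values only)
    long_range_reward = MrlRewardModel(
        shared_embedding=embedding,        # assumed: embedding module shared with the actor
        device=torch.device('cuda'),
        bleu_with_saved_data=True,
    )

    stage_config = {                       # partial CurriculumConfig for one stage
        'strategy': MrlStrategy.LONG_RANGE_STRATEGY,
        'epochs': 5,
        'reward_model': long_range_reward,  # if omitted, the trainer falls back to the shared model
    }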
@@ -90,6 +91,7 @@ class MRLTrainer:
         """
         self.actor = actor
         self.critic = critic
+        self.shared_reward_model = reward
         self.reward = reward
         self.device = device
         self.max_seq_len = config.get('max_seq_len', 256)
@@ -221,21 +223,29 @@ class MRLTrainer:
 
         return generated_answer, log_probs
 
+    def _calculate_reward(self, generated: TokenizedDict, reference: TokenizedDict,
+                          saved_query: TokenizedDict, saved_answer: TokenizedDict,
+                          mode: MrlRewardMode = MrlRewardMode.STANDARD,
+                          prev_data: tuple[TokenizedDict, TokenizedDict] = None):
+        saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
+                                         pad_token_id=self.pad_token_id)
+        prev_data = smart_concat(prev_data[0], prev_data[1], self.max_seq_len,
+                                 self.pad_token_id) if prev_data is not None else None
+        return self.reward(generated, reference, saved_interaction, mode=mode, prev_data=prev_data), saved_interaction
+
     def compute_reward(self, generated: TokenizedDict, reference: TokenizedDict,
                        saved_data: tuple[TokenizedDict, TokenizedDict], mode: MrlRewardMode = MrlRewardMode.STANDARD,
-                       eval_mode: bool = False) -> list[float]:
+                       eval_mode: bool = False, prev_data: tuple[TokenizedDict, TokenizedDict] = None) -> list[float]:
         """Compute reward based on memory retention (e.g., BLEU-4)."""
         saved_query, saved_answer = saved_data
         # 1. Concat saved (previous) interaction and calculate reward using generated sequence, reference and saved data - with autocast on/off
         if self.use_amp:
             with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
-                saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
-                                                 pad_token_id=self.pad_token_id)
-                reward = self.reward(generated, reference, saved_interaction, mode=mode)
+                reward, saved_interaction = self._calculate_reward(generated, reference, saved_query, saved_answer,
+                                                                   mode=mode, prev_data=prev_data)
         else:
-            saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
-                                             pad_token_id=self.pad_token_id)
-            reward = self.reward(generated, reference, saved_interaction, mode=mode)
+            reward, saved_interaction = self._calculate_reward(generated, reference, saved_query, saved_answer,
+                                                               mode=mode, prev_data=prev_data)
 
         # 2. Run 'on reward' callbacks
         for cb in self.callbacks:
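Note: the new _calculate_reward helper concatenates the saved interaction (and, when present, the previous one) with smart_concat and forwards both to the reward model, so callers of compute_reward only need to thread prev_data through. A minimal sketch of such a call, reusing the variable names from the trajectory loop further below; trainer stands in for an MRLTrainer instance:

    # Illustrative call only; mirrors what the trainer loop below does internally.
    reward = trainer.compute_reward(
        detached_answer,             # generated answer batch (TokenizedDict)
        interaction['answer'],       # reference answer batch
        (query, answer),             # saved interaction, concatenated via smart_concat
        mode=MrlRewardMode.STANDARD,
        prev_data=prev_interaction,  # None on the first step; otherwise feeds the reward model's running mean
    )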
@@ -289,22 +299,27 @@ class MRLTrainer:
             # state from existing one, instead of new random one)
             reset_done = self.reset_stm()
 
-            # 4. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
+            # 4. Reset reward prev data running mean - it's calculated for multi-step retention, we have to reset it before episode
+            self.reward.reset_running_mean()
+
+            # 5. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
             first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
             interactions = interactions[:self.curriculum_steps]
             interactions_len = len(interactions)
-            # 5. Encode and update STM with data to save from first interaction
+            # 6. Encode and update STM with data to save from first interaction
             self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
 
-            # 6. Save first interaction as data to save (for trajectory state)
+            # 7. Save first interaction as data to save (for trajectory state)
            query, answer = first_query, first_answer
 
-            # 7. Run training strategy for follow-up interactions
+            # 8. Run training strategy for follow-up interactions
             episode_steps = []
             episode_rewards = []
 
+            prev_interaction = None
+
             for i, interaction in enumerate(interactions):
-                # 8. Generate batch of answers based on batch of follow-up queries
+                # 9. Generate batch of answers based on batch of follow-up queries
                 next_query = self._move_batch(interaction['query'])
                 generated_answer, log_probs = self.generate_answer(next_query)
 
@@ -312,7 +327,7 @@ class MRLTrainer:
 
                 detached_answer = self._cpu_detach(generated_answer) # detach and keep states on CPU
 
-                # 9. Depending on strategy compute reward
+                # 10. Depending on strategy compute reward
                 if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
                     # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
                     reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
@@ -320,18 +335,19 @@ class MRLTrainer:
                 elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
                     # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
                     reward = self.compute_reward(detached_answer, interaction['answer'],
-                                                 (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
+                                                 (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                 prev_data=prev_interaction)
                 else:
                     # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
                     reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                 mode=MrlRewardMode.STANDARD)
+                                                 mode=MrlRewardMode.STANDARD, prev_data=prev_interaction)
 
-                # 10. Update STM with generated response (except last interaction, it's not needed)
+                # 11. Update STM with generated response (except last interaction, it's not needed)
                 if not is_last_interaction:
                     self.encode_and_update_stm(next_query,
                                                generated_answer) # update with generated_answer on GPU
 
-                # 11. Store trajectory step
+                # 12. Store trajectory step
                 trajectory: MrlTrajectoryStep = {
                     'state': (query, answer, interaction['query']),
                     'action': detached_answer,
@@ -342,10 +358,12 @@ class MRLTrainer:
                 episode_steps.append(trajectory)
                 episode_rewards.append(reward)
 
-                # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
+                # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
+                if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
+                    prev_interaction = (query, answer)
                 query, answer = interaction['query'], detached_answer
 
-            # 13. Append full batched episode (number of steps depends on curriculum stage) to trajectories
+            # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
             episode_trajectory: MrlTrajectoryEpisode = {
                 'reset_stm': reset_done,
                 'steps': episode_steps,
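Note: taken together, the hunks above change the episode bookkeeping so that the reward model's running mean is reset per episode and fed the previous interaction at every step except the topic-changing first step of the long-range strategy. A condensed restatement of that flow (generation and reward calls elided, trainer stands in for the MRLTrainer instance):

    trainer.reward.reset_running_mean()   # running mean is per-episode state
    prev_interaction = None
    for i, interaction in enumerate(interactions):
        # ... generate answer, compute reward with prev_data=prev_interaction ...
        if not (trainer.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
            # the long-range first step changes topic, so it is excluded from the running mean
            prev_interaction = (query, answer)
        query, answer = interaction['query'], detached_answer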
@@ -356,7 +374,7 @@ class MRLTrainer:
 
             self._collect_writer(mean_episode_reward, epoch)
 
-            # 14. Run "on episode collected" callbacks
+            # 15. Run "on episode collected" callbacks
             for cb in self.callbacks:
                 cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
 
@@ -595,63 +613,70 @@ class MRLTrainer:
         for batch in dataloader:
             with torch.no_grad():
                 if batch['query']['input_ids'].size(0) == batch_size:
-                    self._increment_steps('eval')
-                    # 3. Reset STM with random resets ratio
-                    self.reset_stm()
-
-                    # 4. Get batches for first queries, answers and all follow-up interactions
-                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
-                    # 5. Encode and update STM with initial interactions (batch)
-                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
-
-                    # 6. Save follow-up interactions len and first query and answer as previous one for iteration
-                    interactions_len = len(interactions)
-                    query, answer = first_query, first_answer
-                    episode_reward = torch.tensor(0.0).to(self.device)
-                    episode_interactions = torch.tensor(0).to(self.device)
-                    # 7. Run all follow-up interactions
-                    for i, interaction in enumerate(interactions):
-                        # 8. Generate batch of answers
-                        next_query = self._move_batch(interaction['query'])
-                        generated_answer, _ = self.generate_answer(next_query)
-
-                        is_last_interaction = (i + 1) == interactions_len
-
-                        detached_answer = self._cpu_detach(generated_answer)
-
-                        # 9. Depending on current strategy and step, compute reward
-                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
-                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                         mode=MrlRewardMode.NEGATIVE, eval_mode=True)
-                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
-                            reward = self.compute_reward(detached_answer, interaction['answer'],
-                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
-                                                         eval_mode=True)
-                        else:
-                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                         mode=MrlRewardMode.STANDARD, eval_mode=True)
-
-                        # 10. Encode and update memory for the next interaction
-                        if not is_last_interaction:
-                            self.encode_and_update_stm(next_query, generated_answer)
-
-                        # 11. Accumulate rewards
-                        step_reward = torch.tensor(reward).mean().to(self.device)
-                        # total
-                        total_reward += step_reward
-                        count += 1
-                        # episode
-                        episode_reward += step_reward
-                        episode_interactions += 1
-                        # 12. Save previous interaction
-                        query, answer = interaction['query'], detached_answer
-                    avg_episode_reward = (episode_reward / episode_interactions).item()
-                    # 13. Run eval TensorBoard writer with average episode reward
-                    self._eval_writer(avg_episode_reward, epoch)
-
-                    # 14. Run "on eval episode end" callbacks
-                    for cb in self.callbacks:
-                        cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
+                    self._increment_steps('eval')
+                    # 3. Reset STM with random resets ratio and reward model running mean
+                    self.reset_stm()
+                    self.reward.reset_running_mean()
+
+                    # 4. Get batches for first queries, answers and all follow-up interactions
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    # 5. Encode and update STM with initial interactions (batch)
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save follow-up interactions len and first query and answer as previous one for iteration
+                    interactions_len = len(interactions)
+                    query, answer = first_query, first_answer
+                    episode_reward = torch.tensor(0.0).to(self.device)
+                    episode_interactions = torch.tensor(0).to(self.device)
+
+                    prev_interaction = None
+
+                    # 7. Run all follow-up interactions
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers
+                        next_query = self._move_batch(interaction['query'])
+                        generated_answer, _ = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)
+
+                        # 9. Depending on current strategy and step, compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE, eval_mode=True)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                         eval_mode=True, prev_data=prev_interaction)
+                        else:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD, eval_mode=True,
+                                                         prev_data=prev_interaction)
+
+                        # 10. Encode and update memory for the next interaction
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)
+
+                        # 11. Accumulate rewards
+                        step_reward = torch.tensor(reward).mean().to(self.device)
+                        # total
+                        total_reward += step_reward
+                        count += 1
+                        # episode
+                        episode_reward += step_reward
+                        episode_interactions += 1
+                        # 12. Save previous interaction
+                        if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
+                            prev_interaction = (query, answer)
+                        query, answer = interaction['query'], detached_answer
+                    avg_episode_reward = (episode_reward / episode_interactions).item()
+                    # 13. Run eval TensorBoard writer with average episode reward
+                    self._eval_writer(avg_episode_reward, epoch)
+
+                    # 14. Run "on eval episode end" callbacks
+                    for cb in self.callbacks:
+                        cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
 
         # 15. Calculate average reward
         if self.use_ddp:
@@ -679,6 +704,7 @@ class MRLTrainer:
                                            self.shared_callbacks) # trainer callbacks for current curriculum stage
         self.strategy = config.get('strategy',
                                    MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
+        self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
 
         # 2. Get epochs and random resets configs
         epochs = config.get('epochs', 5) # number of epochs for current stage
@@ -805,4 +831,3 @@ class MRLTrainer:
         # 21. Close writer
         if self.writer:
             self.writer.close()
-
rxnn/training/reward.py CHANGED
@@ -11,6 +11,7 @@ class MrlRewardMode(Enum):
     NEGATIVE = 2
     LONG_RANGE = 3
 
+
 class MrlRewardModel:
     def __init__(
             self,
@@ -18,9 +19,14 @@ class MrlRewardModel:
             device: torch.device,
             bleu_with_saved_data: bool = False,
             bleu_factor: float = 0.5,
+            bleu_ref_factor: float = 0.5,
+            bleu_saved_factor: float = 0.5,
             cos_factor: float = 0.5,
             cos_ref_factor: float = 0.5,
             cos_saved_factor: float = 0.5,
+            multi_cos_ref_factor: float = 0.3,
+            multi_cos_saved_factor: float = 0.5,
+            multi_cos_running_mean_factor: float = 0.2,
             neg_bleu_factor: Optional[float] = None,
             neg_cos_factor: Optional[float] = None,
             neg_cos_ref_factor: Optional[float] = None,
@@ -28,45 +34,88 @@ class MrlRewardModel:
             neg_bleu_ref_factor: float = 0.5,
             neg_bleu_saved_factor: float = 0.5,
             allow_not_summing_factors: bool = False,
+            reward_len: bool = False,
+            neg_reward_len: bool = False,
+            max_rewarded_len: int = None,
+            len_factor: int = None,
+            use_running_mean: bool = True,
+            running_mean_decay: float = 0.2,
+            bleu_saved_weights: tuple = (0.5, 0.5),
+            bleu_ref_weights: tuple = (0.5, 0.5),
+            rewards_scale: float = 1.0,
     ):
         self.shared_embedding = shared_embedding.to(device)
         self.device = device
         self.bleu_with_saved_data = bleu_with_saved_data
 
         self.bleu_factor = bleu_factor
+        self.bleu_ref_factor = bleu_ref_factor
+        self.bleu_saved_factor = bleu_saved_factor
         self.cos_factor = cos_factor
         self.cos_ref_factor = cos_ref_factor
         self.cos_saved_factor = cos_saved_factor
+        self.multi_cos_ref_factor = multi_cos_ref_factor
+        self.multi_cos_saved_factor = multi_cos_saved_factor
+        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
         self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
         self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
         self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
         self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
         self.neg_bleu_ref_factor = neg_bleu_ref_factor
         self.neg_bleu_saved_factor = neg_bleu_saved_factor
+        self.reward_len = reward_len
+        self.neg_reward_len = neg_reward_len
+        self.max_rewarded_len = max_rewarded_len
+        self.len_factor = len_factor
+        self.use_running_mean = use_running_mean
+        self.running_mean_decay = running_mean_decay
+        self.bleu_ref_weights = bleu_ref_weights
+        self.bleu_saved_weights = bleu_saved_weights
+        self.rewards_scale = rewards_scale
+
+        self.prev_data_running_mean = None
 
         if not allow_not_summing_factors:
-            assert self.bleu_factor + self.cos_factor == 1.0
-            assert self.cos_ref_factor + self.cos_saved_factor == 1.0
-            assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
-            assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
-            assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+            if reward_len:
+                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+            else:
+                assert self.bleu_factor + self.cos_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
 
     def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
         from nltk.translate.bleu_score import sentence_bleu
-        refs = [reference, saved_data] if self.bleu_with_saved_data else [reference]
-        return sentence_bleu(refs, generated, weights=(0.25, 0.25, 0.25, 0.25))
 
-    def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
+            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+
+    def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor,
+                                saved_data: torch.Tensor) -> float:
         from nltk.translate.bleu_score import sentence_bleu
 
         if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=(0.25, 0.25, 0.25, 0.25))
-            saved_bleu = sentence_bleu([saved_data], generated, weights=(0.25, 0.25, 0.25))
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
             saved_bleu = 1 - saved_bleu
 
-            return (self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu) / 2
+            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
         else:
-            return sentence_bleu([reference], generated, weights=(0.25, 0.25, 0.25, 0.25))
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
 
     def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
         batch_size = generated.size(0)
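Note: with the split ref/saved BLEU factors, the configurable n-gram weights and the optional length term, several factor groups must each sum to 1.0 unless allow_not_summing_factors is set. A hedged configuration sketch using the parameters from this diff; the embedding variable and the concrete values are assumptions, chosen so that every assertion in the constructor passes:

    reward_model = MrlRewardModel(
        shared_embedding=embedding,              # assumed embedding module
        device=torch.device('cuda'),
        bleu_with_saved_data=True,
        # main mix: bleu + cos + len == 1.0 because reward_len=True
        bleu_factor=0.4, cos_factor=0.4,
        reward_len=True, len_factor=0.2, max_rewarded_len=256,
        # reference vs saved-data splits, each pair sums to 1.0
        bleu_ref_factor=0.5, bleu_saved_factor=0.5,
        cos_ref_factor=0.5, cos_saved_factor=0.5,
        # three-way cosine mix used when the prev_data running mean is active
        multi_cos_ref_factor=0.3, multi_cos_saved_factor=0.5, multi_cos_running_mean_factor=0.2,
        # n-gram weights passed straight to nltk's sentence_bleu
        bleu_ref_weights=(0.5, 0.5), bleu_saved_weights=(0.5, 0.5),
        use_running_mean=True, running_mean_decay=0.2,
        rewards_scale=1.0,
    )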
@@ -79,33 +128,81 @@ class MrlRewardModel:
     def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
         generated_emb = self._sequence_embedding(generated)
 
-        gen_and_saved = F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data))
-        gen_and_ref = F.cosine_similarity(generated_emb, self._sequence_embedding(reference))
+        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
+        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
         return gen_and_saved, gen_and_ref
 
-    def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> torch.Tensor:
-        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+    def _cosine_sim_running_mean(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
+        generated_emb = self._sequence_embedding(generated)
 
-        return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
+        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
+        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
+        gen_and_mean = (F.cosine_similarity(generated_emb, self.prev_data_running_mean) + 1) / 2
+        return gen_and_saved, gen_and_ref, gen_and_mean
+
+    def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor,
+                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
+        if self.use_running_mean and negative_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
+                    1 - gen_and_mean)
+        elif self.use_running_mean and include_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
+        else:
+            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
 
-    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> torch.Tensor:
+    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor,
+                        saved_data: torch.Tensor) -> torch.Tensor:
         gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
 
         return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
 
+    def len_reward(self, generated: TokenizedDict):
+        lens = generated['attention_mask'].sum(dim=1)
+        neg_lens = self.max_rewarded_len / lens if self.neg_reward_len else 1.0
+        len_reward = torch.where(lens >= self.max_rewarded_len, neg_lens, lens / self.max_rewarded_len)
+        return len_reward
+
+    def reset_running_mean(self):
+        self.prev_data_running_mean = None
+
+    def init_running_mean(self, prev_data: torch.Tensor):
+        self.prev_data_running_mean = self._sequence_embedding(prev_data)
+
+    def update_running_mean(self, prev_data: torch.Tensor):
+        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
+            prev_data) + self.running_mean_decay * self.prev_data_running_mean
+
     def __call__(
             self,
            generated: TokenizedDict,
             reference: TokenizedDict,
             saved_data: TokenizedDict,
+            prev_data: TokenizedDict = None,
             mode: MrlRewardMode = MrlRewardMode.STANDARD
     ) -> list[float]:
-        if mode == MrlRewardMode.STANDARD or mode == MrlRewardMode.LONG_RANGE:
+        if prev_data is not None:
+            if self.prev_data_running_mean is None:
+                self.init_running_mean(prev_data['input_ids'])
+            else:
+                self.update_running_mean(prev_data['input_ids'])
+
+        if mode == MrlRewardMode.STANDARD:
+            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
+                                       include_running_mean=prev_data is not None)
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        elif mode == MrlRewardMode.LONG_RANGE:
             bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            return (self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine).tolist()
+            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
+                                       negative_running_mean=prev_data is not None)
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
         else:
             bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
             cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            return (self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine).tolist()
+            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
 
+        rewards = (sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+        return rewards.tolist()
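Note: the running mean is episode-level state: the trainer calls reset_running_mean() at the start of an episode, and each __call__ with a non-None prev_data either initialises or decays it before the cosine terms are computed. A minimal usage sketch under those assumptions (the TokenizedDict batches are placeholders):

    reward_model.reset_running_mean()      # start of an episode: forget earlier interactions

    rewards = reward_model(
        generated,                         # generated answer batch
        reference,                         # reference answer batch
        saved_interaction,                 # smart_concat of the saved query/answer
        prev_data=prev_interaction,        # None on the first step; else init/update of the running mean
        mode=MrlRewardMode.STANDARD,       # STANDARD/LONG_RANGE blend BLEU and cosine; NEGATIVE inverts the saved-data terms
    )                                      # -> list[float], one reward per batch item, scaled by rewards_scale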
rxnn-0.2.18.dist-info/METADATA → rxnn-0.2.19.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.18
+Version: 0.2.19
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.18.dist-info/RECORD → rxnn-0.2.19.dist-info/RECORD CHANGED
@@ -16,8 +16,8 @@ rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=Ar2g-vjqTq_4qLKc4L1Ai0j2LX-x98dmsx_VaWVV-Es,39448
-rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
+rxnn/training/mrl.py,sha256=RSbeJRRjAH1lzkySzeoDmng6hleRmUfnNcM1YVv57as,41388
+rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.18.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.18.dist-info/METADATA,sha256=_hGNlaH_rclBfQdzA7tCFhkI-RZPiK5tNBM8tjUsbWQ,25960
-rxnn-0.2.18.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.18.dist-info/RECORD,,
+rxnn-0.2.19.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.19.dist-info/METADATA,sha256=y3om6t6e6WreQXmVjEfmr_vSkqBl-R04Tmch9Qk6rQg,25960
+rxnn-0.2.19.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.19.dist-info/RECORD,,
rxnn-0.2.18.dist-info/LICENSE → rxnn-0.2.19.dist-info/LICENSE  File without changes
rxnn-0.2.18.dist-info/WHEEL → rxnn-0.2.19.dist-info/WHEEL  File without changes