rxnn 0.2.18__py3-none-any.whl → 0.2.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/models.py CHANGED
@@ -1,6 +1,7 @@
  import torch
  import torch.nn as nn
  from enum import Enum
+ from typing import Literal
  from huggingface_hub import PyTorchModelHubMixin
  from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
 
@@ -74,23 +75,27 @@ class MrlActorModel(nn.Module):
          self.decoder = decoder
          self.memory_attention = memory_attention
 
-     def freeze_components(self):
+     def freeze_components(self, stage: Literal['update', 'fetch', 'both'] = 'both'):
          """Freeze encoder/decoder except memory-related layers."""
          if self.encoder.freeze_without_memory is not None:
              self.encoder.freeze_without_memory()
+             if stage == 'update':
+                 self.encoder.freeze_memory()
          else:
              for param in self.encoder.parameters():
                  param.requires_grad = False
-             self.encoder.model.trainable_cross_attention_(True)
+             self.encoder.model.trainable_cross_attention_(True if stage != 'update' else False)
          if self.decoder.freeze_without_memory is not None:
              self.decoder.freeze_without_memory()
+             if stage == 'update':
+                 self.decoder.freeze_memory()
          else:
              for param in self.decoder.parameters():
                  param.requires_grad = False
-             self.decoder.model.trainable_cross_attention_(True)
+             self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False)
          # Unfreeze memory attention
          for param in self.memory_attention.parameters():
-             param.requires_grad = True
+             param.requires_grad = True if stage != 'fetch' else False
 
      def unfreeze_components(self):
          """Unfreeze all components after initial training."""
rxnn/training/mrl.py CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
  from torch.utils.tensorboard import SummaryWriter
  import torch.distributed as dist
  from torch.nn.parallel import DistributedDataParallel
- from typing import Optional, TypedDict
+ from typing import Optional, TypedDict, Union
  from enum import Enum
  import random, os
  from ..transformers.sampler import BatchSampler
@@ -37,10 +37,15 @@ class CurriculumConfig(TypedDict):
      eval_dataset: Optional[MrlCurriculumDataset]
      callbacks: Optional[list[MrlTrainerCallback]]
      strategy: MrlStrategy
-     unfreeze_epoch: Optional[int]
+     unfreeze_epoch: Optional[Union[int, tuple[int, int, int]]]
      random_resets: Optional[bool]
      random_resets_from: Optional[int]
      random_resets_ratio: Optional[float]
+     reward_model: Optional[MrlRewardModel]
+     lr: Optional[float]
+     critic_lr: Optional[float]
+     weight_decay: Optional[float]
+     critic_weight_decay: Optional[float]
 
 
  class SamplerConfig(TypedDict):
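
A hedged sketch of a curriculum stage dict exercising the new optional keys (only fields visible in this diff are used; dataset and callback entries are omitted, and the concrete values are purely illustrative):

    stage_config = {
        'strategy': MrlStrategy.MULTI_STEP_STRATEGY,
        'unfreeze_epoch': (1, 2, 4),   # staged unfreeze: 'fetch' at epoch 1, 'both' at epoch 2, full unfreeze at epoch 4
        'lr': 2e-4,                    # per-stage override; None falls back to the trainer's base optimizer config
        'critic_lr': None,
        'weight_decay': None,
        'critic_weight_decay': None,
        'reward_model': None,          # None reuses the shared reward model
        'random_resets': False,
    }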
@@ -90,6 +95,7 @@ class MRLTrainer:
          """
          self.actor = actor
          self.critic = critic
+         self.shared_reward_model = reward
          self.reward = reward
          self.device = device
          self.max_seq_len = config.get('max_seq_len', 256)
@@ -117,17 +123,15 @@ class MRLTrainer:
          self.use_amp = use_amp
          self.dtype = dtype
 
+         self.base_optim_config = {
+             'lr': config.get('lr', 3e-4),
+             'critic_lr': config.get('critic_lr', 1e-4),
+             'weight_decay': config.get('weight_decay', 0.01),
+             'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+         }
+
          # Optimizers
-         self.optimizer = torch.optim.AdamW(
-             self.actor.unique_parameters(),
-             lr=config.get("lr", 3e-4),
-             weight_decay=config.get("weight_decay", 0.01),
-         )
-         self.critic_optimizer = torch.optim.AdamW(
-             self.critic.parameters(),
-             lr=config.get("critic_lr", 1e-4),
-             weight_decay=config.get("critic_weight_decay", 0.01),
-         )
+         self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config)
 
          self.scaler = torch.amp.GradScaler() if self.use_amp else None
          self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
@@ -154,6 +158,21 @@ class MRLTrainer:
          self.global_epoch = 0
          self.global_epochs_count = 0
 
+     def _init_optimizers(self, lr: float, critic_lr: float, weight_decay: float, critic_weight_decay: float):
+         optimizer = torch.optim.AdamW(
+             self.actor.unique_parameters(),
+             lr=lr,
+             weight_decay=weight_decay,
+         )
+
+         critic_optimizer = torch.optim.AdamW(
+             self.critic.parameters(),
+             lr=critic_lr,
+             weight_decay=critic_weight_decay,
+         )
+         return optimizer, critic_optimizer
+
+
      def _init_steps(self):
          return {
              'collect': 0,
@@ -221,21 +240,29 @@ class MRLTrainer:
 
          return generated_answer, log_probs
 
+     def _calculate_reward(self, generated: TokenizedDict, reference: TokenizedDict,
+                           saved_query: TokenizedDict, saved_answer: TokenizedDict,
+                           mode: MrlRewardMode = MrlRewardMode.STANDARD,
+                           prev_data: tuple[TokenizedDict, TokenizedDict] = None):
+         saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
+                                          pad_token_id=self.pad_token_id)
+         prev_data = smart_concat(prev_data[0], prev_data[1], self.max_seq_len,
+                                  self.pad_token_id) if prev_data is not None else None
+         return self.reward(generated, reference, saved_interaction, mode=mode, prev_data=prev_data), saved_interaction
+
      def compute_reward(self, generated: TokenizedDict, reference: TokenizedDict,
                         saved_data: tuple[TokenizedDict, TokenizedDict], mode: MrlRewardMode = MrlRewardMode.STANDARD,
-                        eval_mode: bool = False) -> list[float]:
+                        eval_mode: bool = False, prev_data: tuple[TokenizedDict, TokenizedDict] = None) -> list[float]:
          """Compute reward based on memory retention (e.g., BLEU-4)."""
          saved_query, saved_answer = saved_data
          # 1. Concat saved (previous) interaction and calculate reward using generated sequence, reference and saved data - with autocast on/off
          if self.use_amp:
              with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
-                 saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
-                                                  pad_token_id=self.pad_token_id)
-                 reward = self.reward(generated, reference, saved_interaction, mode=mode)
+                 reward, saved_interaction = self._calculate_reward(generated, reference, saved_query, saved_answer,
+                                                                    mode=mode, prev_data=prev_data)
          else:
-             saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
-                                              pad_token_id=self.pad_token_id)
-             reward = self.reward(generated, reference, saved_interaction, mode=mode)
+             reward, saved_interaction = self._calculate_reward(generated, reference, saved_query, saved_answer,
+                                                                mode=mode, prev_data=prev_data)
 
          # 2. Run 'on reward' callbacks
          for cb in self.callbacks:
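
A hedged call sketch for the extended reward computation; all dict arguments are assumed TokenizedDict batches prepared elsewhere, and `trainer` is an assumed MRLTrainer instance:

    reward = trainer.compute_reward(
        generated, reference, (saved_query, saved_answer),
        mode=MrlRewardMode.STANDARD,
        prev_data=(prev_query, prev_answer),  # previous interaction, concatenated and fed into the reward model's running mean
    )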
@@ -289,22 +316,27 @@ class MRLTrainer:
              # state from existing one, instead of new random one)
              reset_done = self.reset_stm()
 
-             # 4. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
+             # 4. Reset reward prev data running mean - it's calculated for multi-step retention, we have to reset it before episode
+             self.reward.reset_running_mean()
+
+             # 5. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
              first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
              interactions = interactions[:self.curriculum_steps]
              interactions_len = len(interactions)
-             # 5. Encode and update STM with data to save from first interaction
+             # 6. Encode and update STM with data to save from first interaction
              self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
 
-             # 6. Save first interaction as data to save (for trajectory state)
+             # 7. Save first interaction as data to save (for trajectory state)
              query, answer = first_query, first_answer
 
-             # 7. Run training strategy for follow-up interactions
+             # 8. Run training strategy for follow-up interactions
              episode_steps = []
              episode_rewards = []
 
+             prev_interaction = None
+
              for i, interaction in enumerate(interactions):
-                 # 8. Generate batch of answers based on batch of follow-up queries
+                 # 9. Generate batch of answers based on batch of follow-up queries
                  next_query = self._move_batch(interaction['query'])
                  generated_answer, log_probs = self.generate_answer(next_query)
 
@@ -312,7 +344,7 @@ class MRLTrainer:
 
                  detached_answer = self._cpu_detach(generated_answer) # detach and keep states on CPU
 
-                 # 9. Depending on strategy compute reward
+                 # 10. Depending on strategy compute reward
                  if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
                      # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
                      reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
@@ -320,18 +352,19 @@ class MRLTrainer:
                  elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
                      # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
                      reward = self.compute_reward(detached_answer, interaction['answer'],
-                                                  (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
+                                                  (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                  prev_data=prev_interaction)
                  else:
                      # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
                      reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                  mode=MrlRewardMode.STANDARD)
+                                                  mode=MrlRewardMode.STANDARD, prev_data=prev_interaction)
 
-                 # 10. Update STM with generated response (except last interaction, it's not needed)
+                 # 11. Update STM with generated response (except last interaction, it's not needed)
                  if not is_last_interaction:
                      self.encode_and_update_stm(next_query,
                                                 generated_answer) # update with generated_answer on GPU
 
-                 # 11. Store trajectory step
+                 # 12. Store trajectory step
                  trajectory: MrlTrajectoryStep = {
                      'state': (query, answer, interaction['query']),
                      'action': detached_answer,
@@ -342,10 +375,12 @@ class MRLTrainer:
                  episode_steps.append(trajectory)
                  episode_rewards.append(reward)
 
-                 # 12. Set current interaction query and generated answer (batches), as saved data for next interaction
+                 # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
+                 if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
+                     prev_interaction = (query, answer)
                  query, answer = interaction['query'], detached_answer
 
-             # 13. Append full batched episode (number of steps depends on curriculum stage) to trajectories
+             # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
              episode_trajectory: MrlTrajectoryEpisode = {
                  'reset_stm': reset_done,
                  'steps': episode_steps,
@@ -356,7 +391,7 @@ class MRLTrainer:
 
              self._collect_writer(mean_episode_reward, epoch)
 
-             # 14. Run "on episode collected" callbacks
+             # 15. Run "on episode collected" callbacks
              for cb in self.callbacks:
                  cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
 
@@ -595,63 +630,70 @@ class MRLTrainer:
          for batch in dataloader:
              with torch.no_grad():
                  if batch['query']['input_ids'].size(0) == batch_size:
-                     self._increment_steps('eval')
-                     # 3. Reset STM with random resets ratio
-                     self.reset_stm()
-
-                     # 4. Get batches for first queries, answers and all follow-up interactions
-                     first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
-                     # 5. Encode and update STM with initial interactions (batch)
-                     self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
-
-                     # 6. Save follow-up interactions len and first query and answer as previous one for iteration
-                     interactions_len = len(interactions)
-                     query, answer = first_query, first_answer
-                     episode_reward = torch.tensor(0.0).to(self.device)
-                     episode_interactions = torch.tensor(0).to(self.device)
-                     # 7. Run all follow-up interactions
-                     for i, interaction in enumerate(interactions):
-                         # 8. Generate batch of answers
-                         next_query = self._move_batch(interaction['query'])
-                         generated_answer, _ = self.generate_answer(next_query)
-
-                         is_last_interaction = (i + 1) == interactions_len
-
-                         detached_answer = self._cpu_detach(generated_answer)
-
-                         # 9. Depending on current strategy and step, compute reward
-                         if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
-                             reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                          mode=MrlRewardMode.NEGATIVE, eval_mode=True)
-                         elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
-                             reward = self.compute_reward(detached_answer, interaction['answer'],
-                                                          (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
-                                                          eval_mode=True)
-                         else:
-                             reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                          mode=MrlRewardMode.STANDARD, eval_mode=True)
-
-                         # 10. Encode and update memory for the next interaction
-                         if not is_last_interaction:
-                             self.encode_and_update_stm(next_query, generated_answer)
-
-                         # 11. Accumulate rewards
-                         step_reward = torch.tensor(reward).mean().to(self.device)
-                         # total
-                         total_reward += step_reward
-                         count += 1
-                         # episode
-                         episode_reward += step_reward
-                         episode_interactions += 1
-                         # 12. Save previous interaction
-                         query, answer = interaction['query'], detached_answer
-                     avg_episode_reward = (episode_reward / episode_interactions).item()
-                     # 13. Run eval TensorBoard writer with average episode reward
-                     self._eval_writer(avg_episode_reward, epoch)
-
-                     # 14. Run "on eval episode end" callbacks
-                     for cb in self.callbacks:
-                         cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
+                     self._increment_steps('eval')
+                     # 3. Reset STM with random resets ratio and reward model running mean
+                     self.reset_stm()
+                     self.reward.reset_running_mean()
+
+                     # 4. Get batches for first queries, answers and all follow-up interactions
+                     first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                     # 5. Encode and update STM with initial interactions (batch)
+                     self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                     # 6. Save follow-up interactions len and first query and answer as previous one for iteration
+                     interactions_len = len(interactions)
+                     query, answer = first_query, first_answer
+                     episode_reward = torch.tensor(0.0).to(self.device)
+                     episode_interactions = torch.tensor(0).to(self.device)
+
+                     prev_interaction = None
+
+                     # 7. Run all follow-up interactions
+                     for i, interaction in enumerate(interactions):
+                         # 8. Generate batch of answers
+                         next_query = self._move_batch(interaction['query'])
+                         generated_answer, _ = self.generate_answer(next_query)
+
+                         is_last_interaction = (i + 1) == interactions_len
+
+                         detached_answer = self._cpu_detach(generated_answer)
+
+                         # 9. Depending on current strategy and step, compute reward
+                         if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                             reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                          mode=MrlRewardMode.NEGATIVE, eval_mode=True)
+                         elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                             reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                          (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                          eval_mode=True, prev_data=prev_interaction)
+                         else:
+                             reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                          mode=MrlRewardMode.STANDARD, eval_mode=True,
+                                                          prev_data=prev_interaction)
+
+                         # 10. Encode and update memory for the next interaction
+                         if not is_last_interaction:
+                             self.encode_and_update_stm(next_query, generated_answer)
+
+                         # 11. Accumulate rewards
+                         step_reward = torch.tensor(reward).mean().to(self.device)
+                         # total
+                         total_reward += step_reward
+                         count += 1
+                         # episode
+                         episode_reward += step_reward
+                         episode_interactions += 1
+                         # 12. Save previous interaction
+                         if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
+                             prev_interaction = (query, answer)
+                         query, answer = interaction['query'], detached_answer
+                     avg_episode_reward = (episode_reward / episode_interactions).item()
+                     # 13. Run eval TensorBoard writer with average episode reward
+                     self._eval_writer(avg_episode_reward, epoch)
+
+                     # 14. Run "on eval episode end" callbacks
+                     for cb in self.callbacks:
+                         cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
 
          # 15. Calculate average reward
          if self.use_ddp:
@@ -679,6 +721,14 @@ class MRLTrainer:
                                      self.shared_callbacks) # trainer callbacks for current curriculum stage
          self.strategy = config.get('strategy',
                                     MrlStrategy.MULTI_STEP_STRATEGY) # MRL strategy for given curriculum stage
+         self.reward = config.get('reward_model', self.shared_reward_model) # MRL Reward Model for curriculum stage
+         if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None:
+             self.optimizer, self.critic_optimizer = self._init_optimizers(
+                 lr=config['lr'] or self.base_optim_config['lr'],
+                 critic_lr=config['critic_lr'] or self.base_optim_config['critic_lr'],
+                 weight_decay=config['weight_decay'] or self.base_optim_config['weight_decay'],
+                 critic_weight_decay=config['critic_weight_decay'] or self.base_optim_config['critic_weight_decay']
+             )
 
          # 2. Get epochs and random resets configs
          epochs = config.get('epochs', 5) # number of epochs for current stage
@@ -720,7 +770,11 @@ class MRLTrainer:
 
          # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
          if unfreeze_epoch != 0:
-             self.actor.freeze_components()
+             is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+             if is_staged_unfreeze:
+                 self.actor.freeze_components('update')
+             else:
+                 self.actor.freeze_components()
 
          # 5. Setup train DataLoader
          if self.use_ddp:
@@ -761,8 +815,18 @@ class MRLTrainer:
                  self.random_resets_ratio = 1.0
 
              # 11. Unfreeze all components before selected epoch
-             if epoch == unfreeze_epoch:
-                 self.actor.unfreeze_components()
+             is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+             if is_staged_unfreeze:
+                 fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
+                 if epoch == fetch_epoch:
+                     self.actor.freeze_components('fetch')
+                 elif epoch == both_epoch:
+                     self.actor.freeze_components('both')
+                 elif epoch == all_epoch:
+                     self.actor.unfreeze_components()
+             else:
+                 if epoch == unfreeze_epoch:
+                     self.actor.unfreeze_components()
 
              # 12. Set epoch for distributed sampler
              if train_sampler is not None:
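
Read together with step 4 above (which applies freeze_components('update') once before the first epoch whenever a tuple is configured), an assumed schedule of unfreeze_epoch = (2, 4, 6) would unfold as follows:

    # epochs 0-1: 'update' stage              - only memory attention is trained
    # epoch 2:    freeze_components('fetch')  - memory cross-attention becomes trainable instead
    # epoch 4:    freeze_components('both')   - memory attention and cross-attention train together
    # epoch 6:    unfreeze_components()       - the full actor is trainable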
@@ -805,4 +869,3 @@ class MRLTrainer:
          # 21. Close writer
          if self.writer:
              self.writer.close()
-
rxnn/training/reward.py CHANGED
@@ -11,6 +11,7 @@ class MrlRewardMode(Enum):
      NEGATIVE = 2
      LONG_RANGE = 3
 
+
  class MrlRewardModel:
      def __init__(
              self,
@@ -18,9 +19,14 @@ class MrlRewardModel:
              device: torch.device,
              bleu_with_saved_data: bool = False,
              bleu_factor: float = 0.5,
+             bleu_ref_factor: float = 0.5,
+             bleu_saved_factor: float = 0.5,
              cos_factor: float = 0.5,
              cos_ref_factor: float = 0.5,
              cos_saved_factor: float = 0.5,
+             multi_cos_ref_factor: float = 0.3,
+             multi_cos_saved_factor: float = 0.5,
+             multi_cos_running_mean_factor: float = 0.2,
              neg_bleu_factor: Optional[float] = None,
              neg_cos_factor: Optional[float] = None,
              neg_cos_ref_factor: Optional[float] = None,
@@ -28,45 +34,88 @@ class MrlRewardModel:
              neg_bleu_ref_factor: float = 0.5,
              neg_bleu_saved_factor: float = 0.5,
              allow_not_summing_factors: bool = False,
+             reward_len: bool = False,
+             neg_reward_len: bool = False,
+             max_rewarded_len: int = None,
+             len_factor: int = None,
+             use_running_mean: bool = True,
+             running_mean_decay: float = 0.2,
+             bleu_saved_weights: tuple = (0.5, 0.5),
+             bleu_ref_weights: tuple = (0.5, 0.5),
+             rewards_scale: float = 1.0,
      ):
          self.shared_embedding = shared_embedding.to(device)
          self.device = device
          self.bleu_with_saved_data = bleu_with_saved_data
 
          self.bleu_factor = bleu_factor
+         self.bleu_ref_factor = bleu_ref_factor
+         self.bleu_saved_factor = bleu_saved_factor
          self.cos_factor = cos_factor
          self.cos_ref_factor = cos_ref_factor
          self.cos_saved_factor = cos_saved_factor
+         self.multi_cos_ref_factor = multi_cos_ref_factor
+         self.multi_cos_saved_factor = multi_cos_saved_factor
+         self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
          self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
          self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
          self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
          self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
          self.neg_bleu_ref_factor = neg_bleu_ref_factor
          self.neg_bleu_saved_factor = neg_bleu_saved_factor
+         self.reward_len = reward_len
+         self.neg_reward_len = neg_reward_len
+         self.max_rewarded_len = max_rewarded_len
+         self.len_factor = len_factor
+         self.use_running_mean = use_running_mean
+         self.running_mean_decay = running_mean_decay
+         self.bleu_ref_weights = bleu_ref_weights
+         self.bleu_saved_weights = bleu_saved_weights
+         self.rewards_scale = rewards_scale
+
+         self.prev_data_running_mean = None
 
          if not allow_not_summing_factors:
-             assert self.bleu_factor + self.cos_factor == 1.0
-             assert self.cos_ref_factor + self.cos_saved_factor == 1.0
-             assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
-             assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
-             assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+             if reward_len:
+                 assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
+                 assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
+                 assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                 assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                 assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                 assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                 assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+             else:
+                 assert self.bleu_factor + self.cos_factor == 1.0
+                 assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                 assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                 assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                 assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
+                 assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                 assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
 
      def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
          from nltk.translate.bleu_score import sentence_bleu
-         refs = [reference, saved_data] if self.bleu_with_saved_data else [reference]
-         return sentence_bleu(refs, generated, weights=(0.25, 0.25, 0.25, 0.25))
 
-     def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
+         if self.bleu_with_saved_data:
+             ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+             saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
+             return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
+         else:
+             return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+
+     def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor,
+                                 saved_data: torch.Tensor) -> float:
          from nltk.translate.bleu_score import sentence_bleu
 
          if self.bleu_with_saved_data:
-             ref_bleu = sentence_bleu([reference], generated, weights=(0.25, 0.25, 0.25, 0.25))
-             saved_bleu = sentence_bleu([saved_data], generated, weights=(0.25, 0.25, 0.25))
+             ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+             saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
              saved_bleu = 1 - saved_bleu
 
-             return (self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu) / 2
+             return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
          else:
-             return sentence_bleu([reference], generated, weights=(0.25, 0.25, 0.25, 0.25))
+             return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
 
      def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
          batch_size = generated.size(0)
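
A hedged, standalone illustration of the blended BLEU introduced above. Plain word tokens are used here for readability, while the reward model itself operates on token-id tensors; weights and factors mirror the new defaults.

    from nltk.translate.bleu_score import sentence_bleu

    generated = "the cat sat on the mat".split()
    reference = "the cat sat on a mat".split()
    saved = "a dog sat on the mat".split()

    ref_bleu = sentence_bleu([reference], generated, weights=(0.5, 0.5))  # bleu_ref_weights: bigram BLEU
    saved_bleu = sentence_bleu([saved], generated, weights=(0.5, 0.5))    # bleu_saved_weights
    print(0.5 * ref_bleu + 0.5 * saved_bleu)                              # bleu_ref_factor / bleu_saved_factor blend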
@@ -79,33 +128,81 @@ class MrlRewardModel:
      def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
          generated_emb = self._sequence_embedding(generated)
 
-         gen_and_saved = F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data))
-         gen_and_ref = F.cosine_similarity(generated_emb, self._sequence_embedding(reference))
+         gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
+         gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
          return gen_and_saved, gen_and_ref
 
-     def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> torch.Tensor:
-         gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+     def _cosine_sim_running_mean(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
+         generated_emb = self._sequence_embedding(generated)
 
-         return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
+         gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
+         gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
+         gen_and_mean = (F.cosine_similarity(generated_emb, self.prev_data_running_mean) + 1) / 2
+         return gen_and_saved, gen_and_ref, gen_and_mean
+
+     def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor,
+                      include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
+         if self.use_running_mean and negative_running_mean:
+             gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+             return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
+                     1 - gen_and_mean)
+         elif self.use_running_mean and include_running_mean:
+             gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+             return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
+         else:
+             gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+             return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
 
-     def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> torch.Tensor:
+     def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor,
+                         saved_data: torch.Tensor) -> torch.Tensor:
          gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
 
          return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
 
+     def len_reward(self, generated: TokenizedDict):
+         lens = generated['attention_mask'].sum(dim=1)
+         neg_lens = self.max_rewarded_len / lens if self.neg_reward_len else 1.0
+         len_reward = torch.where(lens >= self.max_rewarded_len, neg_lens, lens / self.max_rewarded_len)
+         return len_reward
+
+     def reset_running_mean(self):
+         self.prev_data_running_mean = None
+
+     def init_running_mean(self, prev_data: torch.Tensor):
+         self.prev_data_running_mean = self._sequence_embedding(prev_data)
+
+     def update_running_mean(self, prev_data: torch.Tensor):
+         self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
+             prev_data) + self.running_mean_decay * self.prev_data_running_mean
+
      def __call__(
              self,
              generated: TokenizedDict,
             reference: TokenizedDict,
              saved_data: TokenizedDict,
+             prev_data: TokenizedDict = None,
              mode: MrlRewardMode = MrlRewardMode.STANDARD
      ) -> list[float]:
-         if mode == MrlRewardMode.STANDARD or mode == MrlRewardMode.LONG_RANGE:
+         if prev_data is not None:
+             if self.prev_data_running_mean is None:
+                 self.init_running_mean(prev_data['input_ids'])
+             else:
+                 self.update_running_mean(prev_data['input_ids'])
+
+         if mode == MrlRewardMode.STANDARD:
+             bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+             cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
+                                        include_running_mean=prev_data is not None)
+             sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+         elif mode == MrlRewardMode.LONG_RANGE:
              bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-             cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-             return (self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine).tolist()
+             cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
+                                        negative_running_mean=prev_data is not None)
+             sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
          else:
              bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
              cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-             return (self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine).tolist()
+             sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
 
+         rewards = (sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+         return rewards.tolist()
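
A hedged construction sketch tying the new options together. Only parameters visible in this diff are used; `embedding` and `device` are assumed to exist, and the concrete values are illustrative but chosen so the factor-sum assertions above hold.

    reward_model = MrlRewardModel(
        embedding, device,
        bleu_with_saved_data=True,
        bleu_ref_weights=(0.5, 0.5), bleu_saved_weights=(0.5, 0.5),  # 2-gram BLEU instead of the old fixed 4-gram weights
        bleu_factor=0.4, cos_factor=0.4,                             # with reward_len, these must sum to 1.0 with len_factor
        reward_len=True, neg_reward_len=True,
        max_rewarded_len=256, len_factor=0.2,                        # optional length term scaled by len_factor
        use_running_mean=True, running_mean_decay=0.2,               # running mean over prev_data for multi-step retention
        rewards_scale=1.0,
    )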
rxnn-0.2.18.dist-info/METADATA → rxnn-0.2.20.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.18
+ Version: 0.2.20
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
rxnn-0.2.18.dist-info/RECORD → rxnn-0.2.20.dist-info/RECORD RENAMED
@@ -15,9 +15,9 @@ rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
  rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
  rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
  rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
- rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
- rxnn/training/mrl.py,sha256=Ar2g-vjqTq_4qLKc4L1Ai0j2LX-x98dmsx_VaWVV-Es,39448
- rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
+ rxnn/training/models.py,sha256=qXfD3_97T9z724NN4myjzrpX6-jYA9Igl266ZwtJCtc,5519
+ rxnn/training/mrl.py,sha256=zk4m1JFuX0y82J0tG2XkY0Pz6Uy2did9cngOXqR9lMk,43326
+ rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
  rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
- rxnn-0.2.18.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
- rxnn-0.2.18.dist-info/METADATA,sha256=_hGNlaH_rclBfQdzA7tCFhkI-RZPiK5tNBM8tjUsbWQ,25960
- rxnn-0.2.18.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- rxnn-0.2.18.dist-info/RECORD,,
+ rxnn-0.2.20.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+ rxnn-0.2.20.dist-info/METADATA,sha256=mgimK5GvI27RapfLjhlIdBwgfVdKoMA5Ig5yVxfeYIw,25960
+ rxnn-0.2.20.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ rxnn-0.2.20.dist-info/RECORD,,