rxnn 0.2.47__tar.gz → 0.2.49__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {rxnn-0.2.47 → rxnn-0.2.49}/PKG-INFO +2 -1
  2. {rxnn-0.2.47 → rxnn-0.2.49}/pyproject.toml +2 -2
  3. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/memory/norm.py +3 -1
  4. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/mrl.py +34 -10
  5. rxnn-0.2.49/src/rxnn/training/reward.py +323 -0
  6. rxnn-0.2.47/src/rxnn/training/reward.py +0 -216
  7. {rxnn-0.2.47 → rxnn-0.2.49}/LICENSE +0 -0
  8. {rxnn-0.2.47 → rxnn-0.2.49}/README.md +0 -0
  9. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/.DS_Store +0 -0
  10. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/__init__.py +0 -0
  11. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/experimental/__init__.py +0 -0
  12. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/experimental/attention.py +0 -0
  13. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/experimental/models.py +0 -0
  14. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/experimental/moe.py +0 -0
  15. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/memory/__init__.py +0 -0
  16. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/memory/attention.py +0 -0
  17. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/memory/stm.py +0 -0
  18. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/rxt/__init__.py +0 -0
  19. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/rxt/models.py +0 -0
  20. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/__init__.py +0 -0
  21. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/base.py +0 -0
  22. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/bml.py +0 -0
  23. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/callbacks.py +0 -0
  24. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/dataset.py +0 -0
  25. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/ddp.py +0 -0
  26. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/models.py +0 -0
  27. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/rl.py +0 -0
  28. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/scheduler.py +0 -0
  29. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/tokenizer.py +0 -0
  30. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/utils.py +0 -0
  31. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/attention.py +0 -0
  33. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/ff.py +0 -0
  34. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/layers.py +0 -0
  35. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/mask.py +0 -0
  36. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/models.py +0 -0
  37. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/moe.py +0 -0
  38. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/positional.py +0 -0
  39. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/transformers/sampler.py +0 -0
  40. {rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/utils.py +0 -0
{rxnn-0.2.47 → rxnn-0.2.49}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.47
+Version: 0.2.49
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -15,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Requires-Dist: datasets (>=3.5.0,<4.0.0)
 Requires-Dist: huggingface-hub (>=0.30.0,<0.31.0)
+Requires-Dist: nltk (>=3.9.1,<4.0.0)
 Requires-Dist: tensorboard (>=2.19.0,<3.0.0)
 Requires-Dist: tokenizers (>=0.21.0,<0.22.0)
 Requires-Dist: torch (>=2.6.0,<3.0.0)
{rxnn-0.2.47 → rxnn-0.2.49}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.47"
+version = "0.2.49"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
@@ -30,4 +30,4 @@ datasets = "^3.5.0"
 tokenizers = "^0.21.0"
 huggingface-hub = "^0.30.0"
 tensorboard = "^2.19.0"
-
+nltk = "^3.9.1"
{rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/memory/norm.py
@@ -163,7 +163,7 @@ def init_memory_norm(
         init_scale: float = 1.0,
         per_dim_scale: bool = False,
 ) -> nn.Module:
-    assert norm_type in ['layer', 'rms', 'adaptive', 'positional']
+    assert norm_type in ['layer', 'rms', 'adaptive', 'positional', 'classic-rms']
     if norm_type == 'layer':
         return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
     elif norm_type == 'rms':
@@ -172,4 +172,6 @@ def init_memory_norm(
         return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
     elif norm_type == 'positional':
         return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate, per_dim_scale)
+    elif norm_type == 'classic-rms':
+        return nn.RMSNorm(dim)
     return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
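
A minimal, hypothetical sketch (not part of the diff): the new 'classic-rms' option simply dispatches to torch's built-in nn.RMSNorm, with no gating, scaling, or adaptive decay applied. The dimensions below are placeholders.

import torch
import torch.nn as nn

dim = 512                          # hypothetical STM slot dimension
norm = nn.RMSNorm(dim)             # what init_memory_norm(..., norm_type='classic-rms') now returns
stm_slots = torch.randn(1, 256, dim)
print(norm(stm_slots).shape)       # torch.Size([1, 256, 512]); shape preserved, slots RMS-normalized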
{rxnn-0.2.47 → rxnn-0.2.49}/src/rxnn/training/mrl.py
@@ -35,6 +35,7 @@ class MrlConfig(TypedDict):
     moe_aux_loss_scale: Optional[float]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    use_memory_warmup: Optional[bool]
 
 
 class MrlStrategy(Enum):
@@ -70,6 +71,7 @@ class CurriculumConfig(TypedDict):
     update_epochs: Optional[int]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    teacher_forcing: Optional[bool]
 
 
 class SamplerConfig(TypedDict):
@@ -136,6 +138,7 @@ class MRLTrainer:
         self.moe_aux_loss_scale = config.get('moe_aux_loss_scale', 0.01)
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
+        self.use_memory_warmup = config.get('use_memory_warmup', False)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -213,6 +216,7 @@
         self.callbacks = []
         self.global_epoch = 0
         self.global_epochs_count = 0
+        self.teacher_forcing = False
 
     def _init_optimizers(
             self,
@@ -381,6 +385,11 @@
             self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                    self.stage_step['collect'])
 
+    def memory_warmup(self, query: TokenizedDict, answer: TokenizedDict):
+        if self.use_memory_warmup:
+            with torch.no_grad():
+                self.encode_and_update_stm(query, answer)
+
     def collect_trajectories(self, dataloader: DataLoader, epoch: int, batch_size: int) -> list[MrlTrajectoryEpisode]:
         """Collect trajectories for PPO for current curriculum step."""
         # 1. Init trajectories list
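
The warmup itself is a no-op unless use_memory_warmup is enabled; a simplified sketch (an assumption, not the actual trainer code) of the call pattern the hunks below introduce:

def first_interaction_step(trainer, first_query, first_answer, reset_done: bool):
    # Hypothetical helper illustrating the pattern from collect_trajectories/evaluate:
    # after an STM reset, the first interaction is first encoded without gradients
    # (memory_warmup), then encoded normally, so tracked updates start from a
    # warmed-up memory state instead of a freshly reset one.
    first_interaction = trainer._move_multiple_batches(first_query, first_answer)
    if reset_done:
        trainer.memory_warmup(*first_interaction)   # wrapped in torch.no_grad() internally
    trainer.encode_and_update_stm(*first_interaction)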
@@ -402,8 +411,13 @@
                 first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
                 interactions = interactions[:self.curriculum_steps]
                 interactions_len = len(interactions)
+
+                first_interaction = self._move_multiple_batches(first_query, first_answer)
+
+                if reset_done:
+                    self.memory_warmup(*first_interaction)
                 # 6. Encode and update STM with data to save from first interaction
-                self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+                self.encode_and_update_stm(*first_interaction)
 
                 # 7. Save first interaction as data to save (for trajectory state)
                 query, answer = first_query, first_answer
@@ -440,8 +454,10 @@
 
                     # 11. Update STM with generated response (except last interaction, it's not needed)
                     if not is_last_interaction:
-                        self.encode_and_update_stm(next_query,
-                                                   generated_answer)  # update with generated_answer on GPU
+                        self.encode_and_update_stm(
+                            next_query,
+                            self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                        )  # update with generated_answer on GPU
 
                     # 12. Store trajectory step
                     trajectory: MrlTrajectoryStep = {
@@ -458,7 +474,7 @@
 
                     # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
                     if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                         prev_interaction = (query, answer)
-                    query, answer = interaction['query'], detached_answer
+                    query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
 
                 # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
                 episode_trajectory: MrlTrajectoryEpisode = {
@@ -649,6 +665,9 @@
 
             self.actor.clone_reset_memory()
 
+            if should_reset_stm and step_idx == 0:
+                self.memory_warmup(query, answer)
+
             # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
             if self.memory_aware_critic:
                 self.encode_and_update_stm(query, answer)
@@ -798,13 +817,16 @@
                 if batch['query']['input_ids'].size(0) == batch_size:
                     self._increment_steps('eval')
                     # 3. Reset STM with random resets ratio and reward model running mean
-                    self.reset_stm()
+                    reset_stm = self.reset_stm()
                     self.reward.reset_running_mean()
 
                     # 4. Get batches for first queries, answers and all follow-up interactions
                     first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
                     # 5. Encode and update STM with initial interactions (batch)
-                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+                    first_interaction = self._move_multiple_batches(first_query, first_answer)
+                    if reset_stm:
+                        self.memory_warmup(*first_interaction)
+                    self.encode_and_update_stm(*first_interaction)
 
                     # 6. Save follow-up interactions len and first query and answer as previous one for iteration
                     interactions_len = len(interactions)
@@ -839,7 +861,10 @@
 
                         # 10. Encode and update memory for the next interaction
                         if not is_last_interaction:
-                            self.encode_and_update_stm(next_query, generated_answer)
+                            self.encode_and_update_stm(
+                                next_query,
+                                self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                            )
 
                         # 11. Accumulate rewards
                         step_reward = torch.tensor(reward).mean().to(self.device)
@@ -852,7 +877,7 @@
                         # 12. Save previous interaction
                         if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                             prev_interaction = (query, answer)
-                        query, answer = interaction['query'], detached_answer
+                        query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
                     avg_episode_reward = (episode_reward / episode_interactions).item()
                     # 13. Run eval TensorBoard writer with average episode reward
                     self._eval_writer(avg_episode_reward, epoch)
@@ -982,8 +1007,7 @@
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs)  # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-
-
+        self.teacher_forcing = config.get('teacher_forcing', False)
 
         def has_param(field: OptimField) -> bool:
             return field in config and config[field] is not None
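
A hypothetical configuration sketch based only on the fields added above; the remaining MrlConfig/CurriculumConfig keys and the trainer wiring are unchanged by this release and are elided here.

mrl_config = {
    # new MrlConfig flag: run a no-grad STM warmup on the first interaction after a reset
    'use_memory_warmup': True,
    # ... other MrlConfig keys as in 0.2.47 ...
}

curriculum_stage = {
    # new CurriculumConfig flag: during this stage, update STM with the reference
    # answer (interaction['answer']) instead of the generated one
    'teacher_forcing': True,
    # ... other CurriculumConfig keys as in 0.2.47 ...
}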
rxnn-0.2.49/src/rxnn/training/reward.py (new file)
@@ -0,0 +1,323 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+from enum import Enum
+from typing import Optional
+from .utils import TokenizedDict
+
+
+class MrlRewardMode(Enum):
+    STANDARD = 1
+    NEGATIVE = 2
+    LONG_RANGE = 3
+
+
+class MrlRewardModel:
+    def __init__(
+            self,
+            shared_embedding: nn.Embedding,
+            device: torch.device,
+            bleu_with_saved_data: bool = False,
+            bleu_factor: float = 0.5,
+            bleu_ref_factor: float = 0.5,
+            bleu_saved_factor: float = 0.5,
+            cos_factor: float = 0.5,
+            cos_ref_factor: float = 0.5,
+            cos_saved_factor: float = 0.5,
+            multi_cos_ref_factor: float = 0.3,
+            multi_cos_saved_factor: float = 0.5,
+            multi_cos_running_mean_factor: float = 0.2,
+            neg_bleu_factor: Optional[float] = None,
+            neg_cos_factor: Optional[float] = None,
+            neg_cos_ref_factor: Optional[float] = None,
+            neg_cos_saved_factor: Optional[float] = None,
+            neg_bleu_ref_factor: float = 0.5,
+            neg_bleu_saved_factor: float = 0.5,
+            allow_not_summing_factors: bool = False,
+            reward_len: bool = False,
+            neg_reward_len: bool = False,
+            max_rewarded_len: int = None,
+            target_len_as_ref: bool = False,
+            len_factor: int = None,
+            use_running_mean: bool = True,
+            running_mean_decay: float = 0.2,
+            bleu_saved_weights: tuple = (0.5, 0.5),
+            bleu_ref_weights: tuple = (0.5, 0.5),
+            tanh_reward_scale: bool = False,
+            rewards_scale: float = 1.0,
+            debug_mode: int = 0,
+    ):
+        self.shared_embedding = shared_embedding.to(device)
+        self.device = device
+        self.bleu_with_saved_data = bleu_with_saved_data
+
+        self.bleu_factor = bleu_factor
+        self.bleu_ref_factor = bleu_ref_factor
+        self.bleu_saved_factor = bleu_saved_factor
+        self.cos_factor = cos_factor
+        self.cos_ref_factor = cos_ref_factor
+        self.cos_saved_factor = cos_saved_factor
+        self.multi_cos_ref_factor = multi_cos_ref_factor
+        self.multi_cos_saved_factor = multi_cos_saved_factor
+        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
+        self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
+        self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
+        self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
+        self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
+        self.neg_bleu_ref_factor = neg_bleu_ref_factor
+        self.neg_bleu_saved_factor = neg_bleu_saved_factor
+        self.reward_len = reward_len
+        self.neg_reward_len = neg_reward_len
+        self.max_rewarded_len = max_rewarded_len
+        self.target_len_as_ref = target_len_as_ref
+        self.len_factor = len_factor
+        self.use_running_mean = use_running_mean
+        self.running_mean_decay = running_mean_decay
+        self.bleu_ref_weights = bleu_ref_weights
+        self.bleu_saved_weights = bleu_saved_weights
+        self.tanh_reward_scale = tanh_reward_scale
+        self.rewards_scale = rewards_scale
+        self.bleu_smoothing = SmoothingFunction().method4
+        self.debug_mode = debug_mode
+
+        self.prev_data_running_mean = None
+
+        if not allow_not_summing_factors:
+            if reward_len:
+                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+            else:
+                assert self.bleu_factor + self.cos_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+
+    def _sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                       masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))
+
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
+            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+    def _negative_sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                                masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))
+
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
+            saved_bleu = 1 - saved_bleu
+
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
+            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+    def batch_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def negative_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[
+        float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._negative_sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def _sequence_embedding(self, sequence: TokenizedDict) -> torch.Tensor:
+        input_ids = sequence['input_ids']
+        attention_mask = sequence['attention_mask']
+
+        # Get embeddings
+        embeddings = self.shared_embedding(input_ids.to(self.device))
+
+        # Apply attention mask
+        mask_expanded = attention_mask.unsqueeze(-1).to(self.device)
+        masked_embeddings = embeddings * mask_expanded
+
+        # Compute mean with masking
+        sum_embeddings = torch.sum(masked_embeddings, dim=1)
+        token_counts = torch.sum(mask_expanded, dim=1)
+        token_counts = torch.clamp(token_counts, min=1e-8)  # Avoid division by zero
+
+        return sum_embeddings / token_counts
+
+    def _cosine_sim(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0)
+
+    def _cosine_sim_running_mean(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+        running_emb = F.normalize(self.prev_data_running_mean, dim=-1)
+
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+        gen_and_mean = F.cosine_similarity(generated_emb, running_emb, dim=1)
+
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+            print('GEN AND MEAN: ', gen_and_mean.mean())
+
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0), torch.clamp(gen_and_mean, min=0)
+
+    def batch_cosine(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict,
+                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
+        if self.use_running_mean and negative_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
+                    1 - gen_and_mean)
+        elif self.use_running_mean and include_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
+        else:
+            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
+
+    def negative_cosine(self, generated: TokenizedDict, reference: TokenizedDict,
+                        saved_data: TokenizedDict) -> torch.Tensor:
+        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+
+        return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
+
+    def len_reward(self, generated: TokenizedDict, reference: TokenizedDict) -> torch.Tensor:
+        target_lens = reference['attention_mask'].sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
+        lens = generated['attention_mask'].sum(dim=1)
+        neg_lens = target_lens / lens if self.neg_reward_len else 1.0
+        len_reward = torch.where(lens >= target_lens, neg_lens, lens / target_lens)
+        return len_reward
+
+    def reset_running_mean(self):
+        self.prev_data_running_mean = None
+
+    def init_running_mean(self, prev_data: TokenizedDict):
+        self.prev_data_running_mean = self._sequence_embedding(prev_data)
+
+    def update_running_mean(self, prev_data: TokenizedDict):
+        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
+            prev_data) + self.running_mean_decay * self.prev_data_running_mean
+
+    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
+        if self.tanh_reward_scale:
+            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
+        else:
+            return rewards
+
+    def __call__(
+            self,
+            generated: TokenizedDict,
+            reference: TokenizedDict,
+            saved_data: TokenizedDict,
+            prev_data: TokenizedDict = None,
+            mode: MrlRewardMode = MrlRewardMode.STANDARD
+    ) -> list[float]:
+        if prev_data is not None:
+            if self.prev_data_running_mean is None:
+                self.init_running_mean(prev_data)
+            else:
+                self.update_running_mean(prev_data)
+
+        if mode == MrlRewardMode.STANDARD:
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data, include_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('STANDARD MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        elif mode == MrlRewardMode.LONG_RANGE:
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data,
+                                       negative_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('LONG MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        else:
+            bleu = self.negative_bleu(generated, reference, saved_data)
+            cosine = self.negative_cosine(generated, reference, saved_data)
+
+            if self.debug_mode >= 1:
+                print('NEGATIVE MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
+
+        if self.reward_len:
+            len_reward = self.len_reward(generated, reference)
+
+            if self.debug_mode >= 1:
+                print('REWARD LEN: ', (len_reward.sum() / len_reward.size(0)).item())
+
+            rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * len_reward) * self.rewards_scale
+        else:
+            rewards = self._pre_scale_rewards(sim_rewards) * self.rewards_scale
+
+        return rewards.tolist()
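
A hedged usage sketch of the rewritten reward model: only the constructor arguments and the __call__ signature come from the file above, while the embedding size, tensors, and factor values are illustrative assumptions. Note that reward_len=True requires bleu_factor + cos_factor + len_factor == 1.0 unless allow_not_summing_factors is set.

import torch
import torch.nn as nn
from rxnn.training.reward import MrlRewardModel

embedding = nn.Embedding(32000, 256)            # stand-in for the model's shared embedding
reward_model = MrlRewardModel(
    embedding,
    torch.device('cpu'),
    bleu_with_saved_data=True,
    reward_len=True,
    bleu_factor=0.4, cos_factor=0.4, len_factor=0.2,
    target_len_as_ref=True,                     # new in 0.2.49: target length taken from the reference
    debug_mode=1,                               # new in 0.2.49: print per-batch BLEU/cosine stats
)

ids = torch.randint(0, 32000, (2, 16))
mask = torch.ones(2, 16, dtype=torch.long)
sample = {'input_ids': ids, 'attention_mask': mask}  # TokenizedDict-shaped batch

# BLEU is now computed on mask-truncated token lists with smoothing, and cosine
# similarity on masked mean-pooled embeddings, so padded batches are handled correctly.
rewards = reward_model(generated=sample, reference=sample, saved_data=sample)
print(rewards)   # list of per-sample rewards, scaled by rewards_scale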
rxnn-0.2.47/src/rxnn/training/reward.py (removed)
@@ -1,216 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from enum import Enum
-from typing import Optional
-from .utils import TokenizedDict
-
-
-class MrlRewardMode(Enum):
-    STANDARD = 1
-    NEGATIVE = 2
-    LONG_RANGE = 3
-
-
-class MrlRewardModel:
-    def __init__(
-            self,
-            shared_embedding: nn.Embedding,
-            device: torch.device,
-            bleu_with_saved_data: bool = False,
-            bleu_factor: float = 0.5,
-            bleu_ref_factor: float = 0.5,
-            bleu_saved_factor: float = 0.5,
-            cos_factor: float = 0.5,
-            cos_ref_factor: float = 0.5,
-            cos_saved_factor: float = 0.5,
-            multi_cos_ref_factor: float = 0.3,
-            multi_cos_saved_factor: float = 0.5,
-            multi_cos_running_mean_factor: float = 0.2,
-            neg_bleu_factor: Optional[float] = None,
-            neg_cos_factor: Optional[float] = None,
-            neg_cos_ref_factor: Optional[float] = None,
-            neg_cos_saved_factor: Optional[float] = None,
-            neg_bleu_ref_factor: float = 0.5,
-            neg_bleu_saved_factor: float = 0.5,
-            allow_not_summing_factors: bool = False,
-            reward_len: bool = False,
-            neg_reward_len: bool = False,
-            max_rewarded_len: int = None,
-            len_factor: int = None,
-            use_running_mean: bool = True,
-            running_mean_decay: float = 0.2,
-            bleu_saved_weights: tuple = (0.5, 0.5),
-            bleu_ref_weights: tuple = (0.5, 0.5),
-            tanh_reward_scale: bool = False,
-            rewards_scale: float = 1.0,
-    ):
-        self.shared_embedding = shared_embedding.to(device)
-        self.device = device
-        self.bleu_with_saved_data = bleu_with_saved_data
-
-        self.bleu_factor = bleu_factor
-        self.bleu_ref_factor = bleu_ref_factor
-        self.bleu_saved_factor = bleu_saved_factor
-        self.cos_factor = cos_factor
-        self.cos_ref_factor = cos_ref_factor
-        self.cos_saved_factor = cos_saved_factor
-        self.multi_cos_ref_factor = multi_cos_ref_factor
-        self.multi_cos_saved_factor = multi_cos_saved_factor
-        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
-        self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
-        self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
-        self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
-        self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
-        self.neg_bleu_ref_factor = neg_bleu_ref_factor
-        self.neg_bleu_saved_factor = neg_bleu_saved_factor
-        self.reward_len = reward_len
-        self.neg_reward_len = neg_reward_len
-        self.max_rewarded_len = max_rewarded_len
-        self.len_factor = len_factor
-        self.use_running_mean = use_running_mean
-        self.running_mean_decay = running_mean_decay
-        self.bleu_ref_weights = bleu_ref_weights
-        self.bleu_saved_weights = bleu_saved_weights
-        self.tanh_reward_scale = tanh_reward_scale
-        self.rewards_scale = rewards_scale
-
-        self.prev_data_running_mean = None
-
-        if not allow_not_summing_factors:
-            if reward_len:
-                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
-                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
-                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
-                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
-                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
-                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
-                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
-            else:
-                assert self.bleu_factor + self.cos_factor == 1.0
-                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
-                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
-                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
-                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
-                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
-                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
-
-    def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
-        from nltk.translate.bleu_score import sentence_bleu
-
-        if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
-            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
-        else:
-            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-
-
-    def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor,
-                                saved_data: torch.Tensor) -> float:
-        from nltk.translate.bleu_score import sentence_bleu
-
-        if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
-            saved_bleu = 1 - saved_bleu
-
-            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
-        else:
-            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-
-    def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
-        batch_size = generated.size(0)
-        return [self._sentence_bleu(generated[i], reference[i], saved_data[i]) for i in range(batch_size)]
-
-    def _sequence_embedding(self, sequence: torch.Tensor) -> torch.Tensor:
-        embedding = self.shared_embedding(sequence.to(self.device))
-        return embedding.mean(dim=1)
-
-    def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
-        generated_emb = self._sequence_embedding(generated)
-
-        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
-        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
-        return gen_and_saved, gen_and_ref
-
-    def _cosine_sim_running_mean(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
-        generated_emb = self._sequence_embedding(generated)
-
-        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
-        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
-        gen_and_mean = (F.cosine_similarity(generated_emb, self.prev_data_running_mean) + 1) / 2
-        return gen_and_saved, gen_and_ref, gen_and_mean
-
-    def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor,
-                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
-        if self.use_running_mean and negative_running_mean:
-            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
-            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
-                    1 - gen_and_mean)
-        elif self.use_running_mean and include_running_mean:
-            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
-            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
-        else:
-            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
-            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
-
-    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor,
-                        saved_data: torch.Tensor) -> torch.Tensor:
-        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
-
-        return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
-
-    def len_reward(self, generated: TokenizedDict):
-        lens = generated['attention_mask'].sum(dim=1)
-        neg_lens = self.max_rewarded_len / lens if self.neg_reward_len else 1.0
-        len_reward = torch.where(lens >= self.max_rewarded_len, neg_lens, lens / self.max_rewarded_len)
-        return len_reward
-
-    def reset_running_mean(self):
-        self.prev_data_running_mean = None
-
-    def init_running_mean(self, prev_data: torch.Tensor):
-        self.prev_data_running_mean = self._sequence_embedding(prev_data)
-
-    def update_running_mean(self, prev_data: torch.Tensor):
-        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
-            prev_data) + self.running_mean_decay * self.prev_data_running_mean
-
-    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
-        if self.tanh_reward_scale:
-            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
-        else:
-            return rewards
-
-    def __call__(
-            self,
-            generated: TokenizedDict,
-            reference: TokenizedDict,
-            saved_data: TokenizedDict,
-            prev_data: TokenizedDict = None,
-            mode: MrlRewardMode = MrlRewardMode.STANDARD
-    ) -> list[float]:
-        if prev_data is not None:
-            if self.prev_data_running_mean is None:
-                self.init_running_mean(prev_data['input_ids'])
-            else:
-                self.update_running_mean(prev_data['input_ids'])
-
-        if mode == MrlRewardMode.STANDARD:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
-                                       include_running_mean=prev_data is not None)
-            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
-        elif mode == MrlRewardMode.LONG_RANGE:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
-                                       negative_running_mean=prev_data is not None)
-            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
-        else:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
-
-        rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
-        return rewards.tolist()