rxnn 0.2.48__tar.gz → 0.2.50__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {rxnn-0.2.48 → rxnn-0.2.50}/PKG-INFO +2 -1
  2. {rxnn-0.2.48 → rxnn-0.2.50}/pyproject.toml +2 -2
  3. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/mrl.py +13 -7
  4. rxnn-0.2.50/src/rxnn/training/reward.py +323 -0
  5. rxnn-0.2.48/src/rxnn/training/reward.py +0 -216
  6. {rxnn-0.2.48 → rxnn-0.2.50}/LICENSE +0 -0
  7. {rxnn-0.2.48 → rxnn-0.2.50}/README.md +0 -0
  8. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/.DS_Store +0 -0
  9. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/__init__.py +0 -0
  10. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/experimental/__init__.py +0 -0
  11. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/experimental/attention.py +0 -0
  12. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/experimental/models.py +0 -0
  13. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/experimental/moe.py +0 -0
  14. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/memory/__init__.py +0 -0
  15. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/memory/attention.py +0 -0
  16. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/memory/norm.py +0 -0
  17. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/memory/stm.py +0 -0
  18. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/rxt/__init__.py +0 -0
  19. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/rxt/models.py +0 -0
  20. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/__init__.py +0 -0
  21. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/base.py +0 -0
  22. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/bml.py +0 -0
  23. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/callbacks.py +0 -0
  24. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/dataset.py +0 -0
  25. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/ddp.py +0 -0
  26. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/models.py +0 -0
  27. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/rl.py +0 -0
  28. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/scheduler.py +0 -0
  29. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/tokenizer.py +0 -0
  30. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/utils.py +0 -0
  31. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/__init__.py +0 -0
  32. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/attention.py +0 -0
  33. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/ff.py +0 -0
  34. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/layers.py +0 -0
  35. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/mask.py +0 -0
  36. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/models.py +0 -0
  37. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/moe.py +0 -0
  38. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/positional.py +0 -0
  39. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/transformers/sampler.py +0 -0
  40. {rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/utils.py +0 -0

{rxnn-0.2.48 → rxnn-0.2.50}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.48
+Version: 0.2.50
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -15,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Requires-Dist: datasets (>=3.5.0,<4.0.0)
 Requires-Dist: huggingface-hub (>=0.30.0,<0.31.0)
+Requires-Dist: nltk (>=3.9.1,<4.0.0)
 Requires-Dist: tensorboard (>=2.19.0,<3.0.0)
 Requires-Dist: tokenizers (>=0.21.0,<0.22.0)
 Requires-Dist: torch (>=2.6.0,<3.0.0)

{rxnn-0.2.48 → rxnn-0.2.50}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.48"
+version = "0.2.50"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
@@ -30,4 +30,4 @@ datasets = "^3.5.0"
 tokenizers = "^0.21.0"
 huggingface-hub = "^0.30.0"
 tensorboard = "^2.19.0"
-
+nltk = "^3.9.1"

{rxnn-0.2.48 → rxnn-0.2.50}/src/rxnn/training/mrl.py
@@ -71,6 +71,7 @@ class CurriculumConfig(TypedDict):
     update_epochs: Optional[int]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    teacher_forcing: Optional[bool]
 
 
 class SamplerConfig(TypedDict):
@@ -215,6 +216,7 @@ class MRLTrainer:
         self.callbacks = []
         self.global_epoch = 0
         self.global_epochs_count = 0
+        self.teacher_forcing = False
 
     def _init_optimizers(
             self,
@@ -452,8 +454,10 @@ class MRLTrainer:
 
             # 11. Update STM with generated response (except last interaction, it's not needed)
             if not is_last_interaction:
-                self.encode_and_update_stm(next_query,
-                                           generated_answer)  # update with generated_answer on GPU
+                self.encode_and_update_stm(
+                    next_query,
+                    self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                )  # update with generated_answer on GPU
 
             # 12. Store trajectory step
             trajectory: MrlTrajectoryStep = {
@@ -470,7 +474,7 @@ class MRLTrainer:
             # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
             if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                 prev_interaction = (query, answer)
-            query, answer = interaction['query'], detached_answer
+            query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
 
         # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
         episode_trajectory: MrlTrajectoryEpisode = {
@@ -857,7 +861,10 @@ class MRLTrainer:
 
             # 10. Encode and update memory for the next interaction
            if not is_last_interaction:
-                self.encode_and_update_stm(next_query, generated_answer)
+                self.encode_and_update_stm(
+                    next_query,
+                    self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                )
 
             # 11. Accumulate rewards
             step_reward = torch.tensor(reward).mean().to(self.device)
@@ -870,7 +877,7 @@ class MRLTrainer:
             # 12. Save previous interaction
             if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                 prev_interaction = (query, answer)
-            query, answer = interaction['query'], detached_answer
+            query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
 
         avg_episode_reward = (episode_reward / episode_interactions).item()
         # 13. Run eval TensorBoard writer with average episode reward
         self._eval_writer(avg_episode_reward, epoch)
@@ -1000,8 +1007,7 @@ class MRLTrainer:
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs)  # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-
-
+        self.teacher_forcing = config.get('teacher_forcing', False)
 
         def has_param(field: OptimField) -> bool:
             return field in config and config[field] is not None
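
The hunks above add an optional teacher forcing switch to MRL curriculum stages: when enabled, short-term memory and the saved interaction are updated with the dataset answer instead of the generated one, and the flag defaults to False when the key is absent. A minimal sketch of how a stage might opt in is shown below; the values are placeholders and only the CurriculumConfig fields visible in this diff are listed, so this is illustrative rather than a complete stage config from the release.

# Hypothetical curriculum stage config (placeholder values, not taken from this diff)
stage_config = {
    'update_epochs': 2,
    'freeze_embeddings': False,
    'embedding_lr': 1e-5,
    'teacher_forcing': True,  # new in 0.2.50: update STM with the reference answer
}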

rxnn-0.2.50/src/rxnn/training/reward.py (new file)
@@ -0,0 +1,323 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+from enum import Enum
+from typing import Optional
+from .utils import TokenizedDict
+
+
+class MrlRewardMode(Enum):
+    STANDARD = 1
+    NEGATIVE = 2
+    LONG_RANGE = 3
+
+
+class MrlRewardModel:
+    def __init__(
+            self,
+            shared_embedding: nn.Embedding,
+            device: torch.device,
+            bleu_with_saved_data: bool = False,
+            bleu_factor: float = 0.5,
+            bleu_ref_factor: float = 0.5,
+            bleu_saved_factor: float = 0.5,
+            cos_factor: float = 0.5,
+            cos_ref_factor: float = 0.5,
+            cos_saved_factor: float = 0.5,
+            multi_cos_ref_factor: float = 0.3,
+            multi_cos_saved_factor: float = 0.5,
+            multi_cos_running_mean_factor: float = 0.2,
+            neg_bleu_factor: Optional[float] = None,
+            neg_cos_factor: Optional[float] = None,
+            neg_cos_ref_factor: Optional[float] = None,
+            neg_cos_saved_factor: Optional[float] = None,
+            neg_bleu_ref_factor: float = 0.5,
+            neg_bleu_saved_factor: float = 0.5,
+            allow_not_summing_factors: bool = False,
+            reward_len: bool = False,
+            neg_reward_len: bool = False,
+            max_rewarded_len: int = None,
+            target_len_as_ref: bool = False,
+            len_factor: int = None,
+            use_running_mean: bool = True,
+            running_mean_decay: float = 0.2,
+            bleu_saved_weights: tuple = (0.5, 0.5),
+            bleu_ref_weights: tuple = (0.5, 0.5),
+            tanh_reward_scale: bool = False,
+            rewards_scale: float = 1.0,
+            debug_mode: int = 0,
+    ):
+        self.shared_embedding = shared_embedding.to(device)
+        self.device = device
+        self.bleu_with_saved_data = bleu_with_saved_data
+
+        self.bleu_factor = bleu_factor
+        self.bleu_ref_factor = bleu_ref_factor
+        self.bleu_saved_factor = bleu_saved_factor
+        self.cos_factor = cos_factor
+        self.cos_ref_factor = cos_ref_factor
+        self.cos_saved_factor = cos_saved_factor
+        self.multi_cos_ref_factor = multi_cos_ref_factor
+        self.multi_cos_saved_factor = multi_cos_saved_factor
+        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
+        self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
+        self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
+        self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
+        self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
+        self.neg_bleu_ref_factor = neg_bleu_ref_factor
+        self.neg_bleu_saved_factor = neg_bleu_saved_factor
+        self.reward_len = reward_len
+        self.neg_reward_len = neg_reward_len
+        self.max_rewarded_len = max_rewarded_len
+        self.target_len_as_ref = target_len_as_ref
+        self.len_factor = len_factor
+        self.use_running_mean = use_running_mean
+        self.running_mean_decay = running_mean_decay
+        self.bleu_ref_weights = bleu_ref_weights
+        self.bleu_saved_weights = bleu_saved_weights
+        self.tanh_reward_scale = tanh_reward_scale
+        self.rewards_scale = rewards_scale
+        self.bleu_smoothing = SmoothingFunction().method4
+        self.debug_mode = debug_mode
+
+        self.prev_data_running_mean = None
+
+        if not allow_not_summing_factors:
+            if reward_len:
+                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+            else:
+                assert self.bleu_factor + self.cos_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+
+    def _sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                       masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))
+
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
+            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+    def _negative_sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                                masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))
+
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
+            saved_bleu = 1 - saved_bleu
+
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
+            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+    def batch_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def negative_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[
+        float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._negative_sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def _sequence_embedding(self, sequence: TokenizedDict) -> torch.Tensor:
+        input_ids = sequence['input_ids']
+        attention_mask = sequence['attention_mask']
+
+        # Get embeddings
+        embeddings = self.shared_embedding(input_ids.to(self.device))
+
+        # Apply attention mask
+        mask_expanded = attention_mask.unsqueeze(-1).to(self.device)
+        masked_embeddings = embeddings * mask_expanded
+
+        # Compute mean with masking
+        sum_embeddings = torch.sum(masked_embeddings, dim=1)
+        token_counts = torch.sum(mask_expanded, dim=1)
+        token_counts = torch.clamp(token_counts, min=1e-8)  # Avoid division by zero
+
+        return sum_embeddings / token_counts
+
+    def _cosine_sim(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0)
+
+    def _cosine_sim_running_mean(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+        running_emb = F.normalize(self.prev_data_running_mean, dim=-1)
+
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+        gen_and_mean = F.cosine_similarity(generated_emb, running_emb, dim=1)
+
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+            print('GEN AND MEAN: ', gen_and_mean.mean())
+
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0), torch.clamp(gen_and_mean, min=0)
+
+    def batch_cosine(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict,
+                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
+        if self.use_running_mean and negative_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
+                    1 - gen_and_mean)
+        elif self.use_running_mean and include_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
+        else:
+            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
+
+    def negative_cosine(self, generated: TokenizedDict, reference: TokenizedDict,
+                        saved_data: TokenizedDict) -> torch.Tensor:
+        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+
+        return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
+
+    def len_reward(self, generated: TokenizedDict, reference: TokenizedDict) -> torch.Tensor:
+        target_lens = reference['attention_mask'].to(self.device).sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
+        lens = generated['attention_mask'].to(self.device).sum(dim=1)
+        neg_lens = target_lens / lens if self.neg_reward_len else 1.0
+        len_reward = torch.where(lens >= target_lens, neg_lens, lens / target_lens)
+        return len_reward
+
+    def reset_running_mean(self):
+        self.prev_data_running_mean = None
+
+    def init_running_mean(self, prev_data: TokenizedDict):
+        self.prev_data_running_mean = self._sequence_embedding(prev_data)
+
+    def update_running_mean(self, prev_data: TokenizedDict):
+        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
+            prev_data) + self.running_mean_decay * self.prev_data_running_mean
+
+    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
+        if self.tanh_reward_scale:
+            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
+        else:
+            return rewards
+
+    def __call__(
+            self,
+            generated: TokenizedDict,
+            reference: TokenizedDict,
+            saved_data: TokenizedDict,
+            prev_data: TokenizedDict = None,
+            mode: MrlRewardMode = MrlRewardMode.STANDARD
+    ) -> list[float]:
+        if prev_data is not None:
+            if self.prev_data_running_mean is None:
+                self.init_running_mean(prev_data)
+            else:
+                self.update_running_mean(prev_data)
+
+        if mode == MrlRewardMode.STANDARD:
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data, include_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('STANDARD MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        elif mode == MrlRewardMode.LONG_RANGE:
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data,
+                                       negative_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('LONG MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        else:
+            bleu = self.negative_bleu(generated, reference, saved_data)
+            cosine = self.negative_cosine(generated, reference, saved_data)
+
+            if self.debug_mode >= 1:
+                print('NEGATIVE MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
+            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
+
+        if self.reward_len:
+            len_reward = self.len_reward(generated, reference)
+
+            if self.debug_mode >= 1:
+                print('REWARD LEN: ', (len_reward.sum() / len_reward.size(0)).item())
+
+            rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * len_reward) * self.rewards_scale
+        else:
+            rewards = self._pre_scale_rewards(sim_rewards) * self.rewards_scale
+
+        return rewards.tolist()
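
The rewritten reward.py computes per-sequence rewards from token-level BLEU (now with NLTK's method4 smoothing and mask-aware truncation, which explains the new nltk dependency) combined with masked mean-pooled embedding cosine similarity against the reference and the saved interaction. Below is a minimal usage sketch, not taken from the package docs: the vocabulary size, embedding dimension, and token id batches are made-up placeholders, and in MRL training the embedding and the TokenizedDict batches come from the trained model and its tokenizer.

import torch
import torch.nn as nn
from rxnn.training.reward import MrlRewardModel, MrlRewardMode

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embedding = nn.Embedding(32000, 256)  # placeholder: normally shared with the trained model

reward_model = MrlRewardModel(embedding, device, bleu_with_saved_data=True)

def tokenized(ids: list[list[int]]) -> dict:
    # Minimal stand-in for TokenizedDict: 'input_ids' plus 'attention_mask'
    t = torch.tensor(ids)
    return {'input_ids': t, 'attention_mask': torch.ones_like(t)}

generated = tokenized([[11, 42, 7, 3]])
reference = tokenized([[11, 42, 9, 3]])
saved_data = tokenized([[5, 11, 42, 7]])

# Returns one float reward per batch element, scaled by rewards_scale
rewards = reward_model(generated, reference, saved_data, mode=MrlRewardMode.STANDARD)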

rxnn-0.2.48/src/rxnn/training/reward.py (deleted)
@@ -1,216 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from enum import Enum
-from typing import Optional
-from .utils import TokenizedDict
-
-
-class MrlRewardMode(Enum):
-    STANDARD = 1
-    NEGATIVE = 2
-    LONG_RANGE = 3
-
-
-class MrlRewardModel:
-    def __init__(
-            self,
-            shared_embedding: nn.Embedding,
-            device: torch.device,
-            bleu_with_saved_data: bool = False,
-            bleu_factor: float = 0.5,
-            bleu_ref_factor: float = 0.5,
-            bleu_saved_factor: float = 0.5,
-            cos_factor: float = 0.5,
-            cos_ref_factor: float = 0.5,
-            cos_saved_factor: float = 0.5,
-            multi_cos_ref_factor: float = 0.3,
-            multi_cos_saved_factor: float = 0.5,
-            multi_cos_running_mean_factor: float = 0.2,
-            neg_bleu_factor: Optional[float] = None,
-            neg_cos_factor: Optional[float] = None,
-            neg_cos_ref_factor: Optional[float] = None,
-            neg_cos_saved_factor: Optional[float] = None,
-            neg_bleu_ref_factor: float = 0.5,
-            neg_bleu_saved_factor: float = 0.5,
-            allow_not_summing_factors: bool = False,
-            reward_len: bool = False,
-            neg_reward_len: bool = False,
-            max_rewarded_len: int = None,
-            len_factor: int = None,
-            use_running_mean: bool = True,
-            running_mean_decay: float = 0.2,
-            bleu_saved_weights: tuple = (0.5, 0.5),
-            bleu_ref_weights: tuple = (0.5, 0.5),
-            tanh_reward_scale: bool = False,
-            rewards_scale: float = 1.0,
-    ):
-        self.shared_embedding = shared_embedding.to(device)
-        self.device = device
-        self.bleu_with_saved_data = bleu_with_saved_data
-
-        self.bleu_factor = bleu_factor
-        self.bleu_ref_factor = bleu_ref_factor
-        self.bleu_saved_factor = bleu_saved_factor
-        self.cos_factor = cos_factor
-        self.cos_ref_factor = cos_ref_factor
-        self.cos_saved_factor = cos_saved_factor
-        self.multi_cos_ref_factor = multi_cos_ref_factor
-        self.multi_cos_saved_factor = multi_cos_saved_factor
-        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
-        self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
-        self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
-        self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
-        self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
-        self.neg_bleu_ref_factor = neg_bleu_ref_factor
-        self.neg_bleu_saved_factor = neg_bleu_saved_factor
-        self.reward_len = reward_len
-        self.neg_reward_len = neg_reward_len
-        self.max_rewarded_len = max_rewarded_len
-        self.len_factor = len_factor
-        self.use_running_mean = use_running_mean
-        self.running_mean_decay = running_mean_decay
-        self.bleu_ref_weights = bleu_ref_weights
-        self.bleu_saved_weights = bleu_saved_weights
-        self.tanh_reward_scale = tanh_reward_scale
-        self.rewards_scale = rewards_scale
-
-        self.prev_data_running_mean = None
-
-        if not allow_not_summing_factors:
-            if reward_len:
-                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
-                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
-                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
-                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
-                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
-                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
-                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
-            else:
-                assert self.bleu_factor + self.cos_factor == 1.0
-                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
-                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
-                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
-                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
-                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
-                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
-
-    def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
-        from nltk.translate.bleu_score import sentence_bleu
-
-        if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
-            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
-        else:
-            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-
-
-    def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor,
-                                saved_data: torch.Tensor) -> float:
-        from nltk.translate.bleu_score import sentence_bleu
-
-        if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
-            saved_bleu = 1 - saved_bleu
-
-            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
-        else:
-            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
-
-    def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
-        batch_size = generated.size(0)
-        return [self._sentence_bleu(generated[i], reference[i], saved_data[i]) for i in range(batch_size)]
-
-    def _sequence_embedding(self, sequence: torch.Tensor) -> torch.Tensor:
-        embedding = self.shared_embedding(sequence.to(self.device))
-        return embedding.mean(dim=1)
-
-    def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
-        generated_emb = self._sequence_embedding(generated)
-
-        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
-        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
-        return gen_and_saved, gen_and_ref
-
-    def _cosine_sim_running_mean(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
-        generated_emb = self._sequence_embedding(generated)
-
-        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
-        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
-        gen_and_mean = (F.cosine_similarity(generated_emb, self.prev_data_running_mean) + 1) / 2
-        return gen_and_saved, gen_and_ref, gen_and_mean
-
-    def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor,
-                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
-        if self.use_running_mean and negative_running_mean:
-            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
-            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
-                    1 - gen_and_mean)
-        elif self.use_running_mean and include_running_mean:
-            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
-            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
-        else:
-            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
-            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
-
-    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor,
-                        saved_data: torch.Tensor) -> torch.Tensor:
-        gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
-
-        return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
-
-    def len_reward(self, generated: TokenizedDict):
-        lens = generated['attention_mask'].sum(dim=1)
-        neg_lens = self.max_rewarded_len / lens if self.neg_reward_len else 1.0
-        len_reward = torch.where(lens >= self.max_rewarded_len, neg_lens, lens / self.max_rewarded_len)
-        return len_reward
-
-    def reset_running_mean(self):
-        self.prev_data_running_mean = None
-
-    def init_running_mean(self, prev_data: torch.Tensor):
-        self.prev_data_running_mean = self._sequence_embedding(prev_data)
-
-    def update_running_mean(self, prev_data: torch.Tensor):
-        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
-            prev_data) + self.running_mean_decay * self.prev_data_running_mean
-
-    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
-        if self.tanh_reward_scale:
-            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
-        else:
-            return rewards
-
-    def __call__(
-            self,
-            generated: TokenizedDict,
-            reference: TokenizedDict,
-            saved_data: TokenizedDict,
-            prev_data: TokenizedDict = None,
-            mode: MrlRewardMode = MrlRewardMode.STANDARD
-    ) -> list[float]:
-        if prev_data is not None:
-            if self.prev_data_running_mean is None:
-                self.init_running_mean(prev_data['input_ids'])
-            else:
-                self.update_running_mean(prev_data['input_ids'])
-
-        if mode == MrlRewardMode.STANDARD:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
-                                       include_running_mean=prev_data is not None)
-            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
-        elif mode == MrlRewardMode.LONG_RANGE:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
-                                       negative_running_mean=prev_data is not None)
-            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
-        else:
-            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
-
-        rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
-        return rewards.tolist()