rxnn 0.2.47__py3-none-any.whl → 0.2.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/norm.py +3 -1
- rxnn/training/mrl.py +34 -10
- rxnn/training/reward.py +152 -45
- {rxnn-0.2.47.dist-info → rxnn-0.2.49.dist-info}/METADATA +2 -1
- {rxnn-0.2.47.dist-info → rxnn-0.2.49.dist-info}/RECORD +7 -7
- {rxnn-0.2.47.dist-info → rxnn-0.2.49.dist-info}/LICENSE +0 -0
- {rxnn-0.2.47.dist-info → rxnn-0.2.49.dist-info}/WHEEL +0 -0
rxnn/memory/norm.py
CHANGED
@@ -163,7 +163,7 @@ def init_memory_norm(
         init_scale: float = 1.0,
         per_dim_scale: bool = False,
 ) -> nn.Module:
-    assert norm_type in ['layer', 'rms', 'adaptive', 'positional']
+    assert norm_type in ['layer', 'rms', 'adaptive', 'positional', 'classic-rms']
     if norm_type == 'layer':
         return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
     elif norm_type == 'rms':
@@ -172,4 +172,6 @@ def init_memory_norm(
         return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
     elif norm_type == 'positional':
         return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate, per_dim_scale)
+    elif norm_type == 'classic-rms':
+        return nn.RMSNorm(dim)
     return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
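The new 'classic-rms' option makes `init_memory_norm` return PyTorch's built-in `nn.RMSNorm` instead of one of the package's gated memory norms. A minimal sketch of what that branch produces; the batch size, slot count, and width below are chosen arbitrarily for illustration and are not taken from the package:

```python
import torch
import torch.nn as nn

# Illustrative shapes only: 8 memory slots of width 256.
num_slots, dim = 8, 256

# What the new 'classic-rms' branch returns: torch's built-in RMSNorm
# (available since PyTorch 2.4), normalizing each slot vector over its last dimension.
norm = nn.RMSNorm(dim)

stm = torch.randn(2, num_slots, dim)  # (batch, slots, dim) short-term memory tensor
normed = norm(stm)                    # same shape, RMS-normalized per slot vector
print(normed.shape)                   # torch.Size([2, 8, 256])
```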

rxnn/training/mrl.py
CHANGED
@@ -35,6 +35,7 @@ class MrlConfig(TypedDict):
     moe_aux_loss_scale: Optional[float]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    use_memory_warmup: Optional[bool]


 class MrlStrategy(Enum):
@@ -70,6 +71,7 @@ class CurriculumConfig(TypedDict):
     update_epochs: Optional[int]
     freeze_embeddings: Optional[bool]
     embedding_lr: Optional[float]
+    teacher_forcing: Optional[bool]


 class SamplerConfig(TypedDict):
@@ -136,6 +138,7 @@ class MRLTrainer:
         self.moe_aux_loss_scale = config.get('moe_aux_loss_scale', 0.01)
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
+        self.use_memory_warmup = config.get('use_memory_warmup', False)
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -213,6 +216,7 @@ class MRLTrainer:
         self.callbacks = []
         self.global_epoch = 0
         self.global_epochs_count = 0
+        self.teacher_forcing = False

     def _init_optimizers(
             self,
@@ -381,6 +385,11 @@ class MRLTrainer:
             self.writer.add_scalar(f'Collect/episode reward (steps: {self.curriculum_steps})', avg_reward,
                                    self.stage_step['collect'])

+    def memory_warmup(self, query: TokenizedDict, answer: TokenizedDict):
+        if self.use_memory_warmup:
+            with torch.no_grad():
+                self.encode_and_update_stm(query, answer)
+
     def collect_trajectories(self, dataloader: DataLoader, epoch: int, batch_size: int) -> list[MrlTrajectoryEpisode]:
         """Collect trajectories for PPO for current curriculum step."""
         # 1. Init trajectories list
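Both switches added above are plain config fields: `use_memory_warmup` lives in `MrlConfig` and defaults to `False`, while `teacher_forcing` is read per curriculum stage from `CurriculumConfig` (also defaulting to `False`). A hedged sketch of how they might be set; every other key and value below is an illustrative placeholder, not a complete or verified trainer configuration:

```python
# Keys use_memory_warmup and teacher_forcing come from the diff above; the
# remaining keys/values are placeholders only.
mrl_config = {
    'update_epochs': 10,
    'freeze_embeddings': False,
    'use_memory_warmup': True,   # new in 0.2.49: no-grad STM warmup after a reset
}

curriculum_stage = {
    'update_epochs': 10,
    'freeze_embeddings': False,
    'teacher_forcing': True,     # new in 0.2.49: update STM with reference answers
}

print(mrl_config['use_memory_warmup'], curriculum_stage['teacher_forcing'])
```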
@@ -402,8 +411,13 @@ class MRLTrainer:
             first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
             interactions = interactions[:self.curriculum_steps]
             interactions_len = len(interactions)
+
+            first_interaction = self._move_multiple_batches(first_query, first_answer)
+
+            if reset_done:
+                self.memory_warmup(*first_interaction)
             # 6. Encode and update STM with data to save from first interaction
-            self.encode_and_update_stm(*
+            self.encode_and_update_stm(*first_interaction)

             # 7. Save first interaction as data to save (for trajectory state)
             query, answer = first_query, first_answer
@@ -440,8 +454,10 @@ class MRLTrainer:

                 # 11. Update STM with generated response (except last interaction, it's not needed)
                 if not is_last_interaction:
-                    self.encode_and_update_stm(
-
+                    self.encode_and_update_stm(
+                        next_query,
+                        self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                    )  # update with generated_answer on GPU

                 # 12. Store trajectory step
                 trajectory: MrlTrajectoryStep = {
@@ -458,7 +474,7 @@ class MRLTrainer:
                 # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
                 if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                     prev_interaction = (query, answer)
-                query, answer = interaction['query'], detached_answer
+                query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)

             # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
             episode_trajectory: MrlTrajectoryEpisode = {
@@ -649,6 +665,9 @@ class MRLTrainer:

         self.actor.clone_reset_memory()

+        if should_reset_stm and step_idx == 0:
+            self.memory_warmup(query, answer)
+
         # 7. In memory aware critic version, encode and update STM before critic update, to include its gradients in critic loss too
         if self.memory_aware_critic:
             self.encode_and_update_stm(query, answer)
@@ -798,13 +817,16 @@ class MRLTrainer:
            if batch['query']['input_ids'].size(0) == batch_size:
                self._increment_steps('eval')
                # 3. Reset STM with random resets ratio and reward model running mean
-               self.reset_stm()
+               reset_stm = self.reset_stm()
                self.reward.reset_running_mean()

                # 4. Get batches for first queries, answers and all follow-up interactions
                first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
                # 5. Encode and update STM with initial interactions (batch)
-               self.
+               first_interaction = self._move_multiple_batches(first_query, first_answer)
+               if reset_stm:
+                   self.memory_warmup(*first_interaction)
+               self.encode_and_update_stm(*first_interaction)

                # 6. Save follow-up interactions len and first query and answer as previous one for iteration
                interactions_len = len(interactions)
@@ -839,7 +861,10 @@ class MRLTrainer:

                    # 10. Encode and update memory for the next interaction
                    if not is_last_interaction:
-                       self.encode_and_update_stm(
+                       self.encode_and_update_stm(
+                           next_query,
+                           self._move_batch(interaction['answer']) if self.teacher_forcing else generated_answer
+                       )

                    # 11. Accumulate rewards
                    step_reward = torch.tensor(reward).mean().to(self.device)
@@ -852,7 +877,7 @@ class MRLTrainer:
                    # 12. Save previous interaction
                    if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
                        prev_interaction = (query, answer)
-                   query, answer = interaction['query'], detached_answer
+                   query, answer = interaction['query'], (interaction['answer'] if self.teacher_forcing else detached_answer)
                avg_episode_reward = (episode_reward / episode_interactions).item()
                # 13. Run eval TensorBoard writer with average episode reward
                self._eval_writer(avg_episode_reward, epoch)
@@ -982,8 +1007,7 @@ class MRLTrainer:
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
         self.update_epochs = config.get('update_epochs', self.shared_update_epochs)  # Internal update epochs
         self.freeze_embeddings = config.get('freeze_embeddings', self.shared_freeze_embeddings)
-
-
+        self.teacher_forcing = config.get('teacher_forcing', False)

         def has_param(field: OptimField) -> bool:
             return field in config and config[field] is not None
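With `teacher_forcing` enabled, the collection and evaluation loops above feed the dataset's reference answer into the STM update (and into the saved "previous answer") instead of the model's own generation. A small standalone sketch of that selection logic; the helper name and the token-id batches are invented for illustration:

```python
import torch

def choose_stm_update(interaction_answer: torch.Tensor, generated_answer: torch.Tensor,
                      teacher_forcing: bool) -> torch.Tensor:
    # Mirrors updated steps 11/13: with teacher forcing the reference answer drives the
    # memory update; otherwise the freshly generated (detached) answer is used, as before.
    return interaction_answer if teacher_forcing else generated_answer

# Hypothetical token-id batches standing in for interaction['answer'] / generated_answer.
interactions = [
    {'answer': torch.tensor([[11, 12, 13]])},
    {'answer': torch.tensor([[21, 22, 23]])},
]
generated = [torch.tensor([[11, 99, 99]]), torch.tensor([[21, 99, 99]])]

for step, interaction in enumerate(interactions):
    update_source = choose_stm_update(interaction['answer'], generated[step], teacher_forcing=True)
    print(step, update_source.tolist())
```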

rxnn/training/reward.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 from enum import Enum
 from typing import Optional
 from .utils import TokenizedDict
@@ -37,6 +38,7 @@ class MrlRewardModel:
             reward_len: bool = False,
             neg_reward_len: bool = False,
             max_rewarded_len: int = None,
+            target_len_as_ref: bool = False,
             len_factor: int = None,
             use_running_mean: bool = True,
             running_mean_decay: float = 0.2,
@@ -44,6 +46,7 @@ class MrlRewardModel:
             bleu_ref_weights: tuple = (0.5, 0.5),
             tanh_reward_scale: bool = False,
             rewards_scale: float = 1.0,
+            debug_mode: int = 0,
     ):
         self.shared_embedding = shared_embedding.to(device)
         self.device = device
@@ -67,6 +70,7 @@ class MrlRewardModel:
         self.reward_len = reward_len
         self.neg_reward_len = neg_reward_len
         self.max_rewarded_len = max_rewarded_len
+        self.target_len_as_ref = target_len_as_ref
         self.len_factor = len_factor
         self.use_running_mean = use_running_mean
         self.running_mean_decay = running_mean_decay
@@ -74,6 +78,8 @@ class MrlRewardModel:
         self.bleu_saved_weights = bleu_saved_weights
         self.tanh_reward_scale = tanh_reward_scale
         self.rewards_scale = rewards_scale
+        self.bleu_smoothing = SmoothingFunction().method4
+        self.debug_mode = debug_mode

         self.prev_data_running_mean = None

@@ -95,59 +101,133 @@ class MrlRewardModel:
         assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
         assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0

-    def _sentence_bleu(self,
-
+    def _sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                       masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]
+
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))

         if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights
-
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
             return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
         else:
             return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)

+    def _negative_sentence_bleu(self, input_ids: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+                                masks: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> float:
+        generated, reference, saved_data = input_ids
+        generated_mask, reference_mask, saved_data_mask = masks
+
+        generated = generated.tolist()[:generated_mask.sum().item()]
+        reference = reference.tolist()[:reference_mask.sum().item()]
+        saved_data = saved_data.tolist()[:saved_data_mask.sum().item()]

-
-
-        from nltk.translate.bleu_score import sentence_bleu
+        if self.debug_mode == 2:
+            print('LENS: ', (len(generated), len(reference), len(saved_data)))

         if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights
-
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights,
+                                     smoothing_function=self.bleu_smoothing)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights,
+                                       smoothing_function=self.bleu_smoothing)
             saved_bleu = 1 - saved_bleu

+            if self.debug_mode == 2:
+                print('REF BLEU: ', ref_bleu)
+                print('SAVED BLEU: ', saved_bleu)
+
             return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
         else:
             return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)

-    def batch_bleu(self, generated:
-        batch_size = generated.size(0)
-
+    def batch_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def negative_bleu(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict) -> list[
+        float]:
+        batch_size = generated['input_ids'].size(0)
+
+        return [
+            self._negative_sentence_bleu(
+                input_ids=(generated['input_ids'][i], reference['input_ids'][i], saved_data['input_ids'][i]),
+                masks=(generated['attention_mask'][i], reference['attention_mask'][i], saved_data['attention_mask'][i])
+            ) for i in range(batch_size)
+        ]
+
+    def _sequence_embedding(self, sequence: TokenizedDict) -> torch.Tensor:
+        input_ids = sequence['input_ids']
+        attention_mask = sequence['attention_mask']
+
+        # Get embeddings
+        embeddings = self.shared_embedding(input_ids.to(self.device))
+
+        # Apply attention mask
+        mask_expanded = attention_mask.unsqueeze(-1).to(self.device)
+        masked_embeddings = embeddings * mask_expanded
+
+        # Compute mean with masking
+        sum_embeddings = torch.sum(masked_embeddings, dim=1)
+        token_counts = torch.sum(mask_expanded, dim=1)
+        token_counts = torch.clamp(token_counts, min=1e-8)  # Avoid division by zero
+
+        return sum_embeddings / token_counts

-    def
-
-
+    def _cosine_sim(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)

-
-
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)

-
-
-
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0)

-    def _cosine_sim_running_mean(self, generated:
-        generated_emb = self._sequence_embedding(generated)
+    def _cosine_sim_running_mean(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict):
+        generated_emb = F.normalize(self._sequence_embedding(generated), dim=-1)
+        saved_data_emb = F.normalize(self._sequence_embedding(saved_data), dim=-1)
+        reference_emb = F.normalize(self._sequence_embedding(reference), dim=-1)
+        running_emb = F.normalize(self.prev_data_running_mean, dim=-1)

-        gen_and_saved =
-        gen_and_ref =
-        gen_and_mean =
-        return gen_and_saved, gen_and_ref, gen_and_mean
+        gen_and_saved = F.cosine_similarity(generated_emb, saved_data_emb, dim=1)
+        gen_and_ref = F.cosine_similarity(generated_emb, reference_emb, dim=1)
+        gen_and_mean = F.cosine_similarity(generated_emb, running_emb, dim=1)

-
+        if self.debug_mode >= 1:
+            print('GEN AND SAVED: ', gen_and_saved.mean())
+            print('GEN AND REF: ', gen_and_ref.mean())
+            print('GEN AND MEAN: ', gen_and_mean.mean())
+
+        return torch.clamp(gen_and_saved, min=0), torch.clamp(gen_and_ref, min=0), torch.clamp(gen_and_mean, min=0)
+
+    def batch_cosine(self, generated: TokenizedDict, reference: TokenizedDict, saved_data: TokenizedDict,
                      include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
         if self.use_running_mean and negative_running_mean:
             gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
             return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
-
+                    1 - gen_and_mean)
         elif self.use_running_mean and include_running_mean:
             gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
             return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
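The hunk above replaces the in-method nltk import with a module-level one and passes `SmoothingFunction().method4` into every `sentence_bleu` call, so short or partially overlapping token sequences no longer collapse to a zero BLEU score. A standalone sketch of that smoothed computation on token-id lists; the ids and the 0.5/0.5 blend factors below are assumptions for illustration, not values taken from the package:

```python
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

smoothing = SmoothingFunction().method4

# Token-id sequences standing in for the masked generated / reference / saved-data ids.
generated = [101, 7592, 2088, 2003, 2307, 102]
reference = [101, 7592, 2088, 2003, 2204, 102]
saved_data = [101, 7592, 2088, 102]

# Same call shape as _sentence_bleu: bigram-weighted BLEU with method4 smoothing.
ref_bleu = sentence_bleu([reference], generated, weights=(0.5, 0.5), smoothing_function=smoothing)
saved_bleu = sentence_bleu([saved_data], generated, weights=(0.5, 0.5), smoothing_function=smoothing)

# Weighted blend mirroring bleu_ref_factor / bleu_saved_factor (0.5 each assumed).
print(round(0.5 * ref_bleu + 0.5 * saved_bleu, 4))
```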
@@ -155,25 +235,26 @@ class MrlRewardModel:
         gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
         return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref

-    def negative_cosine(self, generated:
-                        saved_data:
+    def negative_cosine(self, generated: TokenizedDict, reference: TokenizedDict,
+                        saved_data: TokenizedDict) -> torch.Tensor:
         gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)

         return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref

-    def len_reward(self, generated: TokenizedDict):
+    def len_reward(self, generated: TokenizedDict, reference: TokenizedDict) -> torch.Tensor:
+        target_lens = reference['attention_mask'].sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
         lens = generated['attention_mask'].sum(dim=1)
-        neg_lens =
-        len_reward = torch.where(lens >=
+        neg_lens = target_lens / lens if self.neg_reward_len else 1.0
+        len_reward = torch.where(lens >= target_lens, neg_lens, lens / target_lens)
         return len_reward

     def reset_running_mean(self):
         self.prev_data_running_mean = None

-    def init_running_mean(self, prev_data:
+    def init_running_mean(self, prev_data: TokenizedDict):
         self.prev_data_running_mean = self._sequence_embedding(prev_data)

-    def update_running_mean(self, prev_data:
+    def update_running_mean(self, prev_data: TokenizedDict):
         self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
             prev_data) + self.running_mean_decay * self.prev_data_running_mean

@@ -193,24 +274,50 @@ class MrlRewardModel:
     ) -> list[float]:
         if prev_data is not None:
             if self.prev_data_running_mean is None:
-                self.init_running_mean(prev_data
+                self.init_running_mean(prev_data)
             else:
-                self.update_running_mean(prev_data
+                self.update_running_mean(prev_data)

         if mode == MrlRewardMode.STANDARD:
-            bleu = self.batch_bleu(generated
-            cosine = self.batch_cosine(generated
-
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data, include_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('STANDARD MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
             sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
         elif mode == MrlRewardMode.LONG_RANGE:
-            bleu = self.batch_bleu(generated
-            cosine = self.batch_cosine(generated
+            bleu = self.batch_bleu(generated, reference, saved_data)
+            cosine = self.batch_cosine(generated, reference, saved_data,
                                        negative_running_mean=prev_data is not None)
+
+            if self.debug_mode >= 1:
+                print('LONG MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
             sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
         else:
-            bleu = self.
-            cosine = self.negative_cosine(generated
+            bleu = self.negative_bleu(generated, reference, saved_data)
+            cosine = self.negative_cosine(generated, reference, saved_data)
+
+            if self.debug_mode >= 1:
+                print('NEGATIVE MODE')
+                print('BLEU: ', sum(bleu) / len(bleu))
+                print('COSINE: ', sum(cosine) / len(cosine))
+
             sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine

-
+        if self.reward_len:
+            len_reward = self.len_reward(generated, reference)
+
+            if self.debug_mode >= 1:
+                print('REWARD LEN: ', (len_reward.sum() / len_reward.size(0)).item())
+
+            rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * len_reward) * self.rewards_scale
+        else:
+            rewards = self._pre_scale_rewards(sim_rewards) * self.rewards_scale
+
         return rewards.tolist()
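The reworked `len_reward` can now take its target length from the reference's attention mask (`target_len_as_ref=True`) instead of the fixed `max_rewarded_len`: below the target the reward grows linearly, and with `neg_reward_len` an overshoot is penalized as `target / len`. A standalone sketch of that formula; the lengths below are made up:

```python
import torch

def len_reward(lens: torch.Tensor, target_lens: torch.Tensor, neg_reward_len: bool) -> torch.Tensor:
    # Mirrors the updated formula: below target the reward is len / target;
    # at or above target it is either 1.0 or, with neg_reward_len, target / len.
    neg_lens = target_lens / lens if neg_reward_len else 1.0
    return torch.where(lens >= target_lens, neg_lens, lens / target_lens)

lens = torch.tensor([10., 40., 80.])         # generated answer lengths (made up)
target_lens = torch.tensor([40., 40., 40.])  # e.g. reference lengths when target_len_as_ref=True
print(len_reward(lens, target_lens, neg_reward_len=True))  # tensor([0.2500, 1.0000, 0.5000])
```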

{rxnn-0.2.47.dist-info → rxnn-0.2.49.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.
+Version: 0.2.49
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -15,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Requires-Dist: datasets (>=3.5.0,<4.0.0)
 Requires-Dist: huggingface-hub (>=0.30.0,<0.31.0)
+Requires-Dist: nltk (>=3.9.1,<4.0.0)
 Requires-Dist: tensorboard (>=2.19.0,<3.0.0)
 Requires-Dist: tokenizers (>=0.21.0,<0.22.0)
 Requires-Dist: torch (>=2.6.0,<3.0.0)

{rxnn-0.2.47.dist-info → rxnn-0.2.49.dist-info}/RECORD
CHANGED
@@ -6,7 +6,7 @@ rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=kan6UNPTjLfO7zKNp92hGooldgWPi3li_2-_L5xiErs,2784
-rxnn/memory/norm.py,sha256=
+rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=new_YXLe9vfIBPX-pmFRoV523d7yCjEgfTY06EaH3Ms,14605
@@ -17,8 +17,8 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=L2emJM06u7B9f9T1dFsGXzXX-rsV77ND7L1pAM9Z_Ow,9051
-rxnn/training/mrl.py,sha256=
-rxnn/training/reward.py,sha256=
+rxnn/training/mrl.py,sha256=IOi_xbQ47RPgv_2ucT9EkPeWLGBRlgPxKHFeQsYc3Pw,61074
+rxnn/training/reward.py,sha256=dq3b5DRhBLHOvtlHX3eSSuxYBGYCyV5gVqbzCam4uP8,16112
 rxnn/training/rl.py,sha256=q4NzIZAmXRHVToT13IHrPTtEikWQUvT0NO0IjApjAO8,6171
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.49.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.49.dist-info/METADATA,sha256=Yd5xCJVA_rFdzYkTkHZ8tyronArNMOgUQ6VqNF9-vqs,25997
+rxnn-0.2.49.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.49.dist-info/RECORD,,

{rxnn-0.2.47.dist-info → rxnn-0.2.49.dist-info}/LICENSE
File without changes

{rxnn-0.2.47.dist-info → rxnn-0.2.49.dist-info}/WHEEL
File without changes