rxnn 0.2.18__py3-none-any.whl → 0.2.20__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in the public registry they were published to. It is provided for informational purposes only.
- rxnn/training/models.py +9 -4
- rxnn/training/mrl.py +156 -93
- rxnn/training/reward.py +119 -22
- {rxnn-0.2.18.dist-info → rxnn-0.2.20.dist-info}/METADATA +1 -1
- {rxnn-0.2.18.dist-info → rxnn-0.2.20.dist-info}/RECORD +7 -7
- {rxnn-0.2.18.dist-info → rxnn-0.2.20.dist-info}/LICENSE +0 -0
- {rxnn-0.2.18.dist-info → rxnn-0.2.20.dist-info}/WHEEL +0 -0
rxnn/training/models.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from enum import Enum
+from typing import Literal
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
 
@@ -74,23 +75,27 @@ class MrlActorModel(nn.Module):
         self.decoder = decoder
         self.memory_attention = memory_attention
 
-    def freeze_components(self):
+    def freeze_components(self, stage: Literal['update', 'fetch', 'both'] = 'both'):
         """Freeze encoder/decoder except memory-related layers."""
         if self.encoder.freeze_without_memory is not None:
             self.encoder.freeze_without_memory()
+            if stage == 'update':
+                self.encoder.freeze_memory()
         else:
             for param in self.encoder.parameters():
                 param.requires_grad = False
-            self.encoder.model.trainable_cross_attention_(True)
+            self.encoder.model.trainable_cross_attention_(True if stage != 'update' else False)
         if self.decoder.freeze_without_memory is not None:
             self.decoder.freeze_without_memory()
+            if stage == 'update':
+                self.decoder.freeze_memory()
         else:
             for param in self.decoder.parameters():
                 param.requires_grad = False
-            self.decoder.model.trainable_cross_attention_(True)
+            self.decoder.model.trainable_cross_attention_(True if stage != 'update' else False)
         # Unfreeze memory attention
         for param in self.memory_attention.parameters():
-            param.requires_grad = True
+            param.requires_grad = True if stage != 'fetch' else False
 
     def unfreeze_components(self):
         """Unfreeze all components after initial training."""
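The new `stage` argument splits the old all-at-once freeze into phases: 'update' leaves only the memory-attention (write) path trainable, 'fetch' leaves only the memory cross-attention (read) path trainable, and 'both' (the default, matching the previous behaviour for those layers) trains both while the rest of the encoder/decoder stays frozen. A minimal sketch of that intent, using a hypothetical stand-in module rather than the rxnn classes:

    import torch.nn as nn

    # ToyActor is a hypothetical stand-in; it only mirrors which parameter groups
    # each freeze stage leaves trainable in MrlActorModel.freeze_components().
    class ToyActor(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder_core = nn.Linear(8, 8)             # frozen in every stage
            self.decoder_core = nn.Linear(8, 8)             # frozen in every stage
            self.memory_cross_attention = nn.Linear(8, 8)   # memory "fetch" (read) path
            self.memory_attention = nn.Linear(8, 8)         # memory "update" (write) path

        def freeze_components(self, stage: str = 'both'):
            for p in self.parameters():
                p.requires_grad = False
            if stage != 'update':   # 'fetch' and 'both' train the read path
                for p in self.memory_cross_attention.parameters():
                    p.requires_grad = True
            if stage != 'fetch':    # 'update' and 'both' train the write path
                for p in self.memory_attention.parameters():
                    p.requires_grad = True

    actor = ToyActor()
    actor.freeze_components('update')
    assert all(p.requires_grad for p in actor.memory_attention.parameters())
    assert not any(p.requires_grad for p in actor.memory_cross_attention.parameters())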
rxnn/training/mrl.py
CHANGED
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
-from typing import Optional, TypedDict
+from typing import Optional, TypedDict, Union
 from enum import Enum
 import random, os
 from ..transformers.sampler import BatchSampler
@@ -37,10 +37,15 @@ class CurriculumConfig(TypedDict):
     eval_dataset: Optional[MrlCurriculumDataset]
     callbacks: Optional[list[MrlTrainerCallback]]
     strategy: MrlStrategy
-    unfreeze_epoch: Optional[int]
+    unfreeze_epoch: Optional[Union[int, tuple[int, int, int]]]
     random_resets: Optional[bool]
     random_resets_from: Optional[int]
    random_resets_ratio: Optional[float]
+    reward_model: Optional[MrlRewardModel]
+    lr: Optional[float]
+    critic_lr: Optional[float]
+    weight_decay: Optional[float]
+    critic_weight_decay: Optional[float]
 
 
 class SamplerConfig(TypedDict):
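A curriculum stage can now carry its own reward model, optimizer hyperparameters, and a staged unfreeze schedule. Below is a sketch of a stage entry that uses only the keys visible in this diff; the values are hypothetical, other stage keys (such as the dataset) are omitted, and `MrlStrategy` is assumed to be importable from `rxnn.training.mrl`.

    from rxnn.training.mrl import MrlStrategy

    long_range_stage = {
        'eval_dataset': None,              # placeholder
        'callbacks': None,                 # placeholder
        'strategy': MrlStrategy.LONG_RANGE_STRATEGY,
        'unfreeze_epoch': (2, 4, 6),       # staged: 'fetch' at 2, 'both' at 4, full unfreeze at 6
        'random_resets': True,
        'random_resets_from': 2,
        'random_resets_ratio': 0.4,
        'reward_model': None,              # None -> fall back to the shared reward model
        'lr': 1e-4,                        # per-stage actor learning rate override
        'critic_lr': None,                 # None -> keep the base critic learning rate
        'weight_decay': None,
        'critic_weight_decay': None,
    }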
@@ -90,6 +95,7 @@ class MRLTrainer:
         """
         self.actor = actor
         self.critic = critic
+        self.shared_reward_model = reward
         self.reward = reward
         self.device = device
         self.max_seq_len = config.get('max_seq_len', 256)
@@ -117,17 +123,15 @@ class MRLTrainer:
         self.use_amp = use_amp
         self.dtype = dtype
 
+        self.base_optim_config = {
+            'lr': config.get('lr', 3e-4),
+            'critic_lr': config.get('critic_lr', 1e-4),
+            'weight_decay': config.get('weight_decay', 0.01),
+            'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+        }
+
         # Optimizers
-        self.optimizer = torch.optim.AdamW(
-            self.actor.unique_parameters(),
-            lr=config.get("lr", 3e-4),
-            weight_decay=config.get("weight_decay", 0.01),
-        )
-        self.critic_optimizer = torch.optim.AdamW(
-            self.critic.parameters(),
-            lr=config.get("critic_lr", 1e-4),
-            weight_decay=config.get("critic_weight_decay", 0.01),
-        )
+        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config)
 
         self.scaler = torch.amp.GradScaler() if self.use_amp else None
         self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
@@ -154,6 +158,21 @@ class MRLTrainer:
         self.global_epoch = 0
         self.global_epochs_count = 0
 
+    def _init_optimizers(self, lr: float, critic_lr: float, weight_decay: float, critic_weight_decay: float):
+        optimizer = torch.optim.AdamW(
+            self.actor.unique_parameters(),
+            lr=lr,
+            weight_decay=weight_decay,
+        )
+
+        critic_optimizer = torch.optim.AdamW(
+            self.critic.parameters(),
+            lr=critic_lr,
+            weight_decay=critic_weight_decay,
+        )
+        return optimizer, critic_optimizer
+
+
     def _init_steps(self):
         return {
             'collect': 0,
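Because `_init_optimizers` builds fresh `AdamW` instances, re-running it (as a curriculum stage with learning-rate overrides now does) also discards the optimizers' accumulated moment estimates. A small self-contained illustration of that side effect, not taken from the package:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
    model(torch.randn(2, 4)).sum().backward()
    opt.step()
    print(len(opt.state))   # > 0: AdamW has per-parameter state after a step

    # Rebuilding the optimizer (new lr, same parameters) starts from empty state again.
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    print(len(opt.state))   # 0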
@@ -221,21 +240,29 @@ class MRLTrainer:
 
         return generated_answer, log_probs
 
+    def _calculate_reward(self, generated: TokenizedDict, reference: TokenizedDict,
+                          saved_query: TokenizedDict, saved_answer: TokenizedDict,
+                          mode: MrlRewardMode = MrlRewardMode.STANDARD,
+                          prev_data: tuple[TokenizedDict, TokenizedDict] = None):
+        saved_interaction = smart_concat(saved_query, saved_answer, max_length=self.max_seq_len,
+                                         pad_token_id=self.pad_token_id)
+        prev_data = smart_concat(prev_data[0], prev_data[1], self.max_seq_len,
+                                 self.pad_token_id) if prev_data is not None else None
+        return self.reward(generated, reference, saved_interaction, mode=mode, prev_data=prev_data), saved_interaction
+
     def compute_reward(self, generated: TokenizedDict, reference: TokenizedDict,
                        saved_data: tuple[TokenizedDict, TokenizedDict], mode: MrlRewardMode = MrlRewardMode.STANDARD,
-                       eval_mode: bool = False) -> list[float]:
+                       eval_mode: bool = False, prev_data: tuple[TokenizedDict, TokenizedDict] = None) -> list[float]:
         """Compute reward based on memory retention (e.g., BLEU-4)."""
         saved_query, saved_answer = saved_data
         # 1. Concat saved (previous) interaction and calculate reward using generated sequence, reference and saved data - with autocast on/off
         if self.use_amp:
             with torch.amp.autocast(device_type=self.device.type, dtype=self.dtype):
-                saved_interaction = ...
-                ...
-                reward = self.reward(generated, reference, saved_interaction, mode=mode)
+                reward, saved_interaction = self._calculate_reward(generated, reference, saved_query, saved_answer,
+                                                                   mode=mode, prev_data=prev_data)
         else:
-            saved_interaction = ...
-            ...
-            reward = self.reward(generated, reference, saved_interaction, mode=mode)
+            reward, saved_interaction = self._calculate_reward(generated, reference, saved_query, saved_answer,
+                                                               mode=mode, prev_data=prev_data)
 
         # 2. Run 'on reward' callbacks
         for cb in self.callbacks:
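Both branches now go through `_calculate_reward`, which concatenates the saved query/answer pair into a single saved interaction and, when a previous interaction is supplied, concatenates it too and forwards it as `prev_data` so the reward model can track a running mean of earlier turns. A sketch of that flow with a stand-in for `smart_concat` (whose exact signature and padding behaviour are assumptions here, not taken from the package):

    import torch

    def concat_interaction(query, answer, max_seq_len=256):
        # Assumed behaviour of smart_concat: join query+answer token batches and truncate.
        return {
            'input_ids': torch.cat([query['input_ids'], answer['input_ids']], dim=-1)[:, :max_seq_len],
            'attention_mask': torch.cat([query['attention_mask'], answer['attention_mask']], dim=-1)[:, :max_seq_len],
        }

    def calculate_reward(reward_model, generated, reference, saved_query, saved_answer,
                         mode, prev_data=None):
        saved_interaction = concat_interaction(saved_query, saved_answer)
        prev_interaction = concat_interaction(*prev_data) if prev_data is not None else None
        # prev_data is None on the first follow-up step, so the reward model initialises
        # its running mean on the second step and updates it afterwards.
        return reward_model(generated, reference, saved_interaction,
                            mode=mode, prev_data=prev_interaction), saved_interaction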
@@ -289,22 +316,27 @@ class MRLTrainer:
             # state from existing one, instead of new random one)
             reset_done = self.reset_stm()
 
-            # 4. ...
+            # 4. Reset reward prev data running mean - it's calculated for multi-step retention, we have to reset it before episode
+            self.reward.reset_running_mean()
+
+            # 5. Get first batch of interactions (data to save) and follow-up interactions for current episode, based on curriculum step
             first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
             interactions = interactions[:self.curriculum_steps]
             interactions_len = len(interactions)
-            # ...
+            # 6. Encode and update STM with data to save from first interaction
             self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
 
-            # ...
+            # 7. Save first interaction as data to save (for trajectory state)
             query, answer = first_query, first_answer
 
-            # ...
+            # 8. Run training strategy for follow-up interactions
             episode_steps = []
             episode_rewards = []
 
+            prev_interaction = None
+
             for i, interaction in enumerate(interactions):
-                # ...
+                # 9. Generate batch of answers based on batch of follow-up queries
                 next_query = self._move_batch(interaction['query'])
                 generated_answer, log_probs = self.generate_answer(next_query)
 
@@ -312,7 +344,7 @@
 
                 detached_answer = self._cpu_detach(generated_answer)  # detach and keep states on CPU
 
-                # ...
+                # 10. Depending on strategy compute reward
                 if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
                     # a) long-range - first interaction - change topic - negative reward (it shouldn't include saved data)
                     reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
@@ -320,18 +352,19 @@
                 elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
                     # b) long-range - last interaction - first interaction topic - long-range reward (it should include content from first interaction)
                     reward = self.compute_reward(detached_answer, interaction['answer'],
-                                                 (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE)
+                                                 (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                 prev_data=prev_interaction)
                 else:
                     # c) standard reward - generated answer should include some content from previous interaction (saved data), like reference answer
                     reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
-                                                 mode=MrlRewardMode.STANDARD)
+                                                 mode=MrlRewardMode.STANDARD, prev_data=prev_interaction)
 
-                # ...
+                # 11. Update STM with generated response (except last interaction, it's not needed)
                 if not is_last_interaction:
                     self.encode_and_update_stm(next_query,
                                                generated_answer)  # update with generated_answer on GPU
 
-                # ...
+                # 12. Store trajectory step
                 trajectory: MrlTrajectoryStep = {
                     'state': (query, answer, interaction['query']),
                     'action': detached_answer,
@@ -342,10 +375,12 @@
                 episode_steps.append(trajectory)
                 episode_rewards.append(reward)
 
-                # ...
+                # 13. Set previous and current interaction query and generated answer (batches), as saved data for next interaction
+                if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
+                    prev_interaction = (query, answer)
                 query, answer = interaction['query'], detached_answer
 
-            # ...
+            # 14. Append full batched episode (number of steps depends on curriculum stage) to trajectories
             episode_trajectory: MrlTrajectoryEpisode = {
                 'reset_stm': reset_done,
                 'steps': episode_steps,
@@ -356,7 +391,7 @@
 
             self._collect_writer(mean_episode_reward, epoch)
 
-            # ...
+            # 15. Run "on episode collected" callbacks
            for cb in self.callbacks:
                 cb.on_episode_collected(self.actor, batch_idx, episode_trajectory, mean_episode_reward)
 
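Across follow-up steps the trainer now keeps a one-step-delayed window: the interaction currently saved in STM is the reward's saved data, while the interaction saved before it is passed as `prev_data` for the reward model's running mean. `prev_data` stays `None` on the first step, and in the long-range strategy the skip at `i == 0` means the initial interaction (the one to be recalled at the end) never enters the running mean. A plain-Python sketch of the sliding window for the standard strategy (the strings are illustrative labels, not rxnn types):

    interactions = ['turn-1', 'turn-2', 'turn-3']   # follow-up interactions of one episode
    saved = 'first-interaction'                     # encoded into STM before the loop
    prev_interaction = None                         # feeds the reward model's running mean

    for i, current in enumerate(interactions):
        print(f'step {i}: saved={saved!r}, prev_data={prev_interaction!r}')
        prev_interaction = saved
        saved = current

    # step 0: saved='first-interaction', prev_data=None
    # step 1: saved='turn-1',            prev_data='first-interaction'
    # step 2: saved='turn-2',            prev_data='turn-1'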
@@ -595,63 +630,70 @@ class MRLTrainer:
         for batch in dataloader:
             with torch.no_grad():
                 if batch['query']['input_ids'].size(0) == batch_size:
-                    ... (57 removed lines of the previous evaluation loop, not shown in this diff view)
+                    self._increment_steps('eval')
+                    # 3. Reset STM with random resets ratio and reward model running mean
+                    self.reset_stm()
+                    self.reward.reset_running_mean()
+
+                    # 4. Get batches for first queries, answers and all follow-up interactions
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    # 5. Encode and update STM with initial interactions (batch)
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save follow-up interactions len and first query and answer as previous one for iteration
+                    interactions_len = len(interactions)
+                    query, answer = first_query, first_answer
+                    episode_reward = torch.tensor(0.0).to(self.device)
+                    episode_interactions = torch.tensor(0).to(self.device)
+
+                    prev_interaction = None
+
+                    # 7. Run all follow-up interactions
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers
+                        next_query = self._move_batch(interaction['query'])
+                        generated_answer, _ = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)
+
+                        # 9. Depending on current strategy and step, compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE, eval_mode=True)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                         eval_mode=True, prev_data=prev_interaction)
+                        else:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD, eval_mode=True,
+                                                         prev_data=prev_interaction)
+
+                        # 10. Encode and update memory for the next interaction
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)
+
+                        # 11. Accumulate rewards
+                        step_reward = torch.tensor(reward).mean().to(self.device)
+                        # total
+                        total_reward += step_reward
+                        count += 1
+                        # episode
+                        episode_reward += step_reward
+                        episode_interactions += 1
+                        # 12. Save previous interaction
+                        if not (self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0):
+                            prev_interaction = (query, answer)
+                        query, answer = interaction['query'], detached_answer
+                    avg_episode_reward = (episode_reward / episode_interactions).item()
+                    # 13. Run eval TensorBoard writer with average episode reward
+                    self._eval_writer(avg_episode_reward, epoch)
+
+                    # 14. Run "on eval episode end" callbacks
+                    for cb in self.callbacks:
+                        cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
 
         # 15. Calculate average reward
         if self.use_ddp:
@@ -679,6 +721,14 @@ class MRLTrainer:
                                     self.shared_callbacks)  # trainer callbacks for current curriculum stage
         self.strategy = config.get('strategy',
                                    MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
+        self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
+        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None:
+            self.optimizer, self.critic_optimizer = self._init_optimizers(
+                lr=config['lr'] or self.base_optim_config['lr'],
+                critic_lr=config['critic_lr'] or self.base_optim_config['critic_lr'],
+                weight_decay=config['weight_decay'] or self.base_optim_config['weight_decay'],
+                critic_weight_decay=config['critic_weight_decay'] or self.base_optim_config['critic_weight_decay']
+            )
 
         # 2. Get epochs and random resets configs
         epochs = config.get('epochs', 5)  # number of epochs for current stage
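The per-stage overrides fall back to `base_optim_config` with `or`, so `None` keeps the base value; note that any other falsy override (for example an explicit 0.0 weight decay) also falls back to the base. A quick illustration of that fallback semantics:

    base_optim_config = {'lr': 3e-4, 'critic_lr': 1e-4, 'weight_decay': 0.01, 'critic_weight_decay': 0.01}
    stage = {'lr': 1e-4, 'critic_lr': None, 'weight_decay': 0.0, 'critic_weight_decay': None}

    print(stage['lr'] or base_optim_config['lr'])                      # 0.0001 - override is used
    print(stage['critic_lr'] or base_optim_config['critic_lr'])        # 0.0001 - None falls back
    print(stage['weight_decay'] or base_optim_config['weight_decay'])  # 0.01   - 0.0 is falsy, base wins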
@@ -720,7 +770,11 @@ class MRLTrainer:
 
         # 4. Freeze all components except memory attention and memory cross-attention layers in decoder/encoder
         if unfreeze_epoch != 0:
-            self.actor.freeze_components()
+            is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+            if is_staged_unfreeze:
+                self.actor.freeze_components('update')
+            else:
+                self.actor.freeze_components()
 
         # 5. Setup train DataLoader
         if self.use_ddp:
@@ -761,8 +815,18 @@ class MRLTrainer:
                 self.random_resets_ratio = 1.0
 
             # 11. Unfreeze all components before selected epoch
-            if epoch == unfreeze_epoch:
-                self.actor.unfreeze_components()
+            is_staged_unfreeze = isinstance(unfreeze_epoch, tuple)
+            if is_staged_unfreeze:
+                fetch_epoch, both_epoch, all_epoch = unfreeze_epoch
+                if epoch == fetch_epoch:
+                    self.actor.freeze_components('fetch')
+                elif epoch == both_epoch:
+                    self.actor.freeze_components('both')
+                elif epoch == all_epoch:
+                    self.actor.unfreeze_components()
+            else:
+                if epoch == unfreeze_epoch:
+                    self.actor.unfreeze_components()
 
             # 12. Set epoch for distributed sampler
             if train_sampler is not None:
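With a tuple `unfreeze_epoch = (fetch_epoch, both_epoch, all_epoch)`, the stage starts in the 'update' freeze (only memory-attention trains) and then widens what is trainable at the given epochs. A sketch of the resulting schedule for `(2, 4, 6)` over 8 epochs:

    unfreeze_epoch = (2, 4, 6)
    fetch_epoch, both_epoch, all_epoch = unfreeze_epoch

    stage = 'update'            # set by freeze_components('update') before the first epoch
    for epoch in range(8):
        if epoch == fetch_epoch:
            stage = 'fetch'     # train memory cross-attention (read path) only
        elif epoch == both_epoch:
            stage = 'both'      # train both memory paths
        elif epoch == all_epoch:
            stage = 'unfrozen'  # unfreeze_components(): everything trains
        print(epoch, stage)
    # epochs 0-1: update, 2-3: fetch, 4-5: both, 6-7: unfrozen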
@@ -805,4 +869,3 @@ class MRLTrainer:
         # 21. Close writer
         if self.writer:
             self.writer.close()
-
rxnn/training/reward.py
CHANGED
@@ -11,6 +11,7 @@ class MrlRewardMode(Enum):
     NEGATIVE = 2
     LONG_RANGE = 3
 
+
 class MrlRewardModel:
     def __init__(
             self,
@@ -18,9 +19,14 @@ class MrlRewardModel:
             device: torch.device,
             bleu_with_saved_data: bool = False,
             bleu_factor: float = 0.5,
+            bleu_ref_factor: float = 0.5,
+            bleu_saved_factor: float = 0.5,
             cos_factor: float = 0.5,
             cos_ref_factor: float = 0.5,
             cos_saved_factor: float = 0.5,
+            multi_cos_ref_factor: float = 0.3,
+            multi_cos_saved_factor: float = 0.5,
+            multi_cos_running_mean_factor: float = 0.2,
             neg_bleu_factor: Optional[float] = None,
             neg_cos_factor: Optional[float] = None,
             neg_cos_ref_factor: Optional[float] = None,
@@ -28,45 +34,88 @@ class MrlRewardModel:
             neg_bleu_ref_factor: float = 0.5,
             neg_bleu_saved_factor: float = 0.5,
             allow_not_summing_factors: bool = False,
+            reward_len: bool = False,
+            neg_reward_len: bool = False,
+            max_rewarded_len: int = None,
+            len_factor: int = None,
+            use_running_mean: bool = True,
+            running_mean_decay: float = 0.2,
+            bleu_saved_weights: tuple = (0.5, 0.5),
+            bleu_ref_weights: tuple = (0.5, 0.5),
+            rewards_scale: float = 1.0,
     ):
         self.shared_embedding = shared_embedding.to(device)
         self.device = device
         self.bleu_with_saved_data = bleu_with_saved_data
 
         self.bleu_factor = bleu_factor
+        self.bleu_ref_factor = bleu_ref_factor
+        self.bleu_saved_factor = bleu_saved_factor
         self.cos_factor = cos_factor
         self.cos_ref_factor = cos_ref_factor
         self.cos_saved_factor = cos_saved_factor
+        self.multi_cos_ref_factor = multi_cos_ref_factor
+        self.multi_cos_saved_factor = multi_cos_saved_factor
+        self.multi_cos_running_mean_factor = multi_cos_running_mean_factor
         self.neg_bleu_factor = neg_bleu_factor if neg_bleu_factor is not None else bleu_factor
         self.neg_cos_factor = neg_cos_factor if neg_cos_factor is not None else cos_factor
         self.neg_cos_ref_factor = neg_cos_ref_factor if neg_cos_ref_factor is not None else cos_ref_factor
         self.neg_cos_saved_factor = neg_cos_saved_factor if neg_cos_saved_factor is not None else cos_saved_factor
         self.neg_bleu_ref_factor = neg_bleu_ref_factor
         self.neg_bleu_saved_factor = neg_bleu_saved_factor
+        self.reward_len = reward_len
+        self.neg_reward_len = neg_reward_len
+        self.max_rewarded_len = max_rewarded_len
+        self.len_factor = len_factor
+        self.use_running_mean = use_running_mean
+        self.running_mean_decay = running_mean_decay
+        self.bleu_ref_weights = bleu_ref_weights
+        self.bleu_saved_weights = bleu_saved_weights
+        self.rewards_scale = rewards_scale
+
+        self.prev_data_running_mean = None
 
         if not allow_not_summing_factors:
-            ... (previous factor assertions, 5 removed lines not shown in this diff view)
+            if reward_len:
+                assert self.bleu_factor + self.cos_factor + self.len_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor + self.len_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
+            else:
+                assert self.bleu_factor + self.cos_factor == 1.0
+                assert self.bleu_ref_factor + self.bleu_saved_factor == 1.0
+                assert self.cos_ref_factor + self.cos_saved_factor == 1.0
+                assert self.multi_cos_ref_factor + self.multi_cos_saved_factor + self.multi_cos_running_mean_factor == 1.0
+                assert self.neg_bleu_factor + self.neg_cos_factor == 1.0
+                assert self.neg_cos_ref_factor + self.neg_cos_saved_factor == 1.0
+                assert self.neg_bleu_ref_factor + self.neg_bleu_saved_factor == 1.0
 
     def _sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> float:
         from nltk.translate.bleu_score import sentence_bleu
-        refs = [reference, saved_data] if self.bleu_with_saved_data else [reference]
-        return sentence_bleu(refs, generated, weights=(0.25, 0.25, 0.25, 0.25))
 
-    ...
+        if self.bleu_with_saved_data:
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
+            return self.bleu_ref_factor * ref_bleu + self.bleu_saved_factor * saved_bleu
+        else:
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+
+
+    def _negative_sentence_bleu(self, generated: torch.Tensor, reference: torch.Tensor,
+                                saved_data: torch.Tensor) -> float:
         from nltk.translate.bleu_score import sentence_bleu
 
         if self.bleu_with_saved_data:
-            ref_bleu = sentence_bleu([reference], generated, weights=...
-            saved_bleu = sentence_bleu([saved_data], generated, weights=...
+            ref_bleu = sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
+            saved_bleu = sentence_bleu([saved_data], generated, weights=self.bleu_saved_weights)
             saved_bleu = 1 - saved_bleu
 
-            return ...
+            return self.neg_bleu_ref_factor * ref_bleu + self.neg_bleu_saved_factor * saved_bleu
         else:
-            return sentence_bleu([reference], generated, weights=...
+            return sentence_bleu([reference], generated, weights=self.bleu_ref_weights)
 
     def batch_bleu(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> list[float]:
         batch_size = generated.size(0)
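When `bleu_with_saved_data` is enabled, the generated answer is now scored separately against the reference answer and against the saved interaction, each with its own n-gram weights, and the two scores are mixed with `bleu_ref_factor`/`bleu_saved_factor` (asserted to sum to 1.0). A standalone sketch of that mix with toy token lists; the smoothing function is added here only so the short toy sequences do not score zero and is not part of the package code:

    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

    generated = [1, 2, 3, 4, 5, 6]          # hypothetical token ids
    reference = [1, 2, 3, 4, 9, 10]
    saved     = [1, 2, 7, 8, 5, 6]

    bleu_ref_weights = bleu_saved_weights = (0.5, 0.5)   # bigram BLEU, as in the new defaults
    bleu_ref_factor = bleu_saved_factor = 0.5

    smooth = SmoothingFunction().method1
    ref_bleu = sentence_bleu([reference], generated, weights=bleu_ref_weights, smoothing_function=smooth)
    saved_bleu = sentence_bleu([saved], generated, weights=bleu_saved_weights, smoothing_function=smooth)

    bleu_reward = bleu_ref_factor * ref_bleu + bleu_saved_factor * saved_bleu
    print(round(bleu_reward, 4))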
@@ -79,33 +128,81 @@ class MrlRewardModel:
     def _cosine_sim(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
         generated_emb = self._sequence_embedding(generated)
 
-        gen_and_saved = F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data))
-        gen_and_ref = F.cosine_similarity(generated_emb, self._sequence_embedding(reference))
+        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
+        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
         return gen_and_saved, gen_and_ref
 
-    def ...
-    ...
+    def _cosine_sim_running_mean(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor):
+        generated_emb = self._sequence_embedding(generated)
 
-        ...
+        gen_and_saved = (F.cosine_similarity(generated_emb, self._sequence_embedding(saved_data)) + 1) / 2
+        gen_and_ref = (F.cosine_similarity(generated_emb, self._sequence_embedding(reference)) + 1) / 2
+        gen_and_mean = (F.cosine_similarity(generated_emb, self.prev_data_running_mean) + 1) / 2
+        return gen_and_saved, gen_and_ref, gen_and_mean
+
+    def batch_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor,
+                     include_running_mean: bool = False, negative_running_mean: bool = False) -> torch.Tensor:
+        if self.use_running_mean and negative_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * (
+                    1 - gen_and_mean)
+        elif self.use_running_mean and include_running_mean:
+            gen_and_saved, gen_and_ref, gen_and_mean = self._cosine_sim_running_mean(generated, reference, saved_data)
+            return self.multi_cos_saved_factor * gen_and_saved + self.multi_cos_ref_factor * gen_and_ref + self.multi_cos_saved_factor * gen_and_mean
+        else:
+            gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
+            return self.cos_saved_factor * gen_and_saved + self.cos_ref_factor * gen_and_ref
 
-    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor, saved_data: torch.Tensor) -> torch.Tensor:
+    def negative_cosine(self, generated: torch.Tensor, reference: torch.Tensor,
+                        saved_data: torch.Tensor) -> torch.Tensor:
         gen_and_saved, gen_and_ref = self._cosine_sim(generated, reference, saved_data)
 
         return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
 
+    def len_reward(self, generated: TokenizedDict):
+        lens = generated['attention_mask'].sum(dim=1)
+        neg_lens = self.max_rewarded_len / lens if self.neg_reward_len else 1.0
+        len_reward = torch.where(lens >= self.max_rewarded_len, neg_lens, lens / self.max_rewarded_len)
+        return len_reward
+
+    def reset_running_mean(self):
+        self.prev_data_running_mean = None
+
+    def init_running_mean(self, prev_data: torch.Tensor):
+        self.prev_data_running_mean = self._sequence_embedding(prev_data)
+
+    def update_running_mean(self, prev_data: torch.Tensor):
+        self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
+            prev_data) + self.running_mean_decay * self.prev_data_running_mean
+
     def __call__(
             self,
             generated: TokenizedDict,
             reference: TokenizedDict,
             saved_data: TokenizedDict,
+            prev_data: TokenizedDict = None,
             mode: MrlRewardMode = MrlRewardMode.STANDARD
     ) -> list[float]:
-        if ...
+        if prev_data is not None:
+            if self.prev_data_running_mean is None:
+                self.init_running_mean(prev_data['input_ids'])
+            else:
+                self.update_running_mean(prev_data['input_ids'])
+
+        if mode == MrlRewardMode.STANDARD:
+            bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
+            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
+                                       include_running_mean=prev_data is not None)
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
+        elif mode == MrlRewardMode.LONG_RANGE:
             bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            ...
+            cosine = self.batch_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'],
+                                       negative_running_mean=prev_data is not None)
+            sim_rewards = self.bleu_factor * torch.tensor(bleu, device=self.device) + self.cos_factor * cosine
         else:
             bleu = self.batch_bleu(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
             cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
-            ...
+            sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
 
+        rewards = (sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+        return rewards.tolist()
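The running mean of previous-interaction embeddings is an exponential moving average with `running_mean_decay` (default 0.2), cosine similarities are rescaled from [-1, 1] into [0, 1], and an optional length reward saturates at `max_rewarded_len`. (In the `batch_cosine` shown above, both running-mean branches weight the running-mean term with `multi_cos_saved_factor`; `multi_cos_running_mean_factor` appears to be stored and asserted but not used in this part of the diff.) A small numeric sketch of these pieces, with made-up embeddings and lengths:

    import torch
    import torch.nn.functional as F

    running_mean_decay = 0.2
    prev_emb_step1 = torch.tensor([1.0, 0.0])     # embedding of the first "previous interaction"
    prev_emb_step2 = torch.tensor([0.0, 1.0])     # embedding seen one step later

    # init_running_mean / update_running_mean as an exponential moving average:
    running_mean = prev_emb_step1
    running_mean = (1 - running_mean_decay) * prev_emb_step2 + running_mean_decay * running_mean
    print(running_mean)                            # tensor([0.2000, 0.8000])

    # Cosine similarities are shifted from [-1, 1] into [0, 1]:
    generated_emb = torch.tensor([[1.0, 1.0]])
    sim = (F.cosine_similarity(generated_emb, running_mean.unsqueeze(0)) + 1) / 2
    print(sim)                                     # roughly 0.93 for this toy pair

    # Length reward: linear up to max_rewarded_len, then decaying when neg_reward_len=True:
    max_rewarded_len, neg_reward_len = 8, True
    lens = torch.tensor([4.0, 8.0, 16.0])          # answer lengths from attention_mask.sum(dim=1)
    neg_lens = max_rewarded_len / lens if neg_reward_len else 1.0
    len_reward = torch.where(lens >= max_rewarded_len, neg_lens, lens / max_rewarded_len)
    print(len_reward)                              # tensor([0.5000, 1.0000, 0.5000])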
{rxnn-0.2.18.dist-info → rxnn-0.2.20.dist-info}/RECORD
CHANGED
@@ -15,9 +15,9 @@ rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
-rxnn/training/models.py,sha256=...
-rxnn/training/mrl.py,sha256=...
-rxnn/training/reward.py,sha256=...
+rxnn/training/models.py,sha256=qXfD3_97T9z724NN4myjzrpX6-jYA9Igl266ZwtJCtc,5519
+rxnn/training/mrl.py,sha256=zk4m1JFuX0y82J0tG2XkY0Pz6Uy2did9cngOXqR9lMk,43326
+rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.18.dist-info/LICENSE,sha256=...
-rxnn-0.2.18.dist-info/METADATA,sha256=...
-rxnn-0.2.18.dist-info/WHEEL,sha256=...
-rxnn-0.2.18.dist-info/RECORD,,
+rxnn-0.2.20.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.20.dist-info/METADATA,sha256=mgimK5GvI27RapfLjhlIdBwgfVdKoMA5Ig5yVxfeYIw,25960
+rxnn-0.2.20.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.20.dist-info/RECORD,,
{rxnn-0.2.18.dist-info → rxnn-0.2.20.dist-info}/LICENSE
File without changes
{rxnn-0.2.18.dist-info → rxnn-0.2.20.dist-info}/WHEEL
File without changes