rxnn 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
- rxnn/training/dataset.py +3 -0
- rxnn/training/mrl.py +65 -61
- rxnn/training/rl.py +34 -4
- {rxnn-0.2.13.dist-info → rxnn-0.2.14.dist-info}/METADATA +1 -1
- {rxnn-0.2.13.dist-info → rxnn-0.2.14.dist-info}/RECORD +7 -7
- {rxnn-0.2.13.dist-info → rxnn-0.2.14.dist-info}/LICENSE +0 -0
- {rxnn-0.2.13.dist-info → rxnn-0.2.14.dist-info}/WHEEL +0 -0
rxnn/training/dataset.py
CHANGED
@@ -1098,6 +1098,7 @@ class MrlDatasets:
             load_kwargs: dict = None,
             mrl_ds_kwargs: dict = None,
             eval_split: str = None,
+            max_seq_len: int = 256,
     ):
         """
         Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
@@ -1115,6 +1116,7 @@ class MrlDatasets:
             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
             mrl_ds_kwargs (dict): Additional args for RxNN MrlCurriculumDataset class
             eval_split (str): Load also evaluation/validation split (default: None)
+            max_seq_len (int): Maximum sequence length (default: 256)
         """
         if load_kwargs is None:
             load_kwargs = {}
@@ -1131,6 +1133,7 @@ class MrlDatasets:
                 interactions_field=interactions_field,
                 split=load_split,
                 load_kwargs=load_kwargs,
+                max_seq_len=max_seq_len,
                 **mrl_ds_kwargs,
             )
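As a usage note: a minimal sketch of what the new max_seq_len argument typically controls on the dataset side, namely truncating and padding each tokenized interaction to a fixed length. Only the max_seq_len name and its 256 default come from this diff; the tokenizer, the placeholder model id, and the helper below are illustrative assumptions, not the package's documented API.

    # Hypothetical sketch: pad/truncate each interaction to max_seq_len.
    # Everything except the max_seq_len value is an illustrative assumption.
    from transformers import AutoTokenizer

    max_seq_len = 256  # matches the new default introduced in this release
    tokenizer = AutoTokenizer.from_pretrained('your-org/your-tokenizer')  # placeholder id

    def tokenize_interaction(query: str, answer: str) -> dict:
        # Fixed-width tokenization keeps every batch at max_seq_len tokens,
        # which is what a per-dataset max_seq_len argument is usually for.
        return tokenizer(
            query,
            answer,
            max_length=max_seq_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )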
rxnn/training/mrl.py
CHANGED
@@ -92,8 +92,8 @@ class MRLTrainer:
         self.critic = critic
         self.reward = reward
         self.device = device
-        self.max_seq_len = config.get('max_seq_len',
-        self.critic_max_len = config.get('critic_max_len',
+        self.max_seq_len = config.get('max_seq_len', 256)
+        self.critic_max_len = config.get('critic_max_len', 512)

         # Move models to device
         if use_amp:
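For reference, a minimal sketch of how the new defaults surface through the trainer config; only the 'max_seq_len' and 'critic_max_len' keys and their 256/512 fallbacks come from this diff, and the dict shown is an assumed example, not a complete MRLTrainer config.

    # Hypothetical config sketch: only the two keys below are taken from this diff.
    mrl_config = {
        'max_seq_len': 256,     # read via config.get('max_seq_len', 256)
        'critic_max_len': 512,  # read via config.get('critic_max_len', 512)
    }

    # Omitting the keys now falls back to the same values:
    max_seq_len = mrl_config.get('max_seq_len', 256)
    critic_max_len = mrl_config.get('critic_max_len', 512)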
@@ -306,6 +306,7 @@ class MRLTrainer:
             for i, interaction in enumerate(interactions):
                 # 8. Generate batch of answers based on batch of follow-up queries
                 next_query = self._move_batch(interaction['query'])
+                print(next_query['input_ids'].size())
                 generated_answer, log_probs = self.generate_answer(next_query)

                 is_last_interaction = (i + 1) == interactions_len
@@ -365,7 +366,7 @@ class MRLTrainer:
     def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
         # 1. Calculate values with critic encoder
         values = self.critic(
-
+            inputs['input_ids'],
             attention_mask=inputs['attention_mask'],
         ).squeeze()
         # 2. Calculate critic loss
@@ -461,6 +462,7 @@ class MRLTrainer:
         # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
         # memory, based on collected episode data
         all_losses = []
+        trajectories_len = len(trajectories)
         for episode_idx, episode in enumerate(trajectories):
             episode_steps = episode['steps']
             should_reset_stm = episode['reset_stm']
@@ -506,7 +508,7 @@ class MRLTrainer:
                                     action=MrlActorAction.DECODE)

                 # 7. Calculate RL Algorithm (PPO etc.) loss
-                policy_loss = self.rl_algorithm.policy_loss(action['input_ids'], logits, log_probs, advantages)
+                policy_loss = self.rl_algorithm.policy_loss(next_query, action, inputs['input_ids'], logits, log_probs, advantages)

                 # 8. Reset gradients
                 self.optimizer.zero_grad()
@@ -593,63 +595,64 @@ class MRLTrainer:
         # 2. Run evaluation on all batch episodes
         for batch in dataloader:
             with torch.no_grad():
- ... (57 removed lines, contents truncated in this diff view)
+                if batch['query']['input_ids'].size(0) == batch_size:
+                    self._increment_steps('eval')
+                    # 3. Reset STM with random resets ratio
+                    self.reset_stm()
+
+                    # 4. Get batches for first queries, answers and all follow-up interactions
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    # 5. Encode and update STM with initial interactions (batch)
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save follow-up interactions len and first query and answer as previous one for iteration
+                    interactions_len = len(interactions)
+                    query, answer = first_query, first_answer
+                    episode_reward = torch.tensor(0.0).to(self.device)
+                    episode_interactions = torch.tensor(0).to(self.device)
+                    # 7. Run all follow-up interactions
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers
+                        next_query = self._move_batch(interaction['query'])
+                        generated_answer, _ = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)
+
+                        # 9. Depending on current strategy and step, compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE, eval_mode=True)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                         eval_mode=True)
+                        else:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD, eval_mode=True)
+
+                        # 10. Encode and update memory for the next interaction
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)
+
+                        # 11. Accumulate rewards
+                        step_reward = torch.tensor(reward).mean().to(self.device)
+                        # total
+                        total_reward += step_reward
+                        count += 1
+                        # episode
+                        episode_reward += step_reward
+                        episode_interactions += 1
+                        # 12. Save previous interaction
+                        query, answer = interaction['query'], detached_answer
+                    avg_episode_reward = (episode_reward / episode_interactions).item()
+                    # 13. Run eval TensorBoard writer with average episode reward
+                    self._eval_writer(avg_episode_reward, epoch)
+
+                    # 14. Run "on eval episode end" callbacks
+                    for cb in self.callbacks:
+                        cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)

         # 15. Calculate average reward
         if self.use_ddp:
@@ -804,3 +807,4 @@ class MRLTrainer:
         # 21. Close writer
         if self.writer:
             self.writer.close()
+
rxnn/training/rl.py
CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from abc import ABC, abstractmethod
 from typing import TypedDict
+from .utils import TokenizedDict


 class RlAlgorithm(ABC):
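The newly imported TokenizedDict is used below as a mapping with 'input_ids' and 'attention_mask' tensors (that is how policy_loss indexes it). A minimal stand-in under that assumption; the real definition in rxnn/training/utils.py may differ.

    # Hypothetical stand-in for rxnn.training.utils.TokenizedDict, inferred
    # solely from how policy_loss accesses it; the real definition may differ.
    from typing import TypedDict
    import torch

    class TokenizedDict(TypedDict):
        input_ids: torch.Tensor       # (batch, seq_len) token ids
        attention_mask: torch.Tensor  # (batch, seq_len) 1 = real token, 0 = padding

    query: TokenizedDict = {
        'input_ids': torch.tensor([[101, 7592, 102, 0]]),
        'attention_mask': torch.tensor([[1, 1, 1, 0]]),
    }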
@@ -11,7 +12,8 @@ class RlAlgorithm(ABC):
         self.critic_loss = nn.MSELoss()

     @abstractmethod
-    def policy_loss(self,
+    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
         pass

     @abstractmethod
@@ -35,10 +37,38 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_lambda = config.get('gae_lambda', 0.95)
         self.clip_eps = config.get('clip_eps', 0.2)

-    def policy_loss(self,
+    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+        query_lens = query['attention_mask'].sum(dim=1).long()  # Query lengths per sample
+        answer_mask = answer['attention_mask']
+        answer_lens = answer_mask.sum(dim=1).long()  # Answer lengths per sample (before padding)
+
+        max_length = query['input_ids'].size(1)
+
+        combined_lens = torch.minimum(
+            query_lens + answer_lens,
+            torch.full_like(query_lens, max_length)
+        )
+
+        def extract_answer_tokens(tensor: torch.Tensor) -> torch.Tensor:
+            B, L, *rest = tensor.size()
+            result = torch.zeros((B, max_length, *rest), dtype=tensor.dtype, device=tensor.device)
+
+            for i in range(B):
+                s = query_lens[i].item()
+                e = combined_lens[i].item()
+                valid_len = e - s
+                if valid_len > 0:
+                    result[i, :valid_len] = tensor[i, s:e]
+            return result
+
+        new_logits = extract_answer_tokens(logits)
+
         # a) Get new log probs
-        new_probs = F.log_softmax(
-        new_log_probs = new_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
+        new_probs = F.log_softmax(new_logits, dim=-1)
+        new_log_probs = new_probs.gather(-1, answer['input_ids'].unsqueeze(-1)).squeeze(-1)
+
+        new_log_probs = extract_answer_tokens(new_log_probs.unsqueeze(-1)).squeeze(-1)  # Ensure 3D for extraction (add singleton dim)

         # b) Calculate ratio
         ratio = (new_log_probs - old_log_probs).exp()
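To make the new masking logic concrete, here is a self-contained sketch of the same idea: slice the answer-token positions out of a padded query+answer sequence, then form the PPO probability ratio and clip it. It mirrors extract_answer_tokens and the ratio line above, but the random tensors and the final clipped objective are illustrative assumptions rather than the package's exact code.

    # Illustrative sketch of answer-token extraction plus a PPO clipped objective.
    # Shapes, random inputs, and the clipping step are assumptions for demonstration only.
    import torch
    import torch.nn.functional as F

    B, L, V = 2, 8, 11            # batch, max combined length, vocab size
    clip_eps = 0.2                # same default as PPOAlgorithm above

    query_lens = torch.tensor([3, 5])    # real query tokens per sample
    answer_lens = torch.tensor([4, 2])   # real answer tokens per sample
    combined_lens = torch.minimum(query_lens + answer_lens, torch.full_like(query_lens, L))

    def extract_answer_tokens(tensor: torch.Tensor) -> torch.Tensor:
        # Copy the answer slice [query_len : query_len + answer_len) of each
        # sample into a zero-padded tensor that starts at position 0.
        batch, _, *rest = tensor.size()
        out = torch.zeros((batch, L, *rest), dtype=tensor.dtype, device=tensor.device)
        for i in range(batch):
            s, e = query_lens[i].item(), combined_lens[i].item()
            if e > s:
                out[i, : e - s] = tensor[i, s:e]
        return out

    logits = torch.randn(B, L, V)
    answer_ids = torch.randint(0, V, (B, L))

    new_logits = extract_answer_tokens(logits)
    new_log_probs = F.log_softmax(new_logits, dim=-1).gather(-1, answer_ids.unsqueeze(-1)).squeeze(-1)

    old_log_probs = new_log_probs.detach() + 0.05 * torch.randn_like(new_log_probs)
    advantages = torch.randn(B, L)

    ratio = (new_log_probs - old_log_probs).exp()
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    policy_loss = -torch.min(ratio * advantages, clipped * advantages).mean()
    print(policy_loss)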
{rxnn-0.2.13.dist-info → rxnn-0.2.14.dist-info}/RECORD
CHANGED
@@ -14,11 +14,11 @@ rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
-rxnn/training/dataset.py,sha256=
+rxnn/training/dataset.py,sha256=m1opjNA7XHl6Ys-NtERM00c0BLN2xuu84lsfXp-3GQA,50478
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=CezloyaXOKrc_F_eDt99EZ1fmKAMCCCMh5Ry6vF82Ro,39607
 rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.14.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.14.dist-info/METADATA,sha256=dutamudjxMj9IzykuCONpMyqnU4emEEwvseD4nmKkfs,25960
+rxnn-0.2.14.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.14.dist-info/RECORD,,
{rxnn-0.2.13.dist-info → rxnn-0.2.14.dist-info}/LICENSE
File without changes

{rxnn-0.2.13.dist-info → rxnn-0.2.14.dist-info}/WHEEL
File without changes