rxnn 0.2.13__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/dataset.py +6 -1
- rxnn/training/mrl.py +65 -61
- rxnn/training/rl.py +34 -4
- {rxnn-0.2.13.dist-info → rxnn-0.2.15.dist-info}/METADATA +1 -1
- {rxnn-0.2.13.dist-info → rxnn-0.2.15.dist-info}/RECORD +7 -7
- {rxnn-0.2.13.dist-info → rxnn-0.2.15.dist-info}/LICENSE +0 -0
- {rxnn-0.2.13.dist-info → rxnn-0.2.15.dist-info}/WHEEL +0 -0
rxnn/training/dataset.py
CHANGED
@@ -977,6 +977,7 @@ class MrlCurriculumDataset(Dataset):
             answer_field: str = 'answer',
             interactions_field: str = 'interactions',
             load_kwargs: dict = None,
+            max_seq_len: int = 1024,
             **kwargs
     ):
         """
@@ -993,6 +994,7 @@ class MrlCurriculumDataset(Dataset):
             answer_field (str): Answer field (default: "answer")
             interactions_field (str): Interactions field (default: "interactions")
             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+            max_seq_len (int): Maximum sequence length (default: 1024)
             **kwargs: Additional args for RxNN Dataset class
         """
         if load_kwargs is None:
@@ -1000,7 +1002,7 @@ class MrlCurriculumDataset(Dataset):

         hf_dataset = load_dataset(dataset_id, mrl_subset, split=split, **load_kwargs)

-        return cls(hf_dataset, tokenizer, query_field=query_field, answer_field=answer_field, interactions_field=interactions_field, **kwargs)
+        return cls(hf_dataset, tokenizer, query_field=query_field, answer_field=answer_field, interactions_field=interactions_field, max_seq_len=max_seq_len, **kwargs)

     @staticmethod
     def collate_mrl_batch(batch: list[MrlDataItem]) -> MrlDataItem:
@@ -1098,6 +1100,7 @@ class MrlDatasets:
             load_kwargs: dict = None,
             mrl_ds_kwargs: dict = None,
             eval_split: str = None,
+            max_seq_len: int = 256,
     ):
         """
         Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
@@ -1115,6 +1118,7 @@ class MrlDatasets:
             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
             mrl_ds_kwargs (dict): Additional args for RxNN MrlCurriculumDataset class
             eval_split (str): Load also evaluation/validation split (default: None)
+            max_seq_len (int): Maximum sequence length (default: 256)
         """
         if load_kwargs is None:
             load_kwargs = {}
@@ -1131,6 +1135,7 @@ class MrlDatasets:
                 interactions_field=interactions_field,
                 split=load_split,
                 load_kwargs=load_kwargs,
+                max_seq_len=max_seq_len,
                 **mrl_ds_kwargs,
             )

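For context, the change above only threads the new max_seq_len argument from the Hub-loading classmethods down to the MrlCurriculumDataset constructor; the loading path is otherwise unchanged. A minimal usage sketch follows. The from_hf_hub method name, the tokenizer, and the dataset id/subset are assumptions for illustration and are not taken from this diff.

from transformers import AutoTokenizer
from rxnn.training.dataset import MrlCurriculumDataset

# Assumed names: the loader classmethod is written as from_hf_hub here, and the
# dataset id / subset are placeholders - the point is that max_seq_len is now forwarded.
tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer

dataset = MrlCurriculumDataset.from_hf_hub(
    'your-org/mrl-curriculum-dataset',  # placeholder dataset id
    'steps-4',                          # placeholder MRL subset name
    tokenizer,
    max_seq_len=1024,                   # new argument, passed through to the dataset constructor
)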
rxnn/training/mrl.py
CHANGED
@@ -92,8 +92,8 @@ class MRLTrainer:
         self.critic = critic
         self.reward = reward
         self.device = device
-        self.max_seq_len = config.get('max_seq_len',
-        self.critic_max_len = config.get('critic_max_len',
+        self.max_seq_len = config.get('max_seq_len', 256)
+        self.critic_max_len = config.get('critic_max_len', 512)

         # Move models to device
         if use_amp:
@@ -306,6 +306,7 @@ class MRLTrainer:
             for i, interaction in enumerate(interactions):
                 # 8. Generate batch of answers based on batch of follow-up queries
                 next_query = self._move_batch(interaction['query'])
+                print(next_query['input_ids'].size())
                 generated_answer, log_probs = self.generate_answer(next_query)

                 is_last_interaction = (i + 1) == interactions_len
@@ -365,7 +366,7 @@ class MRLTrainer:
     def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
         # 1. Calculate values with critic encoder
         values = self.critic(
-
+            inputs['input_ids'],
             attention_mask=inputs['attention_mask'],
         ).squeeze()
         # 2. Calculate critic loss
@@ -461,6 +462,7 @@ class MRLTrainer:
         # 1. Run update separately for episodes in trajectory - we have to reset memory before each episode, and update
         # memory, based on collected episode data
         all_losses = []
+        trajectories_len = len(trajectories)
         for episode_idx, episode in enumerate(trajectories):
             episode_steps = episode['steps']
             should_reset_stm = episode['reset_stm']
@@ -506,7 +508,7 @@ class MRLTrainer:
                                           action=MrlActorAction.DECODE)

                 # 7. Calculate RL Algorithm (PPO etc.) loss
-                policy_loss = self.rl_algorithm.policy_loss(action['input_ids'], logits, log_probs, advantages)
+                policy_loss = self.rl_algorithm.policy_loss(next_query, action, inputs['input_ids'], logits, log_probs, advantages)

                 # 8. Reset gradients
                 self.optimizer.zero_grad()
@@ -593,63 +595,64 @@ class MRLTrainer:
         # 2. Run evaluation on all batch episodes
         for batch in dataloader:
             with torch.no_grad():
-                (57 removed lines: previous version of the evaluation loop body, not rendered in this diff view)
+                if batch['query']['input_ids'].size(0) == batch_size:
+                    self._increment_steps('eval')
+                    # 3. Reset STM with random resets ratio
+                    self.reset_stm()
+
+                    # 4. Get batches for first queries, answers and all follow-up interactions
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    # 5. Encode and update STM with initial interactions (batch)
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save follow-up interactions len and first query and answer as previous one for iteration
+                    interactions_len = len(interactions)
+                    query, answer = first_query, first_answer
+                    episode_reward = torch.tensor(0.0).to(self.device)
+                    episode_interactions = torch.tensor(0).to(self.device)
+                    # 7. Run all follow-up interactions
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers
+                        next_query = self._move_batch(interaction['query'])
+                        generated_answer, _ = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)
+
+                        # 9. Depending on current strategy and step, compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE, eval_mode=True)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                         eval_mode=True)
+                        else:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD, eval_mode=True)
+
+                        # 10. Encode and update memory for the next interaction
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)
+
+                        # 11. Accumulate rewards
+                        step_reward = torch.tensor(reward).mean().to(self.device)
+                        # total
+                        total_reward += step_reward
+                        count += 1
+                        # episode
+                        episode_reward += step_reward
+                        episode_interactions += 1
+                        # 12. Save previous interaction
+                        query, answer = interaction['query'], detached_answer
+                    avg_episode_reward = (episode_reward / episode_interactions).item()
+                    # 13. Run eval TensorBoard writer with average episode reward
+                    self._eval_writer(avg_episode_reward, epoch)
+
+                    # 14. Run "on eval episode end" callbacks
+                    for cb in self.callbacks:
+                        cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)

         # 15. Calculate average reward
         if self.use_ddp:
@@ -804,3 +807,4 @@ class MRLTrainer:
         # 21. Close writer
         if self.writer:
             self.writer.close()
+
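Besides the hard-coded fallbacks for max_seq_len (256) and critic_max_len (512), the notable behavioural change above is that the rewritten evaluation loop only processes full batches, comparing the query batch dimension against the configured batch size before running an episode. A standalone sketch of that guard on a toy DataLoader follows; it is illustrative only and is not the trainer's actual code.

import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 4
data = TensorDataset(torch.arange(10).unsqueeze(1))  # 10 samples -> the final batch has only 2
loader = DataLoader(data, batch_size=batch_size)

for (input_ids,) in loader:
    # Same shape check as the new eval loop: only full batches run an evaluation episode
    if input_ids.size(0) == batch_size:
        print(f'evaluating full batch of {batch_size}')
    else:
        print(f'skipping incomplete batch of size {input_ids.size(0)}')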
rxnn/training/rl.py
CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from abc import ABC, abstractmethod
 from typing import TypedDict
+from .utils import TokenizedDict


 class RlAlgorithm(ABC):
@@ -11,7 +12,8 @@ class RlAlgorithm(ABC):
         self.critic_loss = nn.MSELoss()

     @abstractmethod
-    def policy_loss(self,
+    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
         pass

     @abstractmethod
@@ -35,10 +37,38 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_lambda = config.get('gae_lambda', 0.95)
         self.clip_eps = config.get('clip_eps', 0.2)

-    def policy_loss(self,
+    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+        query_lens = query['attention_mask'].sum(dim=1).long()  # Query lengths per sample
+        answer_mask = answer['attention_mask']
+        answer_lens = answer_mask.sum(dim=1).long()  # Answer lengths per sample (before padding)
+
+        max_length = query['input_ids'].size(1)
+
+        combined_lens = torch.minimum(
+            query_lens + answer_lens,
+            torch.full_like(query_lens, max_length)
+        )
+
+        def extract_answer_tokens(tensor: torch.Tensor) -> torch.Tensor:
+            B, L, *rest = tensor.size()
+            result = torch.zeros((B, max_length, *rest), dtype=tensor.dtype, device=tensor.device)
+
+            for i in range(B):
+                s = query_lens[i].item()
+                e = combined_lens[i].item()
+                valid_len = e - s
+                if valid_len > 0:
+                    result[i, :valid_len] = tensor[i, s:e]
+            return result
+
+        new_logits = extract_answer_tokens(logits)
+
         # a) Get new log probs
-        new_probs = F.log_softmax(
-        new_log_probs = new_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
+        new_probs = F.log_softmax(new_logits, dim=-1)
+        new_log_probs = new_probs.gather(-1, answer['input_ids'].unsqueeze(-1)).squeeze(-1)
+
+        new_log_probs = extract_answer_tokens(new_log_probs.unsqueeze(-1)).squeeze(-1)  # Ensure 3D for extraction (add singleton dim)

         # b) Calculate ratio
         ratio = (new_log_probs - old_log_probs).exp()
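The reworked PPO policy_loss above aligns the new log probabilities with the answer region of the sequence: the logits cover query plus answer, so only the slice between each sample's query length and the capped combined length should enter the probability ratio. Below is a standalone re-implementation of the extract_answer_tokens helper with explicit arguments (in the package it is a closure over query_lens, combined_lens and max_length), run on toy tensors to show the indexing it performs; it mirrors the hunk above rather than importing from the package.

import torch

def extract_answer_tokens(tensor: torch.Tensor, query_lens: torch.Tensor,
                          combined_lens: torch.Tensor, max_length: int) -> torch.Tensor:
    # Copy tensor[i, query_len:combined_len] to the front of a zero-padded buffer.
    B, L, *rest = tensor.size()
    result = torch.zeros((B, max_length, *rest), dtype=tensor.dtype, device=tensor.device)
    for i in range(B):
        s, e = query_lens[i].item(), combined_lens[i].item()
        valid_len = e - s
        if valid_len > 0:
            result[i, :valid_len] = tensor[i, s:e]
    return result

# Toy example: batch of 2, sequence length 6, vocab-sized last dim of 3.
logits = torch.randn(2, 6, 3)
query_lens = torch.tensor([2, 3])     # query occupies the first 2 / 3 positions
combined_lens = torch.tensor([5, 6])  # query + answer lengths, capped at max_length
answer_logits = extract_answer_tokens(logits, query_lens, combined_lens, max_length=6)
print(answer_logits.shape)            # torch.Size([2, 6, 3]); positions beyond each answer stay zero-padded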
{rxnn-0.2.13.dist-info → rxnn-0.2.15.dist-info}/RECORD
CHANGED
@@ -14,11 +14,11 @@ rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
-rxnn/training/dataset.py,sha256=
+rxnn/training/dataset.py,sha256=mXUZa6ypTt73sE-G-s9jQ4_Vhp8zw43bjhsLEEPPnDo,50611
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=CezloyaXOKrc_F_eDt99EZ1fmKAMCCCMh5Ry6vF82Ro,39607
 rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.15.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.15.dist-info/METADATA,sha256=ZMMOf5u9DqEQPyYszLG-o51FcVN_3vzE4aDIwsdk-lg,25960
+rxnn-0.2.15.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.15.dist-info/RECORD,,
{rxnn-0.2.13.dist-info → rxnn-0.2.15.dist-info}/LICENSE
File without changes
{rxnn-0.2.13.dist-info → rxnn-0.2.15.dist-info}/WHEEL
File without changes