rxnn 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/memory/stm.py +3 -3
- rxnn/training/dataset.py +3 -0
- rxnn/training/mrl.py +66 -63
- rxnn/training/rl.py +36 -4
- {rxnn-0.2.12.dist-info → rxnn-0.2.14.dist-info}/METADATA +1 -1
- {rxnn-0.2.12.dist-info → rxnn-0.2.14.dist-info}/RECORD +8 -8
- {rxnn-0.2.12.dist-info → rxnn-0.2.14.dist-info}/LICENSE +0 -0
- {rxnn-0.2.12.dist-info → rxnn-0.2.14.dist-info}/WHEEL +0 -0
rxnn/memory/stm.py
CHANGED
```diff
@@ -39,6 +39,7 @@ class ShortTermMemory(nn.Module):
         return self.memory[layer]
 
     def update_layer(self, layer: int, new_stm: torch.Tensor):
+        self.memory = self.memory.clone()
         self.memory[layer] = new_stm
 
     def update_all(self, new_stm: torch.Tensor):
@@ -60,7 +61,7 @@ class ShortTermMemory(nn.Module):
         self.register_buffer('memory', trained_stm)
 
     def reset(self, init_type: str = None):
-        self.memory
+        self.memory = self._init_tensor(init_type).to(self.memory.device)
 
     def resize(self, new_stm_size: int, init_type: str = None):
         self.stm_size = new_stm_size
@@ -85,8 +86,7 @@ class ShortTermMemory(nn.Module):
         if use_mean_from_batch:
             batch_mean = self.memory.mean(dim=(1, 2, 3), keepdim=True)
             delattr(self, 'memory')
-            self.register_buffer('memory',
-            self.memory.copy_(batch_mean)
+            self.register_buffer('memory', batch_mean)
         else:
             delattr(self, 'memory')
             self.register_buffer('memory', self._init_tensor())
```
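The behavioural change here is that `update_layer` now clones the memory buffer before the in-place write, and `reset` actually reassigns the buffer instead of only evaluating it. A minimal standalone sketch of what the clone buys (the `TinySTM` class below is hypothetical, not the rxnn implementation): a reference taken before the update is no longer mutated retroactively by a later in-place layer write.

```python
import torch
import torch.nn as nn

class TinySTM(nn.Module):
    def __init__(self, num_layers: int = 2, stm_size: int = 4, dim: int = 8):
        super().__init__()
        self.register_buffer('memory', torch.zeros(num_layers, stm_size, dim))

    def update_layer(self, layer: int, new_stm: torch.Tensor):
        # Clone first, so the in-place write below lands on a fresh tensor
        # instead of mutating one that older references may still point to.
        self.memory = self.memory.clone()
        self.memory[layer] = new_stm


stm = TinySTM()
snapshot = stm.memory                 # reference taken before the update
stm.update_layer(0, torch.ones(4, 8))
print(snapshot[0].sum().item())       # 0.0 - the snapshot is untouched thanks to the clone
print(stm.memory[0].sum().item())     # 32.0 - the buffer itself carries the new layer state
```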
rxnn/training/dataset.py
CHANGED
```diff
@@ -1098,6 +1098,7 @@ class MrlDatasets:
             load_kwargs: dict = None,
             mrl_ds_kwargs: dict = None,
             eval_split: str = None,
+            max_seq_len: int = 256,
     ):
         """
         Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
@@ -1115,6 +1116,7 @@ class MrlDatasets:
             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
             mrl_ds_kwargs (dict): Additional args for RxNN MrlCurriculumDataset class
             eval_split (str): Load also evaluation/validation split (default: None)
+            max_seq_len (int): Maximum sequence length (default: 256)
         """
         if load_kwargs is None:
             load_kwargs = {}
@@ -1131,6 +1133,7 @@ class MrlDatasets:
             interactions_field=interactions_field,
             split=load_split,
             load_kwargs=load_kwargs,
+            max_seq_len=max_seq_len,
             **mrl_ds_kwargs,
         )
```
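The new `max_seq_len` argument is simply threaded from the `MrlDatasets` loader down to each underlying `MrlCurriculumDataset`. As a rough illustration of what such a cap typically controls, here is a generic tokenization sketch; it uses the Hugging Face `transformers` tokenizer API with `gpt2` as an arbitrary stand-in, not rxnn's dataset code.

```python
from transformers import AutoTokenizer

# Generic illustration of a max_seq_len cap: each interaction is tokenized to a
# fixed length so query/answer tensors stack into uniform batches.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
max_seq_len = 256

encoded = tokenizer(
    "What is short-term memory in a reactive transformer?",
    max_length=max_seq_len,
    truncation=True,
    padding='max_length',
    return_tensors='pt',
)
print(encoded['input_ids'].shape)       # torch.Size([1, 256])
print(encoded['attention_mask'].shape)  # torch.Size([1, 256])
```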
rxnn/training/mrl.py
CHANGED
```diff
@@ -92,8 +92,8 @@ class MRLTrainer:
         self.critic = critic
         self.reward = reward
         self.device = device
-        self.max_seq_len = config.get('max_seq_len',
-        self.critic_max_len = config.get('critic_max_len',
+        self.max_seq_len = config.get('max_seq_len', 256)
+        self.critic_max_len = config.get('critic_max_len', 512)
 
         # Move models to device
         if use_amp:
@@ -306,6 +306,7 @@ class MRLTrainer:
             for i, interaction in enumerate(interactions):
                 # 8. Generate batch of answers based on batch of follow-up queries
                 next_query = self._move_batch(interaction['query'])
+                print(next_query['input_ids'].size())
                 generated_answer, log_probs = self.generate_answer(next_query)
 
                 is_last_interaction = (i + 1) == interactions_len
@@ -365,7 +366,7 @@ class MRLTrainer:
     def _critic_loss(self, inputs: TokenizedDict, rewards: torch.Tensor) -> torch.Tensor:
         # 1. Calculate values with critic encoder
         values = self.critic(
-
+            inputs['input_ids'],
             attention_mask=inputs['attention_mask'],
         ).squeeze()
         # 2. Calculate critic loss
@@ -507,21 +508,21 @@ class MRLTrainer:
                                 action=MrlActorAction.DECODE)
 
             # 7. Calculate RL Algorithm (PPO etc.) loss
-            policy_loss = self.rl_algorithm.policy_loss(action['input_ids'], logits, log_probs, advantages)
+            policy_loss = self.rl_algorithm.policy_loss(next_query, action, inputs['input_ids'], logits, log_probs, advantages)
 
             # 8. Reset gradients
             self.optimizer.zero_grad()
 
             # 9. Update the model in AMP or regular mode
             if self.use_amp:
-                self.scaler.scale(policy_loss).backward()
+                self.scaler.scale(policy_loss).backward(retain_graph=True)
                 self.scaler.unscale_(self.optimizer)
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                error_if_nonfinite=False)
                 self.scaler.step(self.optimizer)
                 self.scaler.update()
             else:
-                policy_loss.backward()
+                policy_loss.backward(retain_graph=True)
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                error_if_nonfinite=False)
                 self.optimizer.step()
@@ -594,63 +595,64 @@ class MRLTrainer:
         # 2. Run evaluation on all batch episodes
         for batch in dataloader:
             with torch.no_grad():
-                (57 removed lines: previous evaluation-loop body, not rendered in this diff view)
+                if batch['query']['input_ids'].size(0) == batch_size:
+                    self._increment_steps('eval')
+                    # 3. Reset STM with random resets ratio
+                    self.reset_stm()
+
+                    # 4. Get batches for first queries, answers and all follow-up interactions
+                    first_query, first_answer, interactions = batch['query'], batch['answer'], batch['interactions']
+                    # 5. Encode and update STM with initial interactions (batch)
+                    self.encode_and_update_stm(*self._move_multiple_batches(first_query, first_answer))
+
+                    # 6. Save follow-up interactions len and first query and answer as previous one for iteration
+                    interactions_len = len(interactions)
+                    query, answer = first_query, first_answer
+                    episode_reward = torch.tensor(0.0).to(self.device)
+                    episode_interactions = torch.tensor(0).to(self.device)
+                    # 7. Run all follow-up interactions
+                    for i, interaction in enumerate(interactions):
+                        # 8. Generate batch of answers
+                        next_query = self._move_batch(interaction['query'])
+                        generated_answer, _ = self.generate_answer(next_query)
+
+                        is_last_interaction = (i + 1) == interactions_len
+
+                        detached_answer = self._cpu_detach(generated_answer)
+
+                        # 9. Depending on current strategy and step, compute reward
+                        if self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and i == 0:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.NEGATIVE, eval_mode=True)
+                        elif self.strategy == MrlStrategy.LONG_RANGE_STRATEGY and is_last_interaction:
+                            reward = self.compute_reward(detached_answer, interaction['answer'],
+                                                         (first_query, first_answer), mode=MrlRewardMode.LONG_RANGE,
+                                                         eval_mode=True)
+                        else:
+                            reward = self.compute_reward(detached_answer, interaction['answer'], (query, answer),
+                                                         mode=MrlRewardMode.STANDARD, eval_mode=True)
+
+                        # 10. Encode and update memory for the next interaction
+                        if not is_last_interaction:
+                            self.encode_and_update_stm(next_query, generated_answer)
+
+                        # 11. Accumulate rewards
+                        step_reward = torch.tensor(reward).mean().to(self.device)
+                        # total
+                        total_reward += step_reward
+                        count += 1
+                        # episode
+                        episode_reward += step_reward
+                        episode_interactions += 1
+                        # 12. Save previous interaction
+                        query, answer = interaction['query'], detached_answer
+                    avg_episode_reward = (episode_reward / episode_interactions).item()
+                    # 13. Run eval TensorBoard writer with average episode reward
+                    self._eval_writer(avg_episode_reward, epoch)
+
+                    # 14. Run "on eval episode end" callbacks
+                    for cb in self.callbacks:
+                        cb.on_eval_episode_end(self.actor, epoch, self.epoch_step['eval'], avg_episode_reward)
 
         # 15. Calculate average reward
         if self.use_ddp:
@@ -805,3 +807,4 @@ class MRLTrainer:
         # 21. Close writer
         if self.writer:
             self.writer.close()
+
```
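Besides concrete defaults for `max_seq_len` (256) and `critic_max_len` (512) and the new `policy_loss` call signature, the notable change is that the policy backward pass now uses `retain_graph=True`, presumably because parts of the actor's computation graph are reused across backward passes. The snippet below is a generic, standalone illustration of what the flag does, not the `MRLTrainer` code.

```python
import torch

# Two losses share the tanh node; the first backward keeps its saved tensors
# alive so a second backward over the same graph does not raise a RuntimeError.
x = torch.ones(3, requires_grad=True)
hidden = torch.tanh(x * 2)          # tanh saves its output for the backward pass
loss_a = (hidden * 3).sum()
loss_b = (hidden * 5).sum()

loss_a.backward(retain_graph=True)  # without retain_graph, the shared graph is freed here
loss_b.backward()                   # reuses the shared tanh node
print(x.grad)                       # gradients from both losses, accumulated
```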
rxnn/training/rl.py
CHANGED
```diff
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from abc import ABC, abstractmethod
 from typing import TypedDict
+from .utils import TokenizedDict
 
 
 class RlAlgorithm(ABC):
@@ -11,7 +12,8 @@ class RlAlgorithm(ABC):
         self.critic_loss = nn.MSELoss()
 
     @abstractmethod
-    def policy_loss(self,
+    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
         pass
 
     @abstractmethod
@@ -35,14 +37,44 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_lambda = config.get('gae_lambda', 0.95)
         self.clip_eps = config.get('clip_eps', 0.2)
 
-    def policy_loss(self,
+    def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
+                    old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
+        query_lens = query['attention_mask'].sum(dim=1).long()  # Query lengths per sample
+        answer_mask = answer['attention_mask']
+        answer_lens = answer_mask.sum(dim=1).long()  # Answer lengths per sample (before padding)
+
+        max_length = query['input_ids'].size(1)
+
+        combined_lens = torch.minimum(
+            query_lens + answer_lens,
+            torch.full_like(query_lens, max_length)
+        )
+
+        def extract_answer_tokens(tensor: torch.Tensor) -> torch.Tensor:
+            B, L, *rest = tensor.size()
+            result = torch.zeros((B, max_length, *rest), dtype=tensor.dtype, device=tensor.device)
+
+            for i in range(B):
+                s = query_lens[i].item()
+                e = combined_lens[i].item()
+                valid_len = e - s
+                if valid_len > 0:
+                    result[i, :valid_len] = tensor[i, s:e]
+            return result
+
+        new_logits = extract_answer_tokens(logits)
+
         # a) Get new log probs
-        new_probs = F.log_softmax(
-        new_log_probs = new_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
+        new_probs = F.log_softmax(new_logits, dim=-1)
+        new_log_probs = new_probs.gather(-1, answer['input_ids'].unsqueeze(-1)).squeeze(-1)
+
+        new_log_probs = extract_answer_tokens(new_log_probs.unsqueeze(-1)).squeeze(-1)  # Ensure 3D for extraction (add singleton dim)
 
         # b) Calculate ratio
         ratio = (new_log_probs - old_log_probs).exp()
 
+        advantages = advantages.unsqueeze(-1)
+
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
```
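The rewritten `policy_loss` first slices each sample's generated-answer tokens out of the full query-plus-answer sequence (the `extract_answer_tokens` helper), then applies the usual PPO clipped surrogate to per-token log-probabilities, with the per-sample advantage broadcast over the token axis. The hunk ends before the return statement, so the final reduction in the sketch below is the standard PPO objective assumed rather than copied from the source, and the shapes are illustrative only.

```python
import torch

# Standalone sketch of the clipped-surrogate step on dummy shapes matching the
# new per-answer-token layout: [batch, answer_len] log-probs, per-sample
# advantage broadcast over tokens, clip_eps mirroring the diff's default (0.2).
batch, answer_len, clip_eps = 4, 16, 0.2

old_log_probs = torch.randn(batch, answer_len)
new_log_probs = old_log_probs + 0.05 * torch.randn(batch, answer_len)
advantages = torch.randn(batch).unsqueeze(-1)   # [batch, 1], broadcast over tokens

ratio = (new_log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean()   # standard PPO objective, sign flipped for minimisation
print(policy_loss.item())
```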
{rxnn-0.2.12.dist-info → rxnn-0.2.14.dist-info}/RECORD
CHANGED
```diff
@@ -7,18 +7,18 @@ rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
-rxnn/memory/stm.py,sha256=
+rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=qlYgU002VE21ZOlcxEM9iv9tAvsbe4mngcMI2sw3j9k,12078
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
-rxnn/training/dataset.py,sha256=
+rxnn/training/dataset.py,sha256=m1opjNA7XHl6Ys-NtERM00c0BLN2xuu84lsfXp-3GQA,50478
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=CezloyaXOKrc_F_eDt99EZ1fmKAMCCCMh5Ry6vF82Ro,39607
 rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=DHFwnPUlnq2JVj6CS6DwifnC_eMeBAUVp36UCAWNMis,3934
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.14.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.14.dist-info/METADATA,sha256=dutamudjxMj9IzykuCONpMyqnU4emEEwvseD4nmKkfs,25960
+rxnn-0.2.14.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.14.dist-info/RECORD,,
```
{rxnn-0.2.12.dist-info → rxnn-0.2.14.dist-info}/LICENSE
File without changes

{rxnn-0.2.12.dist-info → rxnn-0.2.14.dist-info}/WHEEL
File without changes