rxnn 0.2.32__py3-none-any.whl → 0.2.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/models.py +24 -4
- rxnn/training/mrl.py +18 -6
- rxnn/training/reward.py +9 -1
- rxnn/training/rl.py +28 -17
- rxnn/transformers/ff.py +2 -0
- {rxnn-0.2.32.dist-info → rxnn-0.2.34.dist-info}/METADATA +1 -1
- {rxnn-0.2.32.dist-info → rxnn-0.2.34.dist-info}/RECORD +9 -9
- {rxnn-0.2.32.dist-info → rxnn-0.2.34.dist-info}/LICENSE +0 -0
- {rxnn-0.2.32.dist-info → rxnn-0.2.34.dist-info}/WHEEL +0 -0
rxnn/training/models.py
CHANGED
@@ -4,6 +4,8 @@ from enum import Enum
 from typing import Literal, Iterator
 from huggingface_hub import PyTorchModelHubMixin
 from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.ff import GatedLinearUnit, get_activation_layer
+
 
 class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
@@ -37,6 +39,7 @@ class MLMTrainingModel(nn.Module):
         y = self.mlm_head(h)
         return y
 
+
 class JointTrainingModel(nn.Module):
     def __init__(
             self,
@@ -58,10 +61,12 @@ class JointTrainingModel(nn.Module):
         y_d = self.decoder(x_d, attention_mask=attention_mask)
         return y_e, y_d
 
+
 class MrlActorAction(Enum):
     DECODE = 1
     UPDATE = 2
 
+
 class MrlActorModel(nn.Module):
     def __init__(
             self,
@@ -153,18 +158,33 @@ class MrlActorModel(nn.Module):
             list(self.memory_attention.parameters())
         ))
 
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None,
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None,
+                action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
         if action == MrlActorAction.DECODE:
             return self.decoder(x, attention_mask=attention_mask)
         else:
             _, ed = self.encoder(x, attention_mask=attention_mask)
             return self.memory_attention(ed, attention_mask=attention_mask)
 
+
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int,
+    def __init__(self, encoder: nn.Module, embed_dim: int,
+                 out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
-        self.value_head = nn.
+        self.value_head = nn.Sequential(
+            GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, 1),
+            get_activation_layer(out_activation)
+        )
+        self.output_scale = output_scale
+
+    def head_parameters(self) -> Iterator[nn.Parameter]:
+        return self.value_head.parameters()
+
+    def encoder_parameters(self) -> Iterator[nn.Parameter]:
+        return self.encoder.parameters()
 
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x, _ = self.encoder(x, attention_mask=attention_mask)
@@ -175,4 +195,4 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x)
+        return self.value_head(x) * self.output_scale
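The new critic value head replaces the previous single projection with a gated block, LayerNorm, a scalar linear projection, a configurable output activation ('sigmoid', 'tanh' or 'linear'), and an output_scale multiplier applied in forward(). Below is a minimal standalone sketch of the same idea; the hand-rolled gating and the activation lookup are stand-ins for rxnn's GatedLinearUnit and get_activation_layer, whose implementations are not shown in this diff.

import torch
import torch.nn as nn

class GatedValueHead(nn.Module):
    # Sketch of the 0.2.34-style critic head: gated unit -> LayerNorm -> Linear(1) -> activation, then scaled.
    def __init__(self, embed_dim: int, out_activation: str = 'sigmoid', output_scale: float = 1.0):
        super().__init__()
        # Stand-in for GatedLinearUnit(embed_dim, embed_dim, nn.SiLU())
        self.gate_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.gate_act = nn.SiLU()
        # Stand-in for get_activation_layer(out_activation)
        acts = {'sigmoid': nn.Sigmoid(), 'tanh': nn.Tanh(), 'linear': nn.Identity()}
        self.head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1), acts[out_activation])
        self.output_scale = output_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = self.gate_act(self.gate_proj(x)) * self.value_proj(x)  # GLU-style gating
        return self.head(gated) * self.output_scale  # e.g. tanh with output_scale=10.0 -> values in [-10, 10]

pooled = torch.randn(4, 256)  # pooled encoder output for a batch of 4
head = GatedValueHead(256, out_activation='tanh', output_scale=10.0)
print(head(pooled).shape)     # torch.Size([4, 1])

Pairing a bounded activation with output_scale keeps critic predictions in a known range, which lines up with the critic value clipping added in rl.py below.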
rxnn/training/mrl.py
CHANGED
@@ -15,11 +15,13 @@ from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
 from .ddp import get_os_ddp_config, distributed_mean
 
+
 class MrlConfig(TypedDict):
     lr: float
     separate_memory_lr: Optional[bool]
     memory_lr: Optional[float]
     critic_lr: float
+    critic_encoder_lr: float
     max_seq_len: int
     critic_max_len: int
     weight_decay: float
@@ -58,6 +60,7 @@ class CurriculumConfig(TypedDict):
     lr: Optional[float]
     memory_lr: Optional[float]
     critic_lr: Optional[float]
+    critic_encoder_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
     update_epochs: Optional[int]
@@ -158,6 +161,7 @@ class MRLTrainer:
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+                'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
             }
         else:
             self.base_optim_config = {
@@ -165,6 +169,7 @@ class MRLTrainer:
                 'critic_lr': config.get('critic_lr', 1e-4),
                 'weight_decay': config.get('weight_decay', 0.01),
                 'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+                'critic_encoder_lr': config.get('critic_encoder_lr', config.get('critic_lr', 1e-4)),
             }
 
         self.optim_config = self.base_optim_config
@@ -202,6 +207,7 @@ class MRLTrainer:
             critic_lr: float,
             weight_decay: float,
             critic_weight_decay: float,
+            critic_encoder_lr: float,
             memory_lr: Optional[float] = None,
     ) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
         if memory_lr is not None:
@@ -219,8 +225,10 @@ class MRLTrainer:
         )
 
         critic_optimizer = torch.optim.AdamW(
-
-
+            [
+                {'params': self.critic.head_parameters(), 'lr': critic_lr},
+                {'params': self.critic.encoder_parameters(), 'lr': critic_encoder_lr},
+            ],
             weight_decay=critic_weight_decay,
         )
 
@@ -481,7 +489,7 @@ class MRLTrainer:
             critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                       pad_token_id=self.pad_token_id)
             values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
             # 2.2 Run backpropagation with scaler
             self.critic_scaler.scale(critic_loss).backward()
             # 2.3 Unscale and clip gradients
@@ -495,7 +503,7 @@ class MRLTrainer:
             critic_state = smart_concat_critic_states(*state, max_length=self.critic_max_len,
                                                       pad_token_id=self.pad_token_id)
             values = self.critic(critic_state['input_ids'], attention_mask=critic_state['attention_mask']).squeeze()
-            critic_loss = self.rl_algorithm.critic_loss(values, ref_values)
+            critic_loss = self.rl_algorithm.critic_loss(values, ref_values.detach())
             # 2.2 Run backpropagation
             critic_loss.backward()
             # 2.3 Clip gradients
@@ -633,7 +641,8 @@ class MRLTrainer:
                 for i, t in enumerate(episode['steps'])
             ]
             values = torch.stack([
-                self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in
+                self._critic_values_with_memory(r, *self._move_multiple_batches(*t['state'])) for t, r in
+                flat_trajectories
             ]).to(self.device)
             rewards = torch.stack([torch.tensor(t['reward']) for t, _ in flat_trajectories]).to(self.device)
             dones = torch.stack([torch.tensor(t['done']) for t, _ in flat_trajectories]).to(self.device)
@@ -646,7 +655,8 @@ class MRLTrainer:
             dones = torch.stack([torch.tensor(t['done']) for t in flat_trajectories]).to(self.device)
         return values, rewards, dones
 
-    def _critic_values_with_memory(self, reset_stm: bool,
+    def _critic_values_with_memory(self, reset_stm: bool,
+                                   *moved_state: tuple[TokenizedDict, TokenizedDict, TokenizedDict]) -> torch.Tensor:
         # 1. Calculate critic values in memory aware version - reset/update STM before calculating values
         with torch.no_grad():
             # 2. Reset STM if it was reset in trajectory collection
@@ -933,6 +943,7 @@ class MRLTrainer:
                 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
                 'critic_weight_decay': config.get('critic_weight_decay',
                                                   self.base_optim_config['critic_weight_decay']),
+                'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
                 'memory_lr': config.get('memory_lr', self.base_optim_config['memory_lr']),
             }
         else:
@@ -942,6 +953,7 @@ class MRLTrainer:
                 'weight_decay': config.get('weight_decay', self.base_optim_config['weight_decay']),
                 'critic_weight_decay': config.get('critic_weight_decay',
                                                   self.base_optim_config['critic_weight_decay']),
+                'critic_encoder_lr': config.get('critic_encoder_lr', self.base_optim_config['critic_encoder_lr']),
             }
             self.optimizer, self.critic_optimizer = self._init_optimizers(**self.optim_config)
         elif self.optim_config != self.base_optim_config:
rxnn/training/reward.py
CHANGED
@@ -42,6 +42,7 @@ class MrlRewardModel:
             running_mean_decay: float = 0.2,
             bleu_saved_weights: tuple = (0.5, 0.5),
             bleu_ref_weights: tuple = (0.5, 0.5),
+            tanh_reward_scale: bool = False,
             rewards_scale: float = 1.0,
     ):
         self.shared_embedding = shared_embedding.to(device)
@@ -71,6 +72,7 @@ class MrlRewardModel:
         self.running_mean_decay = running_mean_decay
         self.bleu_ref_weights = bleu_ref_weights
         self.bleu_saved_weights = bleu_saved_weights
+        self.tanh_reward_scale = tanh_reward_scale
         self.rewards_scale = rewards_scale
 
         self.prev_data_running_mean = None
@@ -175,6 +177,12 @@ class MrlRewardModel:
             self.prev_data_running_mean = (1 - self.running_mean_decay) * self._sequence_embedding(
                 prev_data) + self.running_mean_decay * self.prev_data_running_mean
 
+    def _pre_scale_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
+        if self.tanh_reward_scale:
+            return (rewards * 2) - 1  # Convert [0,1] to [-1,1]
+        else:
+            return rewards
+
     def __call__(
             self,
             generated: TokenizedDict,
@@ -204,5 +212,5 @@ class MrlRewardModel:
         cosine = self.negative_cosine(generated['input_ids'], reference['input_ids'], saved_data['input_ids'])
         sim_rewards = self.neg_bleu_factor * torch.tensor(bleu, device=self.device) + self.neg_cos_factor * cosine
 
-        rewards = (sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
+        rewards = self._pre_scale_rewards(sim_rewards + self.len_factor * self.len_reward(generated) if self.reward_len else sim_rewards) * self.rewards_scale
         return rewards.tolist()
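With tanh_reward_scale enabled, rewards computed in [0, 1] are remapped to [-1, 1] before the final rewards_scale multiplier, which lines up with a tanh-activated critic head. A short standalone sketch of that transformation, assuming the incoming rewards already lie in [0, 1]:

import torch

def pre_scale_rewards(rewards: torch.Tensor, tanh_reward_scale: bool, rewards_scale: float = 1.0) -> torch.Tensor:
    # Map [0, 1] rewards to [-1, 1] when the critic uses a tanh output, then apply the global scale
    scaled = rewards * 2 - 1 if tanh_reward_scale else rewards
    return scaled * rewards_scale

r = torch.tensor([0.0, 0.25, 0.5, 1.0])
print(pre_scale_rewards(r, tanh_reward_scale=True))   # tensor([-1.0000, -0.5000,  0.0000,  1.0000])
print(pre_scale_rewards(r, tanh_reward_scale=False))  # unchanged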
rxnn/training/rl.py
CHANGED
@@ -21,8 +21,8 @@ class RlAlgorithm(ABC):
     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         pass
 
-    def critic_loss(self,
-        return self.critic_loss(
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        return self.critic_loss(values, ref_values)
 
 
 class PPOConfig(TypedDict):
@@ -31,6 +31,8 @@ class PPOConfig(TypedDict):
     gae_gamma: Optional[float]
     entropy_coef: Optional[float]
     use_distributed_advantage_norm: Optional[bool]
+    clip_critic_values: Optional[bool]
+    critic_value_clip: Optional[float]
 
 
 class PPOAlgorithm(RlAlgorithm):
@@ -46,6 +48,14 @@ class PPOAlgorithm(RlAlgorithm):
         self.gae_gamma = config.get('gae_gamma', 0.99)
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
+        self.clip_critic_values = config.get('clip_critic_values', True)
+        self.critic_value_clip = config.get('critic_value_clip', 10.0)
+
+    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
+        # Critic loss with clipped values
+        if self.clip_critic_values:
+            values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+        return self.critic_loss(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -96,23 +106,24 @@ class PPOAlgorithm(RlAlgorithm):
 
         return policy_loss
 
-    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor,
-
-
+    def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor,
+                     last_value: torch.Tensor, dones: torch.Tensor):
+        trajectory_len, batch_size = rewards.shape
+        advantages = torch.zeros_like(rewards, device=rewards.device)
         last_advantage = 0
-
-
-
-
-
-
-
-
-
-
-            delta = rewards[t] + self.gae_gamma * next_values - values[t]
-            advantages[t] = delta + self.gae_gamma * self.gae_lambda * last_advantage
+        next_value = last_value
+        next_done = torch.zeros(batch_size, device=dones.device)  # Last state is terminal
+        dones = dones.float()
+        for t in reversed(range(trajectory_len)):
+            # Check if next state is terminal
+            non_terminal = 1.0 - next_done
+
+            # Delta should not include next_value if next is terminal
+            delta = rewards[t] + self.gae_gamma * next_value * non_terminal - values[t]
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * non_terminal * last_advantage
             last_advantage = advantages[t]
+            next_value = values[t]
+            next_done = dones[t]
 
         returns = advantages + values
         return advantages, returns
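The rewritten _compute_gae walks the trajectory backwards and masks the bootstrapped next value with (1 - next_done), so value estimates no longer leak across episode boundaries, while the new critic_loss clamps predictions to ±critic_value_clip (default 10.0) before the loss is computed. A standalone sketch of both pieces; MSE is used here as an assumed critic objective, since the base class's loss function is not shown in this diff:

import torch

def compute_gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    # rewards, values, dones: [trajectory_len, batch]; last_value: [batch]
    trajectory_len, batch_size = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_advantage = torch.zeros(batch_size)
    next_value, next_done = last_value, torch.zeros(batch_size)
    dones = dones.float()
    for t in reversed(range(trajectory_len)):
        non_terminal = 1.0 - next_done                 # zero out the bootstrap across episode ends
        delta = rewards[t] + gamma * next_value * non_terminal - values[t]
        advantages[t] = delta + gamma * lam * non_terminal * last_advantage
        last_advantage = advantages[t]
        next_value, next_done = values[t], dones[t]
    return advantages, advantages + values             # (advantages, returns)

def clipped_critic_loss(values, ref_values, clip=10.0):
    # Assumed MSE objective with the 0.2.34-style value clamp
    return torch.nn.functional.mse_loss(torch.clamp(values, -clip, clip), ref_values.detach())

T, B = 5, 3
rewards, values, dones = torch.rand(T, B), torch.rand(T, B), torch.zeros(T, B)
dones[2, 1] = 1.0                                      # episode for env 1 terminates at step 2
adv, ret = compute_gae(rewards, values, last_value=torch.rand(B), dones=dones)
print(adv.shape, clipped_critic_loss(values[-1], ret[-1]))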
rxnn/transformers/ff.py
CHANGED
{rxnn-0.2.32.dist-info → rxnn-0.2.34.dist-info}/RECORD
CHANGED
@@ -16,16 +16,16 @@ rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=p72lbzFAmFjpcUvyy4aUB3qd53I8C6Sk5w9nQvsKgTk,35852
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=
-rxnn/training/mrl.py,sha256=
-rxnn/training/reward.py,sha256=
-rxnn/training/rl.py,sha256=
+rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
+rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
+rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
+rxnn/training/rl.py,sha256=ckx1nlzIGZBabzwZNRj4isvHqRZwg0y0jGOT-SN6KZc,5841
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
-rxnn/transformers/ff.py,sha256=
+rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
 rxnn/transformers/layers.py,sha256=l0bXmhN7KOkCw0KTVLixWSo9Op4SesGabWJ4R4EQBMY,7988
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=hey6tFN9gmLfWCZLjtl_9OcvIjGpWLI1IDeVnr5y8YM,10583
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.34.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.34.dist-info/METADATA,sha256=Q7LqPr7KHFhMPL6UrbqG1SmtJbM2Ho-Yuxp_7LyCtYw,25960
+rxnn-0.2.34.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.34.dist-info/RECORD,,