rxnn 0.2.49__py3-none-any.whl → 0.2.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/models.py +7 -6
- rxnn/training/reward.py +2 -2
- rxnn/training/rl.py +1 -0
- {rxnn-0.2.49.dist-info → rxnn-0.2.51.dist-info}/METADATA +1 -1
- {rxnn-0.2.49.dist-info → rxnn-0.2.51.dist-info}/RECORD +7 -7
- {rxnn-0.2.49.dist-info → rxnn-0.2.51.dist-info}/LICENSE +0 -0
- {rxnn-0.2.49.dist-info → rxnn-0.2.51.dist-info}/WHEEL +0 -0
rxnn/training/models.py
CHANGED
@@ -208,17 +208,18 @@ class MrlActorModel(nn.Module):
 
 
 class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipeline_tag="text-classification"):
-    def __init__(self, encoder: nn.Module, embed_dim: int,
-                 out_activation: Literal['sigmoid', 'tanh', 'linear'] = 'sigmoid', output_scale: float = 1.0, **kwargs):
+    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
         super(MrlCriticModel, self).__init__(**kwargs)
         self.encoder = encoder
         self.value_head = nn.Sequential(
             GatedLinearUnit(embed_dim, embed_dim, nn.SiLU()),
             nn.LayerNorm(embed_dim),
-            nn.Linear(embed_dim, 1),
-            get_activation_layer(out_activation)
+            nn.Linear(embed_dim, 1)
         )
-
+        # Learnable scaling parameters
+        self.scale = nn.Parameter(torch.tensor(1.0))
+        self.shift = nn.Parameter(torch.tensor(0.0))
+
 
     def head_parameters(self) -> Iterator[nn.Parameter]:
         return self.value_head.parameters()
@@ -235,4 +236,4 @@ class MrlCriticModel(nn.Module, PyTorchModelHubMixin, license="apache-2.0", pipe
         else:
             x = x.mean(dim=1)
 
-        return self.value_head(x) * self.
+        return self.value_head(x) * self.scale + self.shift
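Taken together, these two hunks drop the configurable output activation (`out_activation` via `get_activation_layer`) and the fixed `output_scale` in favour of an unbounded linear head whose output passes through learnable `scale` and `shift` parameters. A minimal sketch of the resulting value path, with the encoder omitted and the rxnn-specific `GatedLinearUnit` stubbed out by a plain `nn.Linear` (the `TinyCritic` class below is illustrative, not package code):

```python
import torch
import torch.nn as nn

class TinyCritic(nn.Module):
    """Illustrative stand-in for MrlCriticModel's value path (not the real class)."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # Stand-in head: the real one is GatedLinearUnit -> LayerNorm -> Linear(embed_dim, 1)
        self.value_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 1),
        )
        # Learnable scaling parameters, as introduced in 0.2.51
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.mean(dim=1)  # mean-pool over the sequence, as in the diff's else-branch
        return self.value_head(x) * self.scale + self.shift  # learnable affine output

critic = TinyCritic(embed_dim=16)
print(critic(torch.randn(2, 8, 16)).shape)  # torch.Size([2, 1])
```

Since the head is no longer squashed by a sigmoid or tanh, the affine parameters presumably let the critic fit the reward scale during training rather than fixing it up front.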
rxnn/training/reward.py
CHANGED
@@ -242,8 +242,8 @@ class MrlRewardModel:
         return self.neg_cos_saved_factor * (1 - gen_and_saved) + self.neg_cos_ref_factor * gen_and_ref
 
     def len_reward(self, generated: TokenizedDict, reference: TokenizedDict) -> torch.Tensor:
-        target_lens = reference['attention_mask'].sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
-        lens = generated['attention_mask'].sum(dim=1)
+        target_lens = reference['attention_mask'].to(self.device).sum(dim=1) if self.target_len_as_ref else self.max_rewarded_len
+        lens = generated['attention_mask'].to(self.device).sum(dim=1)
         neg_lens = target_lens / lens if self.neg_reward_len else 1.0
         len_reward = torch.where(lens >= target_lens, neg_lens, lens / target_lens)
         return len_reward
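The only change here is moving both attention masks to `self.device` before summing, so the length tensors end up on the same device as the rest of the reward computation. A standalone sketch of the fixed logic (the free-function form, defaults, and argument names are assumptions for illustration; in rxnn this is a method of `MrlRewardModel`):

```python
import torch

def len_reward(generated_mask: torch.Tensor, reference_mask: torch.Tensor,
               device: torch.device, target_len_as_ref: bool = True,
               max_rewarded_len: int = 256, neg_reward_len: bool = False) -> torch.Tensor:
    # Masks may arrive on CPU from a data loader; move them to the target device first
    target_lens = reference_mask.to(device).sum(dim=1) if target_len_as_ref else max_rewarded_len
    lens = generated_mask.to(device).sum(dim=1)
    # At or past the target length, reward 1.0 (or an inverse-length penalty);
    # below it, the reward ramps up linearly with generated length
    neg_lens = target_lens / lens if neg_reward_len else 1.0
    return torch.where(lens >= target_lens, neg_lens, lens / target_lens)

gen = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])  # generated lengths 3 and 4
ref = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])  # target lengths 4 and 2
print(len_reward(gen, ref, torch.device("cpu")))  # tensor([0.7500, 1.0000])
```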
rxnn/training/rl.py
CHANGED
@@ -55,6 +55,7 @@ class PPOAlgorithm(RlAlgorithm):
         # Critic loss with clipped values
         if self.clip_critic_values:
             values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
+            ref_values = torch.clamp(ref_values, -self.critic_value_clip, self.critic_value_clip)
         return self.critic_loss_fn(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
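The added line clamps `ref_values` to the same symmetric bound already applied to `values`, so the critic loss compares prediction and target within the same clipped range. A minimal sketch of that loss under the assumption that `critic_loss_fn` is MSE (attribute names follow the diff; the free-function form is illustrative):

```python
import torch
import torch.nn.functional as F

def clipped_critic_loss(values: torch.Tensor, ref_values: torch.Tensor,
                        critic_value_clip: float = 10.0,
                        clip_critic_values: bool = True) -> torch.Tensor:
    # Critic loss with clipped values; 0.2.51 adds the ref_values clamp so both
    # sides of the loss live in [-critic_value_clip, critic_value_clip]
    if clip_critic_values:
        values = torch.clamp(values, -critic_value_clip, critic_value_clip)
        ref_values = torch.clamp(ref_values, -critic_value_clip, critic_value_clip)
    return F.mse_loss(values, ref_values)

print(clipped_critic_loss(torch.tensor([12.0, -3.0]), torch.tensor([15.0, -2.0])))
# values clamp to [10.0, -3.0], ref to [10.0, -2.0] -> MSE = (0^2 + 1^2) / 2 = 0.5
```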
{rxnn-0.2.49.dist-info → rxnn-0.2.51.dist-info}/RECORD
CHANGED
@@ -16,10 +16,10 @@ rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
 rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36779
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
-rxnn/training/models.py,sha256=
+rxnn/training/models.py,sha256=CS6mjD338knXmCbMZ3bCpOlA-DR3kmQUOSj5u5F6jII,9002
 rxnn/training/mrl.py,sha256=IOi_xbQ47RPgv_2ucT9EkPeWLGBRlgPxKHFeQsYc3Pw,61074
-rxnn/training/reward.py,sha256=
-rxnn/training/rl.py,sha256=
+rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
+rxnn/training/rl.py,sha256=FiOag3kaI4I40ylXE9Yx5iHWmprINBSMBbarKudABEE,6269
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.51.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.51.dist-info/METADATA,sha256=vDeCYIrxa3o0Pe09n_nppoMvyAIHnSyRJw4Q74ofBIQ,25997
+rxnn-0.2.51.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.51.dist-info/RECORD,,
{rxnn-0.2.49.dist-info → rxnn-0.2.51.dist-info}/LICENSE
File without changes
{rxnn-0.2.49.dist-info → rxnn-0.2.51.dist-info}/WHEEL
File without changes