rxnn 0.2.34__py3-none-any.whl → 0.2.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/rl.py
CHANGED
@@ -10,7 +10,7 @@ from .ddp import distributed_mean
|
|
10
10
|
class RlAlgorithm(ABC):
|
11
11
|
def __init__(self):
|
12
12
|
super(RlAlgorithm, self).__init__()
|
13
|
-
self.
|
13
|
+
self.critic_loss_fn = nn.MSELoss()
|
14
14
|
|
15
15
|
@abstractmethod
|
16
16
|
def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
|
@@ -22,7 +22,7 @@ class RlAlgorithm(ABC):
|
|
22
22
|
pass
|
23
23
|
|
24
24
|
def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
|
25
|
-
return self.
|
25
|
+
return self.critic_loss_fn(values, ref_values)
|
26
26
|
|
27
27
|
|
28
28
|
class PPOConfig(TypedDict):
|
@@ -55,7 +55,7 @@ class PPOAlgorithm(RlAlgorithm):
|
|
55
55
|
# Critic loss with clipped values
|
56
56
|
if self.clip_critic_values:
|
57
57
|
values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
|
58
|
-
return self.
|
58
|
+
return self.critic_loss_fn(values, ref_values)
|
59
59
|
|
60
60
|
def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
|
61
61
|
old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
|
@@ -114,15 +114,17 @@ class PPOAlgorithm(RlAlgorithm):
|
|
114
114
|
next_value = last_value
|
115
115
|
next_done = torch.zeros(batch_size, device=dones.device) # Last state is terminal
|
116
116
|
dones = dones.float()
|
117
|
-
for t in reversed(range(trajectory_len)):
|
118
|
-
# Check if next state is terminal
|
119
|
-
non_terminal = 1.0 - next_done
|
120
117
|
|
121
|
-
|
122
|
-
|
123
|
-
|
118
|
+
for t in reversed(range(trajectory_len)):
|
119
|
+
if t == trajectory_len - 1:
|
120
|
+
# For the last step, use the provided last_value
|
121
|
+
delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
|
122
|
+
else:
|
123
|
+
# For other steps, use the next value in the trajectory
|
124
|
+
delta = rewards[t] + self.gae_gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
|
125
|
+
|
126
|
+
advantages[t] = delta + self.gae_gamma * self.gae_lambda * (1 - dones[t]) * last_advantage
|
124
127
|
last_advantage = advantages[t]
|
125
|
-
next_value = values[t]
|
126
128
|
next_done = dones[t]
|
127
129
|
|
128
130
|
returns = advantages + values
|
@@ -19,7 +19,7 @@ rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
|
|
19
19
|
rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
|
20
20
|
rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
|
21
21
|
rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
|
22
|
-
rxnn/training/rl.py,sha256=
|
22
|
+
rxnn/training/rl.py,sha256=OyfqVEh5Z-YvoW7baTTpOH0DuYexUOfhkZmbCpEoS_A,5962
|
23
23
|
rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
|
24
24
|
rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
|
25
25
|
rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
|
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
|
33
33
|
rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
|
34
34
|
rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
|
35
35
|
rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
|
36
|
-
rxnn-0.2.
|
37
|
-
rxnn-0.2.
|
38
|
-
rxnn-0.2.
|
39
|
-
rxnn-0.2.
|
36
|
+
rxnn-0.2.35.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
37
|
+
rxnn-0.2.35.dist-info/METADATA,sha256=EpcmlYIdyJMIP6KnT-065MWoGUudGYFoSbeRJiZAqnk,25960
|
38
|
+
rxnn-0.2.35.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
39
|
+
rxnn-0.2.35.dist-info/RECORD,,
|
File without changes
|
File without changes
|