rxnn 0.2.34__py3-none-any.whl → 0.2.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/rl.py
CHANGED
@@ -10,7 +10,7 @@ from .ddp import distributed_mean
 class RlAlgorithm(ABC):
     def __init__(self):
         super(RlAlgorithm, self).__init__()
-        self.
+        self.critic_loss_fn = nn.MSELoss()

     @abstractmethod
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
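The removed assignment above is truncated in this view, but the added line stores the MSE criterion as critic_loss_fn rather than under the same name as the critic_loss method defined just below. Assuming that naming clash was the motivation for the change, the following minimal sketch shows why the distinct attribute name matters (class names and tensors are illustrative only, not library code):

import torch
from torch import nn

class WithShadowing:
    def __init__(self):
        # Hypothetical old naming: the instance attribute shares its name with
        # the method below, so attribute lookup on instances resolves
        # `critic_loss` to the nn.MSELoss module and the method is bypassed.
        self.critic_loss = nn.MSELoss()

    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
        return self.critic_loss(values, ref_values)

class WithoutShadowing:
    def __init__(self):
        # Naming from the new version: the stored criterion gets a distinct name.
        self.critic_loss_fn = nn.MSELoss()

    def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
        return self.critic_loss_fn(values, ref_values)

values, ref_values = torch.zeros(4), torch.ones(4)
print(WithShadowing().critic_loss(values, ref_values))     # method bypassed, MSELoss called directly
print(WithoutShadowing().critic_loss(values, ref_values))  # method runs, then delegates to MSELoss

With the shared name, the instance attribute set in __init__ wins over the method during attribute lookup, so any extra logic in an overriding critic_loss method (such as the value clipping in PPOAlgorithm below) would be silently skipped.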
@@ -22,7 +22,7 @@ class RlAlgorithm(ABC):
         pass

     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
-        return self.
+        return self.critic_loss_fn(values, ref_values)


 class PPOConfig(TypedDict):
@@ -55,7 +55,7 @@ class PPOAlgorithm(RlAlgorithm):
         # Critic loss with clipped values
         if self.clip_critic_values:
             values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
-        return self.
+        return self.critic_loss_fn(values, ref_values)

     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
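For context on the hunk above: PPOAlgorithm.critic_loss optionally clamps the critic's predictions to ±critic_value_clip before applying the MSE criterion. A standalone sketch of that behavior, using an assumed clip range of 10.0 and made-up tensors (the real value comes from the algorithm's configuration, presumably PPOConfig):

import torch
from torch import nn

critic_value_clip = 10.0        # assumed clip range for illustration
critic_loss_fn = nn.MSELoss()

values = torch.tensor([3.0, 250.0, -80.0])   # raw critic predictions (illustrative)
ref_values = torch.tensor([2.5, 9.0, -7.0])  # reference returns (illustrative)

# Extreme predictions are squashed into [-clip, clip] before the squared error.
clipped = torch.clamp(values, -critic_value_clip, critic_value_clip)
loss = critic_loss_fn(clipped, ref_values)
print(clipped)  # tensor([  3.,  10., -10.])
print(loss)

Clamping keeps occasional extreme value predictions from dominating the squared error and destabilizing critic updates.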
@@ -107,29 +107,31 @@ class PPOAlgorithm(RlAlgorithm):
         return policy_loss

     def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor,
-                     last_value: torch.Tensor, dones: torch.Tensor):
+                     last_value: torch.Tensor, dones: torch.Tensor, last_done: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         trajectory_len, batch_size = rewards.shape
         advantages = torch.zeros_like(rewards, device=rewards.device)
         last_advantage = 0
         next_value = last_value
-        next_done =
+        next_done = last_done
         dones = dones.float()
-        for t in reversed(range(trajectory_len)):
-            # Check if next state is terminal
-            non_terminal = 1.0 - next_done

-
-
-
+        for t in reversed(range(trajectory_len)):
+            if t == trajectory_len - 1:
+                # For the last step, use the provided last_value
+                delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
+            else:
+                # For other steps, use the next value in the trajectory
+                delta = rewards[t] + self.gae_gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
+
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * (1 - dones[t]) * last_advantage
             last_advantage = advantages[t]
-            next_value = values[t]
             next_done = dones[t]

         returns = advantages + values
         return advantages, returns

     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        advantages, ref_values = self._compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1])
+        advantages, ref_values = self._compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1], dones[-1])
         if self.use_distributed_advantage_norm:
             mean_advantage = distributed_mean(advantages.mean())
             std_advantage = distributed_mean(advantages.std())
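The rewritten _compute_gae now receives last_done alongside last_value, bootstraps the final step from them, and reads values[t + 1] / dones[t + 1] directly for earlier steps instead of carrying next_value backwards. A self-contained sketch of the same recursion follows; gae_gamma = 0.99 and gae_lambda = 0.95 are placeholder values, not necessarily the library's defaults:

import torch

def compute_gae(rewards, values, last_value, dones, last_done,
                gamma: float = 0.99, lam: float = 0.95):
    """Generalized Advantage Estimation over a [trajectory_len, batch_size] rollout."""
    trajectory_len, _ = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_advantage = 0
    dones = dones.float()
    next_done = last_done.float()
    for t in reversed(range(trajectory_len)):
        if t == trajectory_len - 1:
            # Final step: bootstrap from the value/done flag of the state after the rollout.
            delta = rewards[t] + gamma * last_value * (1 - next_done) - values[t]
        else:
            # Earlier steps: read the next value/done directly from the trajectory.
            delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
        advantages[t] = delta + gamma * lam * (1 - dones[t]) * last_advantage
        last_advantage = advantages[t]
    returns = advantages + values
    return advantages, returns

# Toy call mirroring the updated calculate_advantages wiring (dones[-1] passed as last_done):
T, B = 5, 2
rewards, values = torch.randn(T, B), torch.randn(T, B)
dones = torch.zeros(T, B)
advantages, ref_values = compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1], dones[-1])
print(advantages.shape, ref_values.shape)  # torch.Size([4, 2]) torch.Size([4, 2])

With the extra argument, a rollout that ends in a terminal state has its bootstrap term masked out by (1 - last_done) at the final step.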
rxnn-0.2.34.dist-info/RECORD → rxnn-0.2.36.dist-info/RECORD
CHANGED
@@ -19,7 +19,7 @@ rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
 rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
 rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=DKB7oiePVMz-8Tp3jJOzxVlPZMd7HUEOJUurCJj9f68,5974
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.36.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.36.dist-info/METADATA,sha256=TZZFu-DI5dxsjYVR3qXL9UyTY_Z53Br9zlk3PtE2mlU,25960
+rxnn-0.2.36.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.36.dist-info/RECORD,,