rxnn 0.2.34__py3-none-any.whl → 0.2.36__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
rxnn/training/rl.py CHANGED
@@ -10,7 +10,7 @@ from .ddp import distributed_mean
 class RlAlgorithm(ABC):
     def __init__(self):
         super(RlAlgorithm, self).__init__()
-        self.critic_loss = nn.MSELoss()
+        self.critic_loss_fn = nn.MSELoss()
 
     @abstractmethod
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
@@ -22,7 +22,7 @@ class RlAlgorithm(ABC):
         pass
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
-        return self.critic_loss(values, ref_values)
+        return self.critic_loss_fn(values, ref_values)
 
 
 class PPOConfig(TypedDict):
@@ -55,7 +55,7 @@ class PPOAlgorithm(RlAlgorithm):
         # Critic loss with clipped values
         if self.clip_critic_values:
             values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
-        return self.critic_loss(values, ref_values)
+        return self.critic_loss_fn(values, ref_values)
 
     def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
                     old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
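
The rename from critic_loss to critic_loss_fn matters because in 0.2.34 the nn.MSELoss instance was stored under the same name as the critic_loss method, and instance attributes take precedence over methods during attribute lookup, so a call like algorithm.critic_loss(values, ref_values) resolved to the bare MSELoss module and skipped the clipping logic in PPOAlgorithm.critic_loss. A minimal sketch of that lookup behaviour, using hypothetical classes rather than the rxnn ones:

import torch
from torch import nn

# Hypothetical classes (not the rxnn ones) showing why the attribute rename matters:
# an instance attribute with the same name as a method shadows that method on lookup.

class Shadowed:
    def __init__(self):
        self.critic_loss = nn.MSELoss()  # hides the critic_loss method below

    def critic_loss(self, values, ref_values):
        values = torch.clamp(values, -10.0, 10.0)  # never runs via instance lookup
        return self.critic_loss(values, ref_values)

class Renamed:
    def __init__(self):
        self.critic_loss_fn = nn.MSELoss()  # distinct name, the method stays reachable

    def critic_loss(self, values, ref_values):
        values = torch.clamp(values, -10.0, 10.0)  # clipping actually applied
        return self.critic_loss_fn(values, ref_values)

v, ref = torch.tensor([100.0]), torch.tensor([0.0])
print(Shadowed().critic_loss(v, ref))  # tensor(10000.) - raw MSE, clamp skipped
print(Renamed().critic_loss(v, ref))   # tensor(100.)   - clamped to 10 before MSE

With the distinct attribute name, the critic_loss override below (with value clipping) is the code that actually runs when the trainer calls it.
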
@@ -107,29 +107,31 @@ class PPOAlgorithm(RlAlgorithm):
         return policy_loss
 
     def _compute_gae(self, rewards: torch.Tensor, values: torch.Tensor,
-                     last_value: torch.Tensor, dones: torch.Tensor):
+                     last_value: torch.Tensor, dones: torch.Tensor, last_done: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         trajectory_len, batch_size = rewards.shape
         advantages = torch.zeros_like(rewards, device=rewards.device)
         last_advantage = 0
         next_value = last_value
-        next_done = torch.zeros(batch_size, device=dones.device) # Last state is terminal
+        next_done = last_done
         dones = dones.float()
-        for t in reversed(range(trajectory_len)):
-            # Check if next state is terminal
-            non_terminal = 1.0 - next_done
 
-            # Delta should not include next_value if next is terminal
-            delta = rewards[t] + self.gae_gamma * next_value * non_terminal - values[t]
-            advantages[t] = delta + self.gae_gamma * self.gae_lambda * non_terminal * last_advantage
+        for t in reversed(range(trajectory_len)):
+            if t == trajectory_len - 1:
+                # For the last step, use the provided last_value
+                delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
+            else:
+                # For other steps, use the next value in the trajectory
+                delta = rewards[t] + self.gae_gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
+
+            advantages[t] = delta + self.gae_gamma * self.gae_lambda * (1 - dones[t]) * last_advantage
             last_advantage = advantages[t]
-            next_value = values[t]
             next_done = dones[t]
 
         returns = advantages + values
         return advantages, returns
 
     def calculate_advantages(self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        advantages, ref_values = self._compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1])
+        advantages, ref_values = self._compute_gae(rewards[:-1], values[:-1], values[-1], dones[:-1], dones[-1])
         if self.use_distributed_advantage_norm:
             mean_advantage = distributed_mean(advantages.mean())
             std_advantage = distributed_mean(advantages.std())
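
The reworked _compute_gae no longer hard-codes the final state as non-terminal: calculate_advantages now forwards dones[-1] as last_done, the bootstrap value last_value is masked by that flag, and interior steps read values[t + 1] and dones[t + 1] straight from the trajectory instead of tracking them in loop variables. A standalone sketch of the same recursion, as a hypothetical free function that assumes rewards, values, and dones are shaped [trajectory_len, batch_size] and gamma/lam are scalars:

import torch

def compute_gae(rewards: torch.Tensor, values: torch.Tensor, last_value: torch.Tensor,
                dones: torch.Tensor, last_done: torch.Tensor,
                gamma: float = 0.99, lam: float = 0.95) -> tuple[torch.Tensor, torch.Tensor]:
    # rewards/values/dones: [trajectory_len, batch_size]; last_value/last_done: [batch_size]
    trajectory_len, _ = rewards.shape
    dones = dones.float()
    advantages = torch.zeros_like(rewards)
    last_advantage = torch.zeros_like(last_value)
    for t in reversed(range(trajectory_len)):
        if t == trajectory_len - 1:
            # Final step: bootstrap from last_value, masked by the real last_done flag
            delta = rewards[t] + gamma * last_value * (1 - last_done.float()) - values[t]
        else:
            # Interior steps: next value/done come straight from the trajectory
            delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
        advantages[t] = delta + gamma * lam * (1 - dones[t]) * last_advantage
        last_advantage = advantages[t]
    returns = advantages + values  # advantages plus the value baseline
    return advantages, returns

# Toy call: 3 steps, batch of 2, second column hits a terminal state at step 1
rewards = torch.ones(3, 2)
values = torch.full((3, 2), 0.5)
dones = torch.tensor([[0.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
adv, ret = compute_gae(rewards, values, torch.full((2,), 0.5), dones, torch.zeros(2))

Passing the real dones[-1] instead of a zero tensor (which the 0.2.34 code used despite its "Last state is terminal" comment) keeps the bootstrap term from leaking value estimates across an episode boundary when the sampled trajectory ends exactly on a terminal step.
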
{rxnn-0.2.34.dist-info → rxnn-0.2.36.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.34
+Version: 0.2.36
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.34.dist-info → rxnn-0.2.36.dist-info}/RECORD RENAMED
@@ -19,7 +19,7 @@ rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
 rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
 rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
-rxnn/training/rl.py,sha256=ckx1nlzIGZBabzwZNRj4isvHqRZwg0y0jGOT-SN6KZc,5841
+rxnn/training/rl.py,sha256=DKB7oiePVMz-8Tp3jJOzxVlPZMd7HUEOJUurCJj9f68,5974
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.34.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.34.dist-info/METADATA,sha256=Q7LqPr7KHFhMPL6UrbqG1SmtJbM2Ho-Yuxp_7LyCtYw,25960
-rxnn-0.2.34.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.34.dist-info/RECORD,,
+rxnn-0.2.36.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.36.dist-info/METADATA,sha256=TZZFu-DI5dxsjYVR3qXL9UyTY_Z53Br9zlk3PtE2mlU,25960
+rxnn-0.2.36.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.36.dist-info/RECORD,,