rxnn 0.2.34__py3-none-any.whl → 0.2.35__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
rxnn/training/rl.py CHANGED
@@ -10,7 +10,7 @@ from .ddp import distributed_mean
10
10
  class RlAlgorithm(ABC):
11
11
  def __init__(self):
12
12
  super(RlAlgorithm, self).__init__()
13
- self.critic_loss = nn.MSELoss()
13
+ self.critic_loss_fn = nn.MSELoss()
14
14
 
15
15
  @abstractmethod
16
16
  def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
@@ -22,7 +22,7 @@ class RlAlgorithm(ABC):
22
22
  pass
23
23
 
24
24
  def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
25
- return self.critic_loss(values, ref_values)
25
+ return self.critic_loss_fn(values, ref_values)
26
26
 
27
27
 
28
28
  class PPOConfig(TypedDict):
@@ -55,7 +55,7 @@ class PPOAlgorithm(RlAlgorithm):
55
55
  # Critic loss with clipped values
56
56
  if self.clip_critic_values:
57
57
  values = torch.clamp(values, -self.critic_value_clip, self.critic_value_clip)
58
- return self.critic_loss(values, ref_values)
58
+ return self.critic_loss_fn(values, ref_values)
59
59
 
60
60
  def policy_loss(self, query: TokenizedDict, answer: TokenizedDict, logits: torch.Tensor,
61
61
  old_log_probs: torch.Tensor, advantages: torch.Tensor) -> torch.Tensor:
@@ -114,15 +114,17 @@ class PPOAlgorithm(RlAlgorithm):
114
114
  next_value = last_value
115
115
  next_done = torch.zeros(batch_size, device=dones.device) # Last state is terminal
116
116
  dones = dones.float()
117
- for t in reversed(range(trajectory_len)):
118
- # Check if next state is terminal
119
- non_terminal = 1.0 - next_done
120
117
 
121
- # Delta should not include next_value if next is terminal
122
- delta = rewards[t] + self.gae_gamma * next_value * non_terminal - values[t]
123
- advantages[t] = delta + self.gae_gamma * self.gae_lambda * non_terminal * last_advantage
118
+ for t in reversed(range(trajectory_len)):
119
+ if t == trajectory_len - 1:
120
+ # For the last step, use the provided last_value
121
+ delta = rewards[t] + self.gae_gamma * next_value * (1 - next_done) - values[t]
122
+ else:
123
+ # For other steps, use the next value in the trajectory
124
+ delta = rewards[t] + self.gae_gamma * values[t + 1] * (1 - dones[t + 1]) - values[t]
125
+
126
+ advantages[t] = delta + self.gae_gamma * self.gae_lambda * (1 - dones[t]) * last_advantage
124
127
  last_advantage = advantages[t]
125
- next_value = values[t]
126
128
  next_done = dones[t]
127
129
 
128
130
  returns = advantages + values
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rxnn
3
- Version: 0.2.34
3
+ Version: 0.2.35
4
4
  Summary: RxNN: Reactive Neural Networks Platform
5
5
  License: Apache-2.0
6
6
  Keywords: deep-learning,ai,machine-learning
@@ -19,7 +19,7 @@ rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
19
19
  rxnn/training/models.py,sha256=y-9XHedSheyK1AmLBp3ayulnUvAmDuJ3t0qVg8wHBRg,7463
20
20
  rxnn/training/mrl.py,sha256=fIrg1Er0aAK4TnyDRmJC1m7az9wdkhikxv0CBCrGT-c,55868
21
21
  rxnn/training/reward.py,sha256=B7nerPk9eNAv2i7umtNF88tVQVwijNNrchIrEITGHKk,11623
22
- rxnn/training/rl.py,sha256=ckx1nlzIGZBabzwZNRj4isvHqRZwg0y0jGOT-SN6KZc,5841
22
+ rxnn/training/rl.py,sha256=OyfqVEh5Z-YvoW7baTTpOH0DuYexUOfhkZmbCpEoS_A,5962
23
23
  rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
24
24
  rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
25
25
  rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
33
33
  rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
34
34
  rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
35
35
  rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
36
- rxnn-0.2.34.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
- rxnn-0.2.34.dist-info/METADATA,sha256=Q7LqPr7KHFhMPL6UrbqG1SmtJbM2Ho-Yuxp_7LyCtYw,25960
38
- rxnn-0.2.34.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
- rxnn-0.2.34.dist-info/RECORD,,
36
+ rxnn-0.2.35.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
37
+ rxnn-0.2.35.dist-info/METADATA,sha256=EpcmlYIdyJMIP6KnT-065MWoGUudGYFoSbeRJiZAqnk,25960
38
+ rxnn-0.2.35.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
+ rxnn-0.2.35.dist-info/RECORD,,
File without changes
File without changes