evolutionary-policy-optimization 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in the public registry, and is provided for informational purposes only.
@@ -424,18 +424,38 @@ class Critic(Module):
         latent,
         old_values,
         target,
-        eps_clip = 0.4
+        eps_clip = 0.4,
+        use_improved = True
     ):
         logits = self.forward(state, latent, return_logits = True)
 
         value = self.maybe_bins_to_value(logits)
 
-        clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+        if use_improved:
+            clipped_target = target.clamp(-eps_clip, eps_clip)
 
-        loss = self.loss_fn(logits, target, reduction = 'none')
-        clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+            old_values_lo = old_values - eps_clip
+            old_values_hi = old_values + eps_clip
 
-        return torch.max(loss, clipped_loss).mean()
+            is_between = lambda lo, hi: (lo < value) & (value < hi)
+
+            clipped_loss = self.loss_fn(logits, clipped_target, reduction = 'none')
+            loss = self.loss_fn(logits, target, reduction = 'none')
+
+            value_loss = torch.where(
+                is_between(target, old_values_lo) | is_between(old_values_hi, target),
+                0.,
+                torch.min(loss, clipped_loss)
+            )
+        else:
+            clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+
+            loss = self.loss_fn(logits, target, reduction = 'none')
+            clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+
+            value_loss = torch.max(loss, clipped_loss)
+
+        return value_loss.mean()
 
     def forward(
         self,
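For reference, the new `use_improved = True` branch follows the clipped critic loss analysed by Lee & Chung (2024), cited in the README change further below. What follows is a minimal, standalone sketch of that branch on plain value tensors: `improved_clipped_critic_loss` is a hypothetical name, and a plain MSE stands in for the Critic's `loss_fn`, which in the library takes logits (the value may be a binned representation, per `maybe_bins_to_value`).

```python
# Hypothetical standalone sketch of the `use_improved = True` branch above.
# MSE stands in for the Critic's `loss_fn`; names here are illustrative only.

import torch
import torch.nn.functional as F

def improved_clipped_critic_loss(value, old_values, target, eps_clip = 0.4):
    # clamp the regression target, mirroring `target.clamp(-eps_clip, eps_clip)` in the diff
    clipped_target = target.clamp(-eps_clip, eps_clip)

    # trust region around the old value estimates
    old_values_lo = old_values - eps_clip
    old_values_hi = old_values + eps_clip

    is_between = lambda lo, hi: (lo < value) & (value < hi)

    loss         = F.mse_loss(value, target, reduction = 'none')
    clipped_loss = F.mse_loss(value, clipped_target, reduction = 'none')

    # zero the loss where the prediction already lies between the target and the
    # trust region boundary; otherwise take the smaller of the two losses
    value_loss = torch.where(
        is_between(target, old_values_lo) | is_between(old_values_hi, target),
        torch.zeros_like(loss),
        torch.min(loss, clipped_loss)
    )

    return value_loss.mean()

# example usage on random tensors
value      = torch.randn(8)
old_values = torch.randn(8)
target     = torch.randn(8)
print(improved_clipped_critic_loss(value, old_values, target))
```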
@@ -826,6 +846,7 @@ class Agent(Module):
         critic_loss_kwargs: dict = dict(
             eps_clip = 0.4
         ),
+        use_improved_critic_loss = True,
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -871,6 +892,8 @@ class Agent(Module):
         self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.critic_loss_kwargs = critic_loss_kwargs
 
+        self.use_improved_critic_loss = use_improved_critic_loss
+
         # fitness score related
 
         self.get_fitness_scores = get_fitness_scores
@@ -1142,6 +1165,7 @@ class Agent(Module):
                 latents,
                 old_values = old_values,
                 target = advantages + old_values,
+                use_improved = self.use_improved_critic_loss,
                 **self.critic_loss_kwargs
             )
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.8
+Version: 0.1.9
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -206,4 +206,13 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Lee2024AnalysisClippedCritic,
+    title   = {On Analysis of Clipped Critic Loss in Proximal Policy Gradient},
+    author  = {Yongjin Lee, Moonyoung Chung},
+    journal = {Authorea},
+    year    = {2024}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=Ua0o4Xe-Z6gy76-nbB1yKndePGurSwW_otXXrrJWhgc,44835
+evolutionary_policy_optimization/epo.py,sha256=9GfSvOz6SwjAuZyhyvsLHPY8b2svMQlM3BRjilwsQ-g,45717
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.8.dist-info/METADATA,sha256=tEVMyHVZjknJMQ0QEIVJhMj6QTDYW5Uqcq6nqa7LHpo,7088
-evolutionary_policy_optimization-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.8.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.9.dist-info/METADATA,sha256=y5w_NwtKNQ07HeYa5r6hcPn7RsqDpehMmt5vj6mTESQ,7316
+evolutionary_policy_optimization-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.9.dist-info/RECORD,,