evolutionary-policy-optimization 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
evolutionary_policy_optimization/epo.py
@@ -424,18 +424,40 @@ class Critic(Module):
         latent,
         old_values,
         target,
-        eps_clip = 0.4
+        eps_clip = 0.4,
+        use_improved = True
     ):
         logits = self.forward(state, latent, return_logits = True)

         value = self.maybe_bins_to_value(logits)

-        clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+        loss_fn = partial(self.loss_fn, reduction = 'none')

-        loss = self.loss_fn(logits, target, reduction = 'none')
-        clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+        if use_improved:
+            clipped_target = target.clamp(-eps_clip, eps_clip)

-        return torch.max(loss, clipped_loss).mean()
+            old_values_lo = old_values - eps_clip
+            old_values_hi = old_values + eps_clip
+
+            is_between = lambda lo, hi: (lo < value) & (value < hi)
+
+            clipped_loss = loss_fn(logits, clipped_target)
+            loss = loss_fn(logits, target)
+
+            value_loss = torch.where(
+                is_between(target, old_values_lo) | is_between(old_values_hi, target),
+                0.,
+                torch.min(loss, clipped_loss)
+            )
+        else:
+            clipped_value = old_values + (value - old_values).clamp(-eps_clip, eps_clip)
+
+            loss = loss_fn(logits, target)
+            clipped_loss = loss_fn(clipped_value, target)
+
+            value_loss = torch.max(loss, clipped_loss)
+
+        return value_loss.mean()

     def forward(
         self,
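The new `use_improved` branch clips the regression target to `[-eps_clip, eps_clip]` and zeroes the loss wherever the current prediction already lies between the target and the edge of the `eps_clip` band around the old value estimate, taking `min(loss, clipped_loss)` elsewhere; the `else` branch keeps the classic PPO value clipping with `max`. Below is a minimal, self-contained sketch of that logic. It substitutes a plain element-wise MSE on raw value predictions for the package's logit/bin-based `self.loss_fn`, and the function name `clipped_value_loss` is illustrative only, not part of the package.

```python
# Standalone sketch of the two value-loss branches in the hunk above.
# Assumption: element-wise MSE on raw value predictions stands in for the
# package's logit/bin-based loss_fn; tensor names mirror the diff.
import torch
import torch.nn.functional as F

def clipped_value_loss(value, old_values, target, eps_clip = 0.4, use_improved = True):
    loss_fn = lambda pred, tgt: F.mse_loss(pred, tgt, reduction = 'none')

    if use_improved:
        # improved variant: clip the regression target to [-eps_clip, eps_clip],
        # and zero the loss wherever the prediction already sits between the
        # target and the boundary of the eps_clip band around the old value
        clipped_target = target.clamp(-eps_clip, eps_clip)

        old_values_lo = old_values - eps_clip
        old_values_hi = old_values + eps_clip

        is_between = lambda lo, hi: (lo < value) & (value < hi)

        loss = loss_fn(value, target)
        clipped_loss = loss_fn(value, clipped_target)

        value_loss = torch.where(
            is_between(target, old_values_lo) | is_between(old_values_hi, target),
            torch.zeros_like(loss),
            torch.min(loss, clipped_loss)
        )
    else:
        # classic PPO value clipping: limit how far the new estimate may move
        # from the old one, then take the pessimistic (larger) loss per element
        clipped_value = old_values + (value - old_values).clamp(-eps_clip, eps_clip)

        loss = loss_fn(value, target)
        clipped_loss = loss_fn(clipped_value, target)

        value_loss = torch.max(loss, clipped_loss)

    return value_loss.mean()
```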
@@ -826,6 +848,7 @@ class Agent(Module):
         critic_loss_kwargs: dict = dict(
             eps_clip = 0.4
         ),
+        use_improved_critic_loss = True,
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -871,6 +894,8 @@ class Agent(Module):
         self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.critic_loss_kwargs = critic_loss_kwargs

+        self.use_improved_critic_loss = use_improved_critic_loss
+
         # fitness score related

         self.get_fitness_scores = get_fitness_scores
@@ -1142,6 +1167,7 @@ class Agent(Module):
                 latents,
                 old_values = old_values,
                 target = advantages + old_values,
+                use_improved = self.use_improved_critic_loss,
                 **self.critic_loss_kwargs
             )
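For a quick side-by-side of the two variants, the hypothetical `clipped_value_loss` sketch above can be exercised on invented tensors, building the target the same way as this call site (`advantages + old_values`); all numbers are made up for illustration.

```python
import torch

# invented example tensors, purely for illustration
old_values = torch.tensor([ 0.10, 0.50, -0.20, 1.00])   # old critic estimates
advantages = torch.tensor([ 0.05, 0.90, -0.80, 0.02])
value      = torch.tensor([ 0.12, 0.70,  0.10, 1.01])   # current critic predictions

# target mirrors the call site above: advantages + old_values
target = advantages + old_values

print(clipped_value_loss(value, old_values, target, use_improved = True))   # improved clipped loss
print(clipped_value_loss(value, old_values, target, use_improved = False))  # classic PPO value clipping
```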
evolutionary_policy_optimization-0.1.8.dist-info/METADATA → evolutionary_policy_optimization-0.1.10.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.8
+Version: 0.1.10
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -206,4 +206,13 @@ agent.load('./agent.pt')
 }
 ```

+```bibtex
+@article{Lee2024AnalysisClippedCritic
+    title = {On Analysis of Clipped Critic Loss in Proximal Policy Gradient},
+    author = {Yongjin Lee, Moonyoung Chung},
+    journal = {Authorea},
+    year = {2024}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
evolutionary_policy_optimization-0.1.8.dist-info/RECORD → evolutionary_policy_optimization-0.1.10.dist-info/RECORD
@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=Ua0o4Xe-Z6gy76-nbB1yKndePGurSwW_otXXrrJWhgc,44835
+evolutionary_policy_optimization/epo.py,sha256=lYK6gfkAfYA5lm2fB0XC1UjvtjmlKCsvQNFrigj4JN0,45669
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.8.dist-info/METADATA,sha256=tEVMyHVZjknJMQ0QEIVJhMj6QTDYW5Uqcq6nqa7LHpo,7088
-evolutionary_policy_optimization-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.8.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.10.dist-info/METADATA,sha256=X2DgvM3fkJej6brFyPXSgR5qX-n7Mdg0FaB79AI_3l8,7317
+evolutionary_policy_optimization-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.10.dist-info/RECORD,,