evolutionary-policy-optimization: 0.1.5-py3-none-any.whl → 0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py

@@ -364,11 +364,38 @@ class Critic(Module):
             hl_gauss_loss = hl_gauss_loss_kwargs
         )
 
+        self.use_regression = use_regression
+
+        hl_gauss_loss = self.to_pred.hl_gauss_loss
+
+        self.maybe_bins_to_value = hl_gauss_loss if not use_regression else identity
+        self.maybe_value_to_bins = hl_gauss_loss.transform_to_logprobs if not use_regression else identity
+        self.loss_fn = hl_gauss_loss if not use_regression else F.mse_loss
+
+    def forward_for_loss(
+        self,
+        state,
+        latent,
+        old_values,
+        target,
+        eps_clip = 0.4
+    ):
+        logits = self.forward(state, latent, return_logits = True)
+
+        value = self.maybe_bins_to_value(logits)
+
+        clipped_value = old_values + (value - old_values).clamp(-eps_clip, eps_clip)
+
+        loss = self.loss_fn(value, target, reduction = 'none')
+        clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+
+        return torch.max(loss, clipped_loss).mean()
+
     def forward(
         self,
         state,
         latent,
-        target = None
+        return_logits = False
     ):
 
         hidden = self.init_layer(state)
@@ -377,7 +404,8 @@ class Critic(Module):
 
         hidden = self.final_act(hidden)
 
-        return self.to_pred(hidden, target = target)
+        pred_kwargs = dict(return_logits = return_logits) if not self.use_regression else dict()
+        return self.to_pred(hidden, **pred_kwargs)
 
     # criteria for running genetic algorithm
 
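The new `forward_for_loss` is a PPO-style clipped value loss: the critic's fresh prediction (decoded from HL-Gauss bins when classification is enabled, via `maybe_bins_to_value`) is constrained to move at most `eps_clip` away from the value recorded at rollout time, and the pessimistic maximum of the clipped and unclipped losses is optimized. A minimal standalone sketch of the idea, regression path only, with illustrative names rather than the package's exact API:

```python
import torch
import torch.nn.functional as F

def clipped_value_loss(value, old_values, target, eps_clip = 0.4):
    # keep the new value estimate within eps_clip of the rollout-time value
    clipped_value = old_values + (value - old_values).clamp(-eps_clip, eps_clip)

    # elementwise losses for the raw and clipped predictions
    loss = F.mse_loss(value, target, reduction = 'none')
    clipped_loss = F.mse_loss(clipped_value, target, reduction = 'none')

    # take the pessimistic (larger) loss per element, then average
    return torch.max(loss, clipped_loss).mean()
```

Note the clamp bounds are symmetric around zero, since what is clipped is the *difference* from the old value, not a probability ratio as in the actor loss.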
@@ -740,6 +768,9 @@ class Agent(Module):
             entropy_weight = .01,
             norm_advantages = True
         ),
+        critic_loss_kwargs: dict = dict(
+            eps_clip = 0.4
+        ),
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -778,9 +809,13 @@ class Agent(Module):
 
         # gae function
 
-        self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.calc_gae = partial(calc_generalized_advantage_estimate, **calc_gae_kwargs)
 
+        # actor critic loss related
+
+        self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
+        self.critic_loss_kwargs = critic_loss_kwargs
+
         # fitness score related
 
         self.get_fitness_scores = get_fitness_scores
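The stored `critic_loss_kwargs` is later splatted into `Critic.forward_for_loss`, so the value-clipping epsilon becomes configurable at construction time. A hypothetical usage sketch, with all other constructor arguments elided and the `0.2` value chosen purely for illustration:

```python
agent = Agent(
    # ... actor, critic, and latent gene pool arguments as before ...
    critic_loss_kwargs = dict(
        eps_clip = 0.2  # tighter value clipping than the 0.4 default
    )
)
```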
@@ -1043,10 +1078,12 @@ class Agent(Module):
 
         # learn critic with maybe classification loss
 
-        critic_loss = self.critic(
+        critic_loss = self.critic.forward_for_loss(
             states,
             latents,
-            target = advantages + old_values
+            old_values = old_values,
+            target = advantages + old_values,
+            **self.critic_loss_kwargs
         )
 
         critic_loss.backward()
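The regression target itself is unchanged from 0.1.5: `advantages + old_values` reconstructs the GAE-based return estimate (V_target = Â + V_old), so only the call path and the new clipping arguments differ in this update step.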
--- evolutionary_policy_optimization-0.1.5.dist-info/METADATA
+++ evolutionary_policy_optimization-0.1.6.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.5
+Version: 0.1.6
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -185,4 +185,15 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Banerjee2022BoostingEI,
+    title   = {Boosting Exploration in Actor-Critic Algorithms by Incentivizing Plausible Novel States},
+    author  = {Chayan Banerjee and Zhiyong Chen and Nasimul Noman},
+    journal = {2023 62nd IEEE Conference on Decision and Control (CDC)},
+    year    = {2022},
+    pages   = {7009-7014},
+    url     = {https://api.semanticscholar.org/CorpusID:252682944}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
--- evolutionary_policy_optimization-0.1.5.dist-info/RECORD
+++ evolutionary_policy_optimization-0.1.6.dist-info/RECORD

@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=XB20JAnmOtemkEjr6op9EwZaFc4z48LiGivpdlcKKJM,42101
+evolutionary_policy_optimization/epo.py,sha256=P-a8A1ky7FgpENUgb8VHk9qADwyQdzpUp40JoaSG2HY,43395
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.5.dist-info/METADATA,sha256=fLHqN3he3chRiGpPxj0dLL41rClzXFUdaGGzHpEd5fE,6330
-evolutionary_policy_optimization-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.5.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.6.dist-info/METADATA,sha256=Bc4MZKhe2H6q7H2mB4g2qObQYHjG-8ZJkcF1XKxgTgw,6742
+evolutionary_policy_optimization-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.6.dist-info/RECORD,,