evolutionary-policy-optimization 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py
@@ -364,11 +364,37 @@ class Critic(Module):
             hl_gauss_loss = hl_gauss_loss_kwargs
         )
 
+        self.use_regression = use_regression
+
+        hl_gauss_loss = self.to_pred.hl_gauss_loss
+
+        self.maybe_bins_to_value = hl_gauss_loss if not use_regression else identity
+        self.loss_fn = hl_gauss_loss if not use_regression else F.mse_loss
+
+    def forward_for_loss(
+        self,
+        state,
+        latent,
+        old_values,
+        target,
+        eps_clip = 0.4
+    ):
+        logits = self.forward(state, latent, return_logits = True)
+
+        value = self.maybe_bins_to_value(logits)
+
+        clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+
+        loss = self.loss_fn(logits, target, reduction = 'none')
+        clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+
+        return torch.max(loss, clipped_loss).mean()
+
     def forward(
         self,
        state,
         latent,
-        target = None
+        return_logits = False
     ):
 
         hidden = self.init_layer(state)
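The new `forward_for_loss` above is a PPO-style clipped value objective: the critic pays the worse of the unclipped and the clipped loss, so its estimate cannot drift arbitrarily far from `old_values` in a single update. A minimal standalone sketch of that objective in plain regression mode; the helper name and example tensors are illustrative, and the clamp uses the conventional ±`eps_clip` band rather than mirroring the released code line for line:

```python
import torch
import torch.nn.functional as F

def clipped_value_loss(values, old_values, target, eps_clip = 0.4):
    # keep the new estimate inside a band around the rollout-time values
    clipped_values = old_values + (values - old_values).clamp(-eps_clip, eps_clip)

    loss = F.mse_loss(values, target, reduction = 'none')
    clipped_loss = F.mse_loss(clipped_values, target, reduction = 'none')

    # pessimistic combination: take the worse loss per element, then average
    return torch.max(loss, clipped_loss).mean()

values     = torch.tensor([1.2, 0.3, -0.5])
old_values = torch.tensor([1.0, 0.0, -0.4])
target     = torch.tensor([1.5, 0.1, -0.6])

print(clipped_value_loss(values, old_values, target))
```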
@@ -377,7 +403,8 @@ class Critic(Module):
 
         hidden = self.final_act(hidden)
 
-        return self.to_pred(hidden, target = target)
+        pred_kwargs = dict(return_logits = return_logits) if not self.use_regression else dict()
+        return self.to_pred(hidden, **pred_kwargs)
 
 # criteria for running genetic algorithm
 
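The conditional `pred_kwargs` in the hunk above exists because `return_logits` only makes sense for the HL-Gauss classification head; in regression mode `to_pred` returns a scalar value and takes no such keyword. A hypothetical sketch of the same dispatch pattern, with two toy heads standing in for `HLGaussLayer` (not the library's actual API):

```python
import torch
from torch import nn

class RegressionHead(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, 1)

    def forward(self, hidden):
        return self.proj(hidden).squeeze(-1)  # one scalar value per sample

class HistogramHead(nn.Module):
    def __init__(self, dim, num_bins = 32):
        super().__init__()
        self.proj = nn.Linear(dim, num_bins)
        self.register_buffer('bin_values', torch.linspace(-1., 1., num_bins))

    def forward(self, hidden, return_logits = False):
        logits = self.proj(hidden)
        if return_logits:
            return logits  # raw bin logits, consumed by the classification loss
        # otherwise collapse the histogram to its expected value
        return (logits.softmax(dim = -1) * self.bin_values).sum(dim = -1)

use_regression = False
to_pred = RegressionHead(16) if use_regression else HistogramHead(16)

hidden = torch.randn(4, 16)
pred_kwargs = dict(return_logits = True) if not use_regression else dict()
out = to_pred(hidden, **pred_kwargs)
print(out.shape)  # torch.Size([4, 32]) here; torch.Size([4]) in regression mode
```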
@@ -740,6 +767,9 @@ class Agent(Module):
             entropy_weight = .01,
             norm_advantages = True
         ),
+        critic_loss_kwargs: dict = dict(
+            eps_clip = 0.4
+        ),
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -778,9 +808,13 @@ class Agent(Module):
 
         # gae function
 
-        self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.calc_gae = partial(calc_generalized_advantage_estimate, **calc_gae_kwargs)
 
+        # actor critic loss related
+
+        self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
+        self.critic_loss_kwargs = critic_loss_kwargs
+
         # fitness score related
 
         self.get_fitness_scores = get_fitness_scores
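The hunk above also shows two ways hyperparameters are threaded through the agent: `actor_loss_kwargs` is bound eagerly with `functools.partial`, while `critic_loss_kwargs` is kept as a plain dict and splatted later when `forward_for_loss` is called (see the training-loop hunk further down), presumably because that loss lives on the critic module rather than as a free function. A small illustration of both patterns, with a made-up stand-in loss:

```python
from functools import partial

def stand_in_loss(pred, target, weight = 1.0):
    # placeholder for actor_loss / Critic.forward_for_loss
    return weight * abs(pred - target)

# bind hyperparameters once, up front (the actor_loss_kwargs pattern)
actor_loss_fn = partial(stand_in_loss, weight = 0.5)

# or keep the kwargs dict and splat it at call time (the critic_loss_kwargs pattern)
critic_loss_kwargs = dict(weight = 2.0)

print(actor_loss_fn(1.0, 3.0))                        # 1.0
print(stand_in_loss(1.0, 3.0, **critic_loss_kwargs))  # 4.0
```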
@@ -809,7 +843,11 @@ class Agent(Module):
 
         dummy = tensor(0)
 
+        self.clip_grad_norm_ = nn.utils.clip_grad_norm_
+
         if wrap_with_accelerate:
+            self.clip_grad_norm_ = self.accelerate.clip_grad_norm_
+
             (
                 self.actor,
                 self.critic,
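With `self.clip_grad_norm_` assigned up front, the training code further down can clip gradients the same way whether or not the module is wrapped with accelerate. A hedged sketch of the selection pattern; `torch.nn.utils.clip_grad_norm_` and `Accelerator.clip_grad_norm_` are the real APIs being switched between, while the toy model and loss around them are illustrative:

```python
import torch
from torch import nn

model = nn.Linear(8, 2)
use_accelerate = False  # flip to True when training through accelerate

if use_accelerate:
    from accelerate import Accelerator
    accelerator = Accelerator()
    model = accelerator.prepare(model)
    clip_grad_norm_ = accelerator.clip_grad_norm_  # handles unscaling / sharded params
else:
    clip_grad_norm_ = nn.utils.clip_grad_norm_     # plain PyTorch fallback

loss = model(torch.randn(4, 8)).sum()
loss.backward()
clip_grad_norm_(model.parameters(), max_norm = 0.5)
```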
@@ -1036,23 +1074,25 @@ class Agent(Module):
                 actor_loss.backward()
 
                 if exists(self.has_grad_clip):
-                    self.accelerate.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
+                    self.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
 
                 self.actor_optim.step()
                 self.actor_optim.zero_grad()
 
                 # learn critic with maybe classification loss
 
-                critic_loss = self.critic(
+                critic_loss = self.critic.forward_for_loss(
                     states,
                     latents,
-                    target = advantages + old_values
+                    old_values = old_values,
+                    target = advantages + old_values,
+                    **self.critic_loss_kwargs
                 )
 
                 critic_loss.backward()
 
                 if exists(self.has_grad_clip):
-                    self.accelerate.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
+                    self.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
 
                 self.critic_optim.step()
                 self.critic_optim.zero_grad()
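In the updated call site above, the critic's target stays `advantages + old_values`: the GAE advantage is (approximately) the λ-return minus the rollout-time value estimate, so adding `old_values` back recovers the return target, while `old_values` is now also passed separately so the loss can clip the new prediction against it. A tiny numeric sanity check of that identity (the tensors are made up):

```python
import torch

old_values = torch.tensor([0.5, 1.0, -0.2])   # V(s) recorded during the rollout
returns    = torch.tensor([1.0, 0.8,  0.1])   # (bootstrapped) lambda-returns

advantages = returns - old_values             # what GAE estimates
target     = advantages + old_values          # recovers the returns for the critic

assert torch.allclose(target, returns)
print(target)
```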
@@ -1076,7 +1116,7 @@ class Agent(Module):
                 (diversity_loss * self.diversity_aux_loss_weight).backward()
 
                 if exists(self.has_grad_clip):
-                    self.accelerate.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)
+                    self.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)
 
                 self.latent_optim.step()
                 self.latent_optim.zero_grad()
--- evolutionary_policy_optimization-0.1.5.dist-info/METADATA
+++ evolutionary_policy_optimization-0.1.7.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.5
+Version: 0.1.7
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -185,4 +185,15 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Banerjee2022BoostingEI,
+    title   = {Boosting Exploration in Actor-Critic Algorithms by Incentivizing Plausible Novel States},
+    author  = {Chayan Banerjee and Zhiyong Chen and Nasimul Noman},
+    journal = {2023 62nd IEEE Conference on Decision and Control (CDC)},
+    year    = {2022},
+    pages   = {7009-7014},
+    url     = {https://api.semanticscholar.org/CorpusID:252682944}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
--- evolutionary_policy_optimization-0.1.5.dist-info/RECORD
+++ evolutionary_policy_optimization-0.1.7.dist-info/RECORD
@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=XB20JAnmOtemkEjr6op9EwZaFc4z48LiGivpdlcKKJM,42101
+evolutionary_policy_optimization/epo.py,sha256=5rOygXAfbb4dmjfseBcHgxHPpTFNMrrMDrY9IsJuZ28,43381
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.5.dist-info/METADATA,sha256=fLHqN3he3chRiGpPxj0dLL41rClzXFUdaGGzHpEd5fE,6330
-evolutionary_policy_optimization-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.5.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.7.dist-info/METADATA,sha256=yc_7LIYTbAhc7disU0o4ep-xVT1Ku3_nEF01yHcUzDE,6742
+evolutionary_policy_optimization-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.7.dist-info/RECORD,,