evolutionary-policy-optimization 0.1.4-py3-none-any.whl → 0.1.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +42 -6
- {evolutionary_policy_optimization-0.1.4.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/METADATA +12 -1
- {evolutionary_policy_optimization-0.1.4.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.1.4.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.4.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -364,11 +364,38 @@ class Critic(Module):
             hl_gauss_loss = hl_gauss_loss_kwargs
         )
 
+        self.use_regression = use_regression
+
+        hl_gauss_loss = self.to_pred.hl_gauss_loss
+
+        self.maybe_bins_to_value = hl_gauss_loss if not use_regression else identity
+        self.maybe_value_to_bins = hl_gauss_loss.transform_to_logprobs if not use_regression else identity
+        self.loss_fn = hl_gauss_loss if not use_regression else F.mse_loss
+
+    def forward_for_loss(
+        self,
+        state,
+        latent,
+        old_values,
+        target,
+        eps_clip = 0.4
+    ):
+        logits = self.forward(state, latent, return_logits = True)
+
+        value = self.maybe_bins_to_value(logits)
+
+        clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+
+        loss = self.loss_fn(value, target, reduction = 'none')
+        clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+
+        return torch.max(loss, clipped_loss).mean()
+
     def forward(
         self,
         state,
         latent,
-
+        return_logits = False
     ):
 
         hidden = self.init_layer(state)
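The new `forward_for_loss` adds a PPO-style clipped value loss to the critic. Below is a minimal standalone sketch of just the clipping logic, using the plain-regression branch (`F.mse_loss`) rather than the package's HL-Gauss loss; the function name is mine and the clamp bounds are copied verbatim from the diff above.

```python
import torch
import torch.nn.functional as F

def clipped_value_loss(value, old_values, target, eps_clip = 0.4):
    # keep the new value prediction within a band around the old one,
    # mirroring the clamp(...) call added in Critic.forward_for_loss
    clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)

    loss = F.mse_loss(value, target, reduction = 'none')
    clipped_loss = F.mse_loss(clipped_value, target, reduction = 'none')

    # elementwise max means the optimizer gains nothing by pushing the
    # value estimate far outside the clipped band
    return torch.max(loss, clipped_loss).mean()

# toy usage with random tensors
value, old_values, target = torch.randn(3, 8).unbind(0)
print(clipped_value_loss(value, old_values, target))
```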
@@ -377,7 +404,8 @@ class Critic(Module):
 
         hidden = self.final_act(hidden)
 
-
+        pred_kwargs = dict(return_logits = return_logits) if not self.use_regression else dict()
+        return self.to_pred(hidden, **pred_kwargs)
 
 # criteria for running genetic algorithm
 
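The `return_logits` flag only makes sense on the HL-Gauss (binned classification) path, so the hunk above drops it from the kwargs when the critic runs in plain regression mode. A small sketch of that conditional-kwargs dispatch, with a hypothetical stand-in prediction head (`BinnedHead` is illustrative, not the package's HLGaussLayer):

```python
import torch
from torch import nn

class BinnedHead(nn.Module):
    # stand-in head: returns raw bin logits on request, otherwise a scalar
    # expectation over fixed bin centers
    def __init__(self, dim, num_bins = 8):
        super().__init__()
        self.proj = nn.Linear(dim, num_bins)
        self.register_buffer('bin_values', torch.linspace(-1., 1., num_bins))

    def forward(self, hidden, return_logits = False):
        logits = self.proj(hidden)
        if return_logits:
            return logits
        return (logits.softmax(dim = -1) * self.bin_values).sum(dim = -1)

use_regression = False
to_pred = BinnedHead(16)
hidden = torch.randn(2, 16)

# same guard the diff adds: only forward return_logits when not in regression mode
pred_kwargs = dict(return_logits = True) if not use_regression else dict()
out = to_pred(hidden, **pred_kwargs)
print(out.shape)  # bin logits of shape (2, 8) when requested
```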
@@ -740,6 +768,9 @@ class Agent(Module):
             entropy_weight = .01,
             norm_advantages = True
         ),
+        critic_loss_kwargs: dict = dict(
+            eps_clip = 0.4
+        ),
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -778,9 +809,13 @@ class Agent(Module):
 
         # gae function
 
-        self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.calc_gae = partial(calc_generalized_advantage_estimate, **calc_gae_kwargs)
 
+        # actor critic loss related
+
+        self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
+        self.critic_loss_kwargs = critic_loss_kwargs
+
         # fitness score related
 
         self.get_fitness_scores = get_fitness_scores
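Taken together with the constructor change above, the pattern is: accept a `critic_loss_kwargs` dict at init, store it on the module, and splat it into the critic loss call during the learning step. A hypothetical condensed sketch of that wiring (class and method names here are illustrative, not the package's `Agent` API):

```python
from functools import partial

class AgentSketch:
    # illustrative only: shows how loss kwargs are captured at init and reused later
    def __init__(
        self,
        actor_loss_fn,
        actor_loss_kwargs: dict = dict(entropy_weight = .01, norm_advantages = True),
        critic_loss_kwargs: dict = dict(eps_clip = 0.4)
    ):
        self.actor_loss = partial(actor_loss_fn, **actor_loss_kwargs)
        self.critic_loss_kwargs = critic_loss_kwargs

    def critic_update(self, critic, states, latents, old_values, advantages):
        # mirrors the learn step further down in this diff
        return critic.forward_for_loss(
            states,
            latents,
            old_values = old_values,
            target = advantages + old_values,
            **self.critic_loss_kwargs
        )
```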
@@ -900,7 +935,6 @@ class Agent(Module):
         if not exists(latent) and exists(latent_id):
             latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        print(self.device, state.device, next(self.actor.parameters()).device)
         logits = maybe_unwrap(self.actor)(state, latent)
 
         if not sample:
@@ -1044,10 +1078,12 @@ class Agent(Module):
 
         # learn critic with maybe classification loss
 
-        critic_loss = self.critic(
+        critic_loss = self.critic.forward_for_loss(
             states,
             latents,
-
+            old_values = old_values,
+            target = advantages + old_values,
+            **self.critic_loss_kwargs
         )
 
         critic_loss.backward()
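The critic is now regressed toward `advantages + old_values`, i.e. the return estimate recovered from the values stored at rollout time (GAE-style advantages are defined as returns minus value estimates). A toy numeric check of that identity, with made-up numbers:

```python
import torch

old_values = torch.tensor([1.0, 2.0, 3.0])   # value estimates recorded at rollout time
returns    = torch.tensor([1.5, 1.8, 3.9])   # hypothetical discounted returns

advantages = returns - old_values            # advantage = return - value estimate
target     = advantages + old_values         # the critic target used in the diff

assert torch.allclose(target, returns)       # recovers the return estimate exactly
```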
{evolutionary_policy_optimization-0.1.4.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.4
+Version: 0.1.6
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -185,4 +185,15 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Banerjee2022BoostingEI,
+    title = {Boosting Exploration in Actor-Critic Algorithms by Incentivizing Plausible Novel States},
+    author = {Chayan Banerjee and Zhiyong Chen and Nasimul Noman},
+    journal = {2023 62nd IEEE Conference on Decision and Control (CDC)},
+    year = {2022},
+    pages = {7009-7014},
+    url = {https://api.semanticscholar.org/CorpusID:252682944}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
{evolutionary_policy_optimization-0.1.4.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/RECORD

@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=P-a8A1ky7FgpENUgb8VHk9qADwyQdzpUp40JoaSG2HY,43395
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
+evolutionary_policy_optimization-0.1.6.dist-info/METADATA,sha256=Bc4MZKhe2H6q7H2mB4g2qObQYHjG-8ZJkcF1XKxgTgw,6742
+evolutionary_policy_optimization-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.6.dist-info/RECORD,,
WHEEL and licenses/LICENSE: files without changes (only the dist-info directory name changed).