evolutionary-policy-optimization 0.1.5-py3-none-any.whl → 0.1.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +42 -5
- {evolutionary_policy_optimization-0.1.5.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/METADATA +12 -1
- {evolutionary_policy_optimization-0.1.5.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.1.5.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.5.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

```diff
@@ -364,11 +364,38 @@ class Critic(Module):
             hl_gauss_loss = hl_gauss_loss_kwargs
         )
 
+        self.use_regression = use_regression
+
+        hl_gauss_loss = self.to_pred.hl_gauss_loss
+
+        self.maybe_bins_to_value = hl_gauss_loss if not use_regression else identity
+        self.maybe_value_to_bins = hl_gauss_loss.transform_to_logprobs if not use_regression else identity
+        self.loss_fn = hl_gauss_loss if not use_regression else F.mse_loss
+
+    def forward_for_loss(
+        self,
+        state,
+        latent,
+        old_values,
+        target,
+        eps_clip = 0.4
+    ):
+        logits = self.forward(state, latent, return_logits = True)
+
+        value = self.maybe_bins_to_value(logits)
+
+        clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+
+        loss = self.loss_fn(value, target, reduction = 'none')
+        clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+
+        return torch.max(loss, clipped_loss).mean()
+
     def forward(
         self,
         state,
         latent,
-
+        return_logits = False
     ):
 
         hidden = self.init_layer(state)
```
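The new `forward_for_loss` applies a PPO-style clipped value loss: the fresh value prediction is kept near the value recorded at rollout time, and the elementwise maximum of the clipped and unclipped losses is taken. Below is a minimal, self-contained sketch of that idea, using the conventional symmetric clamp of the value delta and plain MSE in place of the HL-Gauss classification loss; the function and variable names are illustrative, not part of the package API.

```python
import torch
import torch.nn.functional as F

def clipped_value_loss(value, old_values, target, eps_clip = 0.4):
    # keep the new prediction within eps_clip of the rollout-time prediction,
    # then take the elementwise worse (max) of the clipped and unclipped losses,
    # so the critic cannot move too far from its old estimate in one update
    clipped_value = old_values + (value - old_values).clamp(-eps_clip, eps_clip)

    loss = F.mse_loss(value, target, reduction = 'none')
    clipped_loss = F.mse_loss(clipped_value, target, reduction = 'none')

    return torch.max(loss, clipped_loss).mean()

values     = torch.tensor([0.9, 0.1, 0.5])
old_values = torch.tensor([0.5, 0.2, 0.4])
targets    = torch.tensor([1.0, 0.0, 0.6])
print(clipped_value_loss(values, old_values, targets))
```

Note that the sketch clamps the delta to a symmetric `(-eps_clip, eps_clip)` range, the conventional PPO value clip, whereas the diff above clamps to `(1. - eps_clip, 1. + eps_clip)`.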
```diff
@@ -377,7 +404,8 @@ class Critic(Module):
 
         hidden = self.final_act(hidden)
 
-
+        pred_kwargs = dict(return_logits = return_logits) if not self.use_regression else dict()
+        return self.to_pred(hidden, **pred_kwargs)
 
 # criteria for running genetic algorithm
 
```
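The `return_logits` flag only matters when the critic is in classification mode (`use_regression = False`): the head then produces logits over discrete value bins, and `maybe_bins_to_value` converts them back to a scalar. A toy sketch of that bins-to-value conversion follows, taking the expectation over evenly spaced bin centers; the bin range, bin count, and helper name are assumptions made for the sketch, the package delegates the real conversion to its HL-Gauss prediction head.

```python
import torch

def bins_to_value(logits, min_value = -10., max_value = 10.):
    # illustrative only: a classification-style critic predicts a distribution over
    # fixed value bins; the scalar value is the expectation under that distribution
    num_bins = logits.shape[-1]
    bin_centers = torch.linspace(min_value, max_value, num_bins)
    probs = logits.softmax(dim = -1)
    return (probs * bin_centers).sum(dim = -1)

logits = torch.randn(4, 51)          # batch of 4 states, 51 value bins
print(bins_to_value(logits).shape)   # torch.Size([4])
```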
```diff
@@ -740,6 +768,9 @@ class Agent(Module):
             entropy_weight = .01,
             norm_advantages = True
         ),
+        critic_loss_kwargs: dict = dict(
+            eps_clip = 0.4
+        ),
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
```
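The new `critic_loss_kwargs` follows the same pattern as the other `*_kwargs` arguments: loss hyperparameters are collected in a dict at construction time and splatted into the loss call during learning, so `eps_clip` can be overridden without touching the loss code. A toy illustration of that plumbing (not the package's classes, all names below are made up for the sketch):

```python
def toy_clipped_loss(pred, target, eps_clip = 0.4):
    # hypothetical stand-in for the critic loss, only to show the kwargs plumbing
    return min(abs(pred - target), eps_clip)

class ToyAgent:
    def __init__(self, critic_loss_kwargs: dict = dict(eps_clip = 0.4)):
        # the dict is stored as-is at construction time...
        self.critic_loss_kwargs = critic_loss_kwargs

    def critic_update(self, pred, target):
        # ...and splatted into the loss call at learning time
        return toy_clipped_loss(pred, target, **self.critic_loss_kwargs)

agent = ToyAgent(critic_loss_kwargs = dict(eps_clip = 0.2))
print(agent.critic_update(1.0, 2.0))  # 0.2
```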
```diff
@@ -778,9 +809,13 @@ class Agent(Module):
 
         # gae function
 
-        self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.calc_gae = partial(calc_generalized_advantage_estimate, **calc_gae_kwargs)
 
+        # actor critic loss related
+
+        self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
+        self.critic_loss_kwargs = critic_loss_kwargs
+
         # fitness score related
 
         self.get_fitness_scores = get_fitness_scores
```
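`self.actor_loss` is built with `functools.partial`, so `actor_loss_kwargs` is bound once in `__init__` and later calls pass only per-batch tensors, while `critic_loss_kwargs` is kept as a plain dict because it is splatted into `forward_for_loss` at the call site. A tiny, purely illustrative sketch of the `partial` side of that pattern, with a made-up loss function:

```python
from functools import partial

def toy_loss(pred, target, eps_clip = 0.2, weight = 1.0):
    # hypothetical loss, only to demonstrate pre-binding hyperparameters
    return weight * min(abs(pred - target), eps_clip)

# hyperparameters are bound once, mirroring
# self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
configured_loss = partial(toy_loss, eps_clip = 0.3, weight = 0.5)

# later calls only pass the per-batch data
print(configured_loss(1.0, 1.8))  # 0.5 * 0.3 = 0.15
```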
```diff
@@ -1043,10 +1078,12 @@ class Agent(Module):
 
             # learn critic with maybe classification loss
 
-            critic_loss = self.critic(
+            critic_loss = self.critic.forward_for_loss(
                 states,
                 latents,
-
+                old_values = old_values,
+                target = advantages + old_values,
+                **self.critic_loss_kwargs
             )
 
             critic_loss.backward()
```
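The critic's target here is `advantages + old_values`: adding the GAE advantages back onto the rollout-time values recovers a return estimate for each state. The sketch below illustrates that relationship with a simplified single-trajectory GAE; `calc_gae`, the gamma/lam values, and the toy tensors are assumptions for the example, not the package's `calc_generalized_advantage_estimate`.

```python
import torch

def calc_gae(rewards, values, gamma = 0.99, lam = 0.95):
    # simplified single-trajectory generalized advantage estimation (bootstrap value of 0
    # after the final step), purely illustrative
    advantages = torch.zeros_like(rewards)
    next_value, next_advantage = 0., 0.
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]
        next_advantage = delta + gamma * lam * next_advantage
        advantages[t] = next_advantage
        next_value = values[t]
    return advantages

rewards    = torch.tensor([1., 0., 1., 1.])
old_values = torch.tensor([0.5, 0.4, 0.8, 0.9])

advantages = calc_gae(rewards, old_values)

# the critic target used above: advantages plus the rollout-time values gives the
# return estimate each state should be regressed (or classified) toward
value_targets = advantages + old_values
print(value_targets)
```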
{evolutionary_policy_optimization-0.1.5.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/METADATA

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.5
+Version: 0.1.6
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -185,4 +185,15 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Banerjee2022BoostingEI,
+    title = {Boosting Exploration in Actor-Critic Algorithms by Incentivizing Plausible Novel States},
+    author = {Chayan Banerjee and Zhiyong Chen and Nasimul Noman},
+    journal = {2023 62nd IEEE Conference on Decision and Control (CDC)},
+    year = {2022},
+    pages = {7009-7014},
+    url = {https://api.semanticscholar.org/CorpusID:252682944}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
````
{evolutionary_policy_optimization-0.1.5.dist-info → evolutionary_policy_optimization-0.1.6.dist-info}/RECORD

```diff
@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=P-a8A1ky7FgpENUgb8VHk9qADwyQdzpUp40JoaSG2HY,43395
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
+evolutionary_policy_optimization-0.1.6.dist-info/METADATA,sha256=Bc4MZKhe2H6q7H2mB4g2qObQYHjG-8ZJkcF1XKxgTgw,6742
+evolutionary_policy_optimization-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.6.dist-info/RECORD,,
```

The WHEEL and licenses/LICENSE files are unchanged.