evolutionary-policy-optimization 0.1.8-py3-none-any.whl → 0.1.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +29 -5
- {evolutionary_policy_optimization-0.1.8.dist-info → evolutionary_policy_optimization-0.1.9.dist-info}/METADATA +10 -1
- {evolutionary_policy_optimization-0.1.8.dist-info → evolutionary_policy_optimization-0.1.9.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.1.8.dist-info → evolutionary_policy_optimization-0.1.9.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.8.dist-info → evolutionary_policy_optimization-0.1.9.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -424,18 +424,38 @@ class Critic(Module):
         latent,
         old_values,
         target,
-        eps_clip = 0.4
+        eps_clip = 0.4,
+        use_improved = True
     ):
         logits = self.forward(state, latent, return_logits = True)
 
         value = self.maybe_bins_to_value(logits)
 
-        clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+        if use_improved:
+            clipped_target = target.clamp(-eps_clip, eps_clip)
 
-        loss = self.loss_fn(logits, target, reduction = 'none')
-        clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+            old_values_lo = old_values - eps_clip
+            old_values_hi = old_values + eps_clip
 
-        return torch.max(loss, clipped_loss).mean()
+            is_between = lambda lo, hi: (lo < value) & (value < hi)
+
+            clipped_loss = self.loss_fn(logits, clipped_target, reduction = 'none')
+            loss = self.loss_fn(logits, target, reduction = 'none')
+
+            value_loss = torch.where(
+                is_between(target, old_values_lo) | is_between(old_values_hi, target),
+                0.,
+                torch.min(loss, clipped_loss)
+            )
+        else:
+            clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+
+            loss = self.loss_fn(logits, target, reduction = 'none')
+            clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+
+            value_loss = torch.max(loss, clipped_loss)
+
+        return value_loss.mean()
 
     def forward(
         self,
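The new `use_improved` branch implements the clipped critic loss analyzed by Lee and Chung in "On Analysis of Clipped Critic Loss in Proximal Policy Gradient" (the citation added to the README below): rather than taking the pessimistic `torch.max` of the clipped and unclipped losses, it zeroes the loss wherever the value prediction already sits between the target and the clip boundary around `old_values`, and otherwise takes the `torch.min`. Below is a minimal self-contained sketch of that branch; the tensor values are illustrative, and plain MSE stands in for the package's `Critic.loss_fn` (which takes logits and may use binned values).

```python
import torch
import torch.nn.functional as F

# illustrative inputs, not taken from the package
value      = torch.tensor([0.1, 0.9, -0.6])  # current critic prediction
old_values = torch.tensor([0.0, 0.0,  0.0])  # prediction before this update
target     = torch.tensor([0.2, 1.5, -0.2])  # advantages + old_values
eps_clip   = 0.4

# mirrors the new use_improved branch, with F.mse_loss standing in for self.loss_fn
clipped_target = target.clamp(-eps_clip, eps_clip)

old_values_lo = old_values - eps_clip
old_values_hi = old_values + eps_clip

is_between = lambda lo, hi: (lo < value) & (value < hi)

clipped_loss = F.mse_loss(value, clipped_target, reduction = 'none')
loss         = F.mse_loss(value, target, reduction = 'none')

# zero the loss where the prediction lies between the target and the clip
# boundary (clipping would block any further movement there anyway),
# otherwise take the smaller of the clipped and unclipped losses
value_loss = torch.where(
    is_between(target, old_values_lo) | is_between(old_values_hi, target),
    torch.zeros_like(loss),
    torch.min(loss, clipped_loss)
)

print(value_loss.mean())
```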
@@ -826,6 +846,7 @@ class Agent(Module):
         critic_loss_kwargs: dict = dict(
             eps_clip = 0.4
         ),
+        use_improved_critic_loss = True,
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -871,6 +892,8 @@ class Agent(Module):
         self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.critic_loss_kwargs = critic_loss_kwargs
 
+        self.use_improved_critic_loss = use_improved_critic_loss
+
         # fitness score related
 
         self.get_fitness_scores = get_fitness_scores
@@ -1142,6 +1165,7 @@ class Agent(Module):
             latents,
             old_values = old_values,
             target = advantages + old_values,
+            use_improved = self.use_improved_critic_loss,
             **self.critic_loss_kwargs
         )
 
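At the `Agent` level, the new behavior is exposed as `use_improved_critic_loss` (default `True`), stored in `__init__`, and forwarded as `use_improved` when the critic loss is computed against `target = advantages + old_values`, as the hunks above show. A hedged sketch of the two critic-loss knobs this release touches; every other `Agent` constructor argument is omitted here because this diff leaves them unchanged:

```python
# only these two kwargs come from this diff - merge them into your
# existing Agent(...) construction, whatever its other arguments are
agent_kwargs = dict(
    critic_loss_kwargs = dict(eps_clip = 0.4),  # forwarded into the critic loss call
    use_improved_critic_loss = True             # False restores the pre-0.1.9 torch.max clipped loss
)
```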
{evolutionary_policy_optimization-0.1.8.dist-info → evolutionary_policy_optimization-0.1.9.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.8
+Version: 0.1.9
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -206,4 +206,13 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Lee2024AnalysisClippedCritic,
+    title   = {On Analysis of Clipped Critic Loss in Proximal Policy Gradient},
+    author  = {Yongjin Lee and Moonyoung Chung},
+    journal = {Authorea},
+    year    = {2024}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
{evolutionary_policy_optimization-0.1.8.dist-info → evolutionary_policy_optimization-0.1.9.dist-info}/RECORD

@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=9GfSvOz6SwjAuZyhyvsLHPY8b2svMQlM3BRjilwsQ-g,45717
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.8.dist-info/METADATA,sha256=
-evolutionary_policy_optimization-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.8.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.9.dist-info/METADATA,sha256=y5w_NwtKNQ07HeYa5r6hcPn7RsqDpehMmt5vj6mTESQ,7316
+evolutionary_policy_optimization-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.9.dist-info/RECORD,,
WHEEL and licenses/LICENSE: files without changes.