evolutionary-policy-optimization 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
- evolutionary_policy_optimization/epo.py +88 -8
- {evolutionary_policy_optimization-0.1.7.dist-info → evolutionary_policy_optimization-0.1.9.dist-info}/METADATA +20 -1
- {evolutionary_policy_optimization-0.1.7.dist-info → evolutionary_policy_optimization-0.1.9.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.1.7.dist-info → evolutionary_policy_optimization-0.1.9.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.7.dist-info → evolutionary_policy_optimization-0.1.9.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -207,13 +207,60 @@ def mutation(
 ):
     mutations = torch.randn_like(latents)

-    mutated = latents + mutations * mutation_strength
+    if is_tensor(mutation_strength):
+        mutations = einx.multiply('b, b ...', mutation_strength, mutations)
+    else:
+        mutations *= mutation_strength
+
+    mutated = latents + mutations

     if not l2norm_output:
         return mutated

     return l2norm(mutated)

+# drawing mutation strengths from power law distribution
+# proposed by https://arxiv.org/abs/1703.03334
+
+class PowerLawDist(Module):
+    def __init__(
+        self,
+        values: Tensor | list[float] | None = None,
+        bins = None,
+        beta = 1.5,
+    ):
+        super().__init__()
+        assert beta > 1.
+
+        assert exists(bins) or exists(values)
+
+        if exists(values):
+            if not is_tensor(values):
+                values = tensor(values)
+
+            assert values.ndim == 1
+            bins = values.shape[0]
+
+        self.beta = beta
+
+        cdf = torch.linspace(1, bins, bins).pow(-beta).cumsum(dim = -1)
+        cdf = cdf / cdf[-1]
+
+        self.register_buffer('cdf', cdf)
+        self.register_buffer('values', values)
+
+    def forward(self, shape):
+        device = self.cdf.device
+
+        uniform = torch.rand(shape, device = device)
+
+        sampled = torch.searchsorted(self.cdf, uniform)
+
+        if not exists(self.values):
+            return sampled
+
+        return self.values[sampled]
+
 # simple MLP networks, but with latent variables
 # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper

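Reviewer note: for orientation, a minimal standalone sketch of the two mechanisms this hunk adds, using plain PyTorch only. It reimplements the inverse-CDF sampling that `PowerLawDist` performs over a discretized power law p(k) ∝ k^(-beta), and the per-member broadcast that `einx.multiply('b, b ...')` applies when `mutation_strength` arrives as a tensor. All names below are local to the sketch, not part of the package API.

```python
import torch

# discretized power law over ranks 1..bins, p(k) ~ k^(-beta), beta > 1
beta, bins = 1.5, 10
pmf = torch.linspace(1, bins, bins).pow(-beta)
cdf = pmf.cumsum(dim = -1)
cdf = cdf / cdf[-1]  # normalize so cdf[-1] == 1

values = torch.linspace(1, 5, 10)  # candidate mutation strengths (the fast_ga_values default)

# inverse-CDF sampling: a uniform draw u lands in bin k with probability pmf[k]
uniform = torch.rand(32)
sampled_bins = torch.searchsorted(cdf, uniform)
mutation_strength = values[sampled_bins]  # shape (32,), mostly small, occasionally large

# per-member mutation, equivalent to einx.multiply('b, b ...', strength, noise)
latents = torch.randn(32, 128)
noise = torch.randn_like(latents)
mutated = latents + noise * mutation_strength[:, None]  # broadcast strength over trailing dims
```

With beta > 1 the mass concentrates in the low bins, so most sampled strengths are small with occasional large outliers, which is the heavy-tailed mutation schedule proposed in the fast genetic algorithms paper cited in the comment.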
@@ -377,18 +424,38 @@ class Critic(Module):
         latent,
         old_values,
         target,
-        eps_clip = 0.4
+        eps_clip = 0.4,
+        use_improved = True
     ):
         logits = self.forward(state, latent, return_logits = True)

         value = self.maybe_bins_to_value(logits)

-        clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+        if use_improved:
+            clipped_target = target.clamp(-eps_clip, eps_clip)
+
+            old_values_lo = old_values - eps_clip
+            old_values_hi = old_values + eps_clip
+
+            is_between = lambda lo, hi: (lo < value) & (value < hi)

-        loss = self.loss_fn(logits, target, reduction = 'none')
-        clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+            clipped_loss = self.loss_fn(logits, clipped_target, reduction = 'none')
+            loss = self.loss_fn(logits, target, reduction = 'none')

-        return torch.max(loss, clipped_loss).mean()
+            value_loss = torch.where(
+                is_between(target, old_values_lo) | is_between(old_values_hi, target),
+                0.,
+                torch.min(loss, clipped_loss)
+            )
+        else:
+            clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+
+            loss = self.loss_fn(logits, target, reduction = 'none')
+            clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+
+            value_loss = torch.max(loss, clipped_loss)
+
+        return value_loss.mean()

     def forward(
         self,
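Reviewer note: a toy numeric sketch of what the `use_improved` branch computes. Hedged assumptions: plain squared error stands in for `self.loss_fn`, which in the package operates on logits and may be a binned loss, and `old_values` is set to zero so that `clipped_target = target.clamp(-eps_clip, eps_clip)` coincides with clipping around the old value. The `torch.where` mask zeroes the loss wherever the predicted value already lies between the target and the clip boundary, following the Lee & Chung reference added to the README below; elsewhere the smaller of the clipped and unclipped losses is kept.

```python
import torch

eps_clip = 0.4

value      = torch.tensor([ 0.60,  0.20, -0.60])  # current critic predictions
old_values = torch.tensor([ 0.00,  0.00,  0.00])  # values recorded at rollout time
target     = torch.tensor([ 0.90,  0.90, -0.90])  # value targets

clipped_target = target.clamp(-eps_clip, eps_clip)  # [0.40, 0.40, -0.40]

old_values_lo = old_values - eps_clip
old_values_hi = old_values + eps_clip

is_between = lambda lo, hi: (lo < value) & (value < hi)

mse = lambda pred, tgt: (pred - tgt) ** 2  # stand-in for self.loss_fn(..., reduction = 'none')

loss = mse(value, target)
clipped_loss = mse(value, clipped_target)

value_loss = torch.where(
    is_between(target, old_values_lo) | is_between(old_values_hi, target),
    0.,
    torch.min(loss, clipped_loss)
)

print(value_loss)  # tensor([0.0000, 0.0400, 0.0000])
# elements 0 and 2: value already sits between the clip edge and the target -> masked to zero
# element 1: value inside the trust region -> the smaller (clipped) loss applies
```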
@@ -441,6 +508,8 @@ class LatentGenePool(Module):
         frac_elitism = 0.1, # frac of population to preserve from being noised
         frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
         mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
+        fast_genetic_algorithm = False,
+        fast_ga_values = torch.linspace(1, 5, 10),
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
         default_should_run_ga_gamma = 1.5,
         migrate_every = 100, # how many steps before a migration between islands
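Reviewer note: a hedged construction sketch showing where the two new kwargs sit. `num_latents` and `dim_latent` are assumptions taken from the project README, not from this diff; only `fast_genetic_algorithm` and `fast_ga_values` are introduced here.

```python
import torch
from evolutionary_policy_optimization.epo import LatentGenePool  # assumed import path

latent_pool = LatentGenePool(
    num_latents = 128,  # assumed kwarg, from the project README
    dim_latent = 32,    # assumed kwarg, from the project README
    fast_genetic_algorithm = True,              # new in 0.1.9: power-law mutation strengths
    fast_ga_values = torch.linspace(1, 5, 10),  # new in 0.1.9: candidate strengths (default shown)
)
```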
@@ -488,6 +557,8 @@ class LatentGenePool(Module):
         self.crossover_random = crossover_random

         self.mutation_strength = mutation_strength
+        self.mutation_strength_sampler = PowerLawDist(fast_ga_values) if fast_genetic_algorithm else None
+
         self.num_elites = int(frac_elitism * latents_per_island)
         self.has_elites = self.num_elites > 0

@@ -656,9 +727,14 @@ class LatentGenePool(Module):
         if self.has_elites:
             genes, elites = genes[:, :-self.num_elites], genes[:, -self.num_elites:]

-        # 5. mutate with gaussian noise
+        # 5. mutate with gaussian noise
+
+        if exists(self.mutation_strength_sampler):
+            mutation_strength = self.mutation_strength_sampler(genes.shape[:1])
+        else:
+            mutation_strength = self.mutation_strength

-        genes = mutation(genes, mutation_strength = self.mutation_strength)
+        genes = mutation(genes, mutation_strength = mutation_strength)

         # 6. maybe migration

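Reviewer note: end to end, the fast-GA path amounts to sampling one strength per latent and passing the resulting tensor into `mutation`. A sketch of that runtime path, assuming `PowerLawDist` and `mutation` are importable from `evolutionary_policy_optimization.epo` as defined in this diff:

```python
import torch
from evolutionary_policy_optimization.epo import PowerLawDist, mutation  # assumed import path

genes = torch.randn(64, 32)  # stand-in latent pool: (population, dim_latent)

# mirrors what the pool builds when fast_genetic_algorithm = True, with default fast_ga_values
sampler = PowerLawDist(torch.linspace(1, 5, 10))

# one heavy-tailed strength per population member, as in step 5 above
mutation_strength = sampler(genes.shape[:1])

genes = mutation(genes, mutation_strength = mutation_strength)
```

When the flag is off, the scalar `mutation_strength` hyperparameter is used as before.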
@@ -770,6 +846,7 @@ class Agent(Module):
         critic_loss_kwargs: dict = dict(
             eps_clip = 0.4
         ),
+        use_improved_critic_loss = True,
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -815,6 +892,8 @@ class Agent(Module):
         self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.critic_loss_kwargs = critic_loss_kwargs

+        self.use_improved_critic_loss = use_improved_critic_loss
+
         # fitness score related

         self.get_fitness_scores = get_fitness_scores
@@ -1086,6 +1165,7 @@ class Agent(Module):
             latents,
             old_values = old_values,
             target = advantages + old_values,
+            use_improved = self.use_improved_critic_loss,
             **self.critic_loss_kwargs
         )

evolutionary_policy_optimization-0.1.7.dist-info/METADATA → evolutionary_policy_optimization-0.1.9.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.7
+Version: 0.1.9
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -196,4 +196,23 @@ agent.load('./agent.pt')
 }
 ```

+```bibtex
+@article{Doerr2017FastGA,
+    title = {Fast genetic algorithms},
+    author = {Benjamin Doerr and Huu Phuoc Le and R{\'e}gis Makhmara and Ta Duy Nguyen},
+    journal = {Proceedings of the Genetic and Evolutionary Computation Conference},
+    year = {2017},
+    url = {https://api.semanticscholar.org/CorpusID:16196841}
+}
+```
+
+```bibtex
+@article{Lee2024AnalysisClippedCritic
+    title = {On Analysis of Clipped Critic Loss in Proximal Policy Gradient},
+    author = {Yongjin Lee, Moonyoung Chung},
+    journal = {Authorea},
+    year = {2024}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
evolutionary_policy_optimization-0.1.7.dist-info/RECORD → evolutionary_policy_optimization-0.1.9.dist-info/RECORD

@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=9GfSvOz6SwjAuZyhyvsLHPY8b2svMQlM3BRjilwsQ-g,45717
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.7.dist-info/METADATA,sha256=
-evolutionary_policy_optimization-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.7.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.9.dist-info/METADATA,sha256=y5w_NwtKNQ07HeYa5r6hcPn7RsqDpehMmt5vj6mTESQ,7316
+evolutionary_policy_optimization-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.9.dist-info/RECORD,,
WHEEL and licenses/LICENSE: file contents unchanged; only the dist-info directory name changed with the version.