evolutionary-policy-optimization 0.1.6-py3-none-any.whl → 0.1.8-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- evolutionary_policy_optimization/epo.py +67 -8
- {evolutionary_policy_optimization-0.1.6.dist-info → evolutionary_policy_optimization-0.1.8.dist-info}/METADATA +11 -1
- {evolutionary_policy_optimization-0.1.6.dist-info → evolutionary_policy_optimization-0.1.8.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.1.6.dist-info → evolutionary_policy_optimization-0.1.8.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.6.dist-info → evolutionary_policy_optimization-0.1.8.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -207,13 +207,60 @@ def mutation(
 ):
     mutations = torch.randn_like(latents)

-    mutated = latents + mutations * mutation_strength
+    if is_tensor(mutation_strength):
+        mutations = einx.multiply('b, b ...', mutation_strength, mutations)
+    else:
+        mutations *= mutation_strength
+
+    mutated = latents + mutations

     if not l2norm_output:
         return mutated

     return l2norm(mutated)

+# drawing mutation strengths from power law distribution
+# proposed by https://arxiv.org/abs/1703.03334
+
+class PowerLawDist(Module):
+    def __init__(
+        self,
+        values: Tensor | list[float] | None = None,
+        bins = None,
+        beta = 1.5,
+    ):
+        super().__init__()
+        assert beta > 1.
+
+        assert exists(bins) or exists(values)
+
+        if exists(values):
+            if not is_tensor(values):
+                values = tensor(values)
+
+            assert values.ndim == 1
+            bins = values.shape[0]
+
+        self.beta = beta
+
+        cdf = torch.linspace(1, bins, bins).pow(-beta).cumsum(dim = -1)
+        cdf = cdf / cdf[-1]
+
+        self.register_buffer('cdf', cdf)
+        self.register_buffer('values', values)
+
+    def forward(self, shape):
+        device = self.cdf.device
+
+        uniform = torch.rand(shape, device = device)
+
+        sampled = torch.searchsorted(self.cdf, uniform)
+
+        if not exists(self.values):
+            return sampled
+
+        return self.values[sampled]
+
 # simple MLP networks, but with latent variables
 # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper

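The new `PowerLawDist` samples a mutation strength per latent by inverting a discretized power-law CDF: bin `k` carries mass proportional to `k^-beta`, so small strengths dominate while large ones appear occasionally. A minimal standalone sketch of that sampling step (only the math comes from the hunk above; the variable names here are illustrative):

```python
import torch

beta = 1.5
values = torch.linspace(1, 5, 10)   # candidate mutation strengths
bins = values.shape[0]

# P(bin k) ~ k^-beta for k = 1..bins, normalized into a CDF
cdf = torch.linspace(1, bins, bins).pow(-beta).cumsum(dim = -1)
cdf = cdf / cdf[-1]

# invert the discrete CDF with one uniform draw per population member
uniform = torch.rand((8,))
indices = torch.searchsorted(cdf, uniform).clamp(max = bins - 1)  # clamp guards a float rounding edge case

print(values[indices])  # mostly small strengths, occasionally large - the heavy tail of the fast GA
```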
@@ -369,7 +416,6 @@ class Critic(Module):
         hl_gauss_loss = self.to_pred.hl_gauss_loss

         self.maybe_bins_to_value = hl_gauss_loss if not use_regression else identity
-        self.maybe_value_to_bins = hl_gauss_loss.transform_to_logprobs if not use_regression else identity
         self.loss_fn = hl_gauss_loss if not use_regression else F.mse_loss

     def forward_for_loss(
@@ -386,7 +432,7 @@ class Critic(Module):

         clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)

-        loss = self.loss_fn(
+        loss = self.loss_fn(logits, target, reduction = 'none')
         clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')

         return torch.max(loss, clipped_loss).mean()
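The two `Critic` hunks fix the clipped value loss: both an unclipped and a clipped prediction are scored against the target, and the elementwise maximum is kept as a PPO-style pessimistic bound. A sketch of that combination for the plain regression branch (`loss_fn = F.mse_loss`); the HL-Gauss branch works the same way, scoring `logits` instead:

```python
import torch
import torch.nn.functional as F

value      = torch.randn(4)   # current critic prediction
old_values = torch.randn(4)   # prediction recorded at rollout time
target     = torch.randn(4)   # value target
eps_clip   = 0.4              # illustrative clip width

# keep the new prediction within a band around the old one, as in the hunk above
clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)

loss         = F.mse_loss(value, target, reduction = 'none')
clipped_loss = F.mse_loss(clipped_value, target, reduction = 'none')

# pessimistic combination: take the worse loss per element, then average
critic_loss = torch.max(loss, clipped_loss).mean()
```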
@@ -442,6 +488,8 @@ class LatentGenePool(Module):
         frac_elitism = 0.1,                    # frac of population to preserve from being noised
         frac_migrate = 0.1,                    # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
         mutation_strength = 1.,                # factor to multiply to gaussian noise as mutation to latents
+        fast_genetic_algorithm = False,
+        fast_ga_values = torch.linspace(1, 5, 10),
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
         default_should_run_ga_gamma = 1.5,
         migrate_every = 100,                   # how many steps before a migration between islands
@@ -489,6 +537,8 @@ class LatentGenePool(Module):
         self.crossover_random = crossover_random

         self.mutation_strength = mutation_strength
+        self.mutation_strength_sampler = PowerLawDist(fast_ga_values) if fast_genetic_algorithm else None
+
         self.num_elites = int(frac_elitism * latents_per_island)
         self.has_elites = self.num_elites > 0

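Wiring of the two new constructor flags: when `fast_genetic_algorithm` is on, the pool holds a `PowerLawDist` over `fast_ga_values` and draws one strength per latent; otherwise the fixed scalar is used. A hypothetical illustration, assuming the `PowerLawDist` class from the first hunk is in scope:

```python
import torch

fast_genetic_algorithm = True
fast_ga_values = torch.linspace(1, 5, 10)   # default candidate strengths from the diff
mutation_strength = 1.                      # scalar fallback

sampler = PowerLawDist(fast_ga_values) if fast_genetic_algorithm else None

population = 16
if sampler is not None:
    strengths = sampler((population,))      # tensor: one heavy-tailed strength per latent
else:
    strengths = mutation_strength           # single scalar shared by the whole pool
```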
@@ -657,9 +707,14 @@ class LatentGenePool(Module):
         if self.has_elites:
             genes, elites = genes[:, :-self.num_elites], genes[:, -self.num_elites:]

-        # 5. mutate with gaussian noise
+        # 5. mutate with gaussian noise
+
+        if exists(self.mutation_strength_sampler):
+            mutation_strength = self.mutation_strength_sampler(genes.shape[:1])
+        else:
+            mutation_strength = self.mutation_strength

-        genes = mutation(genes, mutation_strength =
+        genes = mutation(genes, mutation_strength = mutation_strength)

         # 6. maybe migration

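With a tensor of strengths, `mutation` now scales each population member's gaussian noise individually via `einx.multiply('b, b ...')`. A plain-torch equivalent of that broadcast, for illustration only:

```python
import torch

genes = torch.randn(16, 32)          # (population, dim_latent)
strengths = torch.rand(16) * 4 + 1   # one mutation strength per member

noise = torch.randn_like(genes)

# 'b, b ...' -> reshape strengths to (16, 1) so they broadcast over the latent dims
noise = strengths.view(-1, *((1,) * (genes.ndim - 1))) * noise

mutated = genes + noise
```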
@@ -844,7 +899,11 @@ class Agent(Module):

         dummy = tensor(0)

+        self.clip_grad_norm_ = nn.utils.clip_grad_norm_
+
         if wrap_with_accelerate:
+            self.clip_grad_norm_ = self.accelerate.clip_grad_norm_
+
             (
                 self.actor,
                 self.critic,
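`Agent` now binds the clipping function once at construction: plain `nn.utils.clip_grad_norm_` by default, replaced by `accelerate`'s method (same `(parameters, max_norm)` call shape) when wrapping is enabled, so the three update sites below all call a single name. A minimal hypothetical sketch of the pattern:

```python
import torch
from torch import nn

class TinyAgent:
    def __init__(self, accelerator = None):
        self.net = nn.Linear(4, 4)

        # default: torch's in-place gradient norm clipping
        self.clip_grad_norm_ = nn.utils.clip_grad_norm_

        if accelerator is not None:
            # accelerate's version also unscales gradients under mixed precision
            self.clip_grad_norm_ = accelerator.clip_grad_norm_

    def update(self, loss, max_grad_norm = 0.5):
        loss.backward()
        self.clip_grad_norm_(self.net.parameters(), max_grad_norm)
```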
@@ -1071,7 +1130,7 @@ class Agent(Module):
             actor_loss.backward()

             if exists(self.has_grad_clip):
-                self.
+                self.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)

             self.actor_optim.step()
             self.actor_optim.zero_grad()
@@ -1089,7 +1148,7 @@ class Agent(Module):
             critic_loss.backward()

             if exists(self.has_grad_clip):
-                self.
+                self.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)

             self.critic_optim.step()
             self.critic_optim.zero_grad()
@@ -1113,7 +1172,7 @@ class Agent(Module):
             (diversity_loss * self.diversity_aux_loss_weight).backward()

             if exists(self.has_grad_clip):
-                self.
+                self.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)

             self.latent_optim.step()
             self.latent_optim.zero_grad()
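The actor, critic, and latent updates above share one backward/clip/step pattern; a generic sketch (names illustrative, not the package's API):

```python
def optim_step(loss, parameters, optim, clip_grad_norm_, max_grad_norm = None):
    loss.backward()

    if max_grad_norm is not None:
        clip_grad_norm_(parameters, max_grad_norm)

    optim.step()
    optim.zero_grad()
```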
{evolutionary_policy_optimization-0.1.6.dist-info → evolutionary_policy_optimization-0.1.8.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.6
+Version: 0.1.8
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -196,4 +196,14 @@ agent.load('./agent.pt')
 }
 ```

+```bibtex
+@article{Doerr2017FastGA,
+    title   = {Fast genetic algorithms},
+    author  = {Benjamin Doerr and Huu Phuoc Le and R{\'e}gis Makhmara and Ta Duy Nguyen},
+    journal = {Proceedings of the Genetic and Evolutionary Computation Conference},
+    year    = {2017},
+    url     = {https://api.semanticscholar.org/CorpusID:16196841}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
{evolutionary_policy_optimization-0.1.6.dist-info → evolutionary_policy_optimization-0.1.8.dist-info}/RECORD

@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=Ua0o4Xe-Z6gy76-nbB1yKndePGurSwW_otXXrrJWhgc,44835
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.6.dist-info/METADATA,sha256=
-evolutionary_policy_optimization-0.1.6.dist-info/WHEEL,sha256=
-evolutionary_policy_optimization-0.1.6.dist-info/licenses/LICENSE,sha256=
-evolutionary_policy_optimization-0.1.6.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.8.dist-info/METADATA,sha256=tEVMyHVZjknJMQ0QEIVJhMj6QTDYW5Uqcq6nqa7LHpo,7088
+evolutionary_policy_optimization-0.1.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.8.dist-info/RECORD,,