evolutionary-policy-optimization 0.0.27-py3-none-any.whl → 0.0.28-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- evolutionary_policy_optimization/epo.py +20 -3
- {evolutionary_policy_optimization-0.0.27.dist-info → evolutionary_policy_optimization-0.0.28.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.28.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.27.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.27.dist-info → evolutionary_policy_optimization-0.0.28.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.27.dist-info → evolutionary_policy_optimization-0.0.28.dist-info}/licenses/LICENSE +0 -0
@@ -311,6 +311,7 @@ class LatentGenePool(Module):
         frac_tournaments = 0.25,   # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
         frac_natural_selected = 0.25,  # number of least fit genes to remove from the pool
         frac_elitism = 0.1,        # frac of population to preserve from being noised
+        frac_migrate = 0.1,        # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
         mutation_strength = 1.,    # factor to multiply to gaussian noise as mutation to latents
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
         default_should_run_ga_gamma = 1.5
@@ -348,12 +349,16 @@ class LatentGenePool(Module):
         self.num_natural_selected = int(frac_natural_selected * latents_per_island)
 
         self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
+
         self.crossover_random = crossover_random
 
         self.mutation_strength = mutation_strength
         self.num_elites = int(frac_elitism * latents_per_island)
         self.has_elites = self.num_elites > 0
 
+        latents_without_elites = num_latents - self.num_elites
+        self.num_migrate = int(frac_migrate * latents_without_elites)
+
         if not exists(should_run_genetic_algorithm):
             should_run_genetic_algorithm = ShouldRunGeneticAlgorithm(gamma = default_should_run_ga_gamma)
 
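As a rough worked example of the new bookkeeping above (the population and island counts are assumptions for illustration; only the 0.1 fractions are defaults from the signature):

    # illustrative numbers only, not values taken from the package
    num_latents = 128
    num_islands = 4
    latents_per_island = num_latents // num_islands           # 32

    num_elites = int(0.1 * latents_per_island)                # frac_elitism = 0.1 -> 3 elites per island
    latents_without_elites = num_latents - num_elites         # 125
    num_migrate = int(0.1 * latents_without_elites)           # frac_migrate = 0.1 -> 12 genes eligible to migrate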
@@ -365,7 +370,6 @@ class LatentGenePool(Module):
         beta0 = 2.,           # exploitation factor, moving fireflies of low light intensity to high
         gamma = 1.,           # controls light intensity decay over distance - setting this to zero will make firefly equivalent to vanilla PSO
         alpha = 0.1,          # exploration factor
-        alpha_decay = 0.995,  # exploration decay each step
         inplace = True,
     ):
         islands = self.num_islands
@@ -411,8 +415,11 @@ class LatentGenePool(Module):
     def genetic_algorithm_step(
         self,
         fitness, # Float['p'],
-        inplace = True
+        inplace = True,
+        migrate = False # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
     ):
+        device = self.latents.device
+
         """
         i - islands
         p - population
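The comment on the new `migrate` flag defers the cadence to the caller: the outer training loop is expected to own a `migrate_every` hyperparameter. A minimal sketch of that loop, assuming an already constructed `LatentGenePool` instance named `latent_gene_pool` and a stand-in fitness tensor (only `genetic_algorithm_step(fitness, migrate = ...)` comes from this diff; every other name and number is illustrative):

    import torch

    migrate_every = 10                        # assumed hyperparameter, as hinted by the in-code comment
    pop_size = 128                            # assumed total number of latent genes

    for step in range(100):
        fitness = torch.randn(pop_size)       # stand-in for real fitness scores, Float['p'] per the docstring

        latent_gene_pool.genetic_algorithm_step(
            fitness,
            migrate = (step + 1) % migrate_every == 0   # trigger migration periodically when using multiple islands
        )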
@@ -462,7 +469,7 @@ class LatentGenePool(Module):
 
         # 2. for finding pairs of parents to replete gene pool, we will go with the popular tournament strategy
 
-        rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants)).argsort(dim = -1)
+        rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants), device = device).argsort(dim = -1)
         rand_tournament_gene_ids_for_gather = rearrange(rand_tournament_gene_ids, 'i p t -> i (p t)')
 
         participant_fitness = fitness.gather(1, rand_tournament_gene_ids_for_gather)
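The only functional change in this hunk is passing `device = device` to `torch.randn`, so the tournament-sampling noise is allocated on the same device as the latents instead of defaulting to the CPU. The same pattern in isolation (shapes are arbitrary):

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # creating the random scores directly on the target device avoids an implicit CPU -> GPU copy
    scores = torch.randn((4, 24, 6), device = device)
    tournament_ids = scores.argsort(dim = -1)   # random permutation indices along the last dimension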
@@ -494,6 +501,16 @@ class LatentGenePool(Module):
 
         genes = mutation(genes, mutation_strength = self.mutation_strength)
 
+        # 6. maybe migration
+
+        if migrate:
+            assert self.num_islands > 1
+            randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)
+
+            migrate_mask = randperm < self.num_migrate
+            maybe_migrated_genes = torch.roll(genes, 1, dims = 0)
+            genes = einx.where('i p, i p g, i p g', migrate_mask, maybe_migrated_genes, genes)
+
         # add back the elites
 
         if self.has_elites:
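The new migration step draws a random permutation per island, marks the `num_migrate` lowest ranks as migrants, and fills those slots with the corresponding genes of the neighbouring island via `torch.roll` along the island dimension. A self-contained sketch of the same mask-and-roll mechanic, using plain `torch.where` in place of `einx.where` (shapes and counts are illustrative):

    import torch

    islands, pop, dim = 4, 8, 16
    num_migrate = 2

    genes = torch.randn(islands, pop, dim)

    # random permutation of positions per island; ranks below num_migrate become migrants
    randperm = torch.randn(islands, pop).argsort(dim = -1)
    migrate_mask = randperm < num_migrate                   # exactly num_migrate True entries per island

    # each island receives the selected genes of the previous island (a fixed ring topology)
    rolled = torch.roll(genes, 1, dims = 0)
    genes = torch.where(migrate_mask.unsqueeze(-1), rolled, genes)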
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.27
+Version: 0.0.28
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=GckXFGdRoZT149cOlMqLUVe9oXr1QXP-gPZTv4H_HFU,21692
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.28.dist-info/METADATA,sha256=Fn846Lxaxo_OrXFD-_8IECOJ9fZL2JosriGUKMO0CfQ,4958
+evolutionary_policy_optimization-0.0.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.28.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.28.dist-info/RECORD,,
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=UCCwYK-b20X-5Cq-pah1NTeHFc_35b4xZ3y0aSR8aaI,20783
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.27.dist-info/METADATA,sha256=pJ2kQD5YtKDSUp1TCO_hsrRMh6FCMm8dyu6WrpVHiQk,4958
-evolutionary_policy_optimization-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.27.dist-info/RECORD,,
The remaining files (WHEEL and licenses/LICENSE) are unchanged between the two versions.