evolutionary-policy-optimization 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +26 -40
- {evolutionary_policy_optimization-0.0.26.dist-info → evolutionary_policy_optimization-0.0.28.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.28.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.26.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.26.dist-info → evolutionary_policy_optimization-0.0.28.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.26.dist-info → evolutionary_policy_optimization-0.0.28.dist-info}/licenses/LICENSE +0 -0
@@ -303,7 +303,6 @@ class LatentGenePool(Module):
|
|
303
303
|
self,
|
304
304
|
num_latents, # same as gene pool size
|
305
305
|
dim_latent, # gene dimension
|
306
|
-
num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
|
307
306
|
num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
|
308
307
|
dim_state = None,
|
309
308
|
frozen_latents = True,
|
@@ -312,6 +311,7 @@ class LatentGenePool(Module):
|
|
312
311
|
frac_tournaments = 0.25, # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
|
313
312
|
frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
|
314
313
|
frac_elitism = 0.1, # frac of population to preserve from being noised
|
314
|
+
frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
|
315
315
|
mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
|
316
316
|
should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
|
317
317
|
default_should_run_ga_gamma = 1.5
|
@@ -320,29 +320,17 @@ class LatentGenePool(Module):
|
|
320
320
|
|
321
321
|
maybe_l2norm = l2norm if l2norm_latent else identity
|
322
322
|
|
323
|
-
latents = torch.randn(num_latents, num_latent_sets, dim_latent)
|
323
|
+
latents = torch.randn(num_latents, dim_latent)
|
324
324
|
|
325
325
|
if l2norm_latent:
|
326
326
|
latents = maybe_l2norm(latents, dim = -1)
|
327
327
|
|
328
328
|
self.num_latents = num_latents
|
329
|
-
self.
|
329
|
+
self.frozen_latents = frozen_latents
|
330
330
|
self.latents = nn.Parameter(latents, requires_grad = not frozen_latents)
|
331
331
|
|
332
332
|
self.maybe_l2norm = maybe_l2norm
|
333
333
|
|
334
|
-
# gene expression as a function of environment
|
335
|
-
|
336
|
-
self.num_latent_sets = num_latent_sets
|
337
|
-
|
338
|
-
if self.needs_latent_gate:
|
339
|
-
assert exists(dim_state), '`dim_state` must be passed in if using gated gene expression'
|
340
|
-
|
341
|
-
self.to_latent_gate = nn.Sequential(
|
342
|
-
Linear(dim_state, num_latent_sets),
|
343
|
-
nn.Softmax(dim = -1)
|
344
|
-
) if self.needs_latent_gate else None
|
345
|
-
|
346
334
|
# some derived values
|
347
335
|
|
348
336
|
assert num_islands >= 1
|
@@ -361,12 +349,16 @@ class LatentGenePool(Module):
|
|
361
349
|
self.num_natural_selected = int(frac_natural_selected * latents_per_island)
|
362
350
|
|
363
351
|
self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
|
352
|
+
|
364
353
|
self.crossover_random = crossover_random
|
365
354
|
|
366
355
|
self.mutation_strength = mutation_strength
|
367
356
|
self.num_elites = int(frac_elitism * latents_per_island)
|
368
357
|
self.has_elites = self.num_elites > 0
|
369
358
|
|
359
|
+
latents_without_elites = num_latents - self.num_elites
|
360
|
+
self.num_migrate = int(frac_migrate * latents_without_elites)
|
361
|
+
|
370
362
|
if not exists(should_run_genetic_algorithm):
|
371
363
|
should_run_genetic_algorithm = ShouldRunGeneticAlgorithm(gamma = default_should_run_ga_gamma)
|
372
364
|
|
@@ -378,7 +370,6 @@ class LatentGenePool(Module):
|
|
378
370
|
beta0 = 2., # exploitation factor, moving fireflies of low light intensity to high
|
379
371
|
gamma = 1., # controls light intensity decay over distance - setting this to zero will make firefly equivalent to vanilla PSO
|
380
372
|
alpha = 0.1, # exploration factor
|
381
|
-
alpha_decay = 0.995, # exploration decay each step
|
382
373
|
inplace = True,
|
383
374
|
):
|
384
375
|
islands = self.num_islands
|
@@ -424,8 +415,11 @@ class LatentGenePool(Module):
|
|
424
415
|
def genetic_algorithm_step(
|
425
416
|
self,
|
426
417
|
fitness, # Float['p'],
|
427
|
-
inplace = True
|
418
|
+
inplace = True,
|
419
|
+
migrate = False # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
|
428
420
|
):
|
421
|
+
device = self.latents.device
|
422
|
+
|
429
423
|
"""
|
430
424
|
i - islands
|
431
425
|
p - population
|
@@ -460,7 +454,7 @@ class LatentGenePool(Module):
|
|
460
454
|
|
461
455
|
return genes
|
462
456
|
|
463
|
-
genes = rearrange(genes, '(i p)
|
457
|
+
genes = rearrange(genes, '(i p) ... -> i p ...', i = islands)
|
464
458
|
|
465
459
|
orig_genes = genes
|
466
460
|
|
@@ -469,13 +463,13 @@ class LatentGenePool(Module):
|
|
469
463
|
|
470
464
|
sorted_indices = fitness.sort(dim = -1).indices
|
471
465
|
natural_selected_indices = sorted_indices[..., -self.num_natural_selected:]
|
472
|
-
natural_select_gene_indices = repeat(natural_selected_indices, '... -> ...
|
466
|
+
natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... g', g = genes.shape[-1])
|
473
467
|
|
474
468
|
genes, fitness = genes.gather(1, natural_select_gene_indices), fitness.gather(1, natural_selected_indices)
|
475
469
|
|
476
470
|
# 2. for finding pairs of parents to replete gene pool, we will go with the popular tournament strategy
|
477
471
|
|
478
|
-
rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants)).argsort(dim = -1)
|
472
|
+
rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants), device = device).argsort(dim = -1)
|
479
473
|
rand_tournament_gene_ids_for_gather = rearrange(rand_tournament_gene_ids, 'i p t -> i (p t)')
|
480
474
|
|
481
475
|
participant_fitness = fitness.gather(1, rand_tournament_gene_ids_for_gather)
|
@@ -484,7 +478,7 @@ class LatentGenePool(Module):
|
|
484
478
|
parent_indices_at_tournament = participant_fitness.topk(2, dim = -1).indices
|
485
479
|
parent_gene_ids = rand_tournament_gene_ids.gather(-1, parent_indices_at_tournament)
|
486
480
|
|
487
|
-
parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents)
|
481
|
+
parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) g', g = genes.shape[-1])
|
488
482
|
|
489
483
|
parents = genes.gather(1, parent_gene_ids_for_gather)
|
490
484
|
parents = rearrange(parents, 'i (p parents) ... -> i p parents ...', parents = 2)
|
@@ -507,6 +501,16 @@ class LatentGenePool(Module):
|
|
507
501
|
|
508
502
|
genes = mutation(genes, mutation_strength = self.mutation_strength)
|
509
503
|
|
504
|
+
# 6. maybe migration
|
505
|
+
|
506
|
+
if migrate:
|
507
|
+
assert self.num_islands > 1
|
508
|
+
randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)
|
509
|
+
|
510
|
+
migrate_mask = randperm < self.num_migrate
|
511
|
+
maybe_migrated_genes = torch.roll(genes, 1, dims = 0)
|
512
|
+
genes = einx.where('i p, i p g, i p g', migrate_mask, maybe_migrated_genes, genes)
|
513
|
+
|
510
514
|
# add back the elites
|
511
515
|
|
512
516
|
if self.has_elites:
|
@@ -555,22 +559,6 @@ class LatentGenePool(Module):
|
|
555
559
|
|
556
560
|
latent = self.latents[latent_id]
|
557
561
|
|
558
|
-
if self.needs_latent_gate:
|
559
|
-
assert exists(state), 'state must be passed in if greater than number of 1 latent set'
|
560
|
-
|
561
|
-
if not fetching_multiple_latents:
|
562
|
-
latent = repeat(latent, '... -> b ...', b = state.shape[0])
|
563
|
-
|
564
|
-
assert latent.shape[0] == state.shape[0]
|
565
|
-
|
566
|
-
gates = self.to_latent_gate(state)
|
567
|
-
latent = einsum(latent, gates, 'b n g, b n -> b g')
|
568
|
-
|
569
|
-
elif fetching_multiple_latents:
|
570
|
-
latent = latent[:, 0]
|
571
|
-
else:
|
572
|
-
latent = latent[0]
|
573
|
-
|
574
562
|
latent = self.maybe_l2norm(latent)
|
575
563
|
|
576
564
|
if not exists(net):
|
@@ -612,7 +600,7 @@ class Agent(Module):
|
|
612
600
|
self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
|
613
601
|
self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
|
614
602
|
|
615
|
-
self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if latent_gene_pool.
|
603
|
+
self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
|
616
604
|
|
617
605
|
def get_actor_actions(
|
618
606
|
self,
|
@@ -687,7 +675,6 @@ def create_agent(
|
|
687
675
|
actor_num_actions,
|
688
676
|
actor_dim_hiddens: int | tuple[int, ...],
|
689
677
|
critic_dim_hiddens: int | tuple[int, ...],
|
690
|
-
num_latent_sets = 1
|
691
678
|
) -> Agent:
|
692
679
|
|
693
680
|
actor = Actor(
|
@@ -707,7 +694,6 @@ def create_agent(
|
|
707
694
|
dim_state = dim_state,
|
708
695
|
num_latents = num_latents,
|
709
696
|
dim_latent = dim_latent,
|
710
|
-
num_latent_sets = num_latent_sets
|
711
697
|
)
|
712
698
|
|
713
699
|
return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: evolutionary-policy-optimization
|
3
|
-
Version: 0.0.26
|
3
|
+
Version: 0.0.28
|
4
4
|
Summary: EPO - Pytorch
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
|
@@ -0,0 +1,7 @@
|
|
1
|
+
evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
|
2
|
+
evolutionary_policy_optimization/epo.py,sha256=GckXFGdRoZT149cOlMqLUVe9oXr1QXP-gPZTv4H_HFU,21692
|
3
|
+
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
+
evolutionary_policy_optimization-0.0.28.dist-info/METADATA,sha256=Fn846Lxaxo_OrXFD-_8IECOJ9fZL2JosriGUKMO0CfQ,4958
|
5
|
+
evolutionary_policy_optimization-0.0.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
+
evolutionary_policy_optimization-0.0.28.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
+
evolutionary_policy_optimization-0.0.28.dist-info/RECORD,,
|
@@ -1,7 +0,0 @@
|
|
1
|
-
evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
|
2
|
-
evolutionary_policy_optimization/epo.py,sha256=zYKRKUkvFdxgHkc2yduN76Hph3asWX33mnpDF3isDfo,22019
|
3
|
-
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
-
evolutionary_policy_optimization-0.0.26.dist-info/METADATA,sha256=l24aFXZu4kp1oxZeIdFTUw1mwkyzln9C64S3HNqebF4,4958
|
5
|
-
evolutionary_policy_optimization-0.0.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
-
evolutionary_policy_optimization-0.0.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
-
evolutionary_policy_optimization-0.0.26.dist-info/RECORD,,
|
File without changes
|