evolutionary-policy-optimization 0.0.47__py3-none-any.whl → 0.0.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +45 -9
- {evolutionary_policy_optimization-0.0.47.dist-info → evolutionary_policy_optimization-0.0.50.dist-info}/METADATA +1 -1
- {evolutionary_policy_optimization-0.0.47.dist-info → evolutionary_policy_optimization-0.0.50.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.0.47.dist-info → evolutionary_policy_optimization-0.0.50.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.47.dist-info → evolutionary_policy_optimization-0.0.50.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/evolutionary_policy_optimization/epo.py
+++ b/evolutionary_policy_optimization/epo.py
@@ -6,7 +6,7 @@ from collections import namedtuple
 from random import randrange
 
 import torch
-from torch import nn, cat, stack, is_tensor, tensor
+from torch import nn, cat, stack, is_tensor, tensor, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
```
```diff
@@ -357,9 +357,10 @@ class LatentGenePool(Module):
         frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
         frac_elitism = 0.1, # frac of population to preserve from being noised
         frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
+        migrate_every = 100, # how many steps before a migration between islands
         mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
-        default_should_run_ga_gamma = 1.5
+        default_should_run_ga_gamma = 1.5,
     ):
         super().__init__()
 
```
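The new `migrate_every` argument gives the gene pool its own migration schedule. A minimal usage sketch, assuming `LatentGenePool` is still exported from the package root and that `num_latents`, `dim_latent`, and `num_islands` keep their existing meanings (the concrete values are placeholders):

```python
from evolutionary_policy_optimization import LatentGenePool

# hypothetical instantiation - only migrate_every is new in this release
latent_pool = LatentGenePool(
    num_latents = 128,    # total genes across all islands
    dim_latent = 32,
    num_islands = 4,
    migrate_every = 100,  # exchange genes between islands every 100 GA steps
)
```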
```diff
@@ -409,6 +410,22 @@ class LatentGenePool(Module):
 
         self.should_run_genetic_algorithm = should_run_genetic_algorithm
 
+        self.can_migrate = num_islands > 1
+        self.migrate_every = migrate_every
+        self.register_buffer('step', tensor(1))
+
+    def get_distance(self):
+        # returns latent euclidean distance as proxy for diversity
+
+        latents = rearrange(self.latents, '(i p) g -> i p g', i = self.num_islands)
+
+        distance = torch.cdist(latents, latents)
+
+        return distance
+
+    def advance_step_(self):
+        self.step.add_(1)
+
     def firefly_step(
         self,
         fitness,
```
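The added `get_distance` rearranges the flat latent pool into `(islands, genes per island, dim)` and takes pairwise euclidean distances with `torch.cdist`, so a larger mean distance signals a more diverse pool. A self-contained sketch of that computation on a dummy tensor (variable names are illustrative, not the library's API):

```python
import torch
from einops import rearrange

num_islands, genes_per_island, dim_latent = 2, 4, 8
latents = torch.randn(num_islands * genes_per_island, dim_latent)

# group the flat pool by island, then take pairwise distances within each island
latents_by_island = rearrange(latents, '(i p) g -> i p g', i = num_islands)
distance = torch.cdist(latents_by_island, latents_by_island)

print(distance.shape)   # torch.Size([2, 4, 4])
print(distance.mean())  # higher mean distance -> more diverse latent pool
```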
```diff
@@ -460,7 +477,7 @@
         self,
         fitness, # Float['p'],
         inplace = True,
-        migrate = False # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
+        migrate = None # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
     ):
         device = self.latents.device
 
```
```diff
@@ -547,8 +564,9 @@
 
         # 6. maybe migration
 
+        migrate = self.can_migrate and default(migrate, divisible_by(self.step.item(), self.migrate_every))
+
         if migrate:
-            assert self.num_islands > 1
             randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)
 
             migrate_mask = randperm < self.num_migrate
```
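With this change the genetic algorithm step decides on migration by itself: an explicit `migrate` argument still wins, but when it is left as `None` the pool migrates whenever its internal step counter is divisible by `migrate_every`. A sketch of that gating with stand-in helpers (`default` and `divisible_by` are assumed to behave like the usual one-liners in `epo.py`):

```python
# stand-in helpers, shown only to illustrate the gating expression above
def default(value, fallback):
    return value if value is not None else fallback

def divisible_by(num, den):
    return (num % den) == 0

can_migrate, step, migrate_every = True, 200, 100

migrate = None  # caller did not force a decision
migrate = can_migrate and default(migrate, divisible_by(step, migrate_every))
print(migrate)  # True - step 200 is a multiple of migrate_every

migrate = False  # caller explicitly disables migration, the schedule is ignored
migrate = can_migrate and default(migrate, divisible_by(step, migrate_every))
print(migrate)  # False
```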
```diff
@@ -581,6 +599,8 @@
 
         self.latents.copy_(genes)
 
+        self.advance_step_()
+
     def forward(
         self,
         *args,
```
```diff
@@ -632,6 +652,7 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        diversity_aux_loss_weight = 0.,
         use_critic_ema = True,
         critic_ema_beta = 0.99,
         max_grad_norm = 0.5,
```
```diff
@@ -687,6 +708,11 @@ class Agent(Module):
 
         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
 
+        # promotes latents to be farther apart for diversity maintenance
+
+        self.has_diversity_loss = diversity_aux_loss_weight > 0.
+        self.diversity_aux_loss_weight = diversity_aux_loss_weight
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
```
```diff
@@ -868,11 +894,19 @@
 
             # maybe update latents, if not frozen
 
-            if not self.latent_gene_pool.frozen_latents:
-                orig_latents.backward(latents.grad)
+            if self.latent_gene_pool.frozen_latents:
+                continue
+
+            orig_latents.backward(latents.grad)
+
+            if self.has_diversity_loss:
+                diversity = self.latent_gene_pool.get_distance()
+                diversity_loss = diversity.mul(-1).exp().mean()
+
+                (diversity_loss * self.diversity_aux_loss_weight).backward()
 
-                self.latent_optim.step()
-                self.latent_optim.zero_grad()
+            self.latent_optim.step()
+            self.latent_optim.zero_grad()
 
         # apply evolution
 
```
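The new auxiliary loss takes the pairwise latent distances from `get_distance`, maps them through `exp(-d)`, and averages, so the penalty is near 1 when latents collapse onto each other and decays toward 0 as they spread out; its gradient therefore pushes latents apart. A standalone sketch of the term on random latents (names and the non-zero weight are illustrative):

```python
import torch

latents = torch.randn(16, 32, requires_grad = True)

distance = torch.cdist(latents, latents)        # pairwise euclidean distances
diversity_loss = distance.mul(-1).exp().mean()  # exp(-d), averaged: small when spread out

diversity_aux_loss_weight = 0.02                # hypothetical non-zero weight
(diversity_loss * diversity_aux_loss_weight).backward()

print(latents.grad.shape)  # torch.Size([16, 32]) - gradients that nudge latents apart
```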
```diff
@@ -919,6 +953,7 @@ def create_agent(
     latent_gene_pool_kwargs: dict = dict(),
     actor_kwargs: dict = dict(),
     critic_kwargs: dict = dict(),
+    **kwargs
 ) -> Agent:
 
     latent_gene_pool = LatentGenePool(
```
```diff
@@ -946,7 +981,8 @@
         actor = actor,
         critic = critic,
         latent_gene_pool = latent_gene_pool,
-        use_critic_ema = use_critic_ema
+        use_critic_ema = use_critic_ema,
+        **kwargs
     )
 
     return agent
```
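Because `create_agent` now forwards `**kwargs` to `Agent`, new options such as `diversity_aux_loss_weight` can be set without touching the factory's signature. A hedged usage sketch; apart from the forwarded keyword, the argument names and values are assumed from earlier versions and are placeholders:

```python
from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,
    num_latents = 16,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128),
    diversity_aux_loss_weight = 0.02,  # passed straight through to Agent via **kwargs
)
```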
```diff
--- a/evolutionary_policy_optimization-0.0.47.dist-info/METADATA
+++ b/evolutionary_policy_optimization-0.0.50.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.47
+Version: 0.0.50
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
```
```diff
--- a/evolutionary_policy_optimization-0.0.47.dist-info/RECORD
+++ b/evolutionary_policy_optimization-0.0.50.dist-info/RECORD
@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=…
+evolutionary_policy_optimization/epo.py,sha256=nIZzjz_3WUvHYPwNH9zoKM3fyRJd3MOd6UL7ooSTQV4,33445
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.47.dist-info/METADATA,sha256=…
-evolutionary_policy_optimization-0.0.47.dist-info/WHEEL,sha256=…
-evolutionary_policy_optimization-0.0.47.dist-info/licenses/LICENSE,sha256=…
-evolutionary_policy_optimization-0.0.47.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.50.dist-info/METADATA,sha256=iCB34UsXdosJV9q8-qXJpJfI8OdNNXB5o8MhmKy82zY,6213
+evolutionary_policy_optimization-0.0.50.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.50.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.50.dist-info/RECORD,,
```