evolutionary-policy-optimization 0.0.47-py3-none-any.whl → 0.0.50-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- evolutionary_policy_optimization/epo.py (0.0.47)
+++ evolutionary_policy_optimization/epo.py (0.0.50)
@@ -6,7 +6,7 @@ from collections import namedtuple
 from random import randrange
 
 import torch
-from torch import nn, cat, stack, is_tensor, tensor
+from torch import nn, cat, stack, is_tensor, tensor, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
@@ -357,9 +357,10 @@ class LatentGenePool(Module):
         frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
         frac_elitism = 0.1, # frac of population to preserve from being noised
         frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
+        migrate_every = 100, # how many steps before a migration between islands
         mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
-        default_should_run_ga_gamma = 1.5
+        default_should_run_ga_gamma = 1.5,
     ):
         super().__init__()
 
@@ -409,6 +410,22 @@ class LatentGenePool(Module):
 
         self.should_run_genetic_algorithm = should_run_genetic_algorithm
 
+        self.can_migrate = num_islands > 1
+        self.migrate_every = migrate_every
+        self.register_buffer('step', tensor(1))
+
+    def get_distance(self):
+        # returns latent euclidean distance as proxy for diversity
+
+        latents = rearrange(self.latents, '(i p) g -> i p g', i = self.num_islands)
+
+        distance = torch.cdist(latents, latents)
+
+        return distance
+
+    def advance_step_(self):
+        self.step.add_(1)
+
     def firefly_step(
         self,
         fitness,
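
The new get_distance helper above regroups the flat pool of latents by island and takes pairwise euclidean distances with torch.cdist, yielding one (population, population) distance matrix per island as a proxy for diversity. A minimal standalone sketch of the same computation; the sizes are assumptions for illustration, not values from the package:

import torch
from einops import rearrange

num_islands, pop_per_island, dim_latent = 2, 4, 8
latents = torch.randn(num_islands * pop_per_island, dim_latent)

# regroup the flat pool by island: (islands * pop, dim) -> (islands, pop, dim)
grouped = rearrange(latents, '(i p) g -> i p g', i = num_islands)

# pairwise euclidean distances within each island: shape (islands, pop, pop)
distance = torch.cdist(grouped, grouped)

assert distance.shape == (num_islands, pop_per_island, pop_per_island)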
@@ -460,7 +477,7 @@ class LatentGenePool(Module):
         self,
         fitness, # Float['p'],
         inplace = True,
-        migrate = False # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
+        migrate = None # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
     ):
         device = self.latents.device
 
@@ -547,8 +564,9 @@ class LatentGenePool(Module):
 
         # 6. maybe migration
 
+        migrate = self.can_migrate and default(migrate, divisible_by(self.step.item(), self.migrate_every))
+
         if migrate:
-            assert self.num_islands > 1
             randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)
 
             migrate_mask = randperm < self.num_migrate
@@ -581,6 +599,8 @@ class LatentGenePool(Module):
 
         self.latents.copy_(genes)
 
+        self.advance_step_()
+
     def forward(
         self,
         *args,
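
Together with migrate_every and the new step buffer (advanced by advance_step_ at the end of each evolution step), migration scheduling now lives inside LatentGenePool: when there is more than one island and migrate is left as None, a migration fires whenever the step count is divisible by migrate_every, while an explicit True or False still overrides the schedule. A rough sketch of that gating, with default and divisible_by re-defined locally under the usual semantics since the package's own helper definitions are not shown in this diff:

from torch import tensor

def divisible_by(num, den):
    return (num % den) == 0

def default(value, fallback):
    return value if value is not None else fallback

can_migrate = True   # corresponds to num_islands > 1
migrate_every = 100
step = tensor(1)     # persistent buffer, incremented once per evolution step

def should_migrate(migrate = None):
    # an explicit True / False wins; otherwise migrate every `migrate_every` steps
    return can_migrate and default(migrate, divisible_by(step.item(), migrate_every))

for _ in range(300):
    if should_migrate():
        pass         # exchange a fraction of genes between islands here
    step.add_(1)     # what advance_step_() does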
@@ -632,6 +652,7 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        diversity_aux_loss_weight = 0.,
         use_critic_ema = True,
         critic_ema_beta = 0.99,
         max_grad_norm = 0.5,
@@ -687,6 +708,11 @@ class Agent(Module):
 
         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
 
+        # promotes latents to be farther apart for diversity maintenance
+
+        self.has_diversity_loss = diversity_aux_loss_weight > 0.
+        self.diversity_aux_loss_weight = diversity_aux_loss_weight
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
@@ -868,11 +894,19 @@ class Agent(Module):
 
                 # maybe update latents, if not frozen
 
-                if not self.latent_gene_pool.frozen_latents:
-                    orig_latents.backward(latents.grad)
+                if self.latent_gene_pool.frozen_latents:
+                    continue
+
+                orig_latents.backward(latents.grad)
+
+                if self.has_diversity_loss:
+                    diversity = self.latent_gene_pool.get_distance()
+                    diversity_loss = diversity.mul(-1).exp().mean()
+
+                    (diversity_loss * self.diversity_aux_loss_weight).backward()
 
-                    self.latent_optim.step()
-                    self.latent_optim.zero_grad()
+                self.latent_optim.step()
+                self.latent_optim.zero_grad()
 
                 # apply evolution
 
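The auxiliary term added here averages exp(-distance) over the pairwise latent distances from get_distance, so it is largest when latents crowd together; descending on it, scaled by diversity_aux_loss_weight (default 0., which disables it), nudges the latent genes apart. A simplified two-vector illustration of the same exp(-distance) penalty, not code from the package:

import torch

a = torch.tensor([0.0, 0.0], requires_grad = True)
b = torch.tensor([0.1, 0.0], requires_grad = True)

# same form as diversity.mul(-1).exp() for a single pair of latents
diversity_loss = (-torch.dist(a, b)).exp()
diversity_loss.backward()

# gradient descent on this loss moves a and b away from each other,
# since the loss shrinks as their distance grows
print(a.grad, b.grad)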
@@ -919,6 +953,7 @@ def create_agent(
     latent_gene_pool_kwargs: dict = dict(),
     actor_kwargs: dict = dict(),
     critic_kwargs: dict = dict(),
+    **kwargs
 ) -> Agent:
 
     latent_gene_pool = LatentGenePool(
@@ -946,7 +981,8 @@ def create_agent(
         actor = actor,
         critic = critic,
         latent_gene_pool = latent_gene_pool,
-        use_critic_ema = use_critic_ema
+        use_critic_ema = use_critic_ema,
+        **kwargs
     )
 
     return agent
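
Because create_agent now forwards **kwargs straight into Agent, Agent-level options such as the new diversity_aux_loss_weight can be set from the factory function without changing its signature. A hypothetical call as a sketch; every argument except the forwarded keyword is taken from the project README rather than this diff, and the values are purely illustrative:

from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,
    num_latents = 16,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64),
    diversity_aux_loss_weight = 1e-2  # forwarded to Agent through the new **kwargs
)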
--- evolutionary_policy_optimization-0.0.47.dist-info/METADATA
+++ evolutionary_policy_optimization-0.0.50.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.47
+Version: 0.0.50
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
--- evolutionary_policy_optimization-0.0.47.dist-info/RECORD
+++ evolutionary_policy_optimization-0.0.50.dist-info/RECORD
@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=-uRpnD0dKF6h4drVSikm9HnlP2OZ0WYQSWRQcghzd9Y,32242
+evolutionary_policy_optimization/epo.py,sha256=nIZzjz_3WUvHYPwNH9zoKM3fyRJd3MOd6UL7ooSTQV4,33445
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.47.dist-info/METADATA,sha256=oSI5NowsOOlQZ5cPmCs-8kYeG6TmzUybpRZt_6-cFWk,6213
-evolutionary_policy_optimization-0.0.47.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.47.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.47.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.50.dist-info/METADATA,sha256=iCB34UsXdosJV9q8-qXJpJfI8OdNNXB5o8MhmKy82zY,6213
+evolutionary_policy_optimization-0.0.50.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.50.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.50.dist-info/RECORD,,