evolutionary-policy-optimization 0.0.47__py3-none-any.whl → 0.0.48__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
evolutionary_policy_optimization/epo.py
@@ -357,9 +357,10 @@ class LatentGenePool(Module):
         frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
         frac_elitism = 0.1, # frac of population to preserve from being noised
         frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
+        migrate_every = 100, # how many steps before a migration between islands
         mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
-        default_should_run_ga_gamma = 1.5
+        default_should_run_ga_gamma = 1.5,
     ):
         super().__init__()

@@ -409,6 +410,13 @@ class LatentGenePool(Module):

         self.should_run_genetic_algorithm = should_run_genetic_algorithm

+        self.can_migrate = num_islands > 1
+        self.migrate_every = migrate_every
+        self.register_buffer('step', tensor(0))
+
+    def advance_step_(self):
+        self.step.add_(1)
+
     def firefly_step(
         self,
         fitness,
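The new step counter is registered as a buffer rather than kept as a plain Python int, so it travels with the module's state_dict and the migration schedule survives checkpointing and device moves. A minimal sketch of just that mechanism (the class name below is hypothetical; only the `register_buffer` / `add_(1)` pattern mirrors the diff):

    from torch import nn, tensor

    class StepCounter(nn.Module):
        def __init__(self):
            super().__init__()
            # 0-dim integer buffer: saved and loaded with state_dict, moved by .to(device)
            self.register_buffer('step', tensor(0))

        def advance_step_(self):
            # in-place increment, matching the trailing-underscore naming convention
            self.step.add_(1)

    counter = StepCounter()
    counter.advance_step_()
    print(counter.state_dict())  # OrderedDict([('step', tensor(1))])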
@@ -460,7 +468,7 @@ class LatentGenePool(Module):
         self,
         fitness, # Float['p'],
         inplace = True,
-        migrate = False # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
+        migrate = None # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
     ):
         device = self.latents.device

@@ -547,8 +555,9 @@ class LatentGenePool(Module):

         # 6. maybe migration

+        migrate = self.can_migrate and default(migrate, divisible_by(self.step.item(), self.migrate_every))
+
         if migrate:
-            assert self.num_islands > 1
             randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)

             migrate_mask = randperm < self.num_migrate
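The added line above replaces the old assertion with an internal schedule. Roughly, it resolves the `migrate` argument as sketched below; the helper bodies for `default` and `divisible_by` are written out here as assumptions for illustration, and only the gating expression itself comes from the diff:

    # hedged sketch of how the migrate flag is resolved as of 0.0.48
    def default(value, fallback):
        return value if value is not None else fallback

    def divisible_by(num, den):
        return (num % den) == 0

    def resolve_migrate(migrate, can_migrate, step, migrate_every):
        # single-island pools never migrate; otherwise an explicit True/False from the
        # caller wins, and None falls back to the every-`migrate_every`-steps schedule
        return can_migrate and default(migrate, divisible_by(step, migrate_every))

    assert resolve_migrate(None, True, 200, 100)        # on schedule -> migrate
    assert not resolve_migrate(None, True, 150, 100)    # off schedule -> skip
    assert not resolve_migrate(True, False, 200, 100)   # one island -> never migrates
    assert not resolve_migrate(False, True, 200, 100)   # caller can still suppress it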
@@ -581,6 +590,8 @@ class LatentGenePool(Module):

         self.latents.copy_(genes)

+        self.advance_step_()
+
     def forward(
         self,
         *args,
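Taken together, the epo.py changes mean the outer training loop no longer has to implement its own `migrate_every` logic: the pool counts its own steps and triggers migration automatically on multi-island setups. A rough usage sketch, assuming the genetic-algorithm update is exposed as `genetic_algorithm_step` and that the constructor arguments shown here exist under these names (treat both as assumptions, since the diff only shows fragments of the signatures):

    import torch
    from evolutionary_policy_optimization import LatentGenePool

    pool = LatentGenePool(
        num_latents = 128,
        dim_latent = 32,
        num_islands = 4,       # migration only activates with more than one island
        migrate_every = 100    # new in 0.0.48: steps between automatic migrations
    )

    for _ in range(1_000):
        fitness = torch.randn(128)            # placeholder fitness, one score per latent
        pool.genetic_algorithm_step(fitness)  # migrate defaults to None -> internal schedule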
{evolutionary_policy_optimization-0.0.47.dist-info → evolutionary_policy_optimization-0.0.48.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.47
+Version: 0.0.48
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
{evolutionary_policy_optimization-0.0.47.dist-info → evolutionary_policy_optimization-0.0.48.dist-info}/RECORD
@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=-uRpnD0dKF6h4drVSikm9HnlP2OZ0WYQSWRQcghzd9Y,32242
+evolutionary_policy_optimization/epo.py,sha256=FkliOiKdmUvKuwFqb1_A-ddahnOqjTR8Djx_I6UZAlU,32625
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.47.dist-info/METADATA,sha256=oSI5NowsOOlQZ5cPmCs-8kYeG6TmzUybpRZt_6-cFWk,6213
-evolutionary_policy_optimization-0.0.47.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.47.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.47.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.48.dist-info/METADATA,sha256=GpuUVs0VO2ydhU3X4-A_cA_xNmvdtYaAM8tb_VKneBo,6213
+evolutionary_policy_optimization-0.0.48.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.48.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.48.dist-info/RECORD,,