evolutionary-policy-optimization 0.0.27__tar.gz → 0.0.29__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.0.27
3
+ Version: 0.0.29
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -311,6 +311,7 @@ class LatentGenePool(Module):
311
311
  frac_tournaments = 0.25, # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
312
312
  frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
313
313
  frac_elitism = 0.1, # frac of population to preserve from being noised
314
+ frac_migrate = 0.1, # frac of population, excluding elites, that migrates between islands. uses a fixed ring migration pattern, as random migration empirically performed worse
314
315
  mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
315
316
  should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
316
317
  default_should_run_ga_gamma = 1.5
@@ -348,12 +349,16 @@ class LatentGenePool(Module):
348
349
  self.num_natural_selected = int(frac_natural_selected * latents_per_island)
349
350
 
350
351
  self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
352
+
351
353
  self.crossover_random = crossover_random
352
354
 
353
355
  self.mutation_strength = mutation_strength
354
356
  self.num_elites = int(frac_elitism * latents_per_island)
355
357
  self.has_elites = self.num_elites > 0
356
358
 
359
+ latents_without_elites = num_latents - self.num_elites
360
+ self.num_migrate = int(frac_migrate * latents_without_elites)
361
+
357
362
  if not exists(should_run_genetic_algorithm):
358
363
  should_run_genetic_algorithm = ShouldRunGeneticAlgorithm(gamma = default_should_run_ga_gamma)
359
364
 
@@ -365,7 +370,6 @@ class LatentGenePool(Module):
365
370
  beta0 = 2., # exploitation factor, moving fireflies of low light intensity to high
366
371
  gamma = 1., # controls light intensity decay over distance - setting this to zero will make firefly equivalent to vanilla PSO
367
372
  alpha = 0.1, # exploration factor
368
- alpha_decay = 0.995, # exploration decay each step
369
373
  inplace = True,
370
374
  ):
371
375
  islands = self.num_islands
@@ -411,8 +415,11 @@ class LatentGenePool(Module):
411
415
  def genetic_algorithm_step(
412
416
  self,
413
417
  fitness, # Float['p'],
414
- inplace = True
418
+ inplace = True,
419
+ migrate = False # trigger a migration when using multiple islands; the outer training loop should pass True every `migrate_every` steps
415
420
  ):
421
+ device = self.latents.device
422
+
416
423
  """
417
424
  i - islands
418
425
  p - population
@@ -462,7 +469,7 @@ class LatentGenePool(Module):
462
469
 
463
470
  # 2. for finding pairs of parents to replete gene pool, we will go with the popular tournament strategy
464
471
 
465
- rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants)).argsort(dim = -1)
472
+ rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants), device = device).argsort(dim = -1)
466
473
  rand_tournament_gene_ids_for_gather = rearrange(rand_tournament_gene_ids, 'i p t -> i (p t)')
467
474
 
468
475
  participant_fitness = fitness.gather(1, rand_tournament_gene_ids_for_gather)
@@ -494,6 +501,20 @@ class LatentGenePool(Module):
494
501
 
495
502
  genes = mutation(genes, mutation_strength = self.mutation_strength)
496
503
 
504
+ # 6. maybe migration
505
+
506
+ if migrate:
507
+ assert self.num_islands > 1
508
+ randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)
509
+
510
+ migrate_mask = randperm < self.num_migrate
511
+
512
+ nonmigrants = rearrange(genes[~migrate_mask], '(i p) g -> i p g', i = islands)
513
+ migrants = rearrange(genes[migrate_mask], '(i p) g -> i p g', i = islands)
514
+ migrants = torch.roll(migrants, 1, dims = 0)
515
+
516
+ genes = cat((nonmigrants, migrants), dim = 1)
517
+
497
518
  # add back the elites
498
519
 
499
520
  if self.has_elites:
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "evolutionary-policy-optimization"
3
- version = "0.0.27"
3
+ version = "0.0.29"
4
4
  description = "EPO - Pytorch"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -37,7 +37,9 @@ def test_readme(
37
37
 
38
38
  fitness = torch.randn(128)
39
39
 
40
- latent_pool.genetic_algorithm_step(fitness) # update once
40
+ latent_pool.genetic_algorithm_step(fitness, migrate = num_islands > 1) # update once
41
+
42
+ latent_pool.firefly_step(fitness)
41
43
 
42
44
  @pytest.mark.parametrize('latent_ids', (2, (2, 4)))
43
45
  def test_create_agent(