evolutionary-policy-optimization 0.0.22__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as published.
@@ -31,11 +31,14 @@ def identity(t):
 def xnor(x, y):
     return not (x ^ y)
 
-def l2norm(t):
-    return F.normalize(t, p = 2, dim = -1)
+def divisible_by(num, den):
+    return (num % den) == 0
 
 # tensor helpers
 
+def l2norm(t):
+    return F.normalize(t, p = 2, dim = -1)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
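For orientation, a minimal sketch of the two helpers touched here (assuming the `torch` and `torch.nn.functional as F` imports from the top of epo.py): `divisible_by` is the guard the new island logic relies on, while `l2norm` simply moved below the `# tensor helpers` marker.

    import torch
    import torch.nn.functional as F

    def divisible_by(num, den):
        return (num % den) == 0

    def l2norm(t):
        return F.normalize(t, p = 2, dim = -1)

    assert divisible_by(128, 4)        # a 128-latent pool splits evenly into 4 islands
    assert not divisible_by(100, 3)    # uneven splits are rejected by the constructor

    t = torch.randn(2, 8)
    print(l2norm(t).norm(dim = -1))    # each row now has unit L2 norm -> tensor([1., 1.])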
@@ -300,6 +303,7 @@ class LatentGenePool(Module):
         num_latents,                     # same as gene pool size
         dim_latent,                      # gene dimension
         num_latent_sets = 1,             # allow for sets of latents / gene per individual, expression of a set controlled by the environment
+        num_islands = 1,                 # add the island strategy, which has been effectively used in a few recent works
         dim_state = None,
         frozen_latents = True,
         crossover_random = True,         # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
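A hypothetical instantiation under the signature above (the import path and argument values are illustrative, not taken from the package docs):

    # hypothetical usage; assumes LatentGenePool is exported at package level
    from evolutionary_policy_optimization import LatentGenePool

    pool = LatentGenePool(
        num_latents = 128,    # total gene pool size across all islands
        dim_latent = 32,
        num_islands = 4       # evolve 4 isolated sub-populations of 32 latents each
    )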
@@ -340,6 +344,9 @@ class LatentGenePool(Module):
 
         # some derived values
 
+        assert num_islands >= 1
+        assert divisible_by(num_latents, num_islands)
+
         assert 0. < frac_tournaments < 1.
         assert 0. < frac_natural_selected < 1.
         assert 0. <= frac_elitism < 1.
@@ -347,13 +354,16 @@ class LatentGenePool(Module):
 
         self.dim_latent = dim_latent
         self.num_latents = num_latents
-        self.num_natural_selected = int(frac_natural_selected * num_latents)
+        self.num_islands = num_islands
+
+        latents_per_island = num_latents // num_islands
+        self.num_natural_selected = int(frac_natural_selected * latents_per_island)
 
         self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
         self.crossover_random = crossover_random
 
         self.mutation_strength = mutation_strength
-        self.num_elites = int(frac_elitism * num_latents)
+        self.num_elites = int(frac_elitism * latents_per_island)
         self.has_elites = self.num_elites > 0
 
         if not exists(should_run_genetic_algorithm):
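To make the per-island arithmetic concrete (numbers are illustrative): with 128 latents split over 4 islands, selection and elitism now operate on island-sized sub-populations of 32 rather than on the full pool.

    num_latents, num_islands = 128, 4
    frac_natural_selected, frac_elitism = 0.25, 0.1

    latents_per_island = num_latents // num_islands                          # 32
    num_natural_selected = int(frac_natural_selected * latents_per_island)   # 8 survivors per island
    num_elites = int(frac_elitism * latents_per_island)                      # 3 elites per island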
@@ -369,11 +379,16 @@ class LatentGenePool(Module):
         inplace = True
     ):
         """
+        i - islands
         p - population
         g - gene dimension
         n - number of genes per individual
+        t - num tournament participants
         """
 
+        islands = self.num_islands
+        tournament_participants = self.num_tournament_participants
+
         if not self.should_run_genetic_algorithm(fitness):
             return
 
@@ -384,39 +399,51 @@ class LatentGenePool(Module):
         pop_size = genes.shape[0]
         assert pop_size == fitness.shape[0]
 
+        # split out the islands
+
+        genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
+        fitness = rearrange(fitness, '(i p) -> i p', i = islands)
+
+        pop_size_per_island = pop_size // islands
+
         # 1. natural selection is simple in silico
         # you sort the population by the fitness and slice off the least fit end
 
-        sorted_indices = fitness.sort().indices
-        natural_selected_indices = sorted_indices[-self.num_natural_selected:]
-        genes, fitness = genes[natural_selected_indices], fitness[natural_selected_indices]
+        sorted_indices = fitness.sort(dim = -1).indices
+        natural_selected_indices = sorted_indices[..., -self.num_natural_selected:]
+        natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... n g', n = genes.shape[-2], g = genes.shape[-1])
+
+        genes, fitness = genes.gather(1, natural_select_gene_indices), fitness.gather(1, natural_selected_indices)
 
         # 2. for finding pairs of parents to replete gene pool, we will go with the popular tournament strategy
 
-        batch_randperm = torch.randn((pop_size - self.num_natural_selected, self.num_tournament_participants)).argsort(dim = -1)
+        rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants)).argsort(dim = -1)
+        rand_tournament_gene_ids_for_gather = rearrange(rand_tournament_gene_ids, 'i p t -> i (p t)')
 
-        participants = genes[batch_randperm]
-        participant_fitness = fitness[batch_randperm]
+        participant_fitness = fitness.gather(1, rand_tournament_gene_ids_for_gather)
+        participant_fitness = rearrange(participant_fitness, 'i (p t) -> i p t', t = tournament_participants)
 
-        tournament_winner_indices = participant_fitness.topk(2, dim = -1).indices
+        parent_indices_at_tournament = participant_fitness.topk(2, dim = -1).indices
+        parent_gene_ids = rand_tournament_gene_ids.gather(-1, parent_indices_at_tournament)
 
-        tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... n g', g = self.dim_latent, n = self.num_latent_sets)
+        parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) n g', n = genes.shape[-2], g = genes.shape[-1])
 
-        parents = participants.gather(-3, tournament_winner_indices)
+        parents = genes.gather(1, parent_gene_ids_for_gather)
+        parents = rearrange(parents, 'i (p parents) ... -> i p parents ...', parents = 2)
 
         # 3. do a crossover of the parents - in their case they went for a simple averaging, but since we are doing tournament style and the same pair of parents may be re-selected, lets make it random interpolation
 
-        parent1, parent2 = parents.unbind(dim = 1)
+        parent1, parent2 = parents.unbind(dim = 2)
         children = crossover_latents(parent1, parent2, random = self.crossover_random)
 
         # append children to gene pool
 
-        genes = cat((children, genes))
+        genes = cat((children, genes), dim = 1)
 
         # 4. they use the elitism strategy to protect best performing genes from being changed
 
         if self.has_elites:
-            genes, elites = genes[:-self.num_elites], genes[-self.num_elites:]
+            genes, elites = genes[:, :-self.num_elites], genes[:, -self.num_elites:]
 
         # 5. mutate with gaussian noise - todo: add drawing the mutation rate from exponential distribution, from the fast genetic algorithms paper from 2017
 
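The hunk above vectorizes the whole genetic algorithm over a new leading island dimension, trading fancy indexing for explicit `gather` calls. A self-contained sketch of the same pattern follows; the shapes, toy fitness values, and the participant-sampling line are illustrative (it draws distinct tournament participants per child slot, one common way to run tournament selection, not necessarily the library's exact sampling):

    import torch
    from einops import rearrange, repeat

    islands, pop_per_island, n, g = 2, 8, 1, 4        # i p n g, as in the docstring
    num_selected, tourn = 4, 3

    genes = torch.randn(islands * pop_per_island, n, g)
    fitness = torch.randn(islands * pop_per_island)

    # split the flat population into islands
    genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
    fitness = rearrange(fitness, '(i p) -> i p', i = islands)

    # 1. natural selection per island: keep the fittest `num_selected`
    keep = fitness.sort(dim = -1).indices[..., -num_selected:]
    keep_genes = repeat(keep, '... -> ... n g', n = n, g = g)
    genes, fitness = genes.gather(1, keep_genes), fitness.gather(1, keep)

    # 2. tournament: draw `tourn` distinct survivor ids per child slot
    num_children = pop_per_island - num_selected
    tournament_ids = torch.randn((islands, num_children, num_selected)).argsort(dim = -1)[..., :tourn]

    part_fitness = fitness.gather(1, rearrange(tournament_ids, 'i c t -> i (c t)'))
    part_fitness = rearrange(part_fitness, 'i (c t) -> i c t', t = tourn)

    # the top-2 fittest participants of each tournament become the parents
    parent_pos = part_fitness.topk(2, dim = -1).indices
    parent_ids = tournament_ids.gather(-1, parent_pos)

    parent_gather = repeat(parent_ids, 'i c two -> i (c two) n g', n = n, g = g)
    parents = rearrange(genes.gather(1, parent_gather), 'i (c two) n g -> i c two n g', two = 2)
    parent1, parent2 = parents.unbind(dim = 2)

    # 3. crossover by random interpolation between the two parents
    children = parent1.lerp(parent2, torch.rand_like(parent1))

    # children refill each island back to its full size
    genes = torch.cat((children, genes), dim = 1)
    print(genes.shape)   # torch.Size([2, 8, 1, 4])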
@@ -425,10 +452,14 @@ class LatentGenePool(Module):
         # add back the elites
 
         if self.has_elites:
-            genes = cat((genes, elites))
+            genes = cat((genes, elites), dim = 1)
 
         genes = self.maybe_l2norm(genes)
 
+        # merge island back into pop dimension
+
+        genes = rearrange(genes, 'i p ... -> (i p) ...')
+
         if not inplace:
             return genes
 
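Continuing the sketch, merging the islands back into a flat population for the rest of the training loop is the inverse of the initial split, a single rearrange:

    import torch
    from einops import rearrange

    genes = torch.randn(2, 8, 1, 4)                    # (islands, pop_per_island, n, g)
    genes = rearrange(genes, 'i p ... -> (i p) ...')
    print(genes.shape)                                 # torch.Size([16, 1, 4])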
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.22
+Version: 0.0.23
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=lMKMGaJdSh71Dnkn6ZvbtVNrCJ00Cv3p_uXJy9D8K90,19908
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.23.dist-info/METADATA,sha256=m_EqBjggqqEq09A9KqjA8BQMWECLIvzwbOR8pc-UFrM,4931
+evolutionary_policy_optimization-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.23.dist-info/RECORD,,
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=TbUX2L-Wa2zIZ2b7iHmBtaym-qDSLAFrC7iU7xReX_k,18449
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.22.dist-info/METADATA,sha256=L3G-tesSEyhrc_SbTN6HuJQlXfogEUvr3W9SXPcnRVw,4931
-evolutionary_policy_optimization-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.22.dist-info/RECORD,,