evolutionary-policy-optimization 0.0.22-py3-none-any.whl → 0.0.24-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
evolutionary_policy_optimization/epo.py

@@ -8,6 +8,7 @@ import torch.nn.functional as F
  from torch.nn import Linear, Module, ModuleList
  from torch.utils.data import TensorDataset, DataLoader
 
+ import einx
  from einops import rearrange, repeat, einsum
  from einops.layers.torch import Rearrange
 
@@ -31,11 +32,14 @@ def identity(t):
  def xnor(x, y):
      return not (x ^ y)
 
- def l2norm(t):
-     return F.normalize(t, p = 2, dim = -1)
+ def divisible_by(num, den):
+     return (num % den) == 0
 
  # tensor helpers
 
+ def l2norm(t):
+     return F.normalize(t, p = 2, dim = -1)
+
  def log(t, eps = 1e-20):
      return t.clamp(min = eps).log()
 
@@ -290,7 +294,7 @@ class ShouldRunGeneticAlgorithm(Module):
  # however, this equation does not make much sense to me if fitness increases unbounded
  # just let it be customizable, and offer a variant where mean and variance is over some threshold (could account for skew too)
 
- return (fitnesses.amax() - fitnesses.amin()) > (self.gamma * torch.median(fitnesses))
+ return (fitnesses.amax(dim = -1) - fitnesses.amin(dim = -1)) > (self.gamma * torch.median(fitnesses, dim = -1).values)
 
  # classes
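
This change makes the spread criterion island-aware: with fitness shaped (islands, population_per_island), amax, amin and median are taken per island, so the check returns one boolean per island instead of a single scalar. A minimal sketch of that behaviour, with gamma and the shapes chosen purely for illustration:

```python
import torch

# illustrative values only - gamma and the shapes are assumptions, not the library's defaults
gamma = 1.5
fitness = torch.randn(4, 8)  # 4 islands, 8 individuals per island

spread    = fitness.amax(dim = -1) - fitness.amin(dim = -1)  # (4,) fitness range per island
threshold = gamma * torch.median(fitness, dim = -1).values   # (4,) per-island median scaled by gamma
should_update_per_island = spread > threshold                # one boolean flag per island
```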
 
@@ -300,6 +304,7 @@ class LatentGenePool(Module):
  num_latents,                     # same as gene pool size
  dim_latent,                      # gene dimension
  num_latent_sets = 1,             # allow for sets of latents / gene per individual, expression of a set controlled by the environment
+ num_islands = 1,                 # add the island strategy, which has been effectively used in a few recent works
  dim_state = None,
  frozen_latents = True,
  crossover_random = True,         # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
@@ -340,6 +345,9 @@ class LatentGenePool(Module):
 
  # some derived values
 
+ assert num_islands >= 1
+ assert divisible_by(num_latents, num_islands)
+
  assert 0. < frac_tournaments < 1.
  assert 0. < frac_natural_selected < 1.
  assert 0. <= frac_elitism < 1.
@@ -347,13 +355,16 @@ class LatentGenePool(Module):
 
  self.dim_latent = dim_latent
  self.num_latents = num_latents
- self.num_natural_selected = int(frac_natural_selected * num_latents)
+ self.num_islands = num_islands
+
+ latents_per_island = num_latents // num_islands
+ self.num_natural_selected = int(frac_natural_selected * latents_per_island)
 
  self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
  self.crossover_random = crossover_random
 
  self.mutation_strength = mutation_strength
- self.num_elites = int(frac_elitism * num_latents)
+ self.num_elites = int(frac_elitism * latents_per_island)
  self.has_elites = self.num_elites > 0
 
  if not exists(should_run_genetic_algorithm):
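
Since selection and elitism now operate within each island, the derived counts are computed from latents_per_island rather than the full pool. A small sketch with assumed example numbers:

```python
# assumed example configuration - numbers are for illustration only
num_latents, num_islands = 128, 4
frac_natural_selected, frac_elitism = 0.25, 0.1

assert num_latents % num_islands == 0  # what divisible_by(num_latents, num_islands) enforces

latents_per_island   = num_latents // num_islands                        # 32 genes per island
num_natural_selected = int(frac_natural_selected * latents_per_island)   # 8 survivors per island
num_elites           = int(frac_elitism * latents_per_island)            # 3 elites protected per island
```
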
@@ -369,13 +380,15 @@ class LatentGenePool(Module):
  inplace = True
  ):
  """
+ i - islands
  p - population
  g - gene dimension
  n - number of genes per individual
+ t - num tournament participants
  """
 
- if not self.should_run_genetic_algorithm(fitness):
-     return
+ islands = self.num_islands
+ tournament_participants = self.num_tournament_participants
 
  assert self.num_latents > 1
 
@@ -384,39 +397,64 @@ class LatentGenePool(Module):
  pop_size = genes.shape[0]
  assert pop_size == fitness.shape[0]
 
+ pop_size_per_island = pop_size // islands
+
+ # split out the islands
+
+ fitness = rearrange(fitness, '(i p) -> i p', i = islands)
+
+ # from the fitness, decide whether to actually run the genetic algorithm or not
+
+ should_update_per_island = self.should_run_genetic_algorithm(fitness)
+
+ if not should_update_per_island.any():
+     if inplace:
+         return
+
+     return genes
+
+ genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
+
+ orig_genes = genes
+
  # 1. natural selection is simple in silico
  # you sort the population by the fitness and slice off the least fit end
 
- sorted_indices = fitness.sort().indices
- natural_selected_indices = sorted_indices[-self.num_natural_selected:]
- genes, fitness = genes[natural_selected_indices], fitness[natural_selected_indices]
+ sorted_indices = fitness.sort(dim = -1).indices
+ natural_selected_indices = sorted_indices[..., -self.num_natural_selected:]
+ natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... n g', n = genes.shape[-2], g = genes.shape[-1])
+
+ genes, fitness = genes.gather(1, natural_select_gene_indices), fitness.gather(1, natural_selected_indices)
 
  # 2. for finding pairs of parents to replete gene pool, we will go with the popular tournament strategy
 
- batch_randperm = torch.randn((pop_size - self.num_natural_selected, self.num_tournament_participants)).argsort(dim = -1)
+ rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants)).argsort(dim = -1)
+ rand_tournament_gene_ids_for_gather = rearrange(rand_tournament_gene_ids, 'i p t -> i (p t)')
 
- participants = genes[batch_randperm]
- participant_fitness = fitness[batch_randperm]
+ participant_fitness = fitness.gather(1, rand_tournament_gene_ids_for_gather)
+ participant_fitness = rearrange(participant_fitness, 'i (p t) -> i p t', t = tournament_participants)
 
- tournament_winner_indices = participant_fitness.topk(2, dim = -1).indices
+ parent_indices_at_tournament = participant_fitness.topk(2, dim = -1).indices
+ parent_gene_ids = rand_tournament_gene_ids.gather(-1, parent_indices_at_tournament)
 
- tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... n g', g = self.dim_latent, n = self.num_latent_sets)
+ parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) n g', n = genes.shape[-2], g = genes.shape[-1])
 
- parents = participants.gather(-3, tournament_winner_indices)
+ parents = genes.gather(1, parent_gene_ids_for_gather)
+ parents = rearrange(parents, 'i (p parents) ... -> i p parents ...', parents = 2)
 
  # 3. do a crossover of the parents - in their case they went for a simple averaging, but since we are doing tournament style and the same pair of parents may be re-selected, lets make it random interpolation
 
- parent1, parent2 = parents.unbind(dim = 1)
+ parent1, parent2 = parents.unbind(dim = 2)
  children = crossover_latents(parent1, parent2, random = self.crossover_random)
 
  # append children to gene pool
 
- genes = cat((children, genes))
+ genes = cat((children, genes), dim = 1)
 
  # 4. they use the elitism strategy to protect best performing genes from being changed
 
  if self.has_elites:
-     genes, elites = genes[:-self.num_elites], genes[-self.num_elites:]
+     genes, elites = genes[:, :-self.num_elites], genes[:, -self.num_elites:]
 
  # 5. mutate with gaussian noise - todo: add drawing the mutation rate from exponential distribution, from the fast genetic algorithms paper from 2017
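
Because each island now keeps its own sub-population, fancy indexing is replaced by gather along the per-island population dimension, so the fittest members of every island are kept without any mixing across islands. A minimal sketch of the natural-selection step under assumed shapes:

```python
import torch
from einops import repeat

# assumed shapes, for illustration only
islands, pop_per_island, n, g = 2, 6, 1, 4  # n latent sets per individual, gene dimension g
num_natural_selected = 3

genes   = torch.randn(islands, pop_per_island, n, g)
fitness = torch.randn(islands, pop_per_island)

# sort each island's fitness independently and keep the top slice
sorted_indices           = fitness.sort(dim = -1).indices
natural_selected_indices = sorted_indices[..., -num_natural_selected:]               # (islands, kept)

# expand the indices so they can gather full gene tensors along the population dim
gene_indices     = repeat(natural_selected_indices, '... -> ... n g', n = n, g = g)  # (islands, kept, n, g)
selected_genes   = genes.gather(1, gene_indices)
selected_fitness = fitness.gather(1, natural_selected_indices)
```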
 
@@ -425,10 +463,18 @@ class LatentGenePool(Module):
  # add back the elites
 
  if self.has_elites:
-     genes = cat((genes, elites))
+     genes = cat((genes, elites), dim = 1)
 
  genes = self.maybe_l2norm(genes)
 
+ # account for criteria of whether to actually run GA or not
+
+ genes = einx.where('i, i ..., i ...', should_update_per_island, genes, orig_genes)
+
+ # merge island back into pop dimension
+
+ genes = rearrange(genes, 'i p ... -> (i p) ...')
+
  if not inplace:
      return genes
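
Finally, einx.where broadcasts the per-island boolean over the remaining axes, so islands whose spread criterion failed simply keep their original genes before the island axis is folded back into the population axis. A small sketch with assumed shapes:

```python
import torch
import einx
from einops import rearrange

# assumed shapes, for illustration only
islands, pop_per_island, n, g = 2, 4, 1, 3

orig_genes    = torch.zeros(islands, pop_per_island, n, g)  # stand-in for the untouched pool
updated_genes = torch.ones(islands, pop_per_island, n, g)   # stand-in for the evolved pool
should_update_per_island = torch.tensor([True, False])

# island 0 takes the evolved genes, island 1 keeps its originals
genes = einx.where('i, i ..., i ...', should_update_per_island, updated_genes, orig_genes)

# merge the island axis back into a flat population: (islands * pop_per_island, n, g)
genes = rearrange(genes, 'i p ... -> (i p) ...')
```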
 
evolutionary_policy_optimization-0.0.22.dist-info/METADATA → evolutionary_policy_optimization-0.0.24.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.0.22
+ Version: 0.0.24
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -36,7 +36,8 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Requires-Dist: adam-atan2-pytorch
  Requires-Dist: assoc-scan
- Requires-Dist: einops>=0.8.0
+ Requires-Dist: einops>=0.8.1
+ Requires-Dist: einx>=0.3.0
  Requires-Dist: hl-gauss-pytorch>=0.1.19
  Requires-Dist: torch>=2.2
  Requires-Dist: tqdm
evolutionary_policy_optimization-0.0.24.dist-info/RECORD (added)

@@ -0,0 +1,7 @@
+ evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+ evolutionary_policy_optimization/epo.py,sha256=-kQgrnnOLiCOZ-6EroO057tDx0sS7TQro92cjJhSbZU,20353
+ evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+ evolutionary_policy_optimization-0.0.24.dist-info/METADATA,sha256=d3imh1p1-nPpNGhD8cReLdL07_-oHZs3YqJaOEJi1TM,4958
+ evolutionary_policy_optimization-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ evolutionary_policy_optimization-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ evolutionary_policy_optimization-0.0.24.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.22.dist-info/RECORD (removed)

@@ -1,7 +0,0 @@
- evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
- evolutionary_policy_optimization/epo.py,sha256=TbUX2L-Wa2zIZ2b7iHmBtaym-qDSLAFrC7iU7xReX_k,18449
- evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
- evolutionary_policy_optimization-0.0.22.dist-info/METADATA,sha256=L3G-tesSEyhrc_SbTN6HuJQlXfogEUvr3W9SXPcnRVw,4931
- evolutionary_policy_optimization-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- evolutionary_policy_optimization-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- evolutionary_policy_optimization-0.0.22.dist-info/RECORD,,