evolutionary-policy-optimization 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
 
+import einx
 from einops import rearrange, repeat, einsum
 from einops.layers.torch import Rearrange
 
@@ -293,7 +294,7 @@ class ShouldRunGeneticAlgorithm(Module):
         # however, this equation does not make much sense to me if fitness increases unbounded
         # just let it be customizable, and offer a variant where mean and variance is over some threshold (could account for skew too)
 
-        return (fitnesses.amax() - fitnesses.amin()) > (self.gamma * torch.median(fitnesses))
+        return (fitnesses.amax(dim = -1) - fitnesses.amin(dim = -1)) > (self.gamma * torch.median(fitnesses, dim = -1).values)
 
 # classes
 
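With `dim = -1` on each reduction, the gate in `ShouldRunGeneticAlgorithm` now returns one boolean per island when handed fitness of shape `(islands, pop_per_island)`, instead of a single scalar over the flattened population. A minimal sketch of the difference, with made-up numbers and a made-up `gamma`:

import torch

fitnesses = torch.tensor([
    [1.0, 2.0, 10.0],  # island 0: wide spread, gate should fire
    [5.0, 5.1, 5.2],   # island 1: narrow spread, gate should not
])
gamma = 1.5

# previous behaviour: a single decision over the whole population
old_gate = (fitnesses.amax() - fitnesses.amin()) > (gamma * torch.median(fitnesses))

# new behaviour: one decision per island
new_gate = (fitnesses.amax(dim = -1) - fitnesses.amin(dim = -1)) > (gamma * torch.median(fitnesses, dim = -1).values)

print(old_gate.shape, new_gate.shape)  # torch.Size([]) torch.Size([2])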
@@ -371,6 +372,49 @@ class LatentGenePool(Module):
 
         self.should_run_genetic_algorithm = should_run_genetic_algorithm
 
+    def firefly_step(
+        self,
+        fitness,
+        beta0 = 2., # exploitation factor, moving fireflies of low light intensity to high
+        gamma = 1., # controls light intensity decay over distance - setting this to zero will make firefly equivalent to vanilla PSO
+        alpha = 0.1, # exploration factor
+        alpha_decay = 0.995, # exploration decay each step
+        inplace = True,
+    ):
+        islands = self.num_islands
+        fireflies = self.latents # the latents are the fireflies
+
+        assert fitness.shape[0] == fireflies.shape[0]
+
+        fitness = rearrange(fitness, '(i p) -> i p', i = islands)
+        fireflies = rearrange(fireflies, '(i p) ... -> i p ...', i = islands)
+
+        # fireflies with lower light intensity (high cost) moves towards the higher intensity (lower cost)
+
+        move_mask = einx.less('i x, i y -> i x y', fitness, fitness)
+
+        # get vectors of fireflies to one another
+        # calculate distance and the beta
+
+        delta_positions = einx.subtract('i y ... d, i x ... d -> i x y ... d', fireflies, fireflies)
+
+        distance = delta_positions.norm(dim = -1)
+
+        betas = beta0 * (-gamma * distance ** 2).exp()
+
+        # move the fireflies according to attraction
+
+        fireflies += einsum(move_mask, betas, delta_positions, 'i x y, i x y ..., i x y ... -> i x ...')
+
+        # merge back the islands
+
+        fireflies = rearrange(fireflies, 'i p ... -> (i p) ...')
+
+        if not inplace:
+            return fireflies
+
+        self.latents.copy_(fireflies)
+
     @torch.no_grad()
     # non-gradient optimization, at least, not on the individual level (taken care of by rl component)
     def genetic_algorithm_step(
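The new `firefly_step` moves each latent toward fitter latents on its island, with an attraction `beta0 * exp(-gamma * distance ** 2)` that decays with squared distance; the exploration parameters `alpha` and `alpha_decay` are accepted but not used in the body shown above. A rough usage sketch follows, assuming `LatentGenePool` is importable from the package root; the constructor arguments here are assumptions for illustration, not taken from this diff:

import torch
from evolutionary_policy_optimization import LatentGenePool

# hypothetical pool configuration - argument names are assumptions for illustration
pool = LatentGenePool(
    num_latents = 32,
    dim_latent = 16,
)

fitness = torch.randn(32)  # one fitness value per latent

# attract low-fitness latents toward high-fitness ones, updating the pool in place
pool.firefly_step(fitness)

# or compute the moved latents without mutating the pool
moved = pool.firefly_step(fitness, inplace = False)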
@@ -389,9 +433,6 @@ class LatentGenePool(Module):
         islands = self.num_islands
         tournament_participants = self.num_tournament_participants
 
-        if not self.should_run_genetic_algorithm(fitness):
-            return
-
         assert self.num_latents > 1
 
         genes = self.latents # the latents are the genes
@@ -399,12 +440,25 @@ class LatentGenePool(Module):
         pop_size = genes.shape[0]
         assert pop_size == fitness.shape[0]
 
+        pop_size_per_island = pop_size // islands
+
         # split out the islands
 
-        genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
         fitness = rearrange(fitness, '(i p) -> i p', i = islands)
 
-        pop_size_per_island = pop_size // islands
+        # from the fitness, decide whether to actually run the genetic algorithm or not
+
+        should_update_per_island = self.should_run_genetic_algorithm(fitness)
+
+        if not should_update_per_island.any():
+            if inplace:
+                return
+
+            return genes
+
+        genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
+
+        orig_genes = genes
 
         # 1. natural selection is simple in silico
         # you sort the population by the fitness and slice off the least fit end
@@ -456,6 +510,10 @@ class LatentGenePool(Module):
 
         genes = self.maybe_l2norm(genes)
 
+        # account for criteria of whether to actually run GA or not
+
+        genes = einx.where('i, i ..., i ...', should_update_per_island, genes, orig_genes)
+
         # merge island back into pop dimension
 
         genes = rearrange(genes, 'i p ... -> (i p) ...')
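The `einx.where` call broadcasts the per-island boolean over every gene in that island, so islands whose gate did not fire keep their original latents even though the rest of the update ran. The same pattern in isolation, with shapes made up for the example:

import torch
import einx

islands, pop, dim = 2, 3, 4

should_update = torch.tensor([True, False])
updated_genes = torch.ones(islands, pop, dim)
orig_genes = torch.zeros(islands, pop, dim)

# island 0 takes the updated genes, island 1 falls back to its originals
genes = einx.where('i, i ..., i ...', should_update, updated_genes, orig_genes)

assert genes[0].eq(1.).all() and genes[1].eq(0.).all()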
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.23
+Version: 0.0.25
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -36,7 +36,8 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: adam-atan2-pytorch
 Requires-Dist: assoc-scan
-Requires-Dist: einops>=0.8.0
+Requires-Dist: einops>=0.8.1
+Requires-Dist: einx>=0.3.0
 Requires-Dist: hl-gauss-pytorch>=0.1.19
 Requires-Dist: torch>=2.2
 Requires-Dist: tqdm
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=BLwy7PBZOjw6H7MFvMq9CC7Mdm3K8fpzBNH6HbNu6LY,21927
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.25.dist-info/METADATA,sha256=p3-_SuLvKs8E0z1l567qA0Pbsv2dOLlrJPX4WYoZaB4,4958
+evolutionary_policy_optimization-0.0.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.25.dist-info/RECORD,,
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=lMKMGaJdSh71Dnkn6ZvbtVNrCJ00Cv3p_uXJy9D8K90,19908
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.23.dist-info/METADATA,sha256=m_EqBjggqqEq09A9KqjA8BQMWECLIvzwbOR8pc-UFrM,4931
-evolutionary_policy_optimization-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.23.dist-info/RECORD,,