evolutionary-policy-optimization 0.0.23__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,6 +8,7 @@ import torch.nn.functional as F
8
8
  from torch.nn import Linear, Module, ModuleList
9
9
  from torch.utils.data import TensorDataset, DataLoader
10
10
 
11
+ import einx
11
12
  from einops import rearrange, repeat, einsum
12
13
  from einops.layers.torch import Rearrange
13
14
 
@@ -293,7 +294,7 @@ class ShouldRunGeneticAlgorithm(Module):
293
294
  # however, this equation does not make much sense to me if fitness increases unbounded
294
295
  # just let it be customizable, and offer a variant where mean and variance are over some threshold (could account for skew too)
295
296
 
296
- return (fitnesses.amax() - fitnesses.amin()) > (self.gamma * torch.median(fitnesses))
297
+ return (fitnesses.amax(dim = -1) - fitnesses.amin(dim = -1)) > (self.gamma * torch.median(fitnesses, dim = -1).values)
297
298
 
298
299
  # classes
299
300
 
@@ -389,9 +390,6 @@ class LatentGenePool(Module):
389
390
  islands = self.num_islands
390
391
  tournament_participants = self.num_tournament_participants
391
392
 
392
- if not self.should_run_genetic_algorithm(fitness):
393
- return
394
-
395
393
  assert self.num_latents > 1
396
394
 
397
395
  genes = self.latents # the latents are the genes
@@ -399,12 +397,25 @@ class LatentGenePool(Module):
399
397
  pop_size = genes.shape[0]
400
398
  assert pop_size == fitness.shape[0]
401
399
 
400
+ pop_size_per_island = pop_size // islands
401
+
402
402
  # split out the islands
403
403
 
404
- genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
405
404
  fitness = rearrange(fitness, '(i p) -> i p', i = islands)
406
405
 
407
- pop_size_per_island = pop_size // islands
406
+ # from the fitness, decide whether to actually run the genetic algorithm or not
407
+
408
+ should_update_per_island = self.should_run_genetic_algorithm(fitness)
409
+
410
+ if not should_update_per_island.any():
411
+ if inplace:
412
+ return
413
+
414
+ return genes
415
+
416
+ genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
417
+
418
+ orig_genes = genes
408
419
 
409
420
  # 1. natural selection is simple in silico
410
421
  # you sort the population by the fitness and slice off the least fit end
@@ -456,6 +467,10 @@ class LatentGenePool(Module):
456
467
 
457
468
  genes = self.maybe_l2norm(genes)
458
469
 
470
+ # only keep the evolved genes for islands that met the criteria for running the GA, otherwise keep the original genes
471
+
472
+ genes = einx.where('i, i ..., i ...', should_update_per_island, genes, orig_genes)
473
+
459
474
  # merge island back into pop dimension
460
475
 
461
476
  genes = rearrange(genes, 'i p ... -> (i p) ...')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.0.23
3
+ Version: 0.0.24
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -36,7 +36,8 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: adam-atan2-pytorch
38
38
  Requires-Dist: assoc-scan
39
- Requires-Dist: einops>=0.8.0
39
+ Requires-Dist: einops>=0.8.1
40
+ Requires-Dist: einx>=0.3.0
40
41
  Requires-Dist: hl-gauss-pytorch>=0.1.19
41
42
  Requires-Dist: torch>=2.2
42
43
  Requires-Dist: tqdm
@@ -0,0 +1,7 @@
1
+ evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
2
+ evolutionary_policy_optimization/epo.py,sha256=-kQgrnnOLiCOZ-6EroO057tDx0sS7TQro92cjJhSbZU,20353
3
+ evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
4
+ evolutionary_policy_optimization-0.0.24.dist-info/METADATA,sha256=d3imh1p1-nPpNGhD8cReLdL07_-oHZs3YqJaOEJi1TM,4958
5
+ evolutionary_policy_optimization-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ evolutionary_policy_optimization-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
+ evolutionary_policy_optimization-0.0.24.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
2
- evolutionary_policy_optimization/epo.py,sha256=lMKMGaJdSh71Dnkn6ZvbtVNrCJ00Cv3p_uXJy9D8K90,19908
3
- evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
4
- evolutionary_policy_optimization-0.0.23.dist-info/METADATA,sha256=m_EqBjggqqEq09A9KqjA8BQMWECLIvzwbOR8pc-UFrM,4931
5
- evolutionary_policy_optimization-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- evolutionary_policy_optimization-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- evolutionary_policy_optimization-0.0.23.dist-info/RECORD,,