evolutionary-policy-optimization 0.0.23-py3-none-any.whl → 0.0.24-py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- evolutionary_policy_optimization/epo.py +21 -6
- {evolutionary_policy_optimization-0.0.23.dist-info → evolutionary_policy_optimization-0.0.24.dist-info}/METADATA +3 -2
- evolutionary_policy_optimization-0.0.24.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.23.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.23.dist-info → evolutionary_policy_optimization-0.0.24.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.23.dist-info → evolutionary_policy_optimization-0.0.24.dist-info}/licenses/LICENSE +0 -0
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader

+import einx
 from einops import rearrange, repeat, einsum
 from einops.layers.torch import Rearrange

@@ -293,7 +294,7 @@ class ShouldRunGeneticAlgorithm(Module):
         # however, this equation does not make much sense to me if fitness increases unbounded
         # just let it be customizable, and offer a variant where mean and variance is over some threshold (could account for skew too)

-        return (fitnesses.amax() - fitnesses.amin()) > (self.gamma * torch.median(fitnesses))
+        return (fitnesses.amax(dim = -1) - fitnesses.amin(dim = -1)) > (self.gamma * torch.median(fitnesses, dim = -1).values)

 # classes

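For context on the change above: the criterion now reduces over the last dimension, so a fitness tensor shaped (islands, population per island) yields one boolean per island instead of a single scalar for the whole population. A minimal standalone sketch of that behavior (the shapes and the gamma threshold below are illustrative, not the package defaults):

import torch

# illustrative sizes and threshold, not taken from the package defaults
islands, pop_per_island = 4, 8
gamma = 1.5
fitnesses = torch.rand(islands, pop_per_island)

# spread of fitness within each island, compared against gamma * that island's median
spread = fitnesses.amax(dim = -1) - fitnesses.amin(dim = -1)
should_run = spread > (gamma * torch.median(fitnesses, dim = -1).values)

print(should_run.shape)  # torch.Size([4]) - one decision per island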
@@ -389,9 +390,6 @@ class LatentGenePool(Module):
         islands = self.num_islands
         tournament_participants = self.num_tournament_participants

-        if not self.should_run_genetic_algorithm(fitness):
-            return
-
         assert self.num_latents > 1

         genes = self.latents # the latents are the genes
@@ -399,12 +397,25 @@
         pop_size = genes.shape[0]
         assert pop_size == fitness.shape[0]

+        pop_size_per_island = pop_size // islands
+
         # split out the islands

-        genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
         fitness = rearrange(fitness, '(i p) -> i p', i = islands)

-
+        # from the fitness, decide whether to actually run the genetic algorithm or not
+
+        should_update_per_island = self.should_run_genetic_algorithm(fitness)
+
+        if not should_update_per_island.any():
+            if inplace:
+                return
+
+            return genes
+
+        genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
+
+        orig_genes = genes

         # 1. natural selection is simple in silico
         # you sort the population by the fitness and slice off the least fit end
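Taken together with the removal of the early exit further up, the decision to evolve is now made after fitness has been split into islands, and an all-False mask short-circuits the method (handing back the genes untouched when not operating in place). A rough sketch of the shape handling, using made-up sizes and a stand-in criterion in place of self.should_run_genetic_algorithm:

import torch
from einops import rearrange

# made-up sizes: 2 islands of 6 individuals, each gene a single (n = 1) latent of dim 32
islands, per_island, n, g = 2, 6, 1, 32
genes = torch.randn(islands * per_island, n, g)
fitness = torch.rand(islands * per_island)

fitness = rearrange(fitness, '(i p) -> i p', i = islands)

# stand-in per-island criterion; the package delegates this to ShouldRunGeneticAlgorithm
should_update_per_island = (fitness.amax(dim = -1) - fitness.amin(dim = -1)) > 0.5

if not should_update_per_island.any():
    pass  # the method would stop here and leave the gene pool as-is

genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
orig_genes = genes  # kept around so islands that should not evolve can be restored later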
@@ -456,6 +467,10 @@ class LatentGenePool(Module):

         genes = self.maybe_l2norm(genes)

+        # account for criteria of whether to actually run GA or not
+
+        genes = einx.where('i, i ..., i ...', should_update_per_island, genes, orig_genes)
+
         # merge island back into pop dimension

         genes = rearrange(genes, 'i p ... -> (i p) ...')
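The new einx.where call is what consumes that per-island mask: the pattern 'i, i ..., i ...' broadcasts the boolean over every trailing axis, so islands flagged True keep their evolved genes while the rest fall back to their originals. A small illustration with arbitrary shapes (the tensors below are invented for the example):

import torch
import einx

# arbitrary shapes: (islands, population per island, num latents, gene dim)
mask = torch.tensor([True, False])
evolved = torch.randn(2, 6, 1, 32)
original = torch.randn(2, 6, 1, 32)

# per-island select: island 0 takes the evolved genes, island 1 keeps the originals
merged = einx.where('i, i ..., i ...', mask, evolved, original)

assert torch.equal(merged[0], evolved[0]) and torch.equal(merged[1], original[1])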
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.23
+Version: 0.0.24
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -36,7 +36,8 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: adam-atan2-pytorch
 Requires-Dist: assoc-scan
-Requires-Dist: einops>=0.8.0
+Requires-Dist: einops>=0.8.1
+Requires-Dist: einx>=0.3.0
 Requires-Dist: hl-gauss-pytorch>=0.1.19
 Requires-Dist: torch>=2.2
 Requires-Dist: tqdm
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=-kQgrnnOLiCOZ-6EroO057tDx0sS7TQro92cjJhSbZU,20353
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.24.dist-info/METADATA,sha256=d3imh1p1-nPpNGhD8cReLdL07_-oHZs3YqJaOEJi1TM,4958
+evolutionary_policy_optimization-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.24.dist-info/RECORD,,
@@ -1,7 +0,0 @@
|
|
1
|
-
evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
|
2
|
-
evolutionary_policy_optimization/epo.py,sha256=lMKMGaJdSh71Dnkn6ZvbtVNrCJ00Cv3p_uXJy9D8K90,19908
|
3
|
-
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
-
evolutionary_policy_optimization-0.0.23.dist-info/METADATA,sha256=m_EqBjggqqEq09A9KqjA8BQMWECLIvzwbOR8pc-UFrM,4931
|
5
|
-
evolutionary_policy_optimization-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
-
evolutionary_policy_optimization-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
-
evolutionary_policy_optimization-0.0.23.dist-info/RECORD,,
|