evolutionary-policy-optimization 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +66 -20
- {evolutionary_policy_optimization-0.0.22.dist-info → evolutionary_policy_optimization-0.0.24.dist-info}/METADATA +3 -2
- evolutionary_policy_optimization-0.0.24.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.22.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.22.dist-info → evolutionary_policy_optimization-0.0.24.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.22.dist-info → evolutionary_policy_optimization-0.0.24.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
 
+import einx
 from einops import rearrange, repeat, einsum
 from einops.layers.torch import Rearrange
 
@@ -31,11 +32,14 @@ def identity(t):
 def xnor(x, y):
     return not (x ^ y)
 
-def …
-    return …
+def divisible_by(num, den):
+    return (num % den) == 0
 
 # tensor helpers
 
+def l2norm(t):
+    return F.normalize(t, p = 2, dim = -1)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
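The two new helpers are tiny but load-bearing: `divisible_by` guards the even split of the gene pool across islands, and `l2norm` keeps latents on the unit hypersphere. A quick sanity check of both, with the helper bodies copied verbatim from the hunk above:

```python
import torch
import torch.nn.functional as F

def divisible_by(num, den):
    return (num % den) == 0

def l2norm(t):
    # unit-normalize the last dimension (p = 2 is the euclidean norm)
    return F.normalize(t, p = 2, dim = -1)

assert divisible_by(64, 4)      # 64 latents split cleanly over 4 islands
assert not divisible_by(64, 5)  # 5 islands would trip the new assert below

t = torch.randn(3, 8)
assert torch.allclose(l2norm(t).norm(dim = -1), torch.ones(3), atol = 1e-6)
```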
@@ -290,7 +294,7 @@ class ShouldRunGeneticAlgorithm(Module):
         # however, this equation does not make much sense to me if fitness increases unbounded
         # just let it be customizable, and offer a variant where mean and variance is over some threshold (could account for skew too)
 
-        return (fitnesses.amax() - fitnesses.amin()) > (self.gamma * torch.median(fitnesses))
+        return (fitnesses.amax(dim = -1) - fitnesses.amin(dim = -1)) > (self.gamma * torch.median(fitnesses, dim = -1).values)
 
# classes
 
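The criterion itself is unchanged; only its shape handling is. With fitness now shaped `(islands, population per island)`, reducing over `dim = -1` yields one boolean per island instead of a single scalar. A minimal illustration, with `gamma` as a made-up threshold (the real default lives in `ShouldRunGeneticAlgorithm`):

```python
import torch

gamma = 1.5  # illustrative only

fitnesses = torch.tensor([
    [1.0, 1.1, 1.2, 1.3],  # tightly clustered island - not worth evolving yet
    [0.5, 2.0, 4.0, 9.0],  # wide spread island - run the genetic algorithm
])

spread = fitnesses.amax(dim = -1) - fitnesses.amin(dim = -1)
should_run = spread > gamma * torch.median(fitnesses, dim = -1).values
print(should_run)  # tensor([False,  True])
```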
@@ -300,6 +304,7 @@ class LatentGenePool(Module):
         num_latents, # same as gene pool size
         dim_latent, # gene dimension
         num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
+        num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
         dim_state = None,
         frozen_latents = True,
         crossover_random = True, # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
@@ -340,6 +345,9 @@ class LatentGenePool(Module):
 
         # some derived values
 
+        assert num_islands >= 1
+        assert divisible_by(num_latents, num_islands)
+
         assert 0. < frac_tournaments < 1.
         assert 0. < frac_natural_selected < 1.
         assert 0. <= frac_elitism < 1.
@@ -347,13 +355,16 @@ class LatentGenePool(Module):
 
         self.dim_latent = dim_latent
         self.num_latents = num_latents
-        self.num_natural_selected = int(frac_natural_selected * num_latents)
+        self.num_islands = num_islands
+
+        latents_per_island = num_latents // num_islands
+        self.num_natural_selected = int(frac_natural_selected * latents_per_island)
 
         self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
         self.crossover_random = crossover_random
 
         self.mutation_strength = mutation_strength
-        self.num_elites = int(frac_elitism * num_latents)
+        self.num_elites = int(frac_elitism * latents_per_island)
         self.has_elites = self.num_elites > 0
 
         if not exists(should_run_genetic_algorithm):
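Note that the natural-selection and elitism counts are now derived from the per-island population, not the global pool. Working through the arithmetic with hypothetical numbers (none of these values are package defaults):

```python
num_latents = 128
num_islands = 4
frac_natural_selected = 0.25
frac_elitism = 0.1

latents_per_island = num_latents // num_islands                         # 32
num_natural_selected = int(frac_natural_selected * latents_per_island)  # 8 survivors per island
num_elites = int(frac_elitism * latents_per_island)                     # 3 untouched elites per island
```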
@@ -369,13 +380,15 @@ class LatentGenePool(Module):
         inplace = True
     ):
         """
+        i - islands
         p - population
         g - gene dimension
         n - number of genes per individual
+        t - num tournament participants
         """
 
-        …
-        …
+        islands = self.num_islands
+        tournament_participants = self.num_tournament_participants
 
         assert self.num_latents > 1
 
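The new `i` axis simply comes from reshaping the flat population, so islands are contiguous chunks of the gene pool. A round-trip sketch of the split and merge used in the next hunk, under arbitrary assumed shapes:

```python
import torch
from einops import rearrange

islands, pop_per_island, n, g = 4, 8, 1, 32
genes = torch.randn(islands * pop_per_island, n, g)  # flat pool, islands stored contiguously

genes_islands = rearrange(genes, '(i p) n g -> i p n g', i = islands)  # split for the GA
genes_merged  = rearrange(genes_islands, 'i p ... -> (i p) ...')       # merge back afterwards

assert torch.equal(genes, genes_merged)
```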
@@ -384,39 +397,64 @@ class LatentGenePool(Module):
         pop_size = genes.shape[0]
         assert pop_size == fitness.shape[0]
 
+        pop_size_per_island = pop_size // islands
+
+        # split out the islands
+
+        fitness = rearrange(fitness, '(i p) -> i p', i = islands)
+
+        # from the fitness, decide whether to actually run the genetic algorithm or not
+
+        should_update_per_island = self.should_run_genetic_algorithm(fitness)
+
+        if not should_update_per_island.any():
+            if inplace:
+                return
+
+            return genes
+
+        genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
+
+        orig_genes = genes
+
         # 1. natural selection is simple in silico
         # you sort the population by the fitness and slice off the least fit end
 
-        sorted_indices = fitness.sort().indices
-        natural_selected_indices = sorted_indices[-self.num_natural_selected:]
-        …
+        sorted_indices = fitness.sort(dim = -1).indices
+        natural_selected_indices = sorted_indices[..., -self.num_natural_selected:]
+        natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... n g', n = genes.shape[-2], g = genes.shape[-1])
+
+        genes, fitness = genes.gather(1, natural_select_gene_indices), fitness.gather(1, natural_selected_indices)
 
         # 2. for finding pairs of parents to replete gene pool, we will go with the popular tournament strategy
 
-        …
+        rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants)).argsort(dim = -1)
+        rand_tournament_gene_ids_for_gather = rearrange(rand_tournament_gene_ids, 'i p t -> i (p t)')
 
-        …
-        participant_fitness = …
+        participant_fitness = fitness.gather(1, rand_tournament_gene_ids_for_gather)
+        participant_fitness = rearrange(participant_fitness, 'i (p t) -> i p t', t = tournament_participants)
 
-        …
+        parent_indices_at_tournament = participant_fitness.topk(2, dim = -1).indices
+        parent_gene_ids = rand_tournament_gene_ids.gather(-1, parent_indices_at_tournament)
 
-        …
+        parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) n g', n = genes.shape[-2], g = genes.shape[-1])
 
-        parents = …
+        parents = genes.gather(1, parent_gene_ids_for_gather)
+        parents = rearrange(parents, 'i (p parents) ... -> i p parents ...', parents = 2)
 
         # 3. do a crossover of the parents - in their case they went for a simple averaging, but since we are doing tournament style and the same pair of parents may be re-selected, lets make it random interpolation
 
-        parent1, parent2 = parents.unbind(dim = 1)
+        parent1, parent2 = parents.unbind(dim = 2)
         children = crossover_latents(parent1, parent2, random = self.crossover_random)
 
         # append children to gene pool
 
-        genes = cat((children, genes))
+        genes = cat((children, genes), dim = 1)
 
         # 4. they use the elitism strategy to protect best performing genes from being changed
 
         if self.has_elites:
-            genes, elites = genes[:-self.num_elites], genes[-self.num_elites:]
+            genes, elites = genes[:, :-self.num_elites], genes[:, -self.num_elites:]
 
         # 5. mutate with gaussian noise - todo: add drawing the mutation rate from exponential distribution, from the fast genetic algorithms paper from 2017
 
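The batched tournament is the densest part of the rewrite: participant indices are drawn per island, flattened so `gather` can index the population dim, then folded back. The sketch below reproduces that gather/rearrange dance under assumed sizes; unlike the hunk, it draws participant ids from the whole selected pool (`randn(..., num_selected).argsort()[..., :t]`) so each tournament gets `t` distinct entrants:

```python
import torch

islands, num_selected, num_children, t = 2, 6, 4, 3

fitness = torch.randn(islands, num_selected)  # fitness of the already-selected genes

# t distinct participant ids per tournament, drawn independently per island
ids = torch.randn(islands, num_children, num_selected).argsort(dim = -1)[..., :t]

# flatten tournaments so gather can index the population dim, then fold back
participant_fitness = fitness.gather(1, ids.reshape(islands, -1))
participant_fitness = participant_fitness.reshape(islands, num_children, t)

# the two fittest entrants of every tournament become the parents
winner_slots = participant_fitness.topk(2, dim = -1).indices  # positions within each tournament
parent_ids = ids.gather(-1, winner_slots)                     # (islands, num_children, 2) gene ids
```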
@@ -425,10 +463,18 @@ class LatentGenePool(Module):
         # add back the elites
 
         if self.has_elites:
-            genes = cat((genes, elites))
+            genes = cat((genes, elites), dim = 1)
 
         genes = self.maybe_l2norm(genes)
 
+        # account for criteria of whether to actually run GA or not
+
+        genes = einx.where('i, i ..., i ...', should_update_per_island, genes, orig_genes)
+
+        # merge island back into pop dimension
+
+        genes = rearrange(genes, 'i p ... -> (i p) ...')
+
         if not inplace:
             return genes
 
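`einx.where` is what ties the skip-criterion back in: the pattern `'i, i ..., i ...'` (taken verbatim from the hunk) broadcasts one boolean per island across all trailing dims, so islands that were not evolved keep their original genes. A minimal check of that broadcast, with made-up shapes:

```python
import torch
import einx

should_update = torch.tensor([True, False])  # one flag per island
evolved = torch.ones(2, 4, 1, 8)             # (i, p, n, g) after the GA
original = torch.zeros(2, 4, 1, 8)           # genes as they were before

genes = einx.where('i, i ..., i ...', should_update, evolved, original)
assert genes[0].eq(1).all() and genes[1].eq(0).all()
```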
{evolutionary_policy_optimization-0.0.22.dist-info → evolutionary_policy_optimization-0.0.24.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.22
+Version: 0.0.24
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -36,7 +36,8 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: adam-atan2-pytorch
 Requires-Dist: assoc-scan
-Requires-Dist: einops>=0.8.0
+Requires-Dist: einops>=0.8.1
+Requires-Dist: einx>=0.3.0
 Requires-Dist: hl-gauss-pytorch>=0.1.19
 Requires-Dist: torch>=2.2
 Requires-Dist: tqdm
evolutionary_policy_optimization-0.0.24.dist-info/RECORD

@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=-kQgrnnOLiCOZ-6EroO057tDx0sS7TQro92cjJhSbZU,20353
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.24.dist-info/METADATA,sha256=d3imh1p1-nPpNGhD8cReLdL07_-oHZs3YqJaOEJi1TM,4958
+evolutionary_policy_optimization-0.0.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.24.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.22.dist-info/RECORD

@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=TbUX2L-Wa2zIZ2b7iHmBtaym-qDSLAFrC7iU7xReX_k,18449
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.22.dist-info/METADATA,sha256=L3G-tesSEyhrc_SbTN6HuJQlXfogEUvr3W9SXPcnRVw,4931
-evolutionary_policy_optimization-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.22.dist-info/RECORD,,
The WHEEL and licenses/LICENSE files carry over without changes.