evolutionary-policy-optimization 0.0.20__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +52 -18
- {evolutionary_policy_optimization-0.0.20.dist-info → evolutionary_policy_optimization-0.0.23.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.23.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.20.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.20.dist-info → evolutionary_policy_optimization-0.0.23.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.20.dist-info → evolutionary_policy_optimization-0.0.23.dist-info}/licenses/LICENSE +0 -0
@@ -31,11 +31,14 @@ def identity(t):
|
|
31
31
|
def xnor(x, y):
|
32
32
|
return not (x ^ y)
|
33
33
|
|
34
|
-
def
|
35
|
-
return
|
34
|
+
def divisible_by(num, den):
|
35
|
+
return (num % den) == 0
|
36
36
|
|
37
37
|
# tensor helpers
|
38
38
|
|
39
|
+
def l2norm(t):
|
40
|
+
return F.normalize(t, p = 2, dim = -1)
|
41
|
+
|
39
42
|
def log(t, eps = 1e-20):
|
40
43
|
return t.clamp(min = eps).log()
|
41
44
|
|
@@ -300,7 +303,9 @@ class LatentGenePool(Module):
|
|
300
303
|
num_latents, # same as gene pool size
|
301
304
|
dim_latent, # gene dimension
|
302
305
|
num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
|
306
|
+
num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
|
303
307
|
dim_state = None,
|
308
|
+
frozen_latents = True,
|
304
309
|
crossover_random = True, # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
|
305
310
|
l2norm_latent = False, # whether to enforce latents on hypersphere,
|
306
311
|
frac_tournaments = 0.25, # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
|
@@ -321,7 +326,7 @@ class LatentGenePool(Module):
|
|
321
326
|
|
322
327
|
self.num_latents = num_latents
|
323
328
|
self.needs_latent_gate = num_latent_sets > 1
|
324
|
-
self.latents = nn.Parameter(latents, requires_grad =
|
329
|
+
self.latents = nn.Parameter(latents, requires_grad = not frozen_latents)
|
325
330
|
|
326
331
|
self.maybe_l2norm = maybe_l2norm
|
327
332
|
|
@@ -339,6 +344,9 @@ class LatentGenePool(Module):
|
|
339
344
|
|
340
345
|
# some derived values
|
341
346
|
|
347
|
+
assert num_islands >= 1
|
348
|
+
assert divisible_by(num_latents, num_islands)
|
349
|
+
|
342
350
|
assert 0. < frac_tournaments < 1.
|
343
351
|
assert 0. < frac_natural_selected < 1.
|
344
352
|
assert 0. <= frac_elitism < 1.
|
@@ -346,13 +354,16 @@ class LatentGenePool(Module):
|
|
346
354
|
|
347
355
|
self.dim_latent = dim_latent
|
348
356
|
self.num_latents = num_latents
|
349
|
-
self.
|
357
|
+
self.num_islands = num_islands
|
358
|
+
|
359
|
+
latents_per_island = num_latents // num_islands
|
360
|
+
self.num_natural_selected = int(frac_natural_selected * latents_per_island)
|
350
361
|
|
351
362
|
self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
|
352
363
|
self.crossover_random = crossover_random
|
353
364
|
|
354
365
|
self.mutation_strength = mutation_strength
|
355
|
-
self.num_elites = int(frac_elitism *
|
366
|
+
self.num_elites = int(frac_elitism * latents_per_island)
|
356
367
|
self.has_elites = self.num_elites > 0
|
357
368
|
|
358
369
|
if not exists(should_run_genetic_algorithm):
|
@@ -368,11 +379,16 @@ class LatentGenePool(Module):
|
|
368
379
|
inplace = True
|
369
380
|
):
|
370
381
|
"""
|
382
|
+
i - islands
|
371
383
|
p - population
|
372
384
|
g - gene dimension
|
373
385
|
n - number of genes per individual
|
386
|
+
t - num tournament participants
|
374
387
|
"""
|
375
388
|
|
389
|
+
islands = self.num_islands
|
390
|
+
tournament_participants = self.num_tournament_participants
|
391
|
+
|
376
392
|
if not self.should_run_genetic_algorithm(fitness):
|
377
393
|
return
|
378
394
|
|
@@ -383,39 +399,51 @@ class LatentGenePool(Module):
|
|
383
399
|
pop_size = genes.shape[0]
|
384
400
|
assert pop_size == fitness.shape[0]
|
385
401
|
|
402
|
+
# split out the islands
|
403
|
+
|
404
|
+
genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
|
405
|
+
fitness = rearrange(fitness, '(i p) -> i p', i = islands)
|
406
|
+
|
407
|
+
pop_size_per_island = pop_size // islands
|
408
|
+
|
386
409
|
# 1. natural selection is simple in silico
|
387
410
|
# you sort the population by the fitness and slice off the least fit end
|
388
411
|
|
389
|
-
sorted_indices = fitness.sort().indices
|
390
|
-
natural_selected_indices = sorted_indices[-self.num_natural_selected:]
|
391
|
-
|
412
|
+
sorted_indices = fitness.sort(dim = -1).indices
|
413
|
+
natural_selected_indices = sorted_indices[..., -self.num_natural_selected:]
|
414
|
+
natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... n g', n = genes.shape[-2], g = genes.shape[-1])
|
415
|
+
|
416
|
+
genes, fitness = genes.gather(1, natural_select_gene_indices), fitness.gather(1, natural_selected_indices)
|
392
417
|
|
393
418
|
# 2. for finding pairs of parents to replete gene pool, we will go with the popular tournament strategy
|
394
419
|
|
395
|
-
|
420
|
+
rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants)).argsort(dim = -1)
|
421
|
+
rand_tournament_gene_ids_for_gather = rearrange(rand_tournament_gene_ids, 'i p t -> i (p t)')
|
396
422
|
|
397
|
-
|
398
|
-
participant_fitness =
|
423
|
+
participant_fitness = fitness.gather(1, rand_tournament_gene_ids_for_gather)
|
424
|
+
participant_fitness = rearrange(participant_fitness, 'i (p t) -> i p t', t = tournament_participants)
|
399
425
|
|
400
|
-
|
426
|
+
parent_indices_at_tournament = participant_fitness.topk(2, dim = -1).indices
|
427
|
+
parent_gene_ids = rand_tournament_gene_ids.gather(-1, parent_indices_at_tournament)
|
401
428
|
|
402
|
-
|
429
|
+
parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) n g', n = genes.shape[-2], g = genes.shape[-1])
|
403
430
|
|
404
|
-
parents =
|
431
|
+
parents = genes.gather(1, parent_gene_ids_for_gather)
|
432
|
+
parents = rearrange(parents, 'i (p parents) ... -> i p parents ...', parents = 2)
|
405
433
|
|
406
434
|
# 3. do a crossover of the parents - in their case they went for a simple averaging, but since we are doing tournament style and the same pair of parents may be re-selected, lets make it random interpolation
|
407
435
|
|
408
|
-
parent1, parent2 = parents.unbind(dim =
|
436
|
+
parent1, parent2 = parents.unbind(dim = 2)
|
409
437
|
children = crossover_latents(parent1, parent2, random = self.crossover_random)
|
410
438
|
|
411
439
|
# append children to gene pool
|
412
440
|
|
413
|
-
genes = cat((children, genes))
|
441
|
+
genes = cat((children, genes), dim = 1)
|
414
442
|
|
415
443
|
# 4. they use the elitism strategy to protect best performing genes from being changed
|
416
444
|
|
417
445
|
if self.has_elites:
|
418
|
-
genes, elites = genes[:-self.num_elites], genes[-self.num_elites:]
|
446
|
+
genes, elites = genes[:, :-self.num_elites], genes[:, -self.num_elites:]
|
419
447
|
|
420
448
|
# 5. mutate with gaussian noise - todo: add drawing the mutation rate from exponential distribution, from the fast genetic algorithms paper from 2017
|
421
449
|
|
@@ -424,10 +452,14 @@ class LatentGenePool(Module):
|
|
424
452
|
# add back the elites
|
425
453
|
|
426
454
|
if self.has_elites:
|
427
|
-
genes = cat((genes, elites))
|
455
|
+
genes = cat((genes, elites), dim = 1)
|
428
456
|
|
429
457
|
genes = self.maybe_l2norm(genes)
|
430
458
|
|
459
|
+
# merge island back into pop dimension
|
460
|
+
|
461
|
+
genes = rearrange(genes, 'i p ... -> (i p) ...')
|
462
|
+
|
431
463
|
if not inplace:
|
432
464
|
return genes
|
433
465
|
|
@@ -477,6 +509,8 @@ class LatentGenePool(Module):
|
|
477
509
|
else:
|
478
510
|
latent = latent[0]
|
479
511
|
|
512
|
+
latent = self.maybe_l2norm(latent)
|
513
|
+
|
480
514
|
if not exists(net):
|
481
515
|
return latent
|
482
516
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: evolutionary-policy-optimization
|
3
|
-
Version: 0.0.20
|
3
|
+
Version: 0.0.23
|
4
4
|
Summary: EPO - Pytorch
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
|
@@ -0,0 +1,7 @@
|
|
1
|
+
evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
|
2
|
+
evolutionary_policy_optimization/epo.py,sha256=lMKMGaJdSh71Dnkn6ZvbtVNrCJ00Cv3p_uXJy9D8K90,19908
|
3
|
+
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
+
evolutionary_policy_optimization-0.0.23.dist-info/METADATA,sha256=m_EqBjggqqEq09A9KqjA8BQMWECLIvzwbOR8pc-UFrM,4931
|
5
|
+
evolutionary_policy_optimization-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
+
evolutionary_policy_optimization-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
+
evolutionary_policy_optimization-0.0.23.dist-info/RECORD,,
|
@@ -1,7 +0,0 @@
|
|
1
|
-
evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
|
2
|
-
evolutionary_policy_optimization/epo.py,sha256=BTBqkgDq-x4dUMlKdSojvV2Yjzf9pDUZGMik32WjdHQ,18361
|
3
|
-
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
-
evolutionary_policy_optimization-0.0.20.dist-info/METADATA,sha256=0QNTGATtchVuxVplbrfXAtupcrMKEQD-uisM7CFm7qE,4931
|
5
|
-
evolutionary_policy_optimization-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
-
evolutionary_policy_optimization-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
-
evolutionary_policy_optimization-0.0.20.dist-info/RECORD,,
|
File without changes
|