evolutionary_policy_optimization-0.0.46-py3-none-any.whl → evolutionary_policy_optimization-0.0.48-py3-none-any.whl
- evolutionary_policy_optimization/epo.py +28 -5
- evolutionary_policy_optimization/mock_env.py +2 -1
- {evolutionary_policy_optimization-0.0.46.dist-info → evolutionary_policy_optimization-0.0.48.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.48.dist-info/RECORD +8 -0
- evolutionary_policy_optimization-0.0.46.dist-info/RECORD +0 -8
- {evolutionary_policy_optimization-0.0.46.dist-info → evolutionary_policy_optimization-0.0.48.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.46.dist-info → evolutionary_policy_optimization-0.0.48.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 from functools import partial, wraps
 from pathlib import Path
 from collections import namedtuple
+from random import randrange
 
 import torch
 from torch import nn, cat, stack, is_tensor, tensor
@@ -356,9 +357,10 @@ class LatentGenePool(Module):
         frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
         frac_elitism = 0.1, # frac of population to preserve from being noised
         frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
+        migrate_every = 100, # how many steps before a migration between islands
         mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
-        default_should_run_ga_gamma = 1.5
+        default_should_run_ga_gamma = 1.5,
     ):
         super().__init__()
 
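This hunk introduces a `migrate_every` hyperparameter on `LatentGenePool` (and fixes the missing trailing comma after `default_should_run_ga_gamma`). A minimal usage sketch follows; only `migrate_every` is confirmed by this diff, while the other constructor arguments (`num_latents`, `dim_latent`, `num_islands`) are assumptions inferred from the surrounding parameters and comments, not the full signature.

```python
# a minimal sketch, assuming these constructor arguments exist; not the full signature
from evolutionary_policy_optimization.epo import LatentGenePool

pool = LatentGenePool(
    num_latents = 128,    # assumed: population size
    dim_latent = 32,      # assumed: latent gene dimension
    num_islands = 4,      # migration only applies with more than one island
    migrate_every = 100   # new in 0.0.48: steps between automatic migrations
)
```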
@@ -408,6 +410,13 @@ class LatentGenePool(Module):
 
         self.should_run_genetic_algorithm = should_run_genetic_algorithm
 
+        self.can_migrate = num_islands > 1
+        self.migrate_every = migrate_every
+        self.register_buffer('step', tensor(0))
+
+    def advance_step_(self):
+        self.step.add_(1)
+
     def firefly_step(
         self,
         fitness,
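The step counter is registered as a buffer rather than kept as a plain Python int, so it moves with the module between devices and is saved and restored through `state_dict()`. A small self-contained sketch of the same pattern (an illustration of `register_buffer` plus an in-place increment, not the library's code):

```python
import torch
from torch import nn, tensor

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        # buffer: stored in state_dict and moved with .to(device), but not a learnable parameter
        self.register_buffer('step', tensor(0))

    def advance_step_(self):
        self.step.add_(1)  # in-place increment, matching the trailing-underscore convention

counter = Counter()
counter.advance_step_()
assert counter.step.item() == 1
assert 'step' in counter.state_dict()  # persists across checkpointing
```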
@@ -459,7 +468,7 @@ class LatentGenePool(Module):
         self,
         fitness, # Float['p'],
         inplace = True,
-        migrate =
+        migrate = None # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
     ):
         device = self.latents.device
 
@@ -546,8 +555,9 @@ class LatentGenePool(Module):
 
         # 6. maybe migration
 
+        migrate = self.can_migrate and default(migrate, divisible_by(self.step.item(), self.migrate_every))
+
         if migrate:
-            assert self.num_islands > 1
             randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)
 
             migrate_mask = randperm < self.num_migrate
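The removed assert is replaced by a computed default: when the caller does not pass `migrate` explicitly, migration fires whenever the internal step counter is divisible by `migrate_every`, and only if there is more than one island. A rough sketch of that gating logic; the `default` and `divisible_by` helpers below are stand-ins written to match how they are used in the hunk, not the library's own definitions.

```python
# stand-in helpers, assumed to behave like the library's `default` / `divisible_by`
def default(value, fallback):
    return value if value is not None else fallback

def divisible_by(num, den):
    return (num % den) == 0

def should_migrate(migrate, step, num_islands, migrate_every = 100):
    can_migrate = num_islands > 1
    # an explicit True/False from the caller wins; otherwise migrate on schedule
    return can_migrate and default(migrate, divisible_by(step, migrate_every))

assert should_migrate(None, step = 200, num_islands = 4)      # on schedule
assert not should_migrate(None, step = 201, num_islands = 4)  # off schedule
assert not should_migrate(True, step = 201, num_islands = 1)  # single island never migrates
assert should_migrate(True, step = 201, num_islands = 4)      # forced by caller
```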
@@ -580,6 +590,8 @@ class LatentGenePool(Module):
 
         self.latents.copy_(genes)
 
+        self.advance_step_()
+
     def forward(
         self,
         *args,
@@ -989,7 +1001,8 @@ class EPO(Module):
     @torch.no_grad()
     def forward(
         self,
-        env
+        env,
+        fix_seed_across_latents = True
     ) -> MemoriesAndCumulativeRewards:
 
         self.agent.eval()
@@ -1002,12 +1015,22 @@ class EPO(Module):
 
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
+            # maybe fix seed for environment across all latents
+
+            env_reset_kwargs = dict()
+
+            if fix_seed_across_latents:
+                seed = randrange(int(1e6))
+                env_reset_kwargs = dict(seed = seed)
+
+            # for each latent (on a single machine for now)
+
             for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
                 time = 0
 
                 # initial state
 
-                state = env.reset()
+                state = env.reset(**env_reset_kwargs)
 
                 # get latent from pool
 
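Drawing one seed per episode and reusing it for every latent's `env.reset(...)` means all latents are evaluated on the same environment roll, so differences in cumulative reward reflect the latents rather than environment randomness; passing `fix_seed_across_latents = False` to `EPO.forward` restores independent resets. A small sketch of the idea against a Gymnasium-style `reset(seed=...)` API; the function, its arguments, and the elided policy step are placeholders, not the library's code.

```python
from random import randrange

def rollout_all_latents(env, num_latents, episodes_per_latent, fix_seed_across_latents = True):
    for episode_id in range(episodes_per_latent):
        # one seed per episode, shared by every latent in that episode
        env_reset_kwargs = dict()
        if fix_seed_across_latents:
            env_reset_kwargs = dict(seed = randrange(int(1e6)))

        for latent_id in range(num_latents):
            state = env.reset(**env_reset_kwargs)  # same initial conditions for each latent
            ...  # act with the policy conditioned on latent `latent_id`
```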
{evolutionary_policy_optimization-0.0.46.dist-info → evolutionary_policy_optimization-0.0.48.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.46
+Version: 0.0.48
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.48.dist-info/RECORD
@@ -0,0 +1,8 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/epo.py,sha256=FkliOiKdmUvKuwFqb1_A-ddahnOqjTR8Djx_I6UZAlU,32625
+evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
+evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
+evolutionary_policy_optimization-0.0.48.dist-info/METADATA,sha256=GpuUVs0VO2ydhU3X4-A_cA_xNmvdtYaAM8tb_VKneBo,6213
+evolutionary_policy_optimization-0.0.48.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.48.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.48.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=SAhWgRY8uPQEKFg1_nz1mvh8A6S_sHwnDykhd0F5xEI,31853
-evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
-evolutionary_policy_optimization/mock_env.py,sha256=6AIc4mwL_C6JkAxwESJgCLxXHMzCAu2FcffVg3HkSm0,920
-evolutionary_policy_optimization-0.0.46.dist-info/METADATA,sha256=xP2kdKo52-X4Z5XXTPpW0M_NFI0spuigeL7fvqFlsRM,6213
-evolutionary_policy_optimization-0.0.46.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.46.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.46.dist-info/RECORD,,
WHEEL and licenses/LICENSE: files without changes.