evolutionary-policy-optimization 0.0.54.tar.gz → 0.0.56.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/PKG-INFO +1 -1
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/distributed.py +6 -0
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/epo.py +18 -15
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/experimental.py +26 -1
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/pyproject.toml +1 -1
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/.github/workflows/python-publish.yml +0 -0
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/.github/workflows/test.yml +0 -0
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/.gitignore +0 -0
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/LICENSE +0 -0
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/README.md +0 -0
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/__init__.py +0 -0
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/mock_env.py +0 -0
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/requirements.txt +0 -0
- {evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/tests/test_epo.py +0 -0
{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.54
+Version: 0.0.56
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/distributed.py
RENAMED
@@ -25,6 +25,12 @@ def pad_dim_to(t, length, dim = 0):
 def is_distributed():
     return dist.is_initialized() and dist.get_world_size() > 1
 
+def get_world_and_rank():
+    if not is_distributed():
+        return 1, 0
+
+    return dist.get_world_size(), dist.get_rank()
+
 def maybe_sync_seed(device, max_size = int(1e6)):
     rand_int = torch.randint(0, max_size, (), device = device)
 
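A quick sanity check of the new helper outside of a distributed launch (a minimal sketch; under a multi-process `torchrun` launch it would return the real world size and rank instead):

from evolutionary_policy_optimization.distributed import get_world_and_rank

# no process group has been initialized here, so is_distributed() is False
# and the helper falls back to a single-machine view of the world
world_size, rank = get_world_and_rank()

assert (world_size, rank) == (1, 0)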
{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/epo.py
RENAMED
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
-from functools import partial, wraps
 from pathlib import Path
+from math import ceil
+from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
 
@@ -19,6 +20,7 @@ from einops.layers.torch import Rearrange
 
 from evolutionary_policy_optimization.distributed import (
     is_distributed,
+    get_world_and_rank,
     maybe_sync_seed,
     all_gather_variable_dim,
     maybe_barrier
@@ -372,13 +374,16 @@ class LatentGenePool(Module):
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
         default_should_run_ga_gamma = 1.5,
         migrate_every = 100, # how many steps before a migration between islands
-        apply_genetic_algorithm_every = 2
+        apply_genetic_algorithm_every = 2, # how many steps before crossover + mutation happens for genes
+        init_latent_fn: Callable = None
     ):
         super().__init__()
 
         maybe_l2norm = l2norm if l2norm_latent else identity
 
-        latents = torch.randn(num_latents, dim_latent)
+        init_fn = default(init_latent_fn, torch.randn)
+
+        latents = init_fn((num_latents, dim_latent))
 
         if l2norm_latent:
             latents = maybe_l2norm(latents, dim = -1)
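The new `init_latent_fn` hook replaces the hard-coded `torch.randn`: the pool calls `init_fn((num_latents, dim_latent))`, so any callable taking a shape tuple works. A hedged sketch of a custom init (assuming `LatentGenePool` is importable from the package root and the remaining constructor arguments keep their defaults):

import torch

from evolutionary_policy_optimization import LatentGenePool

# initialize latents uniformly in [-1, 1) instead of the default gaussian;
# the callable receives the full (num_latents, dim_latent) shape tuple
pool = LatentGenePool(
    num_latents = 128,
    dim_latent = 32,
    init_latent_fn = lambda shape: torch.rand(shape) * 2 - 1
)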
@@ -1061,22 +1066,20 @@ class EPO(Module):
     def latents_for_machine(self):
         num_latents = self.num_latents
 
-        if not is_distributed():
-            return list(range(self.num_latents))
+        world_size, rank = get_world_and_rank()
 
-        world_size, rank = dist.get_world_size(), dist.get_rank()
         assert num_latents >= world_size, 'number of latents must be greater than world size for now'
         assert rank < world_size
 
-
-
-
-
+        num_latents_per_machine = ceil(num_latents / world_size)
+
+        for i in range(num_latents_per_machine):
+            latent_id = rank * num_latents_per_machine + i
 
-
-
+            if latent_id >= num_latents:
+                continue
 
-
+            yield i
 
     @torch.no_grad()
     def forward(
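The replacement body shards latents contiguously: each rank takes a chunk of `ceil(num_latents / world_size)` ids, skipping any id that runs past the end on the last rank. A standalone sketch of that arithmetic (yielding the global `latent_id` to make the chunking visible, where the generator above yields the local index `i`):

from math import ceil

def shard(num_latents, world_size, rank):
    # contiguous chunk per rank; the last rank may get a short chunk
    per_machine = ceil(num_latents / world_size)

    for i in range(per_machine):
        latent_id = rank * per_machine + i

        if latent_id >= num_latents:
            continue

        yield latent_id

# 10 latents across 4 ranks -> chunks of 3, with only latent 9 left for rank 3
assert list(shard(10, 4, 0)) == [0, 1, 2]
assert list(shard(10, 4, 3)) == [9]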
@@ -1093,7 +1096,7 @@ class EPO(Module):
 
         cumulative_rewards = torch.zeros((self.num_latents))
 
-
+        latent_ids_gen = self.latents_for_machine()
 
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
@@ -1109,7 +1112,7 @@ class EPO(Module):
 
             # for each latent (on a single machine for now)
 
-            for latent_id in tqdm(
+            for latent_id in tqdm(latent_ids_gen, desc = 'latent'):
                 time = 0
 
                 # initial state
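One subtlety with the switch from a list to a generator: `latent_ids_gen` is created once, before the episode loop, and a generator can only be consumed once, so episodes after the first would find it already exhausted. A minimal illustration of the underlying Python behavior:

def latent_ids():
    yield from range(3)

gen = latent_ids()

assert list(gen) == [0, 1, 2]  # the first pass consumes the generator
assert list(gen) == []         # later passes find it exhausted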
{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/experimental.py
RENAMED
@@ -39,9 +39,34 @@ def crossover_weights(w1, w2, transpose = False):
 
     return out
 
+def mutate_weight(
+    w,
+    transpose = False,
+    mutation_strength = 1.
+):
+
+    if transpose:
+        w = w.transpose(-1, -2)
+
+    rank = min(w.shape[1:])
+    assert rank >= 2
+
+    u, s, v = torch.svd(w)
+    u = u + torch.randn_like(u) * mutation_strength
+    v = v + torch.randn_like(v) * mutation_strength
+
+    out = u @ torch.diag_embed(s) @ v.mT
+
+    if transpose:
+        out = out.transpose(-1, -2)
+
+    return out
+
 if __name__ == '__main__':
     w1 = torch.randn(32, 16)
     w2 = torch.randn(32, 16)
-    child = crossover_weights(w1, w2)
+
+    child = crossover_weights(w1, w2)
+    mutated_w1 = mutate_weight(w1)
 
     assert child.shape == w2.shape
{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/.github/workflows/python-publish.yml
RENAMED
File without changes

{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/.github/workflows/test.yml
RENAMED
File without changes

{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/.gitignore
RENAMED
File without changes

{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/LICENSE
RENAMED
File without changes

{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/README.md
RENAMED
File without changes

{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/__init__.py
RENAMED
File without changes

{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/mock_env.py
RENAMED
File without changes

{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/requirements.txt
RENAMED
File without changes

{evolutionary_policy_optimization-0.0.54 → evolutionary_policy_optimization-0.0.56}/tests/test_epo.py
RENAMED
File without changes