evolutionary-policy-optimization 0.0.52__tar.gz → 0.0.54__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/PKG-INFO +1 -1
- evolutionary_policy_optimization-0.0.54/evolutionary_policy_optimization/distributed.py +82 -0
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/epo.py +60 -5
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/pyproject.toml +1 -1
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/.github/workflows/python-publish.yml +0 -0
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/.github/workflows/test.yml +0 -0
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/.gitignore +0 -0
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/LICENSE +0 -0
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/README.md +0 -0
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/__init__.py +0 -0
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/experimental.py +0 -0
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/mock_env.py +0 -0
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/requirements.txt +0 -0
- {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/tests/test_epo.py +0 -0
{evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.52
+Version: 0.0.54
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
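Aside from the version bump above, PKG-INFO is unchanged, so pinning `evolutionary-policy-optimization==0.0.54` (or `==0.0.52` for the previous release) in a requirements file selects exactly the versions being compared here.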
evolutionary_policy_optimization-0.0.54/evolutionary_policy_optimization/distributed.py
ADDED
@@ -0,0 +1,82 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function

import torch.distributed as dist

import einx
from einops import rearrange

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def divisible_by(num, den):
    return (num % den) == 0

def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

def maybe_sync_seed(device, max_size = int(1e6)):
    rand_int = torch.randint(0, max_size, (), device = device)

    if is_distributed():
        dist.all_reduce(rand_int)

    return rand_int.item()

def maybe_barrier():
    if not is_distributed():
        return

    dist.barrier()

def all_gather_same_dim(t):
    t = t.contiguous()
    world_size = dist.get_world_size()
    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
    dist.all_gather(gathered_tensors, t)
    return gathered_tensors

def gather_sizes(t, *, dim):
    size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
    sizes = all_gather_same_dim(size)
    return torch.stack(sizes)

def has_only_one_value(t):
    return (t == t[0]).all()

def all_gather_variable_dim(t, dim = 0, sizes = None):
    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()

    if not exists(sizes):
        sizes = gather_sizes(t, dim = dim)

    if has_only_one_value(sizes):
        gathered_tensors = all_gather_same_dim(t)
        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
        return gathered_tensors, sizes

    max_size = sizes.amax().item()

    padded_t = pad_dim_to(t, max_size, dim = dim)
    gathered_tensors = all_gather_same_dim(padded_t)

    gathered_tensors = torch.cat(gathered_tensors, dim = dim)
    seq = torch.arange(max_size, device = device)

    mask = einx.less('j i -> (i j)', seq, sizes)
    seq = torch.arange(mask.shape[-1], device = device)
    indices = seq[mask]

    gathered_tensors = gathered_tensors.index_select(dim, indices)

    return gathered_tensors, sizes
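For orientation, here is a minimal usage sketch of the new helpers; it is not part of the package, and the script name, the two-process `torchrun` launch, the `gloo` backend, and the tensor shapes are all assumptions for illustration:

```python
# demo.py — hypothetical sketch; launch with: torchrun --nproc_per_node=2 demo.py
import torch
import torch.distributed as dist

from evolutionary_policy_optimization.distributed import (
    maybe_sync_seed,
    maybe_barrier,
    all_gather_variable_dim
)

dist.init_process_group('gloo')  # cpu-only backend, enough for the sketch
rank = dist.get_rank()

# each rank contributes a different number of rows along dim 0
local = torch.randn(rank + 1, 4)

# pads the ragged tensors to a common length, all-gathers, then strips the padding back out
gathered, sizes = all_gather_variable_dim(local, dim = 0)
assert gathered.shape[0] == sizes.sum().item()

# every rank ends up with the same random seed (used below for environment resets)
seed = maybe_sync_seed(device = local.device)

maybe_barrier()  # no-op when not running distributed
dist.destroy_process_group()
```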
{evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/epo.py
RENAMED
@@ -8,14 +8,22 @@ from random import randrange
 import torch
 from torch import nn, cat, stack, is_tensor, tensor, Tensor
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map

 import einx
 from einops import rearrange, repeat, einsum, pack
 from einops.layers.torch import Rearrange

+from evolutionary_policy_optimization.distributed import (
+    is_distributed,
+    maybe_sync_seed,
+    all_gather_variable_dim,
+    maybe_barrier
+)
+
 from assoc_scan import AssocScan

 from adam_atan2_pytorch import AdoptAtan2
@@ -360,10 +368,11 @@ class LatentGenePool(Module):
         frac_natural_selected = 0.25,  # number of least fit genes to remove from the pool
         frac_elitism = 0.1,            # frac of population to preserve from being noised
         frac_migrate = 0.1,            # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
-        migrate_every = 100,           # how many steps before a migration between islands
         mutation_strength = 1.,        # factor to multiply to gaussian noise as mutation to latents
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
         default_should_run_ga_gamma = 1.5,
+        migrate_every = 100,           # how many steps before a migration between islands
+        apply_genetic_algorithm_every = 2 # how many steps before crossover + mutation happens for genes
     ):
         super().__init__()

@@ -413,7 +422,10 @@ class LatentGenePool(Module):
         self.should_run_genetic_algorithm = should_run_genetic_algorithm

         self.can_migrate = num_islands > 1
+
         self.migrate_every = migrate_every
+        self.apply_genetic_algorithm_every = apply_genetic_algorithm_every
+
         self.register_buffer('step', tensor(1))

     def get_distance(self):
@@ -483,6 +495,10 @@ class LatentGenePool(Module):
     ):
         device = self.latents.device

+        if not divisible_by(self.step.item(), self.apply_genetic_algorithm_every):
+            self.advance_step_()
+            return
+
         """
         i - islands
         p - population
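To make the effect of the new gate concrete, here is a tiny standalone sketch (names mirror the diff; a plain `step` integer stands in for the module's `step` buffer and `advance_step_`): with the default `apply_genetic_algorithm_every = 2` and the step buffer initialised to 1, crossover and mutation only run on even steps, while the counter always advances.

```python
# hypothetical illustration of the gating logic above
def divisible_by(num, den):
    return (num % den) == 0

step = 1                             # mirrors register_buffer('step', tensor(1))
apply_genetic_algorithm_every = 2

for _ in range(6):
    if not divisible_by(step, apply_genetic_algorithm_every):
        step += 1                    # stands in for self.advance_step_()
        continue

    print(f'crossover + mutation at step {step}')  # prints at steps 2, 4, 6
    step += 1
```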
@@ -814,6 +830,15 @@ class Agent(Module):

         fitness_scores = self.get_fitness_scores(cumulative_rewards, memories)

+        # stack memories
+
+        memories = map(stack, zip(*memories))
+
+        maybe_barrier()
+
+        if is_distributed():
+            memories = map(partial(all_gather_variable_dim, dim = 0), memories)
+
         (
             episode_ids,
             states,
@@ -823,7 +848,7 @@ class Agent(Module):
             rewards,
             values,
             dones
-        ) =
+        ) = memories

         advantages = self.calc_gae(
             rewards[:-1],
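For intuition about the memory handling above, a small sketch (the three fields and their shapes are made up; the real memories carry more fields, as the destructuring above shows) of what `map(stack, zip(*memories))` does before the optional all-gather: it transposes a list of per-timestep tuples into one stacked tensor per field.

```python
import torch
from torch import stack

# three fake timesteps, each a (state, reward, done) tuple
memories = [
    (torch.randn(4), torch.tensor(1.0), torch.tensor(False)),
    (torch.randn(4), torch.tensor(0.5), torch.tensor(False)),
    (torch.randn(4), torch.tensor(2.0), torch.tensor(True)),
]

states, rewards, dones = map(stack, zip(*memories))

print(states.shape)   # torch.Size([3, 4])
print(rewards.shape)  # torch.Size([3])
print(dones.shape)    # torch.Size([3])
```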
@@ -1027,6 +1052,32 @@ class EPO(Module):
         self.episodes_per_latent = episodes_per_latent
         self.max_episode_length = max_episode_length

+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
+    def latents_for_machine(self):
+        num_latents = self.num_latents
+
+        if not is_distributed():
+            return list(range(self.num_latents))
+
+        world_size, rank = dist.get_world_size(), dist.get_rank()
+        assert num_latents >= world_size, 'number of latents must be greater than world size for now'
+        assert rank < world_size
+
+        pad_id = -1
+        num_latents_rounded_up = ceil(num_latents / world_size) * world_size
+        latent_ids = torch.arange(num_latents_rounded_up)
+        latent_ids[latent_ids >= num_latents] = pad_id
+
+        latent_ids = rearrange(latent_ids, '(world latents) -> world latents', world = world_size)
+        out = latent_ids[rank]
+
+        return out[out != pad_id].tolist()
+
     @torch.no_grad()
     def forward(
         self,
@@ -1042,19 +1093,23 @@ class EPO(Module):

         cumulative_rewards = torch.zeros((self.num_latents))

+        latent_ids = self.latents_for_machine()
+
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):

+            maybe_barrier()
+
             # maybe fix seed for environment across all latents

             env_reset_kwargs = dict()

             if fix_seed_across_latents:
-                seed =
+                seed = maybe_sync_seed(device = self.device)
                 env_reset_kwargs = dict(seed = seed)

             # for each latent (on a single machine for now)

-            for latent_id in tqdm(
+            for latent_id in tqdm(latent_ids, desc = 'latent'):
                 time = 0

                 # initial state
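The per-machine latent split consumed by the new `for latent_id in tqdm(latent_ids, ...)` loop can be checked in isolation; a sketch of the arithmetic from `latents_for_machine`, assuming 10 latents spread over a world size of 4 (numbers chosen only for illustration):

```python
from math import ceil

import torch
from einops import rearrange

num_latents, world_size = 10, 4
pad_id = -1

# round the latent count up to a multiple of world size, mark the padding ids
num_latents_rounded_up = ceil(num_latents / world_size) * world_size  # 12
latent_ids = torch.arange(num_latents_rounded_up)
latent_ids[latent_ids >= num_latents] = pad_id

# one row of latent ids per rank, padding stripped before use
latent_ids = rearrange(latent_ids, '(world latents) -> world latents', world = world_size)

for rank in range(world_size):
    out = latent_ids[rank]
    print(rank, out[out != pad_id].tolist())

# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7, 8]
# 3 [9]
```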
All remaining files are unchanged between 0.0.52 and 0.0.54: .github/workflows/python-publish.yml, .github/workflows/test.yml, .gitignore, LICENSE, README.md, evolutionary_policy_optimization/__init__.py, evolutionary_policy_optimization/experimental.py, evolutionary_policy_optimization/mock_env.py, requirements.txt, and tests/test_epo.py.