evolutionary-policy-optimization 0.0.53__tar.gz → 0.0.54__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/PKG-INFO +1 -1
- evolutionary_policy_optimization-0.0.54/evolutionary_policy_optimization/distributed.py +82 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/epo.py +51 -4
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/pyproject.toml +1 -1
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/.github/workflows/python-publish.yml +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/.github/workflows/test.yml +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/.gitignore +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/LICENSE +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/README.md +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/__init__.py +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/experimental.py +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/mock_env.py +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/requirements.txt +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/tests/test_epo.py +0 -0
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.53
+Version: 0.0.54
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.54/evolutionary_policy_optimization/distributed.py
ADDED
@@ -0,0 +1,82 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import Function
+
+import torch.distributed as dist
+
+import einx
+from einops import rearrange
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+def pad_dim_to(t, length, dim = 0):
+    pad_length = length - t.shape[dim]
+    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
+
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
+def maybe_sync_seed(device, max_size = int(1e6)):
+    rand_int = torch.randint(0, max_size, (), device = device)
+
+    if is_distributed():
+        dist.all_reduce(rand_int)
+
+    return rand_int.item()
+
+def maybe_barrier():
+    if not is_distributed():
+        return
+
+    dist.barrier()
+
+def all_gather_same_dim(t):
+    t = t.contiguous()
+    world_size = dist.get_world_size()
+    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
+    dist.all_gather(gathered_tensors, t)
+    return gathered_tensors
+
+def gather_sizes(t, *, dim):
+    size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
+    sizes = all_gather_same_dim(size)
+    return torch.stack(sizes)
+
+def has_only_one_value(t):
+    return (t == t[0]).all()
+
+def all_gather_variable_dim(t, dim = 0, sizes = None):
+    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
+
+    if not exists(sizes):
+        sizes = gather_sizes(t, dim = dim)
+
+    if has_only_one_value(sizes):
+        gathered_tensors = all_gather_same_dim(t)
+        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
+        return gathered_tensors, sizes
+
+    max_size = sizes.amax().item()
+
+    padded_t = pad_dim_to(t, max_size, dim = dim)
+    gathered_tensors = all_gather_same_dim(padded_t)
+
+    gathered_tensors = torch.cat(gathered_tensors, dim = dim)
+    seq = torch.arange(max_size, device = device)
+
+    mask = einx.less('j i -> (i j)', seq, sizes)
+    seq = torch.arange(mask.shape[-1], device = device)
+    indices = seq[mask]
+
+    gathered_tensors = gathered_tensors.index_select(dim, indices)
+
+    return gathered_tensors, sizes
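The new distributed.py above revolves around all_gather_variable_dim, which pads each rank's tensor to the largest length along the gather dimension, all-gathers, and then strips the padding back out with a mask. A minimal sketch of how these helpers could be exercised, assuming a gloo backend, a CPU-only run, and a torchrun launch; the script name, shapes, and values are made up for illustration:

# demo_distributed_gather.py (illustrative only, not part of the package)
# run with: torchrun --nproc_per_node=2 demo_distributed_gather.py

import torch
import torch.distributed as dist

from evolutionary_policy_optimization.distributed import (
    is_distributed,
    maybe_sync_seed,
    maybe_barrier,
    all_gather_variable_dim
)

def main():
    dist.init_process_group(backend = 'gloo')
    rank = dist.get_rank()

    assert is_distributed()  # expects more than one process

    # every rank draws a random int; all_reduce makes them agree on one seed
    seed = maybe_sync_seed(device = torch.device('cpu'))

    # each rank contributes a different number of rows (rank + 1)
    local = torch.full((rank + 1, 3), float(rank))

    maybe_barrier()
    gathered, sizes = all_gather_variable_dim(local, dim = 0)

    # gathered holds every rank's rows concatenated in rank order,
    # with the padding used during the gather stripped back out
    print(rank, seed, tuple(gathered.shape), sizes.tolist())

    dist.destroy_process_group()

if __name__ == '__main__':
    main()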
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/epo.py
RENAMED
@@ -8,14 +8,22 @@ from random import randrange
 import torch
 from torch import nn, cat, stack, is_tensor, tensor, Tensor
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map
 
 import einx
 from einops import rearrange, repeat, einsum, pack
 from einops.layers.torch import Rearrange
 
+from evolutionary_policy_optimization.distributed import (
+    is_distributed,
+    maybe_sync_seed,
+    all_gather_variable_dim,
+    maybe_barrier
+)
+
 from assoc_scan import AssocScan
 
 from adam_atan2_pytorch import AdoptAtan2
@@ -822,6 +830,15 @@ class Agent(Module):
 
         fitness_scores = self.get_fitness_scores(cumulative_rewards, memories)
 
+        # stack memories
+
+        memories = map(stack, zip(*memories))
+
+        maybe_barrier()
+
+        if is_distributed():
+            memories = map(partial(all_gather_variable_dim, dim = 0), memories)
+
         (
             episode_ids,
             states,
@@ -831,7 +848,7 @@ class Agent(Module):
             rewards,
             values,
             dones
-        ) =
+        ) = memories
 
         advantages = self.calc_gae(
             rewards[:-1],
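The added block stacks the list of per-timestep memory tuples into one tensor per field before the optional cross-rank gather. A toy sketch of that zip-then-stack step, with invented fields and shapes (the real memories carry more fields, as unpacked in the diff above):

# toy illustration of `memories = map(stack, zip(*memories))`; the field
# names and shapes here are invented, only the reshaping pattern matches

import torch
from torch import stack

# three timesteps, each a (state, reward) tuple
memories = [
    (torch.randn(4), torch.tensor(1.0)),
    (torch.randn(4), torch.tensor(0.5)),
    (torch.randn(4), torch.tensor(0.0)),
]

# zip(*...) groups by field, stack turns each group into one tensor
states, rewards = map(stack, zip(*memories))

print(states.shape)   # torch.Size([3, 4])
print(rewards.shape)  # torch.Size([3])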
@@ -1035,6 +1052,32 @@ class EPO(Module):
         self.episodes_per_latent = episodes_per_latent
         self.max_episode_length = max_episode_length
 
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
+    def latents_for_machine(self):
+        num_latents = self.num_latents
+
+        if not is_distributed():
+            return list(range(self.num_latents))
+
+        world_size, rank = dist.get_world_size(), dist.get_rank()
+        assert num_latents >= world_size, 'number of latents must be greater than world size for now'
+        assert rank < world_size
+
+        pad_id = -1
+        num_latents_rounded_up = ceil(num_latents / world_size) * world_size
+        latent_ids = torch.arange(num_latents_rounded_up)
+        latent_ids[latent_ids >= num_latents] = pad_id
+
+        latent_ids = rearrange(latent_ids, '(world latents) -> world latents', world = world_size)
+        out = latent_ids[rank]
+
+        return out[out != pad_id].tolist()
+
     @torch.no_grad()
     def forward(
         self,
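latents_for_machine rounds the latent count up to a multiple of the world size, pads the overhang with -1, and gives each rank one contiguous slice of latent ids. A standalone sketch of that partitioning with hypothetical numbers (10 latents across 4 ranks), reusing only the arithmetic from the method above:

# sketch of the latent sharding logic, outside the class, with made-up sizes

from math import ceil

import torch
from einops import rearrange

num_latents, world_size, pad_id = 10, 4, -1

# round up so every rank gets an equal-length row, pad the overhang
rounded_up = ceil(num_latents / world_size) * world_size   # 12
latent_ids = torch.arange(rounded_up)
latent_ids[latent_ids >= num_latents] = pad_id

per_rank = rearrange(latent_ids, '(world latents) -> world latents', world = world_size)

for rank in range(world_size):
    ids = per_rank[rank]
    print(rank, ids[ids != pad_id].tolist())

# rank 0 -> [0, 1, 2], rank 1 -> [3, 4, 5], rank 2 -> [6, 7, 8], rank 3 -> [9]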
@@ -1050,19 +1093,23 @@ class EPO(Module):
 
         cumulative_rewards = torch.zeros((self.num_latents))
 
+        latent_ids = self.latents_for_machine()
+
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
+            maybe_barrier()
+
             # maybe fix seed for environment across all latents
 
             env_reset_kwargs = dict()
 
             if fix_seed_across_latents:
-                seed =
+                seed = maybe_sync_seed(device = self.device)
                 env_reset_kwargs = dict(seed = seed)
 
             # for each latent (on a single machine for now)
 
-            for latent_id in tqdm(
+            for latent_id in tqdm(latent_ids, desc = 'latent'):
                 time = 0
 
                 # initial state
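With fix_seed_across_latents, every latent on every rank now resets its environment with the same seed returned by maybe_sync_seed, so fitness scores are compared on identical episodes. A sketch of the idea against a gymnasium environment; gymnasium, CartPole-v1 and the literal seed are stand-ins for illustration, not part of this package:

# hypothetical illustration: the same synced seed gives every latent the
# same initial state; in the trainer the seed comes from maybe_sync_seed(...)

import gymnasium as gym

env = gym.make('CartPole-v1')
seed = 1234  # stand-in for maybe_sync_seed(device = ...)

for latent_id in range(3):
    state, _ = env.reset(seed = seed)
    print(latent_id, state)  # identical observation for every latent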
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/.github/workflows/python-publish.yml
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/.github/workflows/test.yml
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/.gitignore
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/LICENSE
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/README.md
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/__init__.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/experimental.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/mock_env.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/requirements.txt
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/tests/test_epo.py
RENAMED
File without changes