evolutionary-policy-optimization 0.0.53__tar.gz → 0.0.55__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/PKG-INFO +1 -1
- evolutionary_policy_optimization-0.0.55/evolutionary_policy_optimization/distributed.py +88 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/evolutionary_policy_optimization/epo.py +52 -5
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/pyproject.toml +1 -1
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/.github/workflows/python-publish.yml +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/.github/workflows/test.yml +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/.gitignore +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/LICENSE +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/README.md +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/evolutionary_policy_optimization/__init__.py +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/evolutionary_policy_optimization/experimental.py +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/evolutionary_policy_optimization/mock_env.py +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/requirements.txt +0 -0
- {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/tests/test_epo.py +0 -0
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.53
+Version: 0.0.55
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.55/evolutionary_policy_optimization/distributed.py
ADDED
@@ -0,0 +1,88 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import Function
+
+import torch.distributed as dist
+
+import einx
+from einops import rearrange
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+def pad_dim_to(t, length, dim = 0):
+    pad_length = length - t.shape[dim]
+    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
+
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
+def get_world_and_rank():
+    if not is_distributed():
+        return 1, 0
+
+    return dist.get_world_size(), dist.get_rank()
+
+def maybe_sync_seed(device, max_size = int(1e6)):
+    rand_int = torch.randint(0, max_size, (), device = device)
+
+    if is_distributed():
+        dist.all_reduce(rand_int)
+
+    return rand_int.item()
+
+def maybe_barrier():
+    if not is_distributed():
+        return
+
+    dist.barrier()
+
+def all_gather_same_dim(t):
+    t = t.contiguous()
+    world_size = dist.get_world_size()
+    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
+    dist.all_gather(gathered_tensors, t)
+    return gathered_tensors
+
+def gather_sizes(t, *, dim):
+    size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
+    sizes = all_gather_same_dim(size)
+    return torch.stack(sizes)
+
+def has_only_one_value(t):
+    return (t == t[0]).all()
+
+def all_gather_variable_dim(t, dim = 0, sizes = None):
+    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
+
+    if not exists(sizes):
+        sizes = gather_sizes(t, dim = dim)
+
+    if has_only_one_value(sizes):
+        gathered_tensors = all_gather_same_dim(t)
+        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
+        return gathered_tensors, sizes
+
+    max_size = sizes.amax().item()
+
+    padded_t = pad_dim_to(t, max_size, dim = dim)
+    gathered_tensors = all_gather_same_dim(padded_t)
+
+    gathered_tensors = torch.cat(gathered_tensors, dim = dim)
+    seq = torch.arange(max_size, device = device)
+
+    mask = einx.less('j i -> (i j)', seq, sizes)
+    seq = torch.arange(mask.shape[-1], device = device)
+    indices = seq[mask]
+
+    gathered_tensors = gathered_tensors.index_select(dim, indices)
+
+    return gathered_tensors, sizes
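The new `distributed.py` module is a thin wrapper around `torch.distributed`: the `maybe_*` helpers fall back to single-process defaults when no process group is initialized, while the `all_gather_*` helpers require one. The following is a minimal sketch, not part of the package: a hypothetical demo script (script name, launch command, and gloo/CPU backend are all assumptions) showing how the helpers behave across two ranks, assuming version 0.0.55 is installed.

```python
# sketch_distributed_helpers.py -- hypothetical demo script (not part of the package)
# run as: torchrun --nproc_per_node=2 sketch_distributed_helpers.py
# assumes evolutionary-policy-optimization >= 0.0.55 and the CPU-only gloo backend

import torch
import torch.distributed as dist

from evolutionary_policy_optimization.distributed import (
    get_world_and_rank,
    maybe_sync_seed,
    all_gather_variable_dim,
    maybe_barrier
)

dist.init_process_group(backend = 'gloo')

world_size, rank = get_world_and_rank()

# each rank draws a random int locally; the all-reduce sums them, so every rank ends up with the same seed
seed = maybe_sync_seed(device = torch.device('cpu'))
print(f'rank {rank}/{world_size}: shared seed {seed}')

# each rank contributes a different number of rows (rank + 1);
# the helper pads to the max size, all-gathers, then strips the padding rows
local = torch.full((rank + 1, 3), float(rank))
gathered, sizes = all_gather_variable_dim(local, dim = 0)

print(f'rank {rank}: gathered shape {tuple(gathered.shape)}, sizes {sizes.tolist()}')  # (3, 3), [1, 2]

maybe_barrier()
dist.destroy_process_group()
```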
{evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/evolutionary_policy_optimization/epo.py
RENAMED
@@ -1,21 +1,31 @@
 from __future__ import annotations
 
-from functools import partial, wraps
 from pathlib import Path
+from math import ceil
+from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
 
 import torch
 from torch import nn, cat, stack, is_tensor, tensor, Tensor
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map
 
 import einx
 from einops import rearrange, repeat, einsum, pack
 from einops.layers.torch import Rearrange
 
+from evolutionary_policy_optimization.distributed import (
+    is_distributed,
+    get_world_and_rank,
+    maybe_sync_seed,
+    all_gather_variable_dim,
+    maybe_barrier
+)
+
 from assoc_scan import AssocScan
 
 from adam_atan2_pytorch import AdoptAtan2
@@ -822,6 +832,15 @@ class Agent(Module):
 
         fitness_scores = self.get_fitness_scores(cumulative_rewards, memories)
 
+        # stack memories
+
+        memories = map(stack, zip(*memories))
+
+        maybe_barrier()
+
+        if is_distributed():
+            memories = map(partial(all_gather_variable_dim, dim = 0), memories)
+
         (
             episode_ids,
             states,
@@ -831,7 +850,7 @@ class Agent(Module):
             rewards,
             values,
             dones
-        ) =
+        ) = memories
 
         advantages = self.calc_gae(
             rewards[:-1],
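In 0.0.55 the per-timestep memory tuples are stacked before the unpacking site, so each field can be all-gathered across ranks first. Below is a standalone illustration of what `map(stack, zip(*memories))` produces; the three fields and their shapes are made up for illustration, and the real tuple carries more fields, as the unpacking above shows.

```python
# standalone illustration of `memories = map(stack, zip(*memories))`;
# the field names/shapes here are hypothetical, the real memory tuple has more fields
import torch
from torch import stack

memories = [
    (torch.tensor(0), torch.randn(4), torch.tensor(1.0)),  # step 0: (episode_id, state, reward)
    (torch.tensor(0), torch.randn(4), torch.tensor(0.5)),  # step 1
]

# zip(*memories) transposes the list of per-step tuples into one sequence per field,
# then stack turns each sequence into a single tensor along a new leading (time) dim
episode_ids, states, rewards = map(stack, zip(*memories))

print(episode_ids.shape, states.shape, rewards.shape)
# torch.Size([2]) torch.Size([2, 4]) torch.Size([2])
```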
@@ -1035,6 +1054,30 @@ class EPO(Module):
         self.episodes_per_latent = episodes_per_latent
         self.max_episode_length = max_episode_length
 
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
+    def latents_for_machine(self):
+        num_latents = self.num_latents
+
+        world_size, rank = get_world_and_rank()
+
+        assert num_latents >= world_size, 'number of latents must be greater than world size for now'
+        assert rank < world_size
+
+        num_latents_per_machine = ceil(num_latents / world_size)
+
+        for i in range(num_latents_per_machine):
+            latent_id = rank * num_latents_per_machine + i
+
+            if latent_id >= num_latents:
+                continue
+
+            yield i
+
     @torch.no_grad()
     def forward(
         self,
@@ -1050,19 +1093,23 @@ class EPO(Module):
 
         cumulative_rewards = torch.zeros((self.num_latents))
 
+        latent_ids_gen = self.latents_for_machine()
+
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
+            maybe_barrier()
+
             # maybe fix seed for environment across all latents
 
             env_reset_kwargs = dict()
 
             if fix_seed_across_latents:
-                seed =
+                seed = maybe_sync_seed(device = self.device)
                 env_reset_kwargs = dict(seed = seed)
 
             # for each latent (on a single machine for now)
 
-            for latent_id in tqdm(
+            for latent_id in tqdm(latent_ids_gen, desc = 'latent'):
                 time = 0
 
                 # initial state
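The new `latents_for_machine` generator block-partitions the latent population across ranks with `ceil(num_latents / world_size)`, yielding the loop index `i` and skipping positions past the end of the population, while `maybe_sync_seed` keeps the environment seed identical across ranks within each episode. Below is a standalone sketch of the partitioning logic; the 8-latent / 3-rank numbers are arbitrary illustration values.

```python
# standalone copy of the partitioning logic in EPO.latents_for_machine;
# num_latents = 8 and world_size = 3 are arbitrary illustration values
from math import ceil

def latents_for_machine(num_latents, world_size, rank):
    num_latents_per_machine = ceil(num_latents / world_size)

    for i in range(num_latents_per_machine):
        latent_id = rank * num_latents_per_machine + i

        # the last rank simply gets a shorter block when the split is uneven
        if latent_id >= num_latents:
            continue

        yield i  # note: the generator yields the local index i, mirroring the source

for rank in range(3):
    print(rank, list(latents_for_machine(8, 3, rank)))

# 0 [0, 1, 2]
# 1 [0, 1, 2]
# 2 [0, 1]
```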