evolutionary-policy-optimization 0.0.53__tar.gz → 0.0.54__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (14)
  1. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/PKG-INFO +1 -1
  2. evolutionary_policy_optimization-0.0.54/evolutionary_policy_optimization/distributed.py +82 -0
  3. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/epo.py +51 -4
  4. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/pyproject.toml +1 -1
  5. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/.github/workflows/python-publish.yml +0 -0
  6. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/.github/workflows/test.yml +0 -0
  7. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/.gitignore +0 -0
  8. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/LICENSE +0 -0
  9. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/README.md +0 -0
  10. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/experimental.py +0 -0
  12. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/mock_env.py +0 -0
  13. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/requirements.txt +0 -0
  14. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.54}/tests/test_epo.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.53
+Version: 0.0.54
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization/distributed.py (new file)
@@ -0,0 +1,82 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import Function
+
+import torch.distributed as dist
+
+import einx
+from einops import rearrange
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+def pad_dim_to(t, length, dim = 0):
+    pad_length = length - t.shape[dim]
+    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
+
+def is_distributed():
+    return dist.is_initialized() and dist.get_world_size() > 1
+
+def maybe_sync_seed(device, max_size = int(1e6)):
+    rand_int = torch.randint(0, max_size, (), device = device)
+
+    if is_distributed():
+        dist.all_reduce(rand_int)
+
+    return rand_int.item()
+
+def maybe_barrier():
+    if not is_distributed():
+        return
+
+    dist.barrier()
+
+def all_gather_same_dim(t):
+    t = t.contiguous()
+    world_size = dist.get_world_size()
+    gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
+    dist.all_gather(gathered_tensors, t)
+    return gathered_tensors
+
+def gather_sizes(t, *, dim):
+    size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
+    sizes = all_gather_same_dim(size)
+    return torch.stack(sizes)
+
+def has_only_one_value(t):
+    return (t == t[0]).all()
+
+def all_gather_variable_dim(t, dim = 0, sizes = None):
+    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
+
+    if not exists(sizes):
+        sizes = gather_sizes(t, dim = dim)
+
+    if has_only_one_value(sizes):
+        gathered_tensors = all_gather_same_dim(t)
+        gathered_tensors = torch.cat(gathered_tensors, dim = dim)
+        return gathered_tensors, sizes
+
+    max_size = sizes.amax().item()
+
+    padded_t = pad_dim_to(t, max_size, dim = dim)
+    gathered_tensors = all_gather_same_dim(padded_t)
+
+    gathered_tensors = torch.cat(gathered_tensors, dim = dim)
+    seq = torch.arange(max_size, device = device)
+
+    mask = einx.less('j i -> (i j)', seq, sizes)
+    seq = torch.arange(mask.shape[-1], device = device)
+    indices = seq[mask]
+
+    gathered_tensors = gathered_tensors.index_select(dim, indices)
+
+    return gathered_tensors, sizes
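The new distributed.py gathers tensors whose gathered dimension can differ per rank: each rank pads its tensor up to the largest size, all-gathers the padded copies, then drops the padding rows with a mask built from the gathered sizes. Below is a minimal single-process sketch of the padding step only (the gather itself needs an initialized torch.distributed process group; the shapes are made up for illustration):

import torch
import torch.nn.functional as F

# same helper as in distributed.py: right-pad t along `dim` up to `length`
def pad_dim_to(t, length, dim = 0):
    pad_length = length - t.shape[dim]
    zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))

# pretend this rank collected 3 timesteps while another rank collected 5
local = torch.randn(3, 4)
max_size = 5    # would come from gather_sizes(...).amax() in the real code

# pad to the common size before all_gather_same_dim; all_gather_variable_dim
# later removes the padded rows again via its boolean mask
padded = pad_dim_to(local, max_size, dim = 0)
assert padded.shape == (5, 4)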
evolutionary_policy_optimization/epo.py
@@ -8,14 +8,22 @@ from random import randrange
 import torch
 from torch import nn, cat, stack, is_tensor, tensor, Tensor
 import torch.nn.functional as F
+import torch.distributed as dist
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
-from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_map
 
 import einx
 from einops import rearrange, repeat, einsum, pack
 from einops.layers.torch import Rearrange
 
+from evolutionary_policy_optimization.distributed import (
+    is_distributed,
+    maybe_sync_seed,
+    all_gather_variable_dim,
+    maybe_barrier
+)
+
 from assoc_scan import AssocScan
 
 from adam_atan2_pytorch import AdoptAtan2
@@ -822,6 +830,15 @@ class Agent(Module):
 
         fitness_scores = self.get_fitness_scores(cumulative_rewards, memories)
 
+        # stack memories
+
+        memories = map(stack, zip(*memories))
+
+        maybe_barrier()
+
+        if is_distributed():
+            memories = map(partial(all_gather_variable_dim, dim = 0), memories)
+
         (
             episode_ids,
             states,
@@ -831,7 +848,7 @@ class Agent(Module):
             rewards,
             values,
             dones
-        ) = map(stack, zip(*memories))
+        ) = memories
 
         advantages = self.calc_gae(
             rewards[:-1],
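In the Agent change above, memories is a list of per-timestep tuples; zip(*memories) regroups them field by field and stack turns each field into one tensor, which the distributed branch then all-gathers along dim 0 so every rank learns from the combined rollouts. A standalone sketch of the stacking pattern, with illustrative field names rather than the exact Memory layout:

import torch
from torch import stack

# toy stand-ins for the per-timestep memory tuples collected during rollouts
memories = [
    (torch.tensor(0), torch.randn(4), torch.tensor(1.0)),  # timestep 0
    (torch.tensor(0), torch.randn(4), torch.tensor(0.5)),  # timestep 1
    (torch.tensor(1), torch.randn(4), torch.tensor(2.0)),  # timestep 2
]

# zip(*memories) groups values field-by-field; stack makes each field one tensor
episode_ids, states, rewards = map(stack, zip(*memories))

assert episode_ids.shape == (3,)
assert states.shape == (3, 4)
assert rewards.shape == (3,)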
@@ -1035,6 +1052,32 @@ class EPO(Module):
         self.episodes_per_latent = episodes_per_latent
         self.max_episode_length = max_episode_length
 
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
+    def latents_for_machine(self):
+        num_latents = self.num_latents
+
+        if not is_distributed():
+            return list(range(self.num_latents))
+
+        world_size, rank = dist.get_world_size(), dist.get_rank()
+        assert num_latents >= world_size, 'number of latents must be greater than world size for now'
+        assert rank < world_size
+
+        pad_id = -1
+        num_latents_rounded_up = ceil(num_latents / world_size) * world_size
+        latent_ids = torch.arange(num_latents_rounded_up)
+        latent_ids[latent_ids >= num_latents] = pad_id
+
+        latent_ids = rearrange(latent_ids, '(world latents) -> world latents', world = world_size)
+        out = latent_ids[rank]
+
+        return out[out != pad_id].tolist()
+
     @torch.no_grad()
     def forward(
         self,
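latents_for_machine assigns each rank a contiguous chunk of the latent population: latent ids are padded up to a multiple of the world size, reshaped to (world, latents_per_rank), and each rank keeps its row minus the pad ids. A standalone sketch with made-up numbers (10 latents over a world size of 4):

from math import ceil
import torch
from einops import rearrange

num_latents, world_size = 10, 4
pad_id = -1

# round up so the ids split evenly, marking the padded tail
num_latents_rounded_up = ceil(num_latents / world_size) * world_size   # 12
latent_ids = torch.arange(num_latents_rounded_up)
latent_ids[latent_ids >= num_latents] = pad_id

latent_ids = rearrange(latent_ids, '(world latents) -> world latents', world = world_size)

for rank in range(world_size):
    out = latent_ids[rank]
    print(rank, out[out != pad_id].tolist())

# rank 0 -> [0, 1, 2], rank 1 -> [3, 4, 5], rank 2 -> [6, 7, 8], rank 3 -> [9]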
@@ -1050,19 +1093,23 @@
 
         cumulative_rewards = torch.zeros((self.num_latents))
 
+        latent_ids = self.latents_for_machine()
+
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
+            maybe_barrier()
+
             # maybe fix seed for environment across all latents
 
             env_reset_kwargs = dict()
 
             if fix_seed_across_latents:
-                seed = randrange(int(1e6))
+                seed = maybe_sync_seed(device = self.device)
                 env_reset_kwargs = dict(seed = seed)
 
             # for each latent (on a single machine for now)
 
-            for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
+            for latent_id in tqdm(latent_ids, desc = 'latent'):
                 time = 0
 
                 # initial state
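For context, a hypothetical launch sketch (none of it appears in the diff itself) showing how the new code path activates. Once torch.distributed is initialized, for example under torchrun with the usual env:// rendezvous, is_distributed() returns True, each rank rolls out only its shard of latents, and maybe_sync_seed all-reduces a random integer so every rank resets the environment with the same seed:

# run with e.g.: torchrun --nproc_per_node 2 this_script.py
import torch
import torch.distributed as dist

from evolutionary_policy_optimization.distributed import (
    is_distributed,
    maybe_sync_seed
)

if not dist.is_initialized():
    dist.init_process_group(backend = 'gloo')  # 'nccl' for one process per GPU

device = torch.device('cpu')

print(is_distributed())                  # True once world size > 1
print(maybe_sync_seed(device = device))  # same synced seed printed on every rank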
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.0.53"
+version = "0.0.54"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }