evolutionary-policy-optimization 0.0.53__tar.gz → 0.0.55__tar.gz

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (14)
  1. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/PKG-INFO +1 -1
  2. evolutionary_policy_optimization-0.0.55/evolutionary_policy_optimization/distributed.py +88 -0
  3. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/evolutionary_policy_optimization/epo.py +52 -5
  4. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/pyproject.toml +1 -1
  5. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/.github/workflows/python-publish.yml +0 -0
  6. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/.github/workflows/test.yml +0 -0
  7. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/.gitignore +0 -0
  8. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/LICENSE +0 -0
  9. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/README.md +0 -0
  10. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/evolutionary_policy_optimization/experimental.py +0 -0
  12. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/evolutionary_policy_optimization/mock_env.py +0 -0
  13. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/requirements.txt +0 -0
  14. {evolutionary_policy_optimization-0.0.53 → evolutionary_policy_optimization-0.0.55}/tests/test_epo.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.0.53
+ Version: 0.0.55
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,88 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.autograd import Function
+
+ import torch.distributed as dist
+
+ import einx
+ from einops import rearrange
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     return val if exists(val) else d
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ def pad_dim_to(t, length, dim = 0):
+     pad_length = length - t.shape[dim]
+     zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
+
+ def is_distributed():
+     return dist.is_initialized() and dist.get_world_size() > 1
+
+ def get_world_and_rank():
+     if not is_distributed():
+         return 1, 0
+
+     return dist.get_world_size(), dist.get_rank()
+
+ def maybe_sync_seed(device, max_size = int(1e6)):
+     rand_int = torch.randint(0, max_size, (), device = device)
+
+     if is_distributed():
+         dist.all_reduce(rand_int)
+
+     return rand_int.item()
+
+ def maybe_barrier():
+     if not is_distributed():
+         return
+
+     dist.barrier()
+
+ def all_gather_same_dim(t):
+     t = t.contiguous()
+     world_size = dist.get_world_size()
+     gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
+     dist.all_gather(gathered_tensors, t)
+     return gathered_tensors
+
+ def gather_sizes(t, *, dim):
+     size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
+     sizes = all_gather_same_dim(size)
+     return torch.stack(sizes)
+
+ def has_only_one_value(t):
+     return (t == t[0]).all()
+
+ def all_gather_variable_dim(t, dim = 0, sizes = None):
+     device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
+
+     if not exists(sizes):
+         sizes = gather_sizes(t, dim = dim)
+
+     if has_only_one_value(sizes):
+         gathered_tensors = all_gather_same_dim(t)
+         gathered_tensors = torch.cat(gathered_tensors, dim = dim)
+         return gathered_tensors, sizes
+
+     max_size = sizes.amax().item()
+
+     padded_t = pad_dim_to(t, max_size, dim = dim)
+     gathered_tensors = all_gather_same_dim(padded_t)
+
+     gathered_tensors = torch.cat(gathered_tensors, dim = dim)
+     seq = torch.arange(max_size, device = device)
+
+     mask = einx.less('j i -> (i j)', seq, sizes)
+     seq = torch.arange(mask.shape[-1], device = device)
+     indices = seq[mask]
+
+     gathered_tensors = gathered_tensors.index_select(dim, indices)
+
+     return gathered_tensors, sizes
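The new distributed.py collects the small set of collectives the trainer needs: world-size/rank helpers, a barrier and seed sync that are no-ops outside a process group, and all_gather_variable_dim, which pads each rank's tensor to the largest size along the gather dimension, all-gathers, then drops the padding. A minimal sketch of how these helpers behave, assuming a two-process gloo group launched with something like torchrun; the script name, launch setup, and tensor shapes below are illustrative, not part of the package:

    # demo_distributed.py (hypothetical) -- run under torchrun so the env:// init works
    import torch
    import torch.distributed as dist

    from evolutionary_policy_optimization.distributed import (
        all_gather_variable_dim,
        maybe_sync_seed,
        maybe_barrier
    )

    dist.init_process_group(backend = 'gloo')
    rank = dist.get_rank()

    # each rank holds a different number of rows along dim 0
    local = torch.arange(3 + rank).float()

    gathered, sizes = all_gather_variable_dim(local, dim = 0)
    # with 2 ranks: gathered has 3 + 4 = 7 rows on every rank, sizes == tensor([3, 4])

    # each rank draws its own random int, all_reduce sums them, so the seed agrees everywhere
    seed = maybe_sync_seed(device = torch.device('cpu'))

    maybe_barrier()
    dist.destroy_process_group()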
@@ -1,21 +1,31 @@
  from __future__ import annotations

- from functools import partial, wraps
  from pathlib import Path
+ from math import ceil
+ from functools import partial, wraps
  from collections import namedtuple
  from random import randrange

  import torch
  from torch import nn, cat, stack, is_tensor, tensor, Tensor
  import torch.nn.functional as F
+ import torch.distributed as dist
  from torch.nn import Linear, Module, ModuleList
  from torch.utils.data import TensorDataset, DataLoader
- from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
+ from torch.utils._pytree import tree_map

  import einx
  from einops import rearrange, repeat, einsum, pack
  from einops.layers.torch import Rearrange

+ from evolutionary_policy_optimization.distributed import (
+     is_distributed,
+     get_world_and_rank,
+     maybe_sync_seed,
+     all_gather_variable_dim,
+     maybe_barrier
+ )
+
  from assoc_scan import AssocScan

  from adam_atan2_pytorch import AdoptAtan2
@@ -822,6 +832,15 @@ class Agent(Module):

          fitness_scores = self.get_fitness_scores(cumulative_rewards, memories)

+         # stack memories
+
+         memories = map(stack, zip(*memories))
+
+         maybe_barrier()
+
+         if is_distributed():
+             memories = map(partial(all_gather_variable_dim, dim = 0), memories)
+
          (
              episode_ids,
              states,
@@ -831,7 +850,7 @@ class Agent(Module):
              rewards,
              values,
              dones
-         ) = map(stack, zip(*memories))
+         ) = memories

          advantages = self.calc_gae(
              rewards[:-1],
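In the Agent learning step above, memories is a list of per-timestep tuples; zip(*memories) regroups the values field by field, and stack turns each field into a single tensor, which is what lets each field then be all-gathered across ranks along its leading dimension. A single-process sketch of that stacking pattern, using an illustrative three-field Memory that stands in for the actual namedtuple in epo.py:

    import torch
    from torch import stack
    from collections import namedtuple

    # hypothetical memory tuple; the real one in epo.py carries more fields
    Memory = namedtuple('Memory', ['state', 'action', 'reward'])

    memories = [
        Memory(torch.randn(4), torch.tensor(1), torch.tensor(0.5)),
        Memory(torch.randn(4), torch.tensor(0), torch.tensor(1.0)),
        Memory(torch.randn(4), torch.tensor(2), torch.tensor(0.0)),
    ]

    # zip(*memories) groups values field by field, stack makes each group one tensor
    states, actions, rewards = map(stack, zip(*memories))

    assert states.shape == (3, 4) and actions.shape == (3,) and rewards.shape == (3,)

    # in the distributed branch above, each stacked field is then passed through
    # all_gather_variable_dim(..., dim = 0), since ranks may hold different numbers of timesteps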
@@ -1035,6 +1054,30 @@ class EPO(Module):
          self.episodes_per_latent = episodes_per_latent
          self.max_episode_length = max_episode_length

+         self.register_buffer('dummy', tensor(0))
+
+     @property
+     def device(self):
+         return self.dummy.device
+
+     def latents_for_machine(self):
+         num_latents = self.num_latents
+
+         world_size, rank = get_world_and_rank()
+
+         assert num_latents >= world_size, 'number of latents must be greater than world size for now'
+         assert rank < world_size
+
+         num_latents_per_machine = ceil(num_latents / world_size)
+
+         for i in range(num_latents_per_machine):
+             latent_id = rank * num_latents_per_machine + i
+
+             if latent_id >= num_latents:
+                 continue
+
+             yield i
+
      @torch.no_grad()
      def forward(
          self,
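latents_for_machine above splits the latent population into contiguous blocks of ceil(num_latents / world_size) and has each rank iterate only over its own block, with the last rank's block possibly running short. The sharding arithmetic, sketched with made-up numbers:

    from math import ceil

    # illustrative values, not package defaults
    num_latents, world_size = 10, 4

    per_machine = ceil(num_latents / world_size)  # 3 latents per rank

    for rank in range(world_size):
        ids = [rank * per_machine + i for i in range(per_machine)]
        ids = [i for i in ids if i < num_latents]  # drop ids past the end of the population
        print(rank, ids)

    # 0 [0, 1, 2]
    # 1 [3, 4, 5]
    # 2 [6, 7, 8]
    # 3 [9]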
@@ -1050,19 +1093,23 @@ class EPO(Module):

          cumulative_rewards = torch.zeros((self.num_latents))

+         latent_ids_gen = self.latents_for_machine()
+
          for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):

+             maybe_barrier()
+
              # maybe fix seed for environment across all latents

              env_reset_kwargs = dict()

              if fix_seed_across_latents:
-                 seed = randrange(int(1e6))
+                 seed = maybe_sync_seed(device = self.device)
                  env_reset_kwargs = dict(seed = seed)

              # for each latent (on a single machine for now)

-             for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
+             for latent_id in tqdm(latent_ids_gen, desc = 'latent'):
                  time = 0

                  # initial state
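Taken together, the rollout loop now synchronizes all ranks at the top of each episode, and when fix_seed_across_latents is set it draws the shared episode seed through maybe_sync_seed rather than a local randrange, so every rank resets its environment identically; each rank then rolls out only the latents handed to it by latents_for_machine instead of looping over the full population.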
@@ -1,6 +1,6 @@
  [project]
  name = "evolutionary-policy-optimization"
- version = "0.0.53"
+ version = "0.0.55"
  description = "EPO - Pytorch"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }