evolutionary-policy-optimization 0.0.52__tar.gz → 0.0.54__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (14)
  1. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/PKG-INFO +1 -1
  2. evolutionary_policy_optimization-0.0.54/evolutionary_policy_optimization/distributed.py +82 -0
  3. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/epo.py +60 -5
  4. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/pyproject.toml +1 -1
  5. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/.github/workflows/python-publish.yml +0 -0
  6. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/.github/workflows/test.yml +0 -0
  7. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/.gitignore +0 -0
  8. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/LICENSE +0 -0
  9. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/README.md +0 -0
  10. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/experimental.py +0 -0
  12. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/evolutionary_policy_optimization/mock_env.py +0 -0
  13. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/requirements.txt +0 -0
  14. {evolutionary_policy_optimization-0.0.52 → evolutionary_policy_optimization-0.0.54}/tests/test_epo.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.0.52
+ Version: 0.0.54
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,82 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.autograd import Function
+
+ import torch.distributed as dist
+
+ import einx
+ from einops import rearrange
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     return val if exists(val) else d
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ def pad_dim_to(t, length, dim = 0):
+     pad_length = length - t.shape[dim]
+     zero_pairs = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     return F.pad(t, (*((0, 0) * zero_pairs), 0, pad_length))
+
+ def is_distributed():
+     return dist.is_initialized() and dist.get_world_size() > 1
+
+ def maybe_sync_seed(device, max_size = int(1e6)):
+     rand_int = torch.randint(0, max_size, (), device = device)
+
+     if is_distributed():
+         dist.all_reduce(rand_int)
+
+     return rand_int.item()
+
+ def maybe_barrier():
+     if not is_distributed():
+         return
+
+     dist.barrier()
+
+ def all_gather_same_dim(t):
+     t = t.contiguous()
+     world_size = dist.get_world_size()
+     gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
+     dist.all_gather(gathered_tensors, t)
+     return gathered_tensors
+
+ def gather_sizes(t, *, dim):
+     size = torch.tensor(t.shape[dim], device = t.device, dtype = torch.long)
+     sizes = all_gather_same_dim(size)
+     return torch.stack(sizes)
+
+ def has_only_one_value(t):
+     return (t == t[0]).all()
+
+ def all_gather_variable_dim(t, dim = 0, sizes = None):
+     device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
+
+     if not exists(sizes):
+         sizes = gather_sizes(t, dim = dim)
+
+     if has_only_one_value(sizes):
+         gathered_tensors = all_gather_same_dim(t)
+         gathered_tensors = torch.cat(gathered_tensors, dim = dim)
+         return gathered_tensors, sizes
+
+     max_size = sizes.amax().item()
+
+     padded_t = pad_dim_to(t, max_size, dim = dim)
+     gathered_tensors = all_gather_same_dim(padded_t)
+
+     gathered_tensors = torch.cat(gathered_tensors, dim = dim)
+     seq = torch.arange(max_size, device = device)
+
+     mask = einx.less('j i -> (i j)', seq, sizes)
+     seq = torch.arange(mask.shape[-1], device = device)
+     indices = seq[mask]
+
+     gathered_tensors = gathered_tensors.index_select(dim, indices)
+
+     return gathered_tensors, sizes
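Note on the new module: all_gather_variable_dim pads each rank's tensor to the largest size along the gather dimension, all-gathers the padded tensors, then drops the padding with an einx mask, returning both the concatenated result and the per-rank sizes. A minimal sketch of how it might be exercised (not part of the package; the gloo backend, torchrun-style launch, and demo shapes are assumptions for illustration):

import torch
import torch.distributed as dist

from evolutionary_policy_optimization.distributed import all_gather_variable_dim

def demo_gather():
    # hypothetical demo: each rank contributes a different number of rows
    local = torch.randn(dist.get_rank() + 1, 4)

    # gathered: every rank's rows concatenated along dim 0, padding removed
    # sizes: per-rank row counts, useful for splitting the result back apart
    gathered, sizes = all_gather_variable_dim(local, dim = 0)
    return gathered, sizes

if __name__ == '__main__':
    # assumes a torchrun-style launch that sets the usual rendezvous env vars
    dist.init_process_group('gloo')
    gathered, sizes = demo_gather()
    assert gathered.shape[0] == sizes.sum().item()
    dist.destroy_process_group()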
@@ -8,14 +8,22 @@ from random import randrange
  import torch
  from torch import nn, cat, stack, is_tensor, tensor, Tensor
  import torch.nn.functional as F
+ import torch.distributed as dist
  from torch.nn import Linear, Module, ModuleList
  from torch.utils.data import TensorDataset, DataLoader
- from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
+ from torch.utils._pytree import tree_map

  import einx
  from einops import rearrange, repeat, einsum, pack
  from einops.layers.torch import Rearrange

+ from evolutionary_policy_optimization.distributed import (
+     is_distributed,
+     maybe_sync_seed,
+     all_gather_variable_dim,
+     maybe_barrier
+ )
+
  from assoc_scan import AssocScan

  from adam_atan2_pytorch import AdoptAtan2
@@ -360,10 +368,11 @@ class LatentGenePool(Module):
          frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
          frac_elitism = 0.1, # frac of population to preserve from being noised
          frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
-         migrate_every = 100, # how many steps before a migration between islands
          mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
          should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
          default_should_run_ga_gamma = 1.5,
+         migrate_every = 100, # how many steps before a migration between islands
+         apply_genetic_algorithm_every = 2 # how many steps before crossover + mutation happens for genes
      ):
          super().__init__()

@@ -413,7 +422,10 @@ class LatentGenePool(Module):
          self.should_run_genetic_algorithm = should_run_genetic_algorithm

          self.can_migrate = num_islands > 1
+
          self.migrate_every = migrate_every
+         self.apply_genetic_algorithm_every = apply_genetic_algorithm_every
+
          self.register_buffer('step', tensor(1))

      def get_distance(self):
@@ -483,6 +495,10 @@ class LatentGenePool(Module):
      ):
          device = self.latents.device

+         if not divisible_by(self.step.item(), self.apply_genetic_algorithm_every):
+             self.advance_step_()
+             return
+
          """
          i - islands
          p - population
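The gate above means crossover and mutation now run only on steps divisible by apply_genetic_algorithm_every (default 2); other calls just advance the step counter and return early. A tiny standalone illustration of the schedule (not package code):

def divisible_by(num, den):
    return (num % den) == 0

# with apply_genetic_algorithm_every = 2, the genetic algorithm body runs on steps 2, 4, 6, ...
apply_every = 2
steps_that_run = [step for step in range(1, 7) if divisible_by(step, apply_every)]
assert steps_that_run == [2, 4, 6]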
@@ -814,6 +830,15 @@ class Agent(Module):
 
          fitness_scores = self.get_fitness_scores(cumulative_rewards, memories)

+         # stack memories
+
+         memories = map(stack, zip(*memories))
+
+         maybe_barrier()
+
+         if is_distributed():
+             memories = map(partial(all_gather_variable_dim, dim = 0), memories)
+
          (
              episode_ids,
              states,
@@ -823,7 +848,7 @@ class Agent(Module):
              rewards,
              values,
              dones
-         ) = map(stack, zip(*memories))
+         ) = memories

          advantages = self.calc_gae(
              rewards[:-1],
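Before learning, the per-timestep memories are now stacked into one tensor per field, and under torch.distributed each stacked field is gathered from every rank with all_gather_variable_dim. A standalone illustration of the stacking step only, with toy field names and shapes (not package code):

import torch
from torch import stack

# two timesteps of (episode_id, state) pairs standing in for the real memory fields
memories = [
    (torch.tensor(0), torch.randn(4)),
    (torch.tensor(0), torch.randn(4)),
]

# zip(*memories) regroups by field; stack turns each group into a single tensor
episode_ids, states = map(stack, zip(*memories))
assert episode_ids.shape == (2,) and states.shape == (2, 4)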
@@ -1027,6 +1052,32 @@ class EPO(Module):
          self.episodes_per_latent = episodes_per_latent
          self.max_episode_length = max_episode_length

+         self.register_buffer('dummy', tensor(0))
+
+     @property
+     def device(self):
+         return self.dummy.device
+
+     def latents_for_machine(self):
+         num_latents = self.num_latents
+
+         if not is_distributed():
+             return list(range(self.num_latents))
+
+         world_size, rank = dist.get_world_size(), dist.get_rank()
+         assert num_latents >= world_size, 'number of latents must be greater than world size for now'
+         assert rank < world_size
+
+         pad_id = -1
+         num_latents_rounded_up = ceil(num_latents / world_size) * world_size
+         latent_ids = torch.arange(num_latents_rounded_up)
+         latent_ids[latent_ids >= num_latents] = pad_id
+
+         latent_ids = rearrange(latent_ids, '(world latents) -> world latents', world = world_size)
+         out = latent_ids[rank]
+
+         return out[out != pad_id].tolist()
+
      @torch.no_grad()
      def forward(
          self,
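latents_for_machine pads the latent id range up to a multiple of the world size, reshapes it into one row per rank, and returns this rank's ids with the padding stripped, so each machine rolls out a disjoint slice of the gene pool. A standalone sketch of the same partitioning scheme (split_latents is a hypothetical helper, not the package API):

import torch
from math import ceil
from einops import rearrange

def split_latents(num_latents, world_size, rank, pad_id = -1):
    # round up so the ids reshape cleanly into (world_size, latents_per_rank)
    rounded = ceil(num_latents / world_size) * world_size
    ids = torch.arange(rounded)
    ids[ids >= num_latents] = pad_id
    ids = rearrange(ids, '(world latents) -> world latents', world = world_size)
    out = ids[rank]
    return out[out != pad_id].tolist()

# 10 latents over 4 ranks: ranks 0-2 each handle 3 latents, rank 3 handles only latent 9
assert split_latents(10, 4, 0) == [0, 1, 2]
assert split_latents(10, 4, 3) == [9]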
@@ -1042,19 +1093,23 @@ class EPO(Module):
 
          cumulative_rewards = torch.zeros((self.num_latents))

+         latent_ids = self.latents_for_machine()
+
          for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):

+             maybe_barrier()
+
              # maybe fix seed for environment across all latents

              env_reset_kwargs = dict()

              if fix_seed_across_latents:
-                 seed = randrange(int(1e6))
+                 seed = maybe_sync_seed(device = self.device)
                  env_reset_kwargs = dict(seed = seed)

              # for each latent (on a single machine for now)

-             for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
+             for latent_id in tqdm(latent_ids, desc = 'latent'):
                  time = 0

                  # initial state
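With fix_seed_across_latents, the seed is now produced by maybe_sync_seed, which all-reduces a random integer so every rank resets its environment identically; without an initialized process group it simply returns a local random int. A minimal sketch of the call (the CPU device is only an example):

import torch
from evolutionary_policy_optimization.distributed import maybe_sync_seed

# same call EPO.forward now makes before each episode
seed = maybe_sync_seed(device = torch.device('cpu'))
env_reset_kwargs = dict(seed = seed)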
@@ -1,6 +1,6 @@
  [project]
  name = "evolutionary-policy-optimization"
- version = "0.0.52"
+ version = "0.0.54"
  description = "EPO - Pytorch"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }