evolutionary-policy-optimization 0.0.54__py3-none-any.whl → 0.0.56__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,6 +25,12 @@ def pad_dim_to(t, length, dim = 0):
 def is_distributed():
     return dist.is_initialized() and dist.get_world_size() > 1
 
+def get_world_and_rank():
+    if not is_distributed():
+        return 1, 0
+
+    return dist.get_world_size(), dist.get_rank()
+
 def maybe_sync_seed(device, max_size = int(1e6)):
     rand_int = torch.randint(0, max_size, (), device = device)
 
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
-from functools import partial, wraps
 from pathlib import Path
+from math import ceil
+from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
 
@@ -19,6 +20,7 @@ from einops.layers.torch import Rearrange
 
 from evolutionary_policy_optimization.distributed import (
     is_distributed,
+    get_world_and_rank,
     maybe_sync_seed,
     all_gather_variable_dim,
     maybe_barrier
@@ -372,13 +374,16 @@ class LatentGenePool(Module):
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
         default_should_run_ga_gamma = 1.5,
         migrate_every = 100, # how many steps before a migration between islands
-        apply_genetic_algorithm_every = 2 # how many steps before crossover + mutation happens for genes
+        apply_genetic_algorithm_every = 2, # how many steps before crossover + mutation happens for genes
+        init_latent_fn: Callable = None
     ):
         super().__init__()
 
         maybe_l2norm = l2norm if l2norm_latent else identity
 
-        latents = torch.randn(num_latents, dim_latent)
+        init_fn = default(init_latent_fn, torch.randn)
+
+        latents = init_fn((num_latents, dim_latent))
 
         if l2norm_latent:
             latents = maybe_l2norm(latents, dim = -1)
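`LatentGenePool` now takes an `init_latent_fn` callable; when it is not supplied, `default(init_latent_fn, torch.randn)` falls back to a standard normal init. Note that the initializer is called with the shape as a single tuple, `init_fn((num_latents, dim_latent))`, so a custom function must accept one shape tuple rather than unpacked dimensions. A standalone sketch of the pattern (the `make_latents` wrapper and the uniform initializer are illustrative, not package API):

```python
import torch

def default(value, fallback):
    # same convention as the library helper: use the fallback when value is None
    return value if value is not None else fallback

def make_latents(num_latents, dim_latent, init_latent_fn = None):
    # init_latent_fn receives the full shape as one tuple, matching init_fn((num_latents, dim_latent))
    init_fn = default(init_latent_fn, torch.randn)
    return init_fn((num_latents, dim_latent))

# default: gaussian latents
latents = make_latents(128, 32)

# custom: e.g. uniform in [-1, 1)
uniform_latents = make_latents(128, 32, init_latent_fn = lambda shape: torch.rand(shape) * 2 - 1)

print(latents.shape, uniform_latents.shape)  # torch.Size([128, 32]) twice
```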
@@ -1061,22 +1066,20 @@ class EPO(Module):
     def latents_for_machine(self):
         num_latents = self.num_latents
 
-        if not is_distributed():
-            return list(range(self.num_latents))
+        world_size, rank = get_world_and_rank()
 
-        world_size, rank = dist.get_world_size(), dist.get_rank()
         assert num_latents >= world_size, 'number of latents must be greater than world size for now'
         assert rank < world_size
 
-        pad_id = -1
-        num_latents_rounded_up = ceil(num_latents / world_size) * world_size
-        latent_ids = torch.arange(num_latents_rounded_up)
-        latent_ids[latent_ids >= num_latents] = pad_id
+        num_latents_per_machine = ceil(num_latents / world_size)
+
+        for i in range(num_latents_per_machine):
+            latent_id = rank * num_latents_per_machine + i
 
-        latent_ids = rearrange(latent_ids, '(world latents) -> world latents', world = world_size)
-        out = latent_ids[rank]
+            if latent_id >= num_latents:
+                continue
 
-        return out[out != pad_id].tolist()
+            yield i
 
     @torch.no_grad()
     def forward(
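`latents_for_machine` now derives its shard from `get_world_and_rank` and yields lazily instead of materializing a padded id tensor and rearranging it per rank. The partition is a simple ceil split: each rank covers `ceil(num_latents / world_size)` consecutive slots and skips slots that run past `num_latents`. A self-contained sketch of that partition; it yields the global `latent_id` for clarity, whereas the released method yields the local loop index `i`:

```python
from math import ceil

def latents_for_rank(num_latents, world_size, rank):
    # ceil split: each rank owns a contiguous block of at most ceil(num_latents / world_size) ids
    per_machine = ceil(num_latents / world_size)

    for i in range(per_machine):
        latent_id = rank * per_machine + i

        # the last rank's block may run past the total count; skip the overhang
        if latent_id >= num_latents:
            continue

        yield latent_id

# e.g. 10 latents over 4 ranks -> blocks of 3, 3, 3, 1
assert [list(latents_for_rank(10, 4, r)) for r in range(4)] == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```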
@@ -1093,7 +1096,7 @@ class EPO(Module):
 
         cumulative_rewards = torch.zeros((self.num_latents))
 
-        latent_ids = self.latents_for_machine()
+        latent_ids_gen = self.latents_for_machine()
 
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
@@ -1109,7 +1112,7 @@ class EPO(Module):
 
             # for each latent (on a single machine for now)
 
-            for latent_id in tqdm(latent_ids, desc = 'latent'):
+            for latent_id in tqdm(latent_ids_gen, desc = 'latent'):
                 time = 0
 
                 # initial state
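The rename to `latent_ids_gen` makes explicit that `latents_for_machine` now returns a generator rather than a list, bound once before the episode loop. Generators in Python are single-pass: once iterated to completion, further iterations yield nothing unless the function is called again. The snippet below is a plain-Python illustration of that behavior, not code from the package:

```python
def ids():
    yield from range(3)

gen = ids()

print(list(gen))   # [0, 1, 2]
print(list(gen))   # [] -- the generator is already exhausted

# re-calling the function produces a fresh generator each time
print(list(ids())) # [0, 1, 2]
```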
@@ -39,9 +39,34 @@ def crossover_weights(w1, w2, transpose = False):
 
     return out
 
+def mutate_weight(
+    w,
+    transpose = False,
+    mutation_strength = 1.
+):
+
+    if transpose:
+        w = w.transpose(-1, -2)
+
+    rank = min(w2.shape[1:])
+    assert rank >= 2
+
+    u, s, v = torch.svd(w)
+    u = u + torch.randn_like(u) * mutation_strength
+    v = v + torch.randn_like(v) * mutation_strength
+
+    out = u @ torch.diag_embed(s) @ v.mT
+
+    if transpose:
+        out = out.transpose(-1, -2)
+
+    return out
+
 if __name__ == '__main__':
     w1 = torch.randn(32, 16)
     w2 = torch.randn(32, 16)
-    child = crossover_weights(w2, w2)
+
+    child = crossover_weights(w1, w2)
+    mutated_w1 = mutate_weight(w1)
 
     assert child.shape == w2.shape
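The new `mutate_weight` (in `experimental.py`, alongside `crossover_weights`) perturbs a weight matrix in SVD space: it keeps the singular values and adds Gaussian noise of scale `mutation_strength` to the singular vectors before recomposing the matrix. A self-contained sketch of the same idea, assuming the rank check is meant to be taken over `w` itself (the released body reads `w2.shape`, a name only defined in the `__main__` block):

```python
import torch

def mutate_weight_sketch(w, transpose = False, mutation_strength = 1.):
    # operate on the (out, in) orientation if the stored weight is transposed
    if transpose:
        w = w.transpose(-1, -2)

    # need at least a rank-2 factorization for the perturbation to be meaningful
    assert min(w.shape[-2:]) >= 2

    # decompose, jiggle the singular vectors, keep the spectrum, then recompose
    u, s, v = torch.svd(w)
    u = u + torch.randn_like(u) * mutation_strength
    v = v + torch.randn_like(v) * mutation_strength
    out = u @ torch.diag_embed(s) @ v.mT

    if transpose:
        out = out.transpose(-1, -2)

    return out

w = torch.randn(32, 16)
assert mutate_weight_sketch(w, mutation_strength = 0.1).shape == w.shape
```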
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.54
+Version: 0.0.56
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=N7xmO3CRXeaJAy-2rysZg-DBvkZCZB2ySJT7Iq__r6w,35217
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
+evolutionary_policy_optimization-0.0.56.dist-info/METADATA,sha256=o2-1eCh8MuQVd0SH0GiUBBIAcqdK7cceuiu093cuEA4,6213
+evolutionary_policy_optimization-0.0.56.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.56.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.56.dist-info/RECORD,,
@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=lSSf_vB04NgVJFBh2n36cGuKZWgOpp8PnPpLDmHT6nU,2296
-evolutionary_policy_optimization/epo.py,sha256=5QJj_l4pihbSdRk1aZnE2dUyWlaqb_VjIKo6Azzksgs,35292
-evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
-evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.54.dist-info/METADATA,sha256=phQq8QaMT7TQQG2Sqz1BW4E1dln1HU10DMExwRvGGkg,6213
-evolutionary_policy_optimization-0.0.54.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.54.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.54.dist-info/RECORD,,