evolutionary-policy-optimization 0.0.55__py3-none-any.whl → 0.0.57__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
1
1
  from __future__ import annotations
2
+ from typing import Callable
2
3
 
3
4
  from pathlib import Path
4
5
  from math import ceil
@@ -374,13 +375,16 @@ class LatentGenePool(Module):
374
375
  should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
375
376
  default_should_run_ga_gamma = 1.5,
376
377
  migrate_every = 100, # how many steps before a migration between islands
377
- apply_genetic_algorithm_every = 2 # how many steps before crossover + mutation happens for genes
378
+ apply_genetic_algorithm_every = 2, # how many steps before crossover + mutation happens for genes
379
+ init_latent_fn: Callable | None = None
378
380
  ):
379
381
  super().__init__()
380
382
 
381
383
  maybe_l2norm = l2norm if l2norm_latent else identity
382
384
 
383
- latents = torch.randn(num_latents, dim_latent)
385
+ init_fn = default(init_latent_fn, torch.randn)
386
+
387
+ latents = init_fn((num_latents, dim_latent))
384
388
 
385
389
  if l2norm_latent:
386
390
  latents = maybe_l2norm(latents, dim = -1)
@@ -39,9 +39,34 @@ def crossover_weights(w1, w2, transpose = False):
39
39
 
40
40
  return out
41
41
 
42
def mutate_weight(
    w,
    transpose = False,
    mutation_strength = 1.
):
    """
    Mutate a weight matrix by adding Gaussian noise to its singular vectors.

    The weight is factored with a reduced SVD (w = U diag(S) Vh). Independent
    Gaussian noise scaled by `mutation_strength` is added to U and Vh. The matrix
    is then rebuilt with the original singular values S. The perturbed singular
    vectors are not re-orthogonalized.

    Args:
        w: tensor whose last two dims form the matrix; any leading dims are
           treated as batch dims.
        transpose: if True, mutate w with its last two dims swapped, then swap
           them back before returning.
        mutation_strength: multiplier on the standard-normal noise. A value of 0.
           reconstructs w up to numerical error.

    Returns:
        A tensor with the same shape as `w`.

    Raises:
        AssertionError: if `w` has fewer than two dims, or if either of its
        last two dims is smaller than 2.
    """

    assert w.ndim >= 2, 'weight must have at least two dimensions'

    if transpose:
        w = w.transpose(-1, -2)

    # the rank of the matrix being mutated is bounded by its own last two dims
    rank = min(w.shape[-2:])
    assert rank >= 2

    # torch.svd is deprecated; linalg.svd with full_matrices = False yields the same
    # reduced factorization, but returns Vh (= V^T) rather than V
    u, s, vh = torch.linalg.svd(w, full_matrices = False)

    u = u + torch.randn_like(u) * mutation_strength
    vh = vh + torch.randn_like(vh) * mutation_strength

    out = u @ torch.diag_embed(s) @ vh

    if transpose:
        out = out.transpose(-1, -2)

    return out
64
+
42
65
  if __name__ == '__main__':
43
66
  w1 = torch.randn(32, 16)
44
67
  w2 = torch.randn(32, 16)
45
- child = crossover_weights(w2, w2)
68
+
69
+ child = crossover_weights(w1, w2)
70
+ mutated_w1 = mutate_weight(w1)
46
71
 
47
72
  assert child.shape == w2.shape
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.0.55
3
+ Version: 0.0.57
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,9 @@
1
+ evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
2
+ evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
3
+ evolutionary_policy_optimization/epo.py,sha256=qPj5kRsISY1I6WjCc-ejpuiwOSxtPsSdMABmchXJ3s0,35252
4
+ evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
5
+ evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
6
+ evolutionary_policy_optimization-0.0.57.dist-info/METADATA,sha256=WBHRK98s_lzWbqG4ouq620ayykPF9SHUz3HdvsRUywc,6213
7
+ evolutionary_policy_optimization-0.0.57.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ evolutionary_policy_optimization-0.0.57.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ evolutionary_policy_optimization-0.0.57.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
2
- evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
3
- evolutionary_policy_optimization/epo.py,sha256=e0AI7S5QK_uLfokzWTnsAua_HcPW0PyqY-GzUUev0R8,35123
4
- evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
5
- evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
6
- evolutionary_policy_optimization-0.0.55.dist-info/METADATA,sha256=nsWgp2caBwAiWKMU_BH6Sw58gHdpxE29vXxbAXxWa70,6213
7
- evolutionary_policy_optimization-0.0.55.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- evolutionary_policy_optimization-0.0.55.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- evolutionary_policy_optimization-0.0.55.dist-info/RECORD,,