evolutionary-policy-optimization 0.0.55__tar.gz → 0.0.56__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (14) hide show
  1. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/PKG-INFO +1 -1
  2. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/epo.py +5 -2
  3. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/experimental.py +26 -1
  4. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/pyproject.toml +1 -1
  5. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/.github/workflows/python-publish.yml +0 -0
  6. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/.github/workflows/test.yml +0 -0
  7. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/.gitignore +0 -0
  8. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/LICENSE +0 -0
  9. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/README.md +0 -0
  10. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/distributed.py +0 -0
  12. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/evolutionary_policy_optimization/mock_env.py +0 -0
  13. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/requirements.txt +0 -0
  14. {evolutionary_policy_optimization-0.0.55 → evolutionary_policy_optimization-0.0.56}/tests/test_epo.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.0.55
3
+ Version: 0.0.56
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -374,13 +374,16 @@ class LatentGenePool(Module):
374
374
  should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
375
375
  default_should_run_ga_gamma = 1.5,
376
376
  migrate_every = 100, # how many steps before a migration between islands
377
- apply_genetic_algorithm_every = 2 # how many steps before crossover + mutation happens for genes
377
+ apply_genetic_algorithm_every = 2, # how many steps before crossover + mutation happens for genes
378
+ init_latent_fn: Callable = None
378
379
  ):
379
380
  super().__init__()
380
381
 
381
382
  maybe_l2norm = l2norm if l2norm_latent else identity
382
383
 
383
- latents = torch.randn(num_latents, dim_latent)
384
+ init_fn = default(init_latent_fn, torch.randn)
385
+
386
+ latents = init_fn((num_latents, dim_latent))
384
387
 
385
388
  if l2norm_latent:
386
389
  latents = maybe_l2norm(latents, dim = -1)
@@ -39,9 +39,34 @@ def crossover_weights(w1, w2, transpose = False):
39
39
 
40
40
  return out
41
41
 
42
def mutate_weight(
    w,
    transpose = False,
    mutation_strength = 1.
):
    """
    Mutate a weight matrix by perturbing its singular vectors.

    The matrix is decomposed with a reduced SVD, gaussian noise scaled by
    `mutation_strength` is added to the left and right singular vectors, and
    the matrix is rebuilt with the original (unperturbed) singular values.

    Args:
        w: weight tensor whose last two dimensions form the matrix to mutate;
            leading dimensions, if any, are handled as batch dimensions by
            `torch.svd`.
        transpose: if True, the last two dimensions are swapped before the
            decomposition and swapped back on the result.
        mutation_strength: multiplier on the gaussian noise added to the
            singular vectors; `0.` reconstructs `w` up to numerical error.

    Returns:
        A tensor with the same shape as `w`.

    Raises:
        AssertionError: if `w` has fewer than two dimensions, or if the
            smaller of its last two dimensions is below 2.
    """

    assert w.ndim >= 2, 'weight must have at least two dimensions'

    if transpose:
        w = w.transpose(-1, -2)

    # the rank bound must be taken from `w` itself - the previous version read
    # an undefined name `w2`, which raised a NameError outside of the demo script

    rank = min(w.shape[-2:])
    assert rank >= 2

    # torch.svd returns V (not V^T), so that w == u @ diag(s) @ v.mT

    u, s, v = torch.svd(w)

    u = u + torch.randn_like(u) * mutation_strength
    v = v + torch.randn_like(v) * mutation_strength

    out = u @ torch.diag_embed(s) @ v.mT

    if transpose:
        out = out.transpose(-1, -2)

    return out
64
+
42
65
  if __name__ == '__main__':
43
66
  w1 = torch.randn(32, 16)
44
67
  w2 = torch.randn(32, 16)
45
- child = crossover_weights(w2, w2)
68
+
69
+ child = crossover_weights(w1, w2)
70
+ mutated_w1 = mutate_weight(w1)
46
71
 
47
72
  assert child.shape == w2.shape
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "evolutionary-policy-optimization"
3
- version = "0.0.55"
3
+ version = "0.0.56"
4
4
  description = "EPO - Pytorch"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }