evolutionary-policy-optimization 0.1.7__tar.gz → 0.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/PKG-INFO +11 -1
  2. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/README.md +10 -0
  3. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/epo.py +59 -3
  4. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/pyproject.toml +1 -1
  5. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/tests/test_epo.py +4 -1
  6. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/.github/workflows/python-publish.yml +0 -0
  7. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/.github/workflows/test.yml +0 -0
  8. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/.gitignore +0 -0
  9. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/LICENSE +0 -0
  10. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/distributed.py +0 -0
  12. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  13. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/experimental.py +0 -0
  14. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/mock_env.py +0 -0
  15. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/requirements.txt +0 -0
  16. {evolutionary_policy_optimization-0.1.7 → evolutionary_policy_optimization-0.1.8}/train_gym.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.7
+Version: 0.1.8
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -196,4 +196,14 @@ agent.load('./agent.pt')
 }
 ```

+```bibtex
+@article{Doerr2017FastGA,
+    title   = {Fast genetic algorithms},
+    author  = {Benjamin Doerr and Huu Phuoc Le and R{\'e}gis Makhmara and Ta Duy Nguyen},
+    journal = {Proceedings of the Genetic and Evolutionary Computation Conference},
+    year    = {2017},
+    url     = {https://api.semanticscholar.org/CorpusID:16196841}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
README.md
@@ -143,4 +143,14 @@ agent.load('./agent.pt')
 }
 ```

+```bibtex
+@article{Doerr2017FastGA,
+    title   = {Fast genetic algorithms},
+    author  = {Benjamin Doerr and Huu Phuoc Le and R{\'e}gis Makhmara and Ta Duy Nguyen},
+    journal = {Proceedings of the Genetic and Evolutionary Computation Conference},
+    year    = {2017},
+    url     = {https://api.semanticscholar.org/CorpusID:16196841}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
evolutionary_policy_optimization/epo.py
@@ -207,13 +207,60 @@ def mutation(
 ):
     mutations = torch.randn_like(latents)

-    mutated = latents + mutations * mutation_strength
+    if is_tensor(mutation_strength):
+        mutations = einx.multiply('b, b ...', mutation_strength, mutations)
+    else:
+        mutations *= mutation_strength
+
+    mutated = latents + mutations

     if not l2norm_output:
         return mutated

     return l2norm(mutated)

+# drawing mutation strengths from power law distribution
+# proposed by https://arxiv.org/abs/1703.03334
+
+class PowerLawDist(Module):
+    def __init__(
+        self,
+        values: Tensor | list[float] | None = None,
+        bins = None,
+        beta = 1.5,
+    ):
+        super().__init__()
+        assert beta > 1.
+
+        assert exists(bins) or exists(values)
+
+        if exists(values):
+            if not is_tensor(values):
+                values = tensor(values)
+
+            assert values.ndim == 1
+            bins = values.shape[0]
+
+        self.beta = beta
+
+        cdf = torch.linspace(1, bins, bins).pow(-beta).cumsum(dim = -1)
+        cdf = cdf / cdf[-1]
+
+        self.register_buffer('cdf', cdf)
+        self.register_buffer('values', values)
+
+    def forward(self, shape):
+        device = self.cdf.device
+
+        uniform = torch.rand(shape, device = device)
+
+        sampled = torch.searchsorted(self.cdf, uniform)
+
+        if not exists(self.values):
+            return sampled
+
+        return self.values[sampled]
+
 # simple MLP networks, but with latent variables
 # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper
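For context, the new PowerLawDist module draws per-genome mutation strengths by inverting a discrete power-law CDF over a fixed set of candidate values, following the fast genetic algorithms paper cited above. Below is a minimal standalone sketch of that sampling step; the variable names are illustrative and not part of the package API.

```python
# minimal sketch of the inverse-CDF sampling PowerLawDist.forward performs;
# beta, bins and values mirror the defaults in the diff, other names are illustrative
import torch

beta, bins = 1.5, 10
values = torch.linspace(1, 5, bins)                  # candidate mutation strengths

weights = torch.linspace(1, bins, bins).pow(-beta)   # power-law mass per bin, k ** -beta
cdf = weights.cumsum(dim = -1)
cdf = cdf / cdf[-1]                                  # normalize so the last entry is 1

uniform = torch.rand(32)                             # one draw per genome in a batch of 32
indices = torch.searchsorted(cdf, uniform)           # mostly bin 0, occasionally a larger bin
strengths = values[indices]                          # per-genome strengths: usually small, rarely large
```

The mass concentrates on the smallest bins, so most genomes receive gentle mutations while a minority occasionally get much larger perturbations.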
 
@@ -441,6 +488,8 @@ class LatentGenePool(Module):
         frac_elitism = 0.1, # frac of population to preserve from being noised
         frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
         mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
+        fast_genetic_algorithm = False,
+        fast_ga_values = torch.linspace(1, 5, 10),
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
         default_should_run_ga_gamma = 1.5,
         migrate_every = 100, # how many steps before a migration between islands
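Both new keyword arguments are opt-in and default to the previous behavior. A hedged usage sketch, mirroring the constructor values exercised in tests/test_epo.py with everything else left at defaults:

```python
# hedged usage sketch: enabling power-law sampled mutation strengths on the gene pool;
# constructor values mirror tests/test_epo.py, all other hyperparameters stay at defaults
import torch
from evolutionary_policy_optimization import LatentGenePool

latent_pool = LatentGenePool(
    num_latents = 128,
    dim_latent = 32,
    num_islands = 4,
    fast_genetic_algorithm = True,              # draw mutation strengths from PowerLawDist
    fast_ga_values = torch.linspace(1, 5, 10)   # candidate strengths (same as the default)
)
```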
@@ -488,6 +537,8 @@ class LatentGenePool(Module):
         self.crossover_random = crossover_random

         self.mutation_strength = mutation_strength
+        self.mutation_strength_sampler = PowerLawDist(fast_ga_values) if fast_genetic_algorithm else None
+
         self.num_elites = int(frac_elitism * latents_per_island)
         self.has_elites = self.num_elites > 0

@@ -656,9 +707,14 @@ class LatentGenePool(Module):
         if self.has_elites:
             genes, elites = genes[:, :-self.num_elites], genes[:, -self.num_elites:]

-        # 5. mutate with gaussian noise - todo: add drawing the mutation rate from exponential distribution, from the fast genetic algorithms paper from 2017
+        # 5. mutate with gaussian noise
+
+        if exists(self.mutation_strength_sampler):
+            mutation_strength = self.mutation_strength_sampler(genes.shape[:1])
+        else:
+            mutation_strength = self.mutation_strength

-        genes = mutation(genes, mutation_strength = self.mutation_strength)
+        genes = mutation(genes, mutation_strength = mutation_strength)

         # 6. maybe migration
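When a per-genome strength tensor reaches mutation(), the einx.multiply('b, b ...') call broadcasts it across all trailing latent dimensions. A plain-torch sketch of the same broadcasting, with illustrative shapes:

```python
# plain-torch equivalent of the einx broadcast used in mutation() when
# mutation_strength is a per-genome tensor; shapes here are illustrative
import torch

genes = torch.randn(128, 32)                  # (population, dim_latent)
mutation_strength = torch.rand(128) * 4 + 1   # one sampled strength per genome

noise = torch.randn_like(genes)
noise = noise * mutation_strength.view(-1, *((1,) * (genes.ndim - 1)))  # broadcast over trailing dims
mutated = genes + noise
```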
 
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.1.7"
+version = "0.1.8"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_epo.py
@@ -9,15 +9,18 @@ from evolutionary_policy_optimization import (

 @pytest.mark.parametrize('latent_ids', (2, (2, 4)))
 @pytest.mark.parametrize('num_islands', (1, 4))
+@pytest.mark.parametrize('sampled_mutation_strengths', (False, True))
 def test_readme(
     latent_ids,
-    num_islands
+    num_islands,
+    sampled_mutation_strengths
 ):

     latent_pool = LatentGenePool(
         num_latents = 128,
         dim_latent = 32,
         num_islands = num_islands,
+        fast_genetic_algorithm = sampled_mutation_strengths
     )

     state = torch.randn(2, 512)