evolutionary-policy-optimization 0.1.6__tar.gz → 0.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/PKG-INFO +11 -1
  2. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/README.md +10 -0
  3. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/epo.py +67 -8
  4. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/pyproject.toml +1 -1
  5. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/tests/test_epo.py +6 -2
  6. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/.github/workflows/python-publish.yml +0 -0
  7. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/.github/workflows/test.yml +0 -0
  8. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/.gitignore +0 -0
  9. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/LICENSE +0 -0
  10. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/distributed.py +0 -0
  12. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  13. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/experimental.py +0 -0
  14. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/evolutionary_policy_optimization/mock_env.py +0 -0
  15. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/requirements.txt +0 -0
  16. {evolutionary_policy_optimization-0.1.6 → evolutionary_policy_optimization-0.1.8}/train_gym.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.1.6
3
+ Version: 0.1.8
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -196,4 +196,14 @@ agent.load('./agent.pt')
196
196
  }
197
197
  ```
198
198
 
199
+ ```bibtex
200
+ @article{Doerr2017FastGA,
201
+ title = {Fast genetic algorithms},
202
+ author = {Benjamin Doerr and Huu Phuoc Le and R{\'e}gis Makhmara and Ta Duy Nguyen},
203
+ journal = {Proceedings of the Genetic and Evolutionary Computation Conference},
204
+ year = {2017},
205
+ url = {https://api.semanticscholar.org/CorpusID:16196841}
206
+ }
207
+ ```
208
+
199
209
  *Evolution is cleverer than you are.* - Leslie Orgel
@@ -143,4 +143,14 @@ agent.load('./agent.pt')
143
143
  }
144
144
  ```
145
145
 
146
+ ```bibtex
147
+ @article{Doerr2017FastGA,
148
+ title = {Fast genetic algorithms},
149
+ author = {Benjamin Doerr and Huu Phuoc Le and R{\'e}gis Makhmara and Ta Duy Nguyen},
150
+ journal = {Proceedings of the Genetic and Evolutionary Computation Conference},
151
+ year = {2017},
152
+ url = {https://api.semanticscholar.org/CorpusID:16196841}
153
+ }
154
+ ```
155
+
146
156
  *Evolution is cleverer than you are.* - Leslie Orgel
@@ -207,13 +207,60 @@ def mutation(
207
207
  ):
208
208
  mutations = torch.randn_like(latents)
209
209
 
210
- mutated = latents + mutations * mutation_strength
210
+ if is_tensor(mutation_strength):
211
+ mutations = einx.multiply('b, b ...', mutation_strength, mutations)
212
+ else:
213
+ mutations *= mutation_strength
214
+
215
+ mutated = latents + mutations
211
216
 
212
217
  if not l2norm_output:
213
218
  return mutated
214
219
 
215
220
  return l2norm(mutated)
216
221
 
222
+ # drawing mutation strengths from power law distribution
223
+ # proposed by https://arxiv.org/abs/1703.03334
224
+
225
+ class PowerLawDist(Module):
226
+ def __init__(
227
+ self,
228
+ values: Tensor | list[float] | None = None,
229
+ bins = None,
230
+ beta = 1.5,
231
+ ):
232
+ super().__init__()
233
+ assert beta > 1.
234
+
235
+ assert exists(bins) or exists(values)
236
+
237
+ if exists(values):
238
+ if not is_tensor(values):
239
+ values = tensor(values)
240
+
241
+ assert values.ndim == 1
242
+ bins = values.shape[0]
243
+
244
+ self.beta = beta
245
+
246
+ cdf = torch.linspace(1, bins, bins).pow(-beta).cumsum(dim = -1)
247
+ cdf = cdf / cdf[-1]
248
+
249
+ self.register_buffer('cdf', cdf)
250
+ self.register_buffer('values', values)
251
+
252
+ def forward(self, shape):
253
+ device = self.cdf.device
254
+
255
+ uniform = torch.rand(shape, device = device)
256
+
257
+ sampled = torch.searchsorted(self.cdf, uniform)
258
+
259
+ if not exists(self.values):
260
+ return sampled
261
+
262
+ return self.values[sampled]
263
+
217
264
  # simple MLP networks, but with latent variables
218
265
  # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper
219
266
 
@@ -369,7 +416,6 @@ class Critic(Module):
369
416
  hl_gauss_loss = self.to_pred.hl_gauss_loss
370
417
 
371
418
  self.maybe_bins_to_value = hl_gauss_loss if not use_regression else identity
372
- self.maybe_value_to_bins = hl_gauss_loss.transform_to_logprobs if not use_regression else identity
373
419
  self.loss_fn = hl_gauss_loss if not use_regression else F.mse_loss
374
420
 
375
421
  def forward_for_loss(
@@ -386,7 +432,7 @@ class Critic(Module):
386
432
 
387
433
  clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
388
434
 
389
- loss = self.loss_fn(value, target, reduction = 'none')
435
+ loss = self.loss_fn(logits, target, reduction = 'none')
390
436
  clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
391
437
 
392
438
  return torch.max(loss, clipped_loss).mean()
@@ -442,6 +488,8 @@ class LatentGenePool(Module):
442
488
  frac_elitism = 0.1, # frac of population to preserve from being noised
443
489
  frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
444
490
  mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
491
+ fast_genetic_algorithm = False,
492
+ fast_ga_values = torch.linspace(1, 5, 10),
445
493
  should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
446
494
  default_should_run_ga_gamma = 1.5,
447
495
  migrate_every = 100, # how many steps before a migration between islands
@@ -489,6 +537,8 @@ class LatentGenePool(Module):
489
537
  self.crossover_random = crossover_random
490
538
 
491
539
  self.mutation_strength = mutation_strength
540
+ self.mutation_strength_sampler = PowerLawDist(fast_ga_values) if fast_genetic_algorithm else None
541
+
492
542
  self.num_elites = int(frac_elitism * latents_per_island)
493
543
  self.has_elites = self.num_elites > 0
494
544
 
@@ -657,9 +707,14 @@ class LatentGenePool(Module):
657
707
  if self.has_elites:
658
708
  genes, elites = genes[:, :-self.num_elites], genes[:, -self.num_elites:]
659
709
 
660
- # 5. mutate with gaussian noise - todo: add drawing the mutation rate from exponential distribution, from the fast genetic algorithms paper from 2017
710
+ # 5. mutate with gaussian noise
711
+
712
+ if exists(self.mutation_strength_sampler):
713
+ mutation_strength = self.mutation_strength_sampler(genes.shape[:1])
714
+ else:
715
+ mutation_strength = self.mutation_strength
661
716
 
662
- genes = mutation(genes, mutation_strength = self.mutation_strength)
717
+ genes = mutation(genes, mutation_strength = mutation_strength)
663
718
 
664
719
  # 6. maybe migration
665
720
 
@@ -844,7 +899,11 @@ class Agent(Module):
844
899
 
845
900
  dummy = tensor(0)
846
901
 
902
+ self.clip_grad_norm_ = nn.utils.clip_grad_norm_
903
+
847
904
  if wrap_with_accelerate:
905
+ self.clip_grad_norm_ = self.accelerate.clip_grad_norm_
906
+
848
907
  (
849
908
  self.actor,
850
909
  self.critic,
@@ -1071,7 +1130,7 @@ class Agent(Module):
1071
1130
  actor_loss.backward()
1072
1131
 
1073
1132
  if exists(self.has_grad_clip):
1074
- self.accelerate.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
1133
+ self.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
1075
1134
 
1076
1135
  self.actor_optim.step()
1077
1136
  self.actor_optim.zero_grad()
@@ -1089,7 +1148,7 @@ class Agent(Module):
1089
1148
  critic_loss.backward()
1090
1149
 
1091
1150
  if exists(self.has_grad_clip):
1092
- self.accelerate.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
1151
+ self.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
1093
1152
 
1094
1153
  self.critic_optim.step()
1095
1154
  self.critic_optim.zero_grad()
@@ -1113,7 +1172,7 @@ class Agent(Module):
1113
1172
  (diversity_loss * self.diversity_aux_loss_weight).backward()
1114
1173
 
1115
1174
  if exists(self.has_grad_clip):
1116
- self.accelerate.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)
1175
+ self.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)
1117
1176
 
1118
1177
  self.latent_optim.step()
1119
1178
  self.latent_optim.zero_grad()
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "evolutionary-policy-optimization"
3
- version = "0.1.6"
3
+ version = "0.1.8"
4
4
  description = "EPO - Pytorch"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -9,15 +9,18 @@ from evolutionary_policy_optimization import (
9
9
 
10
10
  @pytest.mark.parametrize('latent_ids', (2, (2, 4)))
11
11
  @pytest.mark.parametrize('num_islands', (1, 4))
12
+ @pytest.mark.parametrize('sampled_mutation_strengths', (False, True))
12
13
  def test_readme(
13
14
  latent_ids,
14
- num_islands
15
+ num_islands,
16
+ sampled_mutation_strengths
15
17
  ):
16
18
 
17
19
  latent_pool = LatentGenePool(
18
20
  num_latents = 128,
19
21
  dim_latent = 32,
20
22
  num_islands = num_islands,
23
+ fast_genetic_algorithm = sampled_mutation_strengths
21
24
  )
22
25
 
23
26
  state = torch.randn(2, 512)
@@ -103,7 +106,8 @@ def test_e2e_with_mock_env(
103
106
  frozen_latents = frozen_latents,
104
107
  frac_natural_selected = 0.75,
105
108
  frac_tournaments = 0.9
106
- )
109
+ ),
110
+ wrap_with_accelerate = False
107
111
  )
108
112
 
109
113
  epo = EPO(