evolutionary-policy-optimization 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -207,13 +207,60 @@ def mutation(
 ):
     mutations = torch.randn_like(latents)

-    mutated = latents + mutations * mutation_strength
+    if is_tensor(mutation_strength):
+        mutations = einx.multiply('b, b ...', mutation_strength, mutations)
+    else:
+        mutations *= mutation_strength
+
+    mutated = latents + mutations

     if not l2norm_output:
         return mutated

     return l2norm(mutated)

+# drawing mutation strengths from power law distribution
+# proposed by https://arxiv.org/abs/1703.03334
+
+class PowerLawDist(Module):
+    def __init__(
+        self,
+        values: Tensor | list[float] | None = None,
+        bins = None,
+        beta = 1.5,
+    ):
+        super().__init__()
+        assert beta > 1.
+
+        assert exists(bins) or exists(values)
+
+        if exists(values):
+            if not is_tensor(values):
+                values = tensor(values)
+
+            assert values.ndim == 1
+            bins = values.shape[0]
+
+        self.beta = beta
+
+        cdf = torch.linspace(1, bins, bins).pow(-beta).cumsum(dim = -1)
+        cdf = cdf / cdf[-1]
+
+        self.register_buffer('cdf', cdf)
+        self.register_buffer('values', values)
+
+    def forward(self, shape):
+        device = self.cdf.device
+
+        uniform = torch.rand(shape, device = device)
+
+        sampled = torch.searchsorted(self.cdf, uniform)
+
+        if not exists(self.values):
+            return sampled
+
+        return self.values[sampled]
+
 # simple MLP networks, but with latent variables
 # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper

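The new `PowerLawDist` module draws mutation strengths from a power law over a fixed set of candidate values, so most draws stay small while occasional large mutations remain possible (the fast genetic algorithms idea cited above). Below is a minimal standalone sketch of the same sampling scheme, assuming only `torch`; the function name and sizes are illustrative and not part of the package API.

```python
import torch

def sample_power_law(values: torch.Tensor, num_samples: int, beta: float = 1.5) -> torch.Tensor:
    # probability of drawing the i-th candidate (1-indexed) is proportional to i ** -beta,
    # so early (small) values are common and later (large) values are rare
    bins = values.shape[0]
    cdf = torch.arange(1, bins + 1, dtype = torch.float).pow(-beta).cumsum(dim = -1)
    cdf = cdf / cdf[-1]

    # invert the cdf with uniform samples, then index into the candidate values
    uniform = torch.rand(num_samples)
    indices = torch.searchsorted(cdf, uniform)
    return values[indices]

# one mutation strength per member of a hypothetical population of 8 latents
strengths = sample_power_law(torch.linspace(1, 5, 10), num_samples = 8)
```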
@@ -377,18 +424,38 @@ class Critic(Module):
         latent,
         old_values,
         target,
-        eps_clip = 0.4
+        eps_clip = 0.4,
+        use_improved = True
     ):
         logits = self.forward(state, latent, return_logits = True)

         value = self.maybe_bins_to_value(logits)

-        clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+        if use_improved:
+            clipped_target = target.clamp(-eps_clip, eps_clip)
+
+            old_values_lo = old_values - eps_clip
+            old_values_hi = old_values + eps_clip
+
+            is_between = lambda lo, hi: (lo < value) & (value < hi)

-        loss = self.loss_fn(logits, target, reduction = 'none')
-        clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+            clipped_loss = self.loss_fn(logits, clipped_target, reduction = 'none')
+            loss = self.loss_fn(logits, target, reduction = 'none')

-        return torch.max(loss, clipped_loss).mean()
+            value_loss = torch.where(
+                is_between(target, old_values_lo) | is_between(old_values_hi, target),
+                0.,
+                torch.min(loss, clipped_loss)
+            )
+        else:
+            clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+
+            loss = self.loss_fn(logits, target, reduction = 'none')
+            clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+
+            value_loss = torch.max(loss, clipped_loss)
+
+        return value_loss.mean()

     def forward(
         self,
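To see how the two branches above differ, here is a small self-contained sketch with plain mean-squared error standing in for the critic's `loss_fn` and a value tensor standing in for the logits; all numbers are made up and only trace the computation as written in the diff.

```python
import torch

value      = torch.tensor([0.2, 1.4, -0.3])   # current critic predictions
old_values = torch.tensor([0.0, 1.0,  0.0])   # predictions from the rollout policy
target     = torch.tensor([0.5, 2.0, -1.0])   # advantages + old_values
eps_clip   = 0.4

loss_fn = lambda pred, tgt: (pred - tgt) ** 2

# improved variant (use_improved = True): the loss is dropped wherever the prediction
# already sits between the target and the nearer edge of the trust region around old_values
clipped_target = target.clamp(-eps_clip, eps_clip)
is_between = lambda lo, hi: (lo < value) & (value < hi)

improved = torch.where(
    is_between(target, old_values - eps_clip) | is_between(old_values + eps_clip, target),
    torch.tensor(0.),
    torch.min(loss_fn(value, target), loss_fn(value, clipped_target))
).mean()

# original variant (use_improved = False): pessimistic max of the unclipped and
# clipped losses, mirroring the clipping exactly as written in the hunk above
clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
classic = torch.max(loss_fn(value, target), loss_fn(clipped_value, target)).mean()
```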
@@ -441,6 +508,8 @@ class LatentGenePool(Module):
        frac_elitism = 0.1,                                  # frac of population to preserve from being noised
        frac_migrate = 0.1,                                  # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
        mutation_strength = 1.,                              # factor to multiply to gaussian noise as mutation to latents
+       fast_genetic_algorithm = False,
+       fast_ga_values = torch.linspace(1, 5, 10),
        should_run_genetic_algorithm: Module | None = None,  # eq (3) in paper
        default_should_run_ga_gamma = 1.5,
        migrate_every = 100,                                 # how many steps before a migration between islands
@@ -488,6 +557,8 @@ class LatentGenePool(Module):
        self.crossover_random = crossover_random

        self.mutation_strength = mutation_strength
+       self.mutation_strength_sampler = PowerLawDist(fast_ga_values) if fast_genetic_algorithm else None
+
        self.num_elites = int(frac_elitism * latents_per_island)
        self.has_elites = self.num_elites > 0

@@ -656,9 +727,14 @@ class LatentGenePool(Module):
        if self.has_elites:
            genes, elites = genes[:, :-self.num_elites], genes[:, -self.num_elites:]

-       # 5. mutate with gaussian noise - todo: add drawing the mutation rate from exponential distribution, from the fast genetic algorithms paper from 2017
+       # 5. mutate with gaussian noise
+
+       if exists(self.mutation_strength_sampler):
+           mutation_strength = self.mutation_strength_sampler(genes.shape[:1])
+       else:
+           mutation_strength = self.mutation_strength

-       genes = mutation(genes, mutation_strength = self.mutation_strength)
+       genes = mutation(genes, mutation_strength = mutation_strength)

        # 6. maybe migration

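Putting the pieces together, the power-law mutation strengths are opt-in via the two new `LatentGenePool` constructor arguments. A hedged usage sketch follows; `num_latents` and `dim_latent` are assumed pre-existing constructor arguments and may not match the actual signature.

```python
import torch
from evolutionary_policy_optimization import LatentGenePool

# only fast_genetic_algorithm / fast_ga_values are new in this release;
# the pool sizes below are hypothetical
latent_pool = LatentGenePool(
    num_latents = 128,              # assumed existing argument
    dim_latent = 32,                # assumed existing argument
    fast_genetic_algorithm = True,  # draw per-latent mutation strengths from the power law
    fast_ga_values = torch.linspace(1, 5, 10)
)
```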
@@ -770,6 +846,7 @@ class Agent(Module):
        critic_loss_kwargs: dict = dict(
            eps_clip = 0.4
        ),
+       use_improved_critic_loss = True,
        ema_kwargs: dict = dict(),
        actor_optim_kwargs: dict = dict(),
        critic_optim_kwargs: dict = dict(),
@@ -815,6 +892,8 @@ class Agent(Module):
        self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
        self.critic_loss_kwargs = critic_loss_kwargs

+       self.use_improved_critic_loss = use_improved_critic_loss
+
        # fitness score related

        self.get_fitness_scores = get_fitness_scores
@@ -1086,6 +1165,7 @@ class Agent(Module):
                latents,
                old_values = old_values,
                target = advantages + old_values,
+               use_improved = self.use_improved_critic_loss,
                **self.critic_loss_kwargs
            )

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.7
+Version: 0.1.9
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -196,4 +196,23 @@ agent.load('./agent.pt')
 }
 ```

+```bibtex
+@article{Doerr2017FastGA,
+    title   = {Fast genetic algorithms},
+    author  = {Benjamin Doerr and Huu Phuoc Le and R{\'e}gis Makhmara and Ta Duy Nguyen},
+    journal = {Proceedings of the Genetic and Evolutionary Computation Conference},
+    year    = {2017},
+    url     = {https://api.semanticscholar.org/CorpusID:16196841}
+}
+```
+
+```bibtex
+@article{Lee2024AnalysisClippedCritic,
+    title   = {On Analysis of Clipped Critic Loss in Proximal Policy Gradient},
+    author  = {Yongjin Lee and Moonyoung Chung},
+    journal = {Authorea},
+    year    = {2024}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=5rOygXAfbb4dmjfseBcHgxHPpTFNMrrMDrY9IsJuZ28,43381
+evolutionary_policy_optimization/epo.py,sha256=9GfSvOz6SwjAuZyhyvsLHPY8b2svMQlM3BRjilwsQ-g,45717
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.7.dist-info/METADATA,sha256=yc_7LIYTbAhc7disU0o4ep-xVT1Ku3_nEF01yHcUzDE,6742
-evolutionary_policy_optimization-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.7.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.9.dist-info/METADATA,sha256=y5w_NwtKNQ07HeYa5r6hcPn7RsqDpehMmt5vj6mTESQ,7316
+evolutionary_policy_optimization-0.1.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.9.dist-info/RECORD,,