evolutionary-policy-optimization 0.0.50__py3-none-any.whl → 0.0.52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -48,6 +48,9 @@ def divisible_by(num, den):
 def l2norm(t):
     return F.normalize(t, p = 2, dim = -1)
 
+def batch_randperm(shape, device):
+    return torch.randn(shape, device = device).argsort(dim = -1)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
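The new helper factors out the argsort-of-gaussian-noise idiom for drawing batched random permutations: because each slice of noise is exchangeable, its argsort along the last dimension is an independent uniform permutation for every leading index. A standalone sketch (not part of the package) of how it behaves:

import torch

def batch_randperm(shape, device):
    # argsort of i.i.d. noise along the last dim -> one random permutation per leading index
    return torch.randn(shape, device = device).argsort(dim = -1)

perms = batch_randperm((2, 3, 5), device = 'cpu')
print(perms.shape)   # torch.Size([2, 3, 5])
print(perms[0, 0])   # e.g. tensor([3, 0, 4, 1, 2]) - a permutation of 0..4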
@@ -393,7 +396,6 @@ class LatentGenePool(Module):
 
         latents_per_island = num_latents // num_islands
         self.num_natural_selected = int(frac_natural_selected * latents_per_island)
-
         self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
 
         self.crossover_random = crossover_random
@@ -530,7 +532,9 @@ class LatentGenePool(Module):
 
         # 2. for finding pairs of parents to replete gene pool, we will go with the popular tournament strategy
 
-        rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants), device = device).argsort(dim = -1)
+        tournament_shape = (islands, pop_size_per_island - self.num_natural_selected, self.num_natural_selected) # (island, num children needed, natural selected population to be bred)
+
+        rand_tournament_gene_ids = batch_randperm(tournament_shape, device)[..., :tournament_participants]
 
        rand_tournament_gene_ids_for_gather = rearrange(rand_tournament_gene_ids, 'i p t -> i (p t)')
 
        participant_fitness = fitness.gather(1, rand_tournament_gene_ids_for_gather)
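The practical effect, as far as can be inferred from the diff: the old argsort over a tensor whose last dimension was only tournament_participants long could only ever yield indices 0..tournament_participants-1, so every tournament drew from the same fixed slice of the selected pool, merely shuffled. The new code permutes all num_natural_selected indices per child and keeps the first tournament_participants, i.e. it samples distinct participants from the whole naturally-selected pool without replacement. A minimal standalone sketch with toy sizes chosen here purely for illustration:

import torch

islands, num_children, num_natural_selected, tournament_participants = 2, 4, 6, 3

def batch_randperm(shape, device):
    return torch.randn(shape, device = device).argsort(dim = -1)

# (island, num children needed, size of the naturally selected pool)
tournament_shape = (islands, num_children, num_natural_selected)

# full permutation of the selected pool per (island, child), then keep the first k
rand_tournament_gene_ids = batch_randperm(tournament_shape, 'cpu')[..., :tournament_participants]

print(rand_tournament_gene_ids.shape)   # torch.Size([2, 4, 3])
# indices within each tournament are unique and may be any member of the selected pool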
@@ -901,7 +905,7 @@ class Agent(Module):
 
         if self.has_diversity_loss:
             diversity = self.latent_gene_pool.get_distance()
-            diversity_loss = diversity.mul(-1).exp().mean()
+            diversity_loss = (-diversity).tril(-1).exp().mean()
 
             (diversity_loss * self.diversity_aux_loss_weight).backward()
 
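A minimal sketch of the new diversity term, assuming get_distance() returns a square, symmetric matrix of pairwise distances between latents (an assumption about the package internals, not stated in the diff). tril(-1) zeroes the diagonal and upper triangle before the exponential, so each unordered pair appears only once in the strictly lower triangle (the zeroed entries still contribute exp(0) = 1 to the mean):

import torch

# hypothetical pairwise distance matrix between 4 latent genes (symmetric, zero diagonal)
diversity = torch.tensor([
    [0., 1., 2., 3.],
    [1., 0., 4., 5.],
    [2., 4., 0., 6.],
    [3., 5., 6., 0.],
])

# keep only the strictly lower triangle, then penalize small distances via exp(-d)
diversity_loss = (-diversity).tril(-1).exp().mean()
print(diversity_loss)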
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.50
+Version: 0.0.52
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=nIZzjz_3WUvHYPwNH9zoKM3fyRJd3MOd6UL7ooSTQV4,33445
+evolutionary_policy_optimization/epo.py,sha256=7iGmye36JsQ1jUm_UCCrSdvC67u70EPTAp3nCbKGjO8,33675
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.50.dist-info/METADATA,sha256=iCB34UsXdosJV9q8-qXJpJfI8OdNNXB5o8MhmKy82zY,6213
-evolutionary_policy_optimization-0.0.50.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.50.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.50.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.52.dist-info/METADATA,sha256=dET2Z11ktfJYei-xFfv7SjGKTglco6dDvELIzKrLsAQ,6213
+evolutionary_policy_optimization-0.0.52.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.52.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.52.dist-info/RECORD,,
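The RECORD entries above follow the wheel convention of path,sha256=<urlsafe-base64 digest with padding stripped>,size. A small sketch of recomputing such a hash for comparison against the diff (the path below assumes a local, unpacked copy of the wheel and is purely illustrative):

import base64, hashlib
from pathlib import Path

def record_hash(path):
    # wheel RECORD hashes are urlsafe base64 of the sha256 digest, '=' padding stripped
    digest = hashlib.sha256(Path(path).read_bytes()).digest()
    return 'sha256=' + base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

print(record_hash('evolutionary_policy_optimization/epo.py'))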