evolutionary-policy-optimization 0.0.48.tar.gz → 0.0.50.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (13):
  1. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/PKG-INFO +1 -1
  2. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/evolutionary_policy_optimization/epo.py +32 -7
  3. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/pyproject.toml +1 -1
  4. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/tests/test_epo.py +2 -0
  5. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/.github/workflows/python-publish.yml +0 -0
  6. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/.github/workflows/test.yml +0 -0
  7. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/.gitignore +0 -0
  8. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/LICENSE +0 -0
  9. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/README.md +0 -0
  10. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/evolutionary_policy_optimization/experimental.py +0 -0
  12. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/evolutionary_policy_optimization/mock_env.py +0 -0
  13. {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/requirements.txt +0 -0
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.48
+Version: 0.0.50
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
--- a/evolutionary_policy_optimization/epo.py
+++ b/evolutionary_policy_optimization/epo.py
@@ -6,7 +6,7 @@ from collections import namedtuple
 from random import randrange
 
 import torch
-from torch import nn, cat, stack, is_tensor, tensor
+from torch import nn, cat, stack, is_tensor, tensor, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
@@ -412,7 +412,16 @@ class LatentGenePool(Module):
 
         self.can_migrate = num_islands > 1
         self.migrate_every = migrate_every
-        self.register_buffer('step', tensor(0))
+        self.register_buffer('step', tensor(1))
+
+    def get_distance(self):
+        # returns latent euclidean distance as proxy for diversity
+
+        latents = rearrange(self.latents, '(i p) g -> i p g', i = self.num_islands)
+
+        distance = torch.cdist(latents, latents)
+
+        return distance
 
     def advance_step_(self):
         self.step.add_(1)
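The new `get_distance` groups the flat latent pool by island with einops `rearrange`, then uses `torch.cdist` for batched pairwise euclidean distances. A minimal standalone sketch of that computation; the pool size, island count, and gene dimension below are illustrative, not values from the package:

```python
import torch
from einops import rearrange

num_islands = 2
latents = torch.randn(8, 16)  # (num_islands * pop_per_island, gene_dim)

# group the flat pool by island: (i p) g -> i p g
grouped = rearrange(latents, '(i p) g -> i p g', i = num_islands)

# batched pairwise euclidean distances within each island: (islands, pop, pop)
distance = torch.cdist(grouped, grouped)

print(distance.shape)  # torch.Size([2, 4, 4])
```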
@@ -643,6 +652,7 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        diversity_aux_loss_weight = 0.,
         use_critic_ema = True,
         critic_ema_beta = 0.99,
         max_grad_norm = 0.5,
@@ -698,6 +708,11 @@ class Agent(Module):
 
         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
 
+        # promotes latents to be farther apart for diversity maintenance
+
+        self.has_diversity_loss = diversity_aux_loss_weight > 0.
+        self.diversity_aux_loss_weight = diversity_aux_loss_weight
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
@@ -879,11 +894,19 @@ class Agent(Module):
 
             # maybe update latents, if not frozen
 
-            if not self.latent_gene_pool.frozen_latents:
-                orig_latents.backward(latents.grad)
+            if self.latent_gene_pool.frozen_latents:
+                continue
+
+            orig_latents.backward(latents.grad)
+
+            if self.has_diversity_loss:
+                diversity = self.latent_gene_pool.get_distance()
+                diversity_loss = diversity.mul(-1).exp().mean()
+
+                (diversity_loss * self.diversity_aux_loss_weight).backward()
 
-                self.latent_optim.step()
-                self.latent_optim.zero_grad()
+            self.latent_optim.step()
+            self.latent_optim.zero_grad()
 
         # apply evolution
 
@@ -930,6 +953,7 @@ def create_agent(
     latent_gene_pool_kwargs: dict = dict(),
     actor_kwargs: dict = dict(),
     critic_kwargs: dict = dict(),
+    **kwargs
 ) -> Agent:
 
     latent_gene_pool = LatentGenePool(
@@ -957,7 +981,8 @@ def create_agent(
         actor = actor,
         critic = critic,
         latent_gene_pool = latent_gene_pool,
-        use_critic_ema = use_critic_ema
+        use_critic_ema = use_critic_ema,
+        **kwargs
     )
 
     return agent
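Since `create_agent` now forwards unrecognized keyword arguments to `Agent`, the new weight can be set at agent creation. A hedged usage sketch mirroring the test configuration further below; the state, latent, and action dimensions are assumptions for illustration:

```python
from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,                    # assumed dims, for illustration only
    num_latents = 16,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64),
    diversity_aux_loss_weight = 1e-3    # forwarded through **kwargs to Agent
)
```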
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.0.48"
+version = "0.0.50"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- a/tests/test_epo.py
+++ b/tests/test_epo.py
@@ -75,6 +75,7 @@ def test_create_agent(
 
 @pytest.mark.parametrize('frozen_latents', (False, True))
 @pytest.mark.parametrize('use_critic_ema', (False, True))
+@pytest.mark.parametrize('diversity_aux_loss_weight', (0., 1e-3))
 def test_e2e_with_mock_env(
     frozen_latents,
     use_critic_ema
@@ -89,6 +90,7 @@ def test_e2e_with_mock_env(
         actor_dim_hiddens = (256, 128),
         critic_dim_hiddens = (256, 128, 64),
         use_critic_ema = use_critic_ema,
+        diversity_aux_loss_weight = diversity_aux_loss_weight,
         latent_gene_pool_kwargs = dict(
             frozen_latents = frozen_latents,
         )
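One caveat: `pytest.mark.parametrize` only injects a value when the test function declares an argument of the same name, and the hunks above do not show `diversity_aux_loss_weight` being added to the `test_e2e_with_mock_env` signature. A minimal self-contained illustration of the mechanism:

```python
import pytest

@pytest.mark.parametrize('diversity_aux_loss_weight', (0., 1e-3))
def test_diversity_weight(diversity_aux_loss_weight):  # name must match the parametrize
    # pytest runs the test once per parametrized value
    assert diversity_aux_loss_weight in (0., 1e-3)
```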