evolutionary-policy-optimization 0.1.16.tar.gz → 0.1.17.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/PKG-INFO +1 -1
  2. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/evolutionary_policy_optimization/epo.py +22 -4
  3. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/pyproject.toml +1 -1
  4. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/tests/test_epo.py +5 -2
  5. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/.github/workflows/python-publish.yml +0 -0
  6. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/.github/workflows/test.yml +0 -0
  7. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/.gitignore +0 -0
  8. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/LICENSE +0 -0
  9. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/README.md +0 -0
  10. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/evolutionary_policy_optimization/distributed.py +0 -0
  12. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  13. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/evolutionary_policy_optimization/experimental.py +0 -0
  14. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/evolutionary_policy_optimization/mock_env.py +0 -0
  15. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/requirements.txt +0 -0
  16. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.17}/train_gym.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.1.16
+ Version: 0.1.17
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization/epo.py
@@ -943,6 +943,8 @@ class Agent(Module):
  eps_clip = 0.4
  ),
  use_improved_critic_loss = True,
+ shrink_and_perturb_every = None,
+ shrink_and_perturb_kwargs: dict = dict(),
  ema_kwargs: dict = dict(),
  actor_optim_kwargs: dict = dict(),
  critic_optim_kwargs: dict = dict(),
@@ -1007,6 +1009,12 @@ class Agent(Module):

  self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if exists(latent_gene_pool) and not latent_gene_pool.frozen_latents else None

+ # shrink and perturb every
+
+ self.should_noise_weights = exists(shrink_and_perturb_every)
+ self.shrink_and_perturb_every = shrink_and_perturb_every
+ self.shrink_and_perturb_ = partial(shrink_and_perturb_, **shrink_and_perturb_kwargs)
+
  # promotes latents to be farther apart for diversity maintenance

  self.has_diversity_loss = diversity_aux_loss_weight > 0.
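The shrink_and_perturb_ helper bound with partial above is not shown in this diff, so its exact signature and defaults are not visible here. For orientation, shrink-and-perturb (Ash & Adams, 2020) scales a network's weights down and adds a small amount of noise so the model stays plastic across repeated training phases. A minimal sketch of that idea, with assumed shrink and perturb factors rather than the package's actual defaults:

import torch
from torch.nn import Module

@torch.no_grad()
def shrink_and_perturb_(
    module: Module,
    shrink = 0.5,    # assumed shrink factor, not necessarily the package default
    perturb = 0.01   # assumed noise scale, not necessarily the package default
):
    # scale every parameter toward zero, then add small gaussian noise, in place
    for param in module.parameters():
        param.mul_(shrink).add_(torch.randn_like(param) * perturb)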
@@ -1016,7 +1024,7 @@ class Agent(Module):

  self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model

- dummy = tensor(0)
+ step = tensor(0)

  self.clip_grad_norm_ = nn.utils.clip_grad_norm_

@@ -1044,15 +1052,15 @@ class Agent(Module):
  if exists(self.critic_ema):
      self.critic_ema.to(self.accelerate.device)

- dummy = dummy.to(self.accelerate.device)
+ step = step.to(self.accelerate.device)

  # device tracking

- self.register_buffer('dummy', dummy)
+ self.register_buffer('step', step)

  @property
  def device(self):
-     return self.dummy.device
+     return self.step.device

  @property
  def unwrapped_latent_gene_pool(self):
@@ -1302,6 +1310,16 @@ class Agent(Module):
  if self.has_latent_genes:
      self.latent_gene_pool.genetic_algorithm_step(fitness_scores)

+ # maybe shrink and perturb
+
+ if self.should_noise_weights and divisible_by(self.step.item(), self.shrink_and_perturb_every):
+     self.shrink_and_perturb_(self.actor)
+     self.shrink_and_perturb_(self.critic)
+
+ # increment step
+
+ self.step.add_(1)
+
  # reinforcement learning related - ppo

  def actor_loss(
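The gate above relies on a divisible_by helper that does not appear in this diff. In lucidrains' codebases it is conventionally a one-line modulo check, roughly as sketched below; assuming that definition, the perturbation also fires on the very first learning call, since the step buffer starts at tensor(0) and is incremented after the check.

def divisible_by(num, den):
    # true when num is an exact multiple of den
    return (num % den) == 0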
pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "evolutionary-policy-optimization"
- version = "0.1.16"
+ version = "0.1.17"
  description = "EPO - Pytorch"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_epo.py
@@ -85,13 +85,15 @@ def test_create_agent(
  @pytest.mark.parametrize('use_improved_critic_loss', (False, True))
  @pytest.mark.parametrize('num_latents', (1, 8))
  @pytest.mark.parametrize('diversity_aux_loss_weight', (0., 1e-3))
+ @pytest.mark.parametrize('shrink_and_perturb_every', (None, 1))
  def test_e2e_with_mock_env(
      frozen_latents,
      use_critic_ema,
      num_latents,
      diversity_aux_loss_weight,
      critic_use_regression,
-     use_improved_critic_loss
+     use_improved_critic_loss,
+     shrink_and_perturb_every
  ):
      from evolutionary_policy_optimization import create_agent, EPO, Env

@@ -106,6 +108,7 @@ def test_e2e_with_mock_env(
  critic_mlp_depth = 4,
  use_critic_ema = use_critic_ema,
  diversity_aux_loss_weight = diversity_aux_loss_weight,
+ shrink_and_perturb_every = shrink_and_perturb_every,
  critic_kwargs = dict(
      use_regression = critic_use_regression
  ),
@@ -115,7 +118,7 @@ def test_e2e_with_mock_env(
  frac_natural_selected = 0.75,
  frac_tournaments = 0.9
  ),
- wrap_with_accelerate = False
+ wrap_with_accelerate = False,
  )

  epo = EPO(
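For context, wiring the new option through from user code mirrors the test above. A minimal sketch, with the state/action and network dimension arguments elided and the values below chosen for illustration rather than taken from the package's defaults:

from evolutionary_policy_optimization import create_agent

agent = create_agent(
    # ... state / action / network dimension arguments as in the existing tests ...
    use_critic_ema = True,
    diversity_aux_loss_weight = 1e-3,
    shrink_and_perturb_every = 1000,     # illustrative period: noise the actor and critic weights every 1000 update steps
    shrink_and_perturb_kwargs = dict(),  # forwarded to the shrink_and_perturb_ helper via functools.partial
    wrap_with_accelerate = False,
)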