evolutionary-policy-optimization 0.1.16__tar.gz → 0.1.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/PKG-INFO +1 -1
  2. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/evolutionary_policy_optimization/epo.py +25 -5
  3. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/pyproject.toml +1 -1
  4. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/tests/test_epo.py +5 -2
  5. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/.github/workflows/python-publish.yml +0 -0
  6. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/.github/workflows/test.yml +0 -0
  7. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/.gitignore +0 -0
  8. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/LICENSE +0 -0
  9. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/README.md +0 -0
  10. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/evolutionary_policy_optimization/distributed.py +0 -0
  12. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  13. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/evolutionary_policy_optimization/experimental.py +0 -0
  14. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/evolutionary_policy_optimization/mock_env.py +0 -0
  15. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/requirements.txt +0 -0
  16. {evolutionary_policy_optimization-0.1.16 → evolutionary_policy_optimization-0.1.18}/train_gym.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.16
+Version: 0.1.18
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization/epo.py
@@ -392,6 +392,8 @@ class MLP(Module):
 
         self.layers = ModuleList(layers)
 
+        self.final_lime = DynamicLIMe(dim, depth + 1)
+
     def forward(
         self,
         x,
@@ -430,7 +432,7 @@ class MLP(Module):
 
             prev_layer_inputs.append(x)
 
-        return x
+        return self.final_lime(x, prev_layer_inputs)
 
 # actor, critic, and agent (actor + critic)
 # eventually, should just create a separate repo and aggregate all the MLP related architectures
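The MLP now routes its final output through a DynamicLIMe over all prior layer inputs instead of returning the last hidden state directly. Below is a minimal sketch of what such a layer-mixing module could look like, assuming it learns per-token weights over the stacked prior hidden states; the class name and details are illustrative, not the package's exact implementation.

import torch
from torch import nn, einsum

class DynamicLIMeSketch(nn.Module):
    # hypothetical stand-in for DynamicLIMe: a learned, input-dependent
    # mixture over the outputs of all previous layers
    def __init__(self, dim, num_layers):
        super().__init__()
        self.to_weights = nn.Linear(dim, num_layers)

    def forward(self, x, hiddens):
        # hiddens: list of `num_layers` tensors shaped like x, e.g. (batch, dim)
        stacked = torch.stack(hiddens, dim = -2)           # (batch, num_layers, dim)
        weights = self.to_weights(x).softmax(dim = -1)     # (batch, num_layers)
        return einsum('bl,bld->bd', weights, stacked)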
@@ -943,6 +945,8 @@ class Agent(Module):
             eps_clip = 0.4
         ),
         use_improved_critic_loss = True,
+        shrink_and_perturb_every = None,
+        shrink_and_perturb_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -1007,6 +1011,12 @@ class Agent(Module):
 
         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if exists(latent_gene_pool) and not latent_gene_pool.frozen_latents else None
 
+        # shrink and perturb every
+
+        self.should_noise_weights = exists(shrink_and_perturb_every)
+        self.shrink_and_perturb_every = shrink_and_perturb_every
+        self.shrink_and_perturb_ = partial(shrink_and_perturb_, **shrink_and_perturb_kwargs)
+
         # promotes latents to be farther apart for diversity maintenance
 
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
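For context, shrink-and-perturb (Ash & Adams, 2020) restores plasticity during long-running training by rescaling existing weights and adding a small amount of fresh noise. A minimal sketch of an in-place helper in the spirit of the shrink_and_perturb_ referenced above; the shrink and noise factors here are illustrative, not necessarily the defaults this package uses.

import torch
from torch import nn

@torch.no_grad()
def shrink_and_perturb_sketch_(module: nn.Module, shrink = 0.5, perturb = 0.01):
    # scale every parameter toward zero, then add a little fresh gaussian noise
    for param in module.parameters():
        noise = torch.randn_like(param) * perturb
        param.mul_(shrink).add_(noise)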
@@ -1016,7 +1026,7 @@ class Agent(Module):
 
         self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model
 
-        dummy = tensor(0)
+        step = tensor(0)
 
         self.clip_grad_norm_ = nn.utils.clip_grad_norm_
@@ -1044,15 +1054,15 @@ class Agent(Module):
         if exists(self.critic_ema):
             self.critic_ema.to(self.accelerate.device)
 
-        dummy = dummy.to(self.accelerate.device)
+        step = step.to(self.accelerate.device)
 
         # device tracking
 
-        self.register_buffer('dummy', dummy)
+        self.register_buffer('step', step)
 
     @property
     def device(self):
-        return self.dummy.device
+        return self.step.device
 
     @property
     def unwrapped_latent_gene_pool(self):
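The dummy tensor previously registered only for device tracking is renamed to step so it can double as a persistent counter for the shrink-and-perturb schedule. A small illustration of the pattern (not the repo's code): a registered buffer follows .to(device) moves and is saved in the state dict, so it serves as both a device probe and a step counter.

import torch
from torch import nn, tensor

class StepTracked(nn.Module):
    def __init__(self):
        super().__init__()
        # buffers move with the module and persist in state_dict
        self.register_buffer('step', tensor(0))

    @property
    def device(self):
        return self.step.device

model = StepTracked()
model.step.add_(1)                      # increment the persistent counter in place
print(model.device, model.step.item())  # cpu 1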
@@ -1302,6 +1312,16 @@ class Agent(Module):
         if self.has_latent_genes:
             self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
 
+        # maybe shrink and perturb
+
+        if self.should_noise_weights and divisible_by(self.step.item(), self.shrink_and_perturb_every):
+            self.shrink_and_perturb_(self.actor)
+            self.shrink_and_perturb_(self.critic)
+
+        # increment step
+
+        self.step.add_(1)
+
     # reinforcement learning related - ppo
 
     def actor_loss(
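Assuming divisible_by is the usual modulo helper, the schedule implied above starts firing immediately: the counter begins at 0 and is incremented after the check, so the actor and critic are noised on learn steps 0, N, 2N, and so on.

def divisible_by(num, den):
    return (num % den) == 0

# with shrink_and_perturb_every = 3, the weights get noised on learn steps 0, 3, 6, ...
for step in range(7):
    if divisible_by(step, 3):
        print(f'shrink and perturb at learn step {step}')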
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.1.16"
+version = "0.1.18"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_epo.py
@@ -85,13 +85,15 @@ def test_create_agent(
 @pytest.mark.parametrize('use_improved_critic_loss', (False, True))
 @pytest.mark.parametrize('num_latents', (1, 8))
 @pytest.mark.parametrize('diversity_aux_loss_weight', (0., 1e-3))
+@pytest.mark.parametrize('shrink_and_perturb_every', (None, 1))
 def test_e2e_with_mock_env(
     frozen_latents,
     use_critic_ema,
     num_latents,
     diversity_aux_loss_weight,
     critic_use_regression,
-    use_improved_critic_loss
+    use_improved_critic_loss,
+    shrink_and_perturb_every
 ):
     from evolutionary_policy_optimization import create_agent, EPO, Env
@@ -106,6 +108,7 @@ def test_e2e_with_mock_env(
         critic_mlp_depth = 4,
         use_critic_ema = use_critic_ema,
         diversity_aux_loss_weight = diversity_aux_loss_weight,
+        shrink_and_perturb_every = shrink_and_perturb_every,
         critic_kwargs = dict(
             use_regression = critic_use_regression
         ),
@@ -115,7 +118,7 @@ def test_e2e_with_mock_env(
             frac_natural_selected = 0.75,
             frac_tournaments = 0.9
         ),
-        wrap_with_accelerate = False
+        wrap_with_accelerate = False,
     )
 
     epo = EPO(