evolutionary-policy-optimization 0.1.8__tar.gz → 0.1.10__tar.gz

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (16)
  1. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/PKG-INFO +10 -1
  2. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/README.md +9 -0
  3. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/epo.py +31 -5
  4. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/pyproject.toml +1 -1
  5. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/tests/test_epo.py +4 -1
  6. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/.github/workflows/python-publish.yml +0 -0
  7. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/.github/workflows/test.yml +0 -0
  8. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/.gitignore +0 -0
  9. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/LICENSE +0 -0
  10. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/distributed.py +0 -0
  12. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  13. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/experimental.py +0 -0
  14. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/mock_env.py +0 -0
  15. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/requirements.txt +0 -0
  16. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.10}/train_gym.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.1.8
+ Version: 0.1.10
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -206,4 +206,13 @@ agent.load('./agent.pt')
  }
  ```

+ ```bibtex
+ @article{Lee2024AnalysisClippedCritic
+     title = {On Analysis of Clipped Critic Loss in Proximal Policy Gradient},
+     author = {Yongjin Lee, Moonyoung Chung},
+     journal = {Authorea},
+     year = {2024}
+ }
+ ```
+
  *Evolution is cleverer than you are.* - Leslie Orgel
README.md
@@ -153,4 +153,13 @@ agent.load('./agent.pt')
  }
  ```

+ ```bibtex
+ @article{Lee2024AnalysisClippedCritic
+     title = {On Analysis of Clipped Critic Loss in Proximal Policy Gradient},
+     author = {Yongjin Lee, Moonyoung Chung},
+     journal = {Authorea},
+     year = {2024}
+ }
+ ```
+
  *Evolution is cleverer than you are.* - Leslie Orgel
evolutionary_policy_optimization/epo.py
@@ -424,18 +424,40 @@ class Critic(Module):
      latent,
      old_values,
      target,
-     eps_clip = 0.4
+     eps_clip = 0.4,
+     use_improved = True
  ):
      logits = self.forward(state, latent, return_logits = True)

      value = self.maybe_bins_to_value(logits)

-     clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+     loss_fn = partial(self.loss_fn, reduction = 'none')

-     loss = self.loss_fn(logits, target, reduction = 'none')
-     clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+     if use_improved:
+         clipped_target = target.clamp(-eps_clip, eps_clip)

-     return torch.max(loss, clipped_loss).mean()
+         old_values_lo = old_values - eps_clip
+         old_values_hi = old_values + eps_clip
+
+         is_between = lambda lo, hi: (lo < value) & (value < hi)
+
+         clipped_loss = loss_fn(logits, clipped_target)
+         loss = loss_fn(logits, target)
+
+         value_loss = torch.where(
+             is_between(target, old_values_lo) | is_between(old_values_hi, target),
+             0.,
+             torch.min(loss, clipped_loss)
+         )
+     else:
+         clipped_value = old_values + (value - old_values).clamp(-eps_clip, eps_clip)
+
+         loss = loss_fn(logits, target)
+         clipped_loss = loss_fn(clipped_value, target)
+
+         value_loss = torch.max(loss, clipped_loss)
+
+     return value_loss.mean()

  def forward(
      self,
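The substantive change is in this Critic loss hunk. With `use_improved = True`, the value loss is zeroed wherever the new prediction has already crossed the eps_clip band around the old value in the direction of the target, and otherwise the smaller of the clipped and unclipped losses is taken, following the newly cited Lee & Chung (2024) analysis of clipped critic losses. Below is a minimal sketch of that branch with plain MSE standing in for the critic's configurable `loss_fn` (the real method works on logits via `maybe_bins_to_value`); the function name and arguments are illustrative, not part of the package API.

```python
import torch
import torch.nn.functional as F

def improved_clipped_value_loss(value, old_value, target, eps_clip = 0.4):
    # band of width eps_clip around the value predicted at rollout time
    old_value_lo = old_value - eps_clip
    old_value_hi = old_value + eps_clip

    # elementwise: is the new prediction strictly inside (lo, hi)?
    is_between = lambda lo, hi: (lo < value) & (value < hi)

    # unclipped loss, and a loss against the clipped target
    # (mirroring the hunk's target.clamp(-eps_clip, eps_clip))
    loss = F.mse_loss(value, target, reduction = 'none')
    clipped_loss = F.mse_loss(value, target.clamp(-eps_clip, eps_clip), reduction = 'none')

    # zero the loss where the prediction already sits past the band in the
    # direction of the target, otherwise take the smaller of the two losses
    return torch.where(
        is_between(target, old_value_lo) | is_between(old_value_hi, target),
        torch.zeros_like(loss),
        torch.minimum(loss, clipped_loss)
    ).mean()
```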
@@ -826,6 +848,7 @@ class Agent(Module):
      critic_loss_kwargs: dict = dict(
          eps_clip = 0.4
      ),
+     use_improved_critic_loss = True,
      ema_kwargs: dict = dict(),
      actor_optim_kwargs: dict = dict(),
      critic_optim_kwargs: dict = dict(),
@@ -871,6 +894,8 @@ class Agent(Module):
      self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
      self.critic_loss_kwargs = critic_loss_kwargs

+     self.use_improved_critic_loss = use_improved_critic_loss
+
      # fitness score related

      self.get_fitness_scores = get_fitness_scores
@@ -1142,6 +1167,7 @@ class Agent(Module):
      latents,
      old_values = old_values,
      target = advantages + old_values,
+     use_improved = self.use_improved_critic_loss,
      **self.critic_loss_kwargs
  )
pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "evolutionary-policy-optimization"
- version = "0.1.8"
+ version = "0.1.10"
  description = "EPO - Pytorch"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_epo.py
@@ -79,6 +79,7 @@ def test_create_agent(
  @pytest.mark.parametrize('frozen_latents', (False, True))
  @pytest.mark.parametrize('use_critic_ema', (False, True))
  @pytest.mark.parametrize('critic_use_regression', (False, True))
+ @pytest.mark.parametrize('use_improved_critic_loss', (False, True))
  @pytest.mark.parametrize('num_latents', (1, 8))
  @pytest.mark.parametrize('diversity_aux_loss_weight', (0., 1e-3))
  def test_e2e_with_mock_env(
@@ -86,7 +87,8 @@ def test_e2e_with_mock_env(
      use_critic_ema,
      num_latents,
      diversity_aux_loss_weight,
-     critic_use_regression
+     critic_use_regression,
+     use_improved_critic_loss
  ):
      from evolutionary_policy_optimization import create_agent, EPO, Env

@@ -102,6 +104,7 @@ def test_e2e_with_mock_env(
      critic_kwargs = dict(
          use_regression = critic_use_regression
      ),
+     use_improved_critic_loss = use_improved_critic_loss,
      latent_gene_pool_kwargs = dict(
          frozen_latents = frozen_latents,
          frac_natural_selected = 0.75,