evolutionary-policy-optimization 0.1.8__tar.gz → 0.1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/PKG-INFO +10 -1
  2. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/README.md +9 -0
  3. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/evolutionary_policy_optimization/epo.py +29 -5
  4. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/pyproject.toml +1 -1
  5. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/tests/test_epo.py +4 -1
  6. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/.github/workflows/python-publish.yml +0 -0
  7. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/.github/workflows/test.yml +0 -0
  8. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/.gitignore +0 -0
  9. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/LICENSE +0 -0
  10. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/evolutionary_policy_optimization/distributed.py +0 -0
  12. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  13. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/evolutionary_policy_optimization/experimental.py +0 -0
  14. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/evolutionary_policy_optimization/mock_env.py +0 -0
  15. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/requirements.txt +0 -0
  16. {evolutionary_policy_optimization-0.1.8 → evolutionary_policy_optimization-0.1.9}/train_gym.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.8
+Version: 0.1.9
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -206,4 +206,13 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Lee2024AnalysisClippedCritic
+    title = {On Analysis of Clipped Critic Loss in Proximal Policy Gradient},
+    author = {Yongjin Lee, Moonyoung Chung},
+    journal = {Authorea},
+    year = {2024}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
README.md
@@ -153,4 +153,13 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Lee2024AnalysisClippedCritic
+    title = {On Analysis of Clipped Critic Loss in Proximal Policy Gradient},
+    author = {Yongjin Lee, Moonyoung Chung},
+    journal = {Authorea},
+    year = {2024}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
evolutionary_policy_optimization/epo.py
@@ -424,18 +424,38 @@ class Critic(Module):
         latent,
         old_values,
         target,
-        eps_clip = 0.4
+        eps_clip = 0.4,
+        use_improved = True
     ):
         logits = self.forward(state, latent, return_logits = True)
 
         value = self.maybe_bins_to_value(logits)
 
-        clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+        if use_improved:
+            clipped_target = target.clamp(-eps_clip, eps_clip)
 
-        loss = self.loss_fn(logits, target, reduction = 'none')
-        clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+            old_values_lo = old_values - eps_clip
+            old_values_hi = old_values + eps_clip
 
-        return torch.max(loss, clipped_loss).mean()
+            is_between = lambda lo, hi: (lo < value) & (value < hi)
+
+            clipped_loss = self.loss_fn(logits, clipped_target, reduction = 'none')
+            loss = self.loss_fn(logits, target, reduction = 'none')
+
+            value_loss = torch.where(
+                is_between(target, old_values_lo) | is_between(old_values_hi, target),
+                0.,
+                torch.min(loss, clipped_loss)
+            )
+        else:
+            clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+
+            loss = self.loss_fn(logits, target, reduction = 'none')
+            clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+
+            value_loss = torch.max(loss, clipped_loss)
+
+        return value_loss.mean()
 
     def forward(
         self,
@@ -826,6 +846,7 @@ class Agent(Module):
         critic_loss_kwargs: dict = dict(
             eps_clip = 0.4
         ),
+        use_improved_critic_loss = True,
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -871,6 +892,8 @@ class Agent(Module):
         self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.critic_loss_kwargs = critic_loss_kwargs
 
+        self.use_improved_critic_loss = use_improved_critic_loss
+
         # fitness score related
 
         self.get_fitness_scores = get_fitness_scores
@@ -1142,6 +1165,7 @@ class Agent(Module):
                 latents,
                 old_values = old_values,
                 target = advantages + old_values,
+                use_improved = self.use_improved_critic_loss,
                 **self.critic_loss_kwargs
             )
 
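The epo.py hunk above is the substantive change in 0.1.9: the critic's clipped value loss gains a `use_improved` branch in the spirit of the newly cited Lee & Chung paper, with the pre-0.1.9 PPO-style clipping kept behind the flag. Below is a minimal, self-contained sketch of the two branches for readers who want the logic outside the class. It substitutes a plain elementwise MSE for the package's `self.loss_fn` (which in the real code is applied to the critic's logits), and the function name `clipped_value_loss` is illustrative, not part of the package API.

```python
import torch
import torch.nn.functional as F

def clipped_value_loss(
    value,              # current critic prediction, shape (batch,)
    old_value,          # critic prediction recorded at rollout time
    target,             # bootstrapped value target (advantages + old values)
    eps_clip = 0.4,
    use_improved = True
):
    # elementwise regression loss; the package instead calls self.loss_fn on logits
    loss = F.mse_loss(value, target, reduction = 'none')

    if use_improved:
        # mirror of the new 0.1.9 branch: clamp the target, and zero the loss
        # wherever the prediction already lies between the target and the edge
        # of the +/- eps_clip band around the old value
        clipped_target = target.clamp(-eps_clip, eps_clip)
        clipped_loss = F.mse_loss(value, clipped_target, reduction = 'none')

        lo = old_value - eps_clip
        hi = old_value + eps_clip
        is_between = lambda a, b: (a < value) & (value < b)

        value_loss = torch.where(
            is_between(target, lo) | is_between(hi, target),
            torch.zeros_like(loss),
            torch.min(loss, clipped_loss)
        )
    else:
        # classic PPO-style value clipping (the legacy branch); note the package's
        # 0.1.8 code clamps the delta to (1 - eps_clip, 1 + eps_clip), whereas the
        # usual symmetric (-eps_clip, eps_clip) band is used here
        clipped_value = old_value + (value - old_value).clamp(-eps_clip, eps_clip)
        clipped_loss = F.mse_loss(clipped_value, target, reduction = 'none')

        value_loss = torch.max(loss, clipped_loss)

    return value_loss.mean()

# quick smoke test with random tensors standing in for critic outputs and targets
values, old_values, targets = torch.randn(3, 8).unbind(0)
print(clipped_value_loss(values, old_values, targets, use_improved = True))
print(clipped_value_loss(values, old_values, targets, use_improved = False))
```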
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.1.8"
+version = "0.1.9"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_epo.py
@@ -79,6 +79,7 @@ def test_create_agent(
 @pytest.mark.parametrize('frozen_latents', (False, True))
 @pytest.mark.parametrize('use_critic_ema', (False, True))
 @pytest.mark.parametrize('critic_use_regression', (False, True))
+@pytest.mark.parametrize('use_improved_critic_loss', (False, True))
 @pytest.mark.parametrize('num_latents', (1, 8))
 @pytest.mark.parametrize('diversity_aux_loss_weight', (0., 1e-3))
 def test_e2e_with_mock_env(
@@ -86,7 +87,8 @@ def test_e2e_with_mock_env(
     use_critic_ema,
     num_latents,
     diversity_aux_loss_weight,
-    critic_use_regression
+    critic_use_regression,
+    use_improved_critic_loss
 ):
     from evolutionary_policy_optimization import create_agent, EPO, Env
 
@@ -102,6 +104,7 @@ def test_e2e_with_mock_env(
         critic_kwargs = dict(
             use_regression = critic_use_regression
         ),
+        use_improved_critic_loss = use_improved_critic_loss,
         latent_gene_pool_kwargs = dict(
             frozen_latents = frozen_latents,
             frac_natural_selected = 0.75,