evolutionary-policy-optimization 0.0.31.tar.gz → 0.0.33.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (13)
  1. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/PKG-INFO +1 -1
  2. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/evolutionary_policy_optimization/epo.py +11 -1
  3. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/pyproject.toml +1 -1
  4. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/.github/workflows/python-publish.yml +0 -0
  5. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/.github/workflows/test.yml +0 -0
  6. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/.gitignore +0 -0
  7. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/LICENSE +0 -0
  8. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/README.md +0 -0
  9. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/evolutionary_policy_optimization/__init__.py +0 -0
  10. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/evolutionary_policy_optimization/experimental.py +0 -0
  11. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/evolutionary_policy_optimization/mock_env.py +0 -0
  12. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/requirements.txt +0 -0
  13. {evolutionary_policy_optimization-0.0.31 → evolutionary_policy_optimization-0.0.33}/tests/test_epo.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.31
+Version: 0.0.33
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization/epo.py
@@ -19,6 +19,8 @@ from adam_atan2_pytorch import AdoptAtan2
 
 from hl_gauss_pytorch import HLGaussLayer
 
+from ema_pytorch import EMA
+
 # helpers
 
 def exists(v):
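
The substantive change in 0.0.33 is maintaining an exponential moving average (EMA) of the critic, via lucidrains' ema-pytorch package. As a minimal sketch of what that wrapper provides (assuming the standard ema-pytorch API; illustration, not code from this diff):

    import torch
    from torch import nn
    from ema_pytorch import EMA

    critic = nn.Linear(512, 1)  # stand-in for the package's Critic module

    critic_ema = EMA(
        critic,
        beta = 0.99,                   # per-step decay toward the online weights
        include_online_model = False   # don't duplicate the online critic inside the EMA module's state dict
    )

    critic_ema.update()                # call after each optimizer step

    values = critic_ema(torch.randn(2, 512))  # forward through the averaged copy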
@@ -271,7 +273,7 @@ class Critic(Module):
 
         self.to_pred = HLGaussLayer(
             dim = dim_last,
-            use_regression = False,
+            use_regression = use_regression,
             hl_gauss_loss = hl_gauss_loss_kwargs
         )
 
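
This also fixes a bug: use_regression was hardcoded to False, so the Critic always used the HL-Gauss classification head even when the caller asked for plain regression; it now forwards the constructor argument. A sketch of the two modes, assuming hl-gauss-pytorch's HLGaussLayer interface (the bin settings here are illustrative):

    import torch
    from hl_gauss_pytorch import HLGaussLayer

    embed = torch.randn(4, 256)
    target = torch.rand(4) * 5.

    # use_regression = False: target values are smeared over histogram bins (HL-Gauss)
    to_pred = HLGaussLayer(
        dim = 256,
        use_regression = False,
        hl_gauss_loss = dict(min_value = 0., max_value = 5., num_bins = 32, sigma = 0.5)
    )

    loss = to_pred(embed, target)  # classification loss over the smeared bins
    value = to_pred(embed)         # scalar value decoded back from bin probabilities

    # use_regression = True: same interface, but a plain scalar regression head
    to_pred_mse = HLGaussLayer(dim = 256, use_regression = True)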
@@ -599,6 +601,8 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        critic_ema_beta = 0.99,
+        ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
         latent_optim_kwargs: dict = dict(),
@@ -606,7 +610,9 @@ class Agent(Module):
         super().__init__()
 
         self.actor = actor
+
         self.critic = critic
+        self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs)
 
         self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
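
The hunks only show the EMA copy being constructed; where it is updated and queried inside the training step is not part of this diff. The usual ema-pytorch pattern would look roughly like the following (hypothetical names and call signatures, for illustration only):

    # after the critic's backward pass...
    agent.critic_optim.step()
    agent.critic_optim.zero_grad()

    # ...pull the averaged weights toward the freshly updated online critic
    agent.critic_ema.update()

    # value estimates can then be read from the smoothed copy, which is
    # typically less noisy than the online critic early in training
    with torch.no_grad():
        values = agent.critic_ema(state_embed, latent)  # hypothetical signature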
@@ -628,6 +634,7 @@ class Agent(Module):
         pkg = dict(
             actor = self.actor.state_dict(),
             critic = self.critic.state_dict(),
+            critic_ema = self.critic_ema.state_dict(),
             latents = self.latent_gene_pool.state_dict(),
             actor_optim = self.actor_optim.state_dict(),
             critic_optim = self.critic_optim.state_dict(),
@@ -644,7 +651,10 @@ class Agent(Module):
         pkg = torch.load(str(path), weights_only = True)
 
         self.actor.load_state_dict(pkg['actor'])
+
         self.critic.load_state_dict(pkg['critic'])
+        self.critic_ema.load_state_dict(pkg['critic_ema'])
+
         self.latent_gene_pool.load_state_dict(pkg['latents'])
 
         self.actor_optim.load_state_dict(pkg['actor_optim'])
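
With critic_ema added to both save() and load(), a checkpoint round trip now restores the averaged critic alongside everything else. Minimal usage, following the method names in the hunks above (the path is illustrative):

    agent.save('./agent.ckpt')
    agent.load('./agent.ckpt')  # restores actor, critic, critic_ema, latents, and the optimizers

One consequence worth noting: a checkpoint written by 0.0.31 has no 'critic_ema' key, so loading it with 0.0.33 as written would raise a KeyError.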
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.0.31"
+version = "0.0.33"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }