evolutionary-policy-optimization 0.0.43__tar.gz → 0.0.44__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/PKG-INFO +1 -1
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/evolutionary_policy_optimization/epo.py +31 -7
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/pyproject.toml +1 -1
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/tests/test_epo.py +5 -2
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/.github/workflows/python-publish.yml +0 -0
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/.github/workflows/test.yml +0 -0
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/.gitignore +0 -0
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/LICENSE +0 -0
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/README.md +0 -0
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/evolutionary_policy_optimization/__init__.py +0 -0
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/evolutionary_policy_optimization/experimental.py +0 -0
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/evolutionary_policy_optimization/mock_env.py +0 -0
- {evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/requirements.txt +0 -0
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.43
+Version: 0.0.44
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/evolutionary_policy_optimization/epo.py
RENAMED
@@ -626,6 +626,7 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        use_critic_ema = True,
         critic_ema_beta = 0.99,
         batch_size = 16,
         calc_gae_kwargs: dict = dict(
@@ -647,7 +648,9 @@ class Agent(Module):
         self.actor = actor
 
         self.critic = critic
-        self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs)
+
+        self.use_critic_ema = use_critic_ema
+        self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
         self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
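
For context, `critic_ema` is the `EMA` wrapper from lucidrains' ema-pytorch package, and 0.0.44 simply makes it optional. A minimal sketch of how the wrapper behaves (the `nn.Linear` stand-in for the critic and the tensor shapes are made up for illustration; only the `EMA(...)` construction mirrors the diff above):

import torch
from torch import nn
from ema_pytorch import EMA

critic = nn.Linear(512, 1)                                    # stand-in for the real Critic module
critic_ema = EMA(critic, beta = 0.99, include_online_model = False)

# after each critic optimizer step, the shadow weights track the online critic
critic_ema.update()

# the wrapper forwards calls to the averaged copy, so it can stand in for the critic itself
smoothed_value = critic_ema(torch.randn(2, 512))
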
@@ -676,7 +679,7 @@ class Agent(Module):
         pkg = dict(
             actor = self.actor.state_dict(),
             critic = self.critic.state_dict(),
-            critic_ema = self.critic_ema.state_dict(),
+            critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
             latents = self.latent_gene_pool.state_dict(),
             actor_optim = self.actor_optim.state_dict(),
             critic_optim = self.critic_optim.state_dict(),
@@ -695,7 +698,9 @@ class Agent(Module):
         self.actor.load_state_dict(pkg['actor'])
 
         self.critic.load_state_dict(pkg['critic'])
-        self.critic_ema.load_state_dict(pkg['critic_ema'])
+
+        if self.use_critic_ema:
+            self.critic_ema.load_state_dict(pkg['critic_ema'])
 
         self.latent_gene_pool.load_state_dict(pkg['latents'])
 
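
The save / load changes just make the checkpoint tolerant of the flag: `critic_ema` is stored only when enabled and skipped on restore otherwise. A hedged round-trip sketch; `agent.load('./agent.pt')` appears verbatim in the test file below, while the `overwrite` keyword on save is an assumption not shown in this diff:

agent.save('./agent.pt', overwrite = True)   # 'overwrite' is assumed; check the Agent.save signature
agent.load('./agent.pt')                     # restores critic_ema only if use_critic_ema was True at save time
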
@@ -733,14 +738,20 @@ class Agent(Module):
         self,
         state,
         latent_id = None,
-        latent = None
+        latent = None,
+        use_ema_if_available = False
     ):
         assert exists(latent_id) or exists(latent)
 
         if not exists(latent):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
-        return self.critic(state, latent)
+        critic_forward = self.critic
+
+        if use_ema_if_available and self.use_critic_ema:
+            critic_forward = self.critic_ema
+
+        return critic_forward(state, latent)
 
     def update_latent_gene_pool_(
         self,
@@ -826,6 +837,11 @@ class Agent(Module):
         self.critic_optim.step()
         self.critic_optim.zero_grad()
 
+        # maybe ema update critic
+
+        if self.use_critic_ema:
+            self.critic_ema.update()
+
         # maybe update latents, if not frozen
 
         if not self.latent_gene_pool.frozen_latents:
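
With `critic_ema_beta = 0.99`, each `update()` call nudges the shadow weights a small step toward the online critic. Schematically, per parameter tensor (ema-pytorch also supports warmup via `update_after_step` / `update_every`, not shown here):

import torch

beta = 0.99
ema_param = torch.zeros(3)      # toy stand-ins for one critic parameter tensor
online_param = torch.ones(3)

# ema <- beta * ema + (1 - beta) * online: a decayed average of recent critic weights
ema_param.lerp_(online_param, 1. - beta)
print(ema_param)                # tensor([0.0100, 0.0100, 0.0100])
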
@@ -875,6 +891,7 @@ def create_agent(
     actor_num_actions,
     actor_dim_hiddens: int | tuple[int, ...],
     critic_dim_hiddens: int | tuple[int, ...],
+    use_critic_ema = True,
     latent_gene_pool_kwargs: dict = dict(),
     actor_kwargs: dict = dict(),
     critic_kwargs: dict = dict(),
@@ -901,7 +918,14 @@ def create_agent(
         **critic_kwargs
     )
 
-    return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
+    agent = Agent(
+        actor = actor,
+        critic = critic,
+        latent_gene_pool = latent_gene_pool,
+        use_critic_ema = use_critic_ema
+    )
+
+    return agent
 
 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
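
A hedged usage sketch of the new pass-through flag. Only `actor_num_actions`, the `*_dim_hiddens` values, `use_critic_ema` and `latent_gene_pool_kwargs` are confirmed by this diff (they match tests/test_epo.py below); the leading `dim_state` / `num_latents` / `dim_latent` arguments are assumptions and may differ in the actual `create_agent` signature:

from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,                      # assumption
    num_latents = 16,                     # assumption
    dim_latent = 32,                      # assumption
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64),
    use_critic_ema = True                 # new in 0.0.44; False skips the EMA critic entirely
)
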
@@ -978,7 +1002,7 @@ class EPO(Module):
 
             # values
 
-            value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+            value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
 
             # get the next state, action, and reward
 
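
The net effect of this rollout change is that value estimates gathered during data collection come from the smoothed critic whenever `use_critic_ema` is enabled. A rough end-to-end sketch in the spirit of tests/test_epo.py; the EPO and Env constructor arguments here are assumptions, not part of this diff:

from evolutionary_policy_optimization import EPO, Env

epo = EPO(
    agent,                                # from the create_agent sketch above
    episodes_per_latent = 1,              # assumption
    max_episode_length = 10,              # assumption
    action_sample_temperature = 1.        # assumption
)

env = Env((512,))                         # mock environment; the shape is an assumption

memories = epo(env)                       # rollout: values now come from critic_ema when available
agent.learn(memories)                     # PPO update, followed by the EMA update shown above
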
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/tests/test_epo.py
RENAMED
@@ -74,8 +74,10 @@ def test_create_agent(
     agent.load('./agent.pt')
 
 @pytest.mark.parametrize('frozen_latents', (False, True))
+@pytest.mark.parametrize('use_critic_ema', (False, True))
 def test_e2e_with_mock_env(
-    frozen_latents
+    frozen_latents,
+    use_critic_ema
 ):
     from evolutionary_policy_optimization import create_agent, EPO, Env
 
@@ -86,8 +88,9 @@ def test_e2e_with_mock_env(
         actor_num_actions = 5,
         actor_dim_hiddens = (256, 128),
         critic_dim_hiddens = (256, 128, 64),
+        use_critic_ema = use_critic_ema,
         latent_gene_pool_kwargs = dict(
-            frozen_latents = frozen_latents
+            frozen_latents = frozen_latents,
        )
    )
 
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/.github/workflows/python-publish.yml
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/.github/workflows/test.yml
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/.gitignore
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/LICENSE
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/README.md
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/evolutionary_policy_optimization/__init__.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/evolutionary_policy_optimization/experimental.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/evolutionary_policy_optimization/mock_env.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.0.43 → evolutionary_policy_optimization-0.0.44}/requirements.txt
RENAMED
File without changes