evolutionary-policy-optimization 0.0.43__py3-none-any.whl → 0.0.45__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +42 -7
- {evolutionary_policy_optimization-0.0.43.dist-info → evolutionary_policy_optimization-0.0.45.dist-info}/METADATA +1 -1
- {evolutionary_policy_optimization-0.0.43.dist-info → evolutionary_policy_optimization-0.0.45.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.0.43.dist-info → evolutionary_policy_optimization-0.0.45.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.43.dist-info → evolutionary_policy_optimization-0.0.45.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/evolutionary_policy_optimization/epo.py
+++ b/evolutionary_policy_optimization/epo.py
@@ -626,7 +626,9 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        use_critic_ema = True,
         critic_ema_beta = 0.99,
+        max_grad_norm = 0.5,
         batch_size = 16,
         calc_gae_kwargs: dict = dict(
             use_accelerated = False,
```
```diff
@@ -647,7 +649,9 @@ class Agent(Module):
         self.actor = actor
 
         self.critic = critic
-
+
+        self.use_critic_ema = use_critic_ema
+        self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
         self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
```
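For readers unfamiliar with weight averaging: the `EMA` wrapper keeps a second, slowly moving copy of the critic whose parameters trail the online critic. Conceptually, each update applies p_ema ← β·p_ema + (1 − β)·p_online. A minimal hand-rolled sketch of that rule (the package delegates this to the `EMA` class constructed above, so the names below are illustrative stand-ins):

```python
import copy

import torch
from torch import nn

critic = nn.Linear(4, 1)            # stand-in for the package's Critic module
critic_ema = copy.deepcopy(critic)  # the averaged copy starts as an exact clone
critic_ema_beta = 0.99

@torch.no_grad()
def update_critic_ema():
    # p_ema <- beta * p_ema + (1 - beta) * p_online, applied parameter-wise
    for p_ema, p_online in zip(critic_ema.parameters(), critic.parameters()):
        p_ema.lerp_(p_online, 1. - critic_ema_beta)
```

With β = 0.99 the averaged critic effectively smooths over roughly the last 1 / (1 − β) = 100 updates, which is what makes it a steadier target for value estimates.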
```diff
@@ -659,7 +663,11 @@ class Agent(Module):
         self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.calc_gae = partial(calc_generalized_advantage_estimate, **calc_gae_kwargs)
 
+        # learning hparams
+
         self.batch_size = batch_size
+        self.max_grad_norm = max_grad_norm
+        self.has_grad_clip = exists(max_grad_norm)
 
         # optimizers
 
```
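`exists` is presumably the usual `is not None` helper used throughout this codebase, so `has_grad_clip` simply records whether a clipping threshold was supplied, and passing `max_grad_norm = None` turns clipping off. A small sketch of that gating, with `exists` restated here so the snippet stands alone:

```python
from torch import nn

def exists(v):
    return v is not None

max_grad_norm = 0.5                   # set to None to disable gradient clipping
has_grad_clip = exists(max_grad_norm)

def maybe_clip_(parameters):
    # no-op when no threshold was given
    if has_grad_clip:
        nn.utils.clip_grad_norm_(parameters, max_grad_norm)
```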
```diff
@@ -676,7 +684,7 @@ class Agent(Module):
         pkg = dict(
             actor = self.actor.state_dict(),
             critic = self.critic.state_dict(),
-            critic_ema = self.critic_ema.state_dict(),
+            critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
             latents = self.latent_gene_pool.state_dict(),
             actor_optim = self.actor_optim.state_dict(),
             critic_optim = self.critic_optim.state_dict(),
```
```diff
@@ -695,7 +703,9 @@ class Agent(Module):
         self.actor.load_state_dict(pkg['actor'])
 
         self.critic.load_state_dict(pkg['critic'])
-
+
+        if self.use_critic_ema:
+            self.critic_ema.load_state_dict(pkg['critic_ema'])
 
         self.latent_gene_pool.load_state_dict(pkg['latents'])
 
```
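Both `save` and `load` now treat the EMA weights as optional: a checkpoint written with `use_critic_ema = False` simply stores `critic_ema = None`, and loading skips it. A self-contained sketch of that guarded round trip, using an `nn.Linear` stand-in for the critic and a hypothetical checkpoint path:

```python
import copy

import torch
from torch import nn

use_critic_ema = True
critic = nn.Linear(4, 1)                                   # stand-in for the package's Critic
critic_ema = copy.deepcopy(critic) if use_critic_ema else None

# save: include the EMA weights only when they exist
pkg = dict(
    critic = critic.state_dict(),
    critic_ema = critic_ema.state_dict() if use_critic_ema else None
)
torch.save(pkg, './agent.ckpt.pt')                         # hypothetical path

# load: restore the EMA weights only when this agent keeps them
pkg = torch.load('./agent.ckpt.pt', weights_only = True)
critic.load_state_dict(pkg['critic'])

if use_critic_ema:
    critic_ema.load_state_dict(pkg['critic_ema'])
```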
```diff
@@ -733,14 +743,20 @@ class Agent(Module):
         self,
         state,
         latent_id = None,
-        latent = None
+        latent = None,
+        use_ema_if_available = False
     ):
         assert exists(latent_id) or exists(latent)
 
         if not exists(latent):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
-
+        critic_forward = self.critic
+
+        if use_ema_if_available and self.use_critic_ema:
+            critic_forward = self.critic_ema
+
+        return critic_forward(state, latent)
 
     def update_latent_gene_pool_(
         self,
```
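`get_critic_values` now lets callers opt into the averaged critic: the online critic stays the default, and the EMA copy is used only when it exists and is explicitly requested. A stripped-down sketch of that selection logic, with a `deepcopy` standing in for the EMA copy so the snippet runs on its own:

```python
import copy

import torch
from torch import nn

critic = nn.Linear(4, 1)              # stand-in for the package's Critic
use_critic_ema = True
critic_ema = copy.deepcopy(critic) if use_critic_ema else None

def get_critic_values(state, use_ema_if_available = False):
    critic_forward = critic

    if use_ema_if_available and use_critic_ema:
        critic_forward = critic_ema   # evaluate with the smoother, averaged weights

    return critic_forward(state)

values = get_critic_values(torch.randn(8, 4), use_ema_if_available = True)
```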
```diff
@@ -810,6 +826,9 @@ class Agent(Module):
 
             actor_loss.backward()
 
+            if exists(self.has_grad_clip):
+                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
+
             self.actor_optim.step()
             self.actor_optim.zero_grad()
 
```
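`clip_grad_norm_` rescales the actor's gradients in place so their global norm is at most `max_grad_norm`, and it has to sit between `backward()` and the optimizer step, exactly where the diff puts it. A self-contained sketch of that ordering with stand-in modules and a toy loss:

```python
import torch
from torch import nn

actor = nn.Linear(4, 2)                                   # stand-in for the package's Actor
actor_optim = torch.optim.Adam(actor.parameters(), lr = 1e-4)
max_grad_norm = 0.5

actor_loss = actor(torch.randn(8, 4)).pow(2).mean()       # stand-in for the PPO actor loss
actor_loss.backward()

if max_grad_norm is not None:
    nn.utils.clip_grad_norm_(actor.parameters(), max_grad_norm)

actor_optim.step()
actor_optim.zero_grad()
```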
```diff
@@ -823,9 +842,17 @@ class Agent(Module):
 
             critic_loss.backward()
 
+            if exists(self.has_grad_clip):
+                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
+
             self.critic_optim.step()
             self.critic_optim.zero_grad()
 
+            # maybe ema update critic
+
+            if self.use_critic_ema:
+                self.critic_ema.update()
+
             # maybe update latents, if not frozen
 
             if not self.latent_gene_pool.frozen_latents:
```
```diff
@@ -875,6 +902,7 @@ def create_agent(
     actor_num_actions,
     actor_dim_hiddens: int | tuple[int, ...],
     critic_dim_hiddens: int | tuple[int, ...],
+    use_critic_ema = True,
     latent_gene_pool_kwargs: dict = dict(),
     actor_kwargs: dict = dict(),
     critic_kwargs: dict = dict(),
```
```diff
@@ -901,7 +929,14 @@ def create_agent(
         **critic_kwargs
     )
 
-
+    agent = Agent(
+        actor = actor,
+        critic = critic,
+        latent_gene_pool = latent_gene_pool,
+        use_critic_ema = use_critic_ema
+    )
+
+    return agent
 
 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
```
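`create_agent` now accepts `use_critic_ema` and forwards it into the `Agent` it builds and returns. A hedged usage sketch follows; the leading size arguments (`dim_state`, `num_latents`, `dim_latent`) and the top-level import are assumptions about parts of the API not shown in this diff:

```python
# assumed import path and leading parameter names - not confirmed by this diff
from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,
    num_latents = 16,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64),
    use_critic_ema = True          # shown in this diff; defaults to True
)
```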
```diff
@@ -978,7 +1013,7 @@ class EPO(Module):
 
                 # values
 
-                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
 
                 # get the next state, action, and reward
 
```
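During rollouts, per-step value estimates are now taken from the EMA critic whenever one is available. `temp_batch_dim` appears to wrap a single-sample call so the unbatched state and latent pick up a temporary batch dimension; a hypothetical re-implementation of that kind of helper, for illustration only (the real helper may differ):

```python
import torch

def temp_batch_dim_sketch(fn):
    # hypothetical helper: unsqueeze a leading batch dim of 1 onto tensor
    # arguments, call fn, then squeeze it back out of the tensor output
    def inner(*args, **kwargs):
        add = lambda t: t.unsqueeze(0) if torch.is_tensor(t) else t
        remove = lambda t: t.squeeze(0) if torch.is_tensor(t) else t
        out = fn(*map(add, args), **{k: add(v) for k, v in kwargs.items()})
        return remove(out)
    return inner

# toy critic over (state, latent), returning one value per batch element
toy_critic = lambda state, latent: state.sum(dim = -1) + latent.sum(dim = -1)

value = temp_batch_dim_sketch(toy_critic)(torch.randn(5), latent = torch.randn(3))
print(value.shape)   # torch.Size([]) - unbatched in, unbatched out
```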
```diff
--- a/evolutionary_policy_optimization-0.0.43.dist-info/METADATA
+++ b/evolutionary_policy_optimization-0.0.45.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.43
+Version: 0.0.45
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
```
```diff
--- a/evolutionary_policy_optimization-0.0.43.dist-info/RECORD
+++ b/evolutionary_policy_optimization-0.0.45.dist-info/RECORD
@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=NA-d7pWDJyQYULDIVB25lnbpTbwMyxc1U8RU8XGTNts,31500
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=6AIc4mwL_C6JkAxwESJgCLxXHMzCAu2FcffVg3HkSm0,920
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
+evolutionary_policy_optimization-0.0.45.dist-info/METADATA,sha256=3jXsZBoltrWQJk2Yd6zu1KmCcl9AEuhxES_mX8E1lAk,6213
+evolutionary_policy_optimization-0.0.45.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.45.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.45.dist-info/RECORD,,
```
WHEEL and licenses/LICENSE: file contents unchanged (only the dist-info directory was renamed for the new version).