evolutionary-policy-optimization 0.0.42__py3-none-any.whl → 0.0.44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/evolutionary_policy_optimization/epo.py
+++ b/evolutionary_policy_optimization/epo.py
@@ -626,6 +626,7 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        use_critic_ema = True,
         critic_ema_beta = 0.99,
         batch_size = 16,
         calc_gae_kwargs: dict = dict(
@@ -647,7 +648,9 @@ class Agent(Module):
         self.actor = actor
 
         self.critic = critic
-        self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs)
+
+        self.use_critic_ema = use_critic_ema
+        self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
         self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
@@ -676,7 +679,7 @@ class Agent(Module):
         pkg = dict(
             actor = self.actor.state_dict(),
             critic = self.critic.state_dict(),
-            critic_ema = self.critic_ema.state_dict(),
+            critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
             latents = self.latent_gene_pool.state_dict(),
             actor_optim = self.actor_optim.state_dict(),
             critic_optim = self.critic_optim.state_dict(),
@@ -695,7 +698,9 @@ class Agent(Module):
         self.actor.load_state_dict(pkg['actor'])
 
         self.critic.load_state_dict(pkg['critic'])
-        self.critic_ema.load_state_dict(pkg['critic_ema'])
+
+        if self.use_critic_ema:
+            self.critic_ema.load_state_dict(pkg['critic_ema'])
 
         self.latent_gene_pool.load_state_dict(pkg['latents'])
 
@@ -733,14 +738,20 @@ class Agent(Module):
         self,
         state,
         latent_id = None,
-        latent = None
+        latent = None,
+        use_ema_if_available = False
     ):
         assert exists(latent_id) or exists(latent)
 
         if not exists(latent):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
-        return self.critic(state, latent)
+        critic_forward = self.critic
+
+        if use_ema_if_available and self.use_critic_ema:
+            critic_forward = self.critic_ema
+
+        return critic_forward(state, latent)
 
     def update_latent_gene_pool_(
         self,
@@ -826,6 +837,11 @@ class Agent(Module):
         self.critic_optim.step()
         self.critic_optim.zero_grad()
 
+        # maybe ema update critic
+
+        if self.use_critic_ema:
+            self.critic_ema.update()
+
         # maybe update latents, if not frozen
 
         if not self.latent_gene_pool.frozen_latents:
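The hunks above make the exponential-moving-average copy of the critic optional: it is only constructed, updated after each critic optimizer step, and consulted for value estimates when use_critic_ema is enabled. Below is a minimal sketch of that pattern with a stand-in critic and the same ema-pytorch EMA wrapper the package already imports; it is illustrative only, not the package's exact code.

    import torch
    from torch import nn
    from ema_pytorch import EMA

    use_critic_ema = True

    critic = nn.Linear(16, 1)  # stand-in critic; the real one takes (state, latent)

    # construct the EMA copy only when enabled, mirroring the diff
    critic_ema = EMA(critic, beta = 0.99, include_online_model = False) if use_critic_ema else None

    # ... after each critic optimizer step, refresh the EMA weights
    if use_critic_ema:
        critic_ema.update()

    # at evaluation time, prefer the EMA copy when it exists
    critic_forward = critic_ema if use_critic_ema else critic
    values = critic_forward(torch.randn(2, 16))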
@@ -875,6 +891,7 @@ def create_agent(
     actor_num_actions,
     actor_dim_hiddens: int | tuple[int, ...],
     critic_dim_hiddens: int | tuple[int, ...],
+    use_critic_ema = True,
     latent_gene_pool_kwargs: dict = dict(),
     actor_kwargs: dict = dict(),
     critic_kwargs: dict = dict(),
@@ -901,7 +918,14 @@ def create_agent(
         **critic_kwargs
     )
 
-    return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
+    agent = Agent(
+        actor = actor,
+        critic = critic,
+        latent_gene_pool = latent_gene_pool,
+        use_critic_ema = use_critic_ema
+    )
+
+    return agent
 
 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
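create_agent now threads the new flag through to Agent. A hypothetical call is sketched below; the arguments dim_state, num_latents and dim_latent are assumptions not shown in this diff, so check the package README for the authoritative signature.

    from evolutionary_policy_optimization.epo import create_agent

    agent = create_agent(
        dim_state = 512,           # assumed argument name, not visible in this diff
        num_latents = 16,          # assumed
        dim_latent = 32,           # assumed
        actor_num_actions = 5,
        actor_dim_hiddens = (256, 128),
        critic_dim_hiddens = (256, 128, 64),
        use_critic_ema = True      # set False to skip the EMA critic entirely
    )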
@@ -978,7 +1002,7 @@ class EPO(Module):
 
         # values
 
-        value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+        value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
 
         # get the next state, action, and reward
 
--- a/evolutionary_policy_optimization/mock_env.py
+++ b/evolutionary_policy_optimization/mock_env.py
@@ -7,7 +7,7 @@ from torch.nn import Module
 # functions
 
 def cast_tuple(v):
-    return v if isinstance(v, tuple) else v\
+    return v if isinstance(v, tuple) else (v,)
 
 # mock env
 
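The mock_env.py hunk fixes cast_tuple, which previously ended in a stray line continuation instead of wrapping scalars in a one-element tuple. A quick standalone check of the corrected behavior (the helper is re-declared here rather than imported):

    def cast_tuple(v):
        # wrap non-tuples in a 1-tuple; pass tuples through unchanged
        return v if isinstance(v, tuple) else (v,)

    assert cast_tuple(3) == (3,)
    assert cast_tuple((3, 4)) == (3, 4)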
--- a/evolutionary_policy_optimization-0.0.42.dist-info/METADATA
+++ b/evolutionary_policy_optimization-0.0.44.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.42
+Version: 0.0.44
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
--- /dev/null
+++ b/evolutionary_policy_optimization-0.0.44.dist-info/RECORD
@@ -0,0 +1,8 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/epo.py,sha256=8ZyW21nbEokTii7alduJVDfPLhRLsLCjttWni6-BNFE,31072
+evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
+evolutionary_policy_optimization/mock_env.py,sha256=6AIc4mwL_C6JkAxwESJgCLxXHMzCAu2FcffVg3HkSm0,920
+evolutionary_policy_optimization-0.0.44.dist-info/METADATA,sha256=8tJisD2xu5Q7udBrmFnfy6oxl_RQrCF9DYrE6yHkP2M,6213
+evolutionary_policy_optimization-0.0.44.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.44.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.44.dist-info/RECORD,,
--- a/evolutionary_policy_optimization-0.0.42.dist-info/RECORD
+++ /dev/null
@@ -1,8 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=Yf-iw1gqmAUEVzg6_PwYy-q4005eroZKUYGxNgwCsKk,30440
-evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
-evolutionary_policy_optimization/mock_env.py,sha256=QqVPZVJtrvQmSDcnYDTob_A5sDwiUzGj6_tmo6BII5c,918
-evolutionary_policy_optimization-0.0.42.dist-info/METADATA,sha256=wiDM3tKsE9zHhyZJGaGcSA-jZuo38W4b_SCU2vQvpGc,6213
-evolutionary_policy_optimization-0.0.42.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.42.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.42.dist-info/RECORD,,