evolutionary-policy-optimization 0.0.43__py3-none-any.whl → 0.0.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -626,7 +626,9 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        use_critic_ema = True,
         critic_ema_beta = 0.99,
+        max_grad_norm = 0.5,
         batch_size = 16,
         calc_gae_kwargs: dict = dict(
             use_accelerated = False,
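
Two new hyperparameters land in Agent.__init__: use_critic_ema (default True) keeps an exponential-moving-average copy of the critic for value estimation, and max_grad_norm (default 0.5) turns on gradient-norm clipping during updates. A minimal sketch of overriding them; the actor, critic and latent_gene_pool arguments are placeholders for modules built elsewhere with the package's own classes, not part of this diff:

    # hedged sketch: `actor`, `critic` and `latent_gene_pool` are placeholders
    # for modules constructed separately with this package's own components
    agent = Agent(
        actor = actor,
        critic = critic,
        latent_gene_pool = latent_gene_pool,
        use_critic_ema = False,   # skip the EMA critic entirely
        max_grad_norm = 1.0       # clip actor / critic gradient norms at 1.0 instead of 0.5
    )
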
@@ -647,7 +649,9 @@ class Agent(Module):
         self.actor = actor
 
         self.critic = critic
-        self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs)
+
+        self.use_critic_ema = use_critic_ema
+        self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
         self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
@@ -659,7 +663,11 @@ class Agent(Module):
         self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.calc_gae = partial(calc_generalized_advantage_estimate, **calc_gae_kwargs)
 
+        # learning hparams
+
         self.batch_size = batch_size
+        self.max_grad_norm = max_grad_norm
+        self.has_grad_clip = exists(max_grad_norm)
 
         # optimizers
 
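
The new has_grad_clip flag is just a cached check on max_grad_norm. A standalone sketch of the pattern, assuming the repo's usual exists helper (a plain None check, reproduced here only for illustration):

    # assumption: exists(v) is the usual `v is not None` helper
    def exists(v):
        return v is not None

    max_grad_norm = 0.5
    has_grad_clip = exists(max_grad_norm)   # True here; False when max_grad_norm is None
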
 
@@ -676,7 +684,7 @@ class Agent(Module):
         pkg = dict(
             actor = self.actor.state_dict(),
             critic = self.critic.state_dict(),
-            critic_ema = self.critic_ema.state_dict(),
+            critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
             latents = self.latent_gene_pool.state_dict(),
             actor_optim = self.actor_optim.state_dict(),
             critic_optim = self.critic_optim.state_dict(),
@@ -695,7 +703,9 @@ class Agent(Module):
         self.actor.load_state_dict(pkg['actor'])
 
         self.critic.load_state_dict(pkg['critic'])
-        self.critic_ema.load_state_dict(pkg['critic_ema'])
+
+        if self.use_critic_ema:
+            self.critic_ema.load_state_dict(pkg['critic_ema'])
 
         self.latent_gene_pool.load_state_dict(pkg['latents'])
 
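
The state-dict handling follows the same switch: the saved package stores critic_ema = None when the EMA critic is disabled, and loading only touches critic_ema when use_critic_ema is set, so checkpoints from either configuration round-trip. A hedged, standalone illustration of that guard (the function names here are placeholders, not the surrounding methods, which this diff does not show):

    # standalone sketch of the conditional save / load pattern used above
    def pack(critic, critic_ema, use_critic_ema):
        return dict(
            critic = critic.state_dict(),
            critic_ema = critic_ema.state_dict() if use_critic_ema else None
        )

    def unpack(pkg, critic, critic_ema, use_critic_ema):
        critic.load_state_dict(pkg['critic'])

        if use_critic_ema:
            critic_ema.load_state_dict(pkg['critic_ema'])
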
 
@@ -733,14 +743,20 @@ class Agent(Module):
         self,
         state,
         latent_id = None,
-        latent = None
+        latent = None,
+        use_ema_if_available = False
     ):
         assert exists(latent_id) or exists(latent)
 
         if not exists(latent):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
-        return self.critic(state, latent)
+        critic_forward = self.critic
+
+        if use_ema_if_available and self.use_critic_ema:
+            critic_forward = self.critic_ema
+
+        return critic_forward(state, latent)
 
     def update_latent_gene_pool_(
         self,
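
get_critic_values gains a use_ema_if_available flag, so callers can ask for values from the smoothed EMA critic without breaking existing call sites. A hedged usage sketch; agent, state and the latent id are placeholders:

    # `agent` is an Agent as above, `state` a batch of environment states
    values_online = agent.get_critic_values(state, latent_id = 0)
    values_ema    = agent.get_critic_values(state, latent_id = 0, use_ema_if_available = True)
    # the second call silently falls back to the online critic when use_critic_ema is False
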
@@ -810,6 +826,9 @@ class Agent(Module):
 
         actor_loss.backward()
 
+        if exists(self.has_grad_clip):
+            nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
+
         self.actor_optim.step()
         self.actor_optim.zero_grad()
 
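
One caveat: has_grad_clip is already a boolean (exists(max_grad_norm)), so wrapping it in exists() again is always true, and the clip would run even if max_grad_norm were set to None. With the default max_grad_norm = 0.5 this has no practical effect. A hedged, standalone sketch of the clip-then-step order with the guard testing the value directly (placeholder names, not the released code):

    from torch import nn

    def clipped_step(loss, model, optimizer, max_grad_norm = 0.5):
        # backward, optionally clip the gradient norm, then step and reset
        loss.backward()

        if max_grad_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        optimizer.zero_grad()
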
 
@@ -823,9 +842,17 @@ class Agent(Module):
 
         critic_loss.backward()
 
+        if exists(self.has_grad_clip):
+            nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
+
         self.critic_optim.step()
         self.critic_optim.zero_grad()
 
+        # maybe ema update critic
+
+        if self.use_critic_ema:
+            self.critic_ema.update()
+
         # maybe update latents, if not frozen
 
         if not self.latent_gene_pool.frozen_latents:
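
critic_ema.update() is delegated to the EMA wrapper (the include_online_model keyword suggests the ema-pytorch package). Conceptually it nudges the shadow weights toward the online critic using critic_ema_beta = 0.99; a sketch of that rule, shown only to make the semantics concrete, not the wrapper's actual implementation:

    import torch

    @torch.no_grad()
    def ema_update_(ema_model, online_model, beta = 0.99):
        # ema <- beta * ema + (1 - beta) * online, parameter by parameter
        for ema_p, online_p in zip(ema_model.parameters(), online_model.parameters()):
            ema_p.lerp_(online_p, 1. - beta)
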
@@ -875,6 +902,7 @@ def create_agent(
     actor_num_actions,
     actor_dim_hiddens: int | tuple[int, ...],
     critic_dim_hiddens: int | tuple[int, ...],
+    use_critic_ema = True,
     latent_gene_pool_kwargs: dict = dict(),
     actor_kwargs: dict = dict(),
     critic_kwargs: dict = dict(),
@@ -901,7 +929,14 @@ def create_agent(
         **critic_kwargs
     )
 
-    return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
+    agent = Agent(
+        actor = actor,
+        critic = critic,
+        latent_gene_pool = latent_gene_pool,
+        use_critic_ema = use_critic_ema
+    )
+
+    return agent
 
 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
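
create_agent threads use_critic_ema through to the Agent it builds. A hedged usage sketch; aside from use_critic_ema, actor_num_actions and the two *_dim_hiddens arguments visible in this diff, the remaining keyword names and values are assumptions about parts of the signature not shown here:

    from evolutionary_policy_optimization.epo import create_agent

    agent = create_agent(
        dim_state = 512,              # assumption: not visible in this diff
        num_latents = 16,             # assumption: not visible in this diff
        dim_latent = 32,              # assumption: not visible in this diff
        actor_num_actions = 5,
        actor_dim_hiddens = (256, 128),
        critic_dim_hiddens = (256, 128, 64),
        use_critic_ema = True
    )
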
@@ -978,7 +1013,7 @@ class EPO(Module):
 
         # values
 
-        value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+        value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
 
         # get the next state, action, and reward
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.43
+Version: 0.0.45
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=Yf-iw1gqmAUEVzg6_PwYy-q4005eroZKUYGxNgwCsKk,30440
+evolutionary_policy_optimization/epo.py,sha256=NA-d7pWDJyQYULDIVB25lnbpTbwMyxc1U8RU8XGTNts,31500
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=6AIc4mwL_C6JkAxwESJgCLxXHMzCAu2FcffVg3HkSm0,920
-evolutionary_policy_optimization-0.0.43.dist-info/METADATA,sha256=pMVLppijepjmI1A9wVVhdX2IXo4BNPsOozpMAAsS6Lo,6213
-evolutionary_policy_optimization-0.0.43.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.43.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.43.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.45.dist-info/METADATA,sha256=3jXsZBoltrWQJk2Yd6zu1KmCcl9AEuhxES_mX8E1lAk,6213
+evolutionary_policy_optimization-0.0.45.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.45.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.45.dist-info/RECORD,,