evolutionary-policy-optimization 0.0.44__py3-none-any.whl → 0.0.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -628,6 +628,7 @@ class Agent(Module):
         latent_lr = 1e-5,
         use_critic_ema = True,
         critic_ema_beta = 0.99,
+        max_grad_norm = 0.5,
         batch_size = 16,
         calc_gae_kwargs: dict = dict(
             use_accelerated = False,
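This hunk adds a new `max_grad_norm = 0.5` hyperparameter to the `Agent` constructor, bounding the total L2 norm of the gradients before each optimizer step. A minimal standalone sketch (not from this package) of what `torch.nn.utils.clip_grad_norm_` does when a gradient exceeds that bound:

import torch
from torch import nn

# toy parameter with an oversized gradient (L2 norm = 5.0)
w = nn.Parameter(torch.zeros(3))
w.grad = torch.tensor([3., 4., 0.])

# rescales the gradient in place so its total norm is at most 0.5,
# and returns the norm measured before clipping
total_norm = nn.utils.clip_grad_norm_([w], max_norm = 0.5)

print(total_norm)  # tensor(5.)
print(w.grad)      # approximately [0.3, 0.4, 0.0]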
@@ -662,7 +663,11 @@ class Agent(Module):
         self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.calc_gae = partial(calc_generalized_advantage_estimate, **calc_gae_kwargs)
 
+        # learning hparams
+
         self.batch_size = batch_size
+        self.max_grad_norm = max_grad_norm
+        self.has_grad_clip = exists(max_grad_norm)
 
         # optimizers
 
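The clipping flag is derived from whether `max_grad_norm` was supplied, so passing `None` disables clipping entirely. A short sketch of that pattern, assuming the usual definition of the `exists` helper (not shown in this diff):

def exists(v):
    # assumed helper: true when a value was provided
    return v is not None

has_grad_clip = exists(0.5)   # True  - the default max_grad_norm enables clipping
has_grad_clip = exists(None)  # False - passing max_grad_norm = None turns it off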
@@ -821,6 +826,9 @@ class Agent(Module):
 
             actor_loss.backward()
 
+            if self.has_grad_clip:
+                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
+
             self.actor_optim.step()
             self.actor_optim.zero_grad()
 
@@ -834,6 +842,9 @@ class Agent(Module):
 
             critic_loss.backward()
 
+            if self.has_grad_clip:
+                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
+
             self.critic_optim.step()
             self.critic_optim.zero_grad()
 
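With these two hunks, both the actor and critic updates follow the standard backward, clip, step, zero_grad sequence. A self-contained illustrative sketch of that sequence on a toy model (hypothetical names, not the `Agent` class itself):

import torch
from torch import nn

model = nn.Linear(4, 1)
optim = torch.optim.Adam(model.parameters(), lr = 1e-3)

max_grad_norm = 0.5
has_grad_clip = max_grad_norm is not None

x, target = torch.randn(8, 4), torch.randn(8, 1)

loss = nn.functional.mse_loss(model(x), target)
loss.backward()

# clip the total gradient norm only when a bound was given
if has_grad_clip:
    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

optim.step()
optim.zero_grad()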
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.44
+Version: 0.0.45
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=8ZyW21nbEokTii7alduJVDfPLhRLsLCjttWni6-BNFE,31072
+evolutionary_policy_optimization/epo.py,sha256=NA-d7pWDJyQYULDIVB25lnbpTbwMyxc1U8RU8XGTNts,31500
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=6AIc4mwL_C6JkAxwESJgCLxXHMzCAu2FcffVg3HkSm0,920
-evolutionary_policy_optimization-0.0.44.dist-info/METADATA,sha256=8tJisD2xu5Q7udBrmFnfy6oxl_RQrCF9DYrE6yHkP2M,6213
-evolutionary_policy_optimization-0.0.44.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.44.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.44.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.45.dist-info/METADATA,sha256=3jXsZBoltrWQJk2Yd6zu1KmCcl9AEuhxES_mX8E1lAk,6213
+evolutionary_policy_optimization-0.0.45.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.45.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.45.dist-info/RECORD,,