evolutionary-policy-optimization 0.0.35__py3-none-any.whl → 0.0.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
  from __future__ import annotations

+ from functools import partial
  from pathlib import Path
  from collections import namedtuple

@@ -68,7 +69,6 @@ def calc_generalized_advantage_estimate(
  gamma = 0.99,
  lam = 0.95,
  use_accelerated = None
-
  ):
  assert values.shape[-1] == (rewards.shape[-1] + 1)

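For reference (not part of the diff): `calc_generalized_advantage_estimate` computes standard generalized advantage estimation, where `values` carries one extra bootstrap entry relative to `rewards`, hence the assertion above. Below is a minimal, unaccelerated sketch of that recursion, assuming tensors shaped `(..., time)` and a boolean `dones` mask; the `use_accelerated` associative-scan path in epo.py is not reproduced here.

    import torch

    def gae_reference(rewards, values, dones, gamma = 0.99, lam = 0.95):
        # rewards, dones: (..., t); values: (..., t + 1), last entry is the bootstrap value
        advantages = []
        gae = torch.zeros_like(rewards[..., 0])

        for t in reversed(range(rewards.shape[-1])):
            not_done = 1. - dones[..., t].float()
            delta = rewards[..., t] + gamma * values[..., t + 1] * not_done - values[..., t]
            gae = delta + gamma * lam * not_done * gae
            advantages.append(gae)

        return torch.stack(list(reversed(advantages)), dim = -1)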
@@ -605,6 +605,16 @@ class Agent(Module):
  critic_lr = 1e-4,
  latent_lr = 1e-5,
  critic_ema_beta = 0.99,
+ batch_size = 16,
+ calc_gae_kwargs: dict = dict(
+     use_accelerated = False,
+     gamma = 0.99,
+     lam = 0.95,
+ ),
+ actor_loss_kwargs: dict = dict(
+     eps_clip = 0.2,
+     entropy_weight = .01
+ ),
  ema_kwargs: dict = dict(),
  actor_optim_kwargs: dict = dict(),
  critic_optim_kwargs: dict = dict(),
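The new `actor_loss_kwargs` defaults (`eps_clip = 0.2`, `entropy_weight = .01`) correspond to the usual PPO clipped surrogate objective with an entropy bonus. The sketch below is a reference formulation matching the call order `actor_loss(logits, old_log_probs, actions, advantages)` used later in this diff; it is an illustration, not the exact implementation in epo.py.

    import torch

    def ppo_actor_loss_reference(logits, old_log_probs, actions, advantages, eps_clip = 0.2, entropy_weight = .01):
        # clipped surrogate objective plus entropy regularization
        dist = torch.distributions.Categorical(logits = logits)
        log_probs = dist.log_prob(actions)

        ratio = (log_probs - old_log_probs).exp()
        clipped_ratio = ratio.clamp(1. - eps_clip, 1. + eps_clip)
        surrogate = torch.min(ratio * advantages, clipped_ratio * advantages)

        return -(surrogate + entropy_weight * dist.entropy()).mean()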
@@ -622,6 +632,13 @@ class Agent(Module):

  assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent

+ # gae function
+
+ self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
+ self.calc_gae = partial(calc_generalized_advantage_estimate, **calc_gae_kwargs)
+
+ self.batch_size = batch_size
+
  # optimizers

  self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
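With `functools.partial`, the hyperparameter dictionaries are bound once in `__init__`, so the update loop below calls `self.calc_gae(...)` and `self.actor_loss(...)` with tensors only. A small self-contained illustration of that pattern (the function here is a stand-in, not the library's API):

    from functools import partial

    def calc_gae(rewards, values, dones, gamma = 0.99, lam = 0.95, use_accelerated = None):
        # stand-in for calc_generalized_advantage_estimate
        ...

    calc_gae_kwargs = dict(use_accelerated = False, gamma = 0.99, lam = 0.95)
    calc_gae_bound = partial(calc_gae, **calc_gae_kwargs)

    # later the bound function is called with tensors only, hyperparameters already fixed:
    # advantages = calc_gae_bound(rewards, values_with_next, dones)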
@@ -690,11 +707,65 @@ class Agent(Module):

  def forward(
      self,
-     memories_and_next_value: MemoriesAndNextValue
+     memories_and_next_value: MemoriesAndNextValue,
+     epochs = 2
  ):
      memories, next_value = memories_and_next_value

-     raise NotImplementedError
+     (
+         states,
+         latent_gene_ids,
+         actions,
+         log_probs,
+         rewards,
+         values,
+         dones
+     ) = map(stack, zip(*memories))
+
+     values_with_next, ps = pack((values, next_value), '*')
+
+     advantages = self.calc_gae(rewards, values_with_next, dones)
+
+     dataset = TensorDataset(states, latent_gene_ids, actions, log_probs, advantages, values)
+
+     dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
+
+     self.actor.train()
+     self.critic.train()
+
+     for _ in range(epochs):
+         for (
+             states,
+             latent_gene_ids,
+             actions,
+             log_probs,
+             advantages,
+             old_values
+         ) in dataloader:
+
+             latents = self.latent_gene_pool(latent_gene_ids)
+
+             # learn actor
+
+             logits = self.actor(states, latents)
+             actor_loss = self.actor_loss(logits, log_probs, actions, advantages)
+
+             actor_loss.backward()
+             self.actor_optim.step()
+             self.actor_optim.zero_grad()
+
+             # learn critic with maybe classification loss
+
+             critic_loss = self.critic(
+                 states,
+                 latents,
+                 targets = advantages + old_values
+             )
+
+             critic_loss.backward()
+
+             self.critic_optim.step()
+             self.critic_optim.zero_grad()

  # reinforcement learning related - ppo

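To make the new training step concrete, a hedged usage sketch follows. The field names are hypothetical and simply mirror the unpack order in `forward` above (state, latent gene id, action, log prob, reward, value, done); the actual `Memory`/`MemoriesAndNextValue` namedtuples, the observation size, and the `Agent` construction live in epo.py and are assumed here.

    import torch
    from collections import namedtuple

    # hypothetical stand-ins matching the unpack order in Agent.forward
    Memory = namedtuple('Memory', ['state', 'latent_gene_id', 'action', 'log_prob', 'reward', 'value', 'done'])
    MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', ['memories', 'next_value'])

    memories = [
        Memory(
            state = torch.randn(512),          # observation (dimension assumed)
            latent_gene_id = torch.tensor(0),  # which latent gene produced this step
            action = torch.tensor(1),
            log_prob = torch.tensor(-0.7),
            reward = torch.tensor(1.),
            value = torch.tensor(0.3),
            done = torch.tensor(False)
        )
        for _ in range(32)
    ]

    rollout = MemoriesAndNextValue(memories, next_value = torch.tensor(0.25))

    # agent = Agent(...)        # construction omitted; see epo.py
    # agent(rollout, epochs = 2)  # runs the PPO update loop added in this diff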
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.0.35
+ Version: 0.0.37
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -1,8 +1,8 @@
  evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
- evolutionary_policy_optimization/epo.py,sha256=OVI2XCAWYtggA4TRVDmXecQkXfXB1r-HchcXGWsvIvg,23941
+ evolutionary_policy_optimization/epo.py,sha256=onIGNWHg1EGQwJ9TfkkJ8Yz8_S-BPoaqrxJwq54BXp0,25992
  evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
  evolutionary_policy_optimization/mock_env.py,sha256=3xrd-gwjZeVd_sEvxIyX0lppnMWcfQGOapO-XjKmExI,816
- evolutionary_policy_optimization-0.0.35.dist-info/METADATA,sha256=mkHr6X8PUDj_cYBEPo4KgW9COSm70Tud_6ckgtvZ-Ds,4992
- evolutionary_policy_optimization-0.0.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- evolutionary_policy_optimization-0.0.35.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- evolutionary_policy_optimization-0.0.35.dist-info/RECORD,,
+ evolutionary_policy_optimization-0.0.37.dist-info/METADATA,sha256=nPWBCvx02MHWdKu5cEoPmHFMFKhwepOfStkXIXR2NHc,4992
+ evolutionary_policy_optimization-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ evolutionary_policy_optimization-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ evolutionary_policy_optimization-0.0.37.dist-info/RECORD,,