evolutionary-policy-optimization 0.0.34__tar.gz → 0.0.36__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (13)
  1. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/PKG-INFO +1 -1
  2. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/evolutionary_policy_optimization/epo.py +71 -4
  3. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/pyproject.toml +1 -1
  4. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/.github/workflows/python-publish.yml +0 -0
  5. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/.github/workflows/test.yml +0 -0
  6. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/.gitignore +0 -0
  7. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/LICENSE +0 -0
  8. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/README.md +0 -0
  9. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/evolutionary_policy_optimization/__init__.py +0 -0
  10. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/evolutionary_policy_optimization/experimental.py +0 -0
  11. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/evolutionary_policy_optimization/mock_env.py +0 -0
  12. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/requirements.txt +0 -0
  13. {evolutionary_policy_optimization-0.0.34 → evolutionary_policy_optimization-0.0.36}/tests/test_epo.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.34
+Version: 0.0.36
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization/epo.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from functools import partial
 from pathlib import Path
 from collections import namedtuple

@@ -68,7 +69,6 @@ def calc_generalized_advantage_estimate(
     gamma = 0.99,
     lam = 0.95,
     use_accelerated = None
-
 ):
     assert values.shape[-1] == (rewards.shape[-1] + 1)

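For reference, the generalized advantage estimate that this signature parameterizes can be computed with the standard backward recursion below. This is a hedged, non-accelerated sketch rather than the package's implementation; it assumes the third argument is a mask equal to 1. - dones, which may differ from how epo.py handles it.

import torch

@torch.no_grad()
def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    # rewards, masks: (..., time); values: (..., time + 1) with a bootstrap value appended
    # masks is assumed here to be (1. - dones)
    assert values.shape[-1] == (rewards.shape[-1] + 1)

    # one-step TD residuals
    deltas = rewards + gamma * values[..., 1:] * masks - values[..., :-1]

    advantages = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[..., 0])

    # backward recursion: A_t = delta_t + gamma * lam * mask_t * A_{t+1}
    for t in reversed(range(rewards.shape[-1])):
        running = deltas[..., t] + gamma * lam * masks[..., t] * running
        advantages[..., t] = running

    return advantages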
@@ -558,6 +558,7 @@ class LatentGenePool(Module):
         state: Tensor | None = None,
         latent_id: int | None = None,
         net: Module | None = None,
+        net_latent_kwarg_name = 'latent',
         **kwargs,
     ):
         device = self.latents.device
@@ -583,9 +584,11 @@ class LatentGenePool(Module):
         if not exists(net):
             return latent

+        latent_kwarg = {net_latent_kwarg_name: latent}
+
         return net(
             *args,
-            latent = latent,
+            **latent_kwarg,
             **kwargs
         )

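The new net_latent_kwarg_name argument lets LatentGenePool.forward hand its latent to a wrapped network under any keyword, not only latent. Below is a minimal sketch of that pattern, using a hypothetical network whose forward expects the latent as gene (the names here are illustrative, not from the package).

import torch
from torch import nn

class HypotheticalNet(nn.Module):
    # toy network that accepts the latent under the keyword `gene`
    def __init__(self, dim_state = 4, dim_latent = 8):
        super().__init__()
        self.proj = nn.Linear(dim_state + dim_latent, 2)

    def forward(self, state, gene):
        return self.proj(torch.cat((state, gene), dim = -1))

def call_with_latent(net, state, latent, net_latent_kwarg_name = 'latent'):
    # same trick as the hunk above: pack the latent under a configurable keyword
    latent_kwarg = {net_latent_kwarg_name: latent}
    return net(state, **latent_kwarg)

net = HypotheticalNet()
out = call_with_latent(net, torch.randn(1, 4), torch.randn(1, 8), net_latent_kwarg_name = 'gene')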
@@ -602,6 +605,16 @@ class Agent(Module):
         critic_lr = 1e-4,
         latent_lr = 1e-5,
         critic_ema_beta = 0.99,
+        batch_size = 16,
+        calc_gae_kwargs: dict = dict(
+            use_accelerated = False,
+            gamma = 0.99,
+            lam = 0.95,
+        ),
+        actor_loss_kwargs: dict = dict(
+            eps_clip = 0.2,
+            entropy_weight = .01
+        ),
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -619,6 +632,13 @@ class Agent(Module):

         assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent

+        # gae function
+
+        self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
+        self.calc_gae = partial(calc_generalized_advantage_estimate, **calc_gae_kwargs)
+
+        self.batch_size = batch_size
+
         # optimizers

         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
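The calc_gae_kwargs and actor_loss_kwargs dicts are bound once via functools.partial, so the training loop can later call self.calc_gae and self.actor_loss with tensors only. A small illustration of that pattern with a stand-in function (not the package's own):

from functools import partial

def calc_gae_stub(rewards, values, masks, gamma = 0.99, lam = 0.95, use_accelerated = None):
    # stand-in for calc_generalized_advantage_estimate, for illustration only
    ...

# hyperparameters are fixed once at construction time
calc_gae = partial(calc_gae_stub, gamma = 0.997, lam = 0.9, use_accelerated = False)

# later calls pass only the tensors, mirroring self.calc_gae(rewards, values_with_next, dones)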
@@ -687,11 +707,58 @@ class Agent(Module):

     def forward(
         self,
-        memories_and_next_value: MemoriesAndNextValue
+        memories_and_next_value: MemoriesAndNextValue,
+        epochs = 2
     ):
         memories, next_value = memories_and_next_value

-        raise NotImplementedError
+        (
+            states,
+            latent_gene_ids,
+            actions,
+            log_probs,
+            rewards,
+            values,
+            dones
+        ) = map(stack, zip(*memories))
+
+        values_with_next, ps = pack((values, next_value), '*')
+
+        advantages = self.calc_gae(rewards, values_with_next, dones)
+
+        dataset = TensorDataset(states, latent_gene_ids, actions, log_probs, advantages, values)
+
+        dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
+
+        self.actor.train()
+        self.critic.train()
+
+        for _ in range(epochs):
+            for (
+                states,
+                latent_gene_ids,
+                actions,
+                log_probs,
+                advantages,
+                old_values
+            ) in dataloader:
+
+                # learn actor
+
+                logits = self.actor(states)
+                actor_loss = self.actor_loss(logits, log_probs, actions, advantages)
+
+                actor_loss.backward()
+                self.actor_optim.step()
+                self.actor_optim.zero_grad()
+
+                # learn critic with maybe classification loss
+
+                critic_loss = self.critic(states, advantages + old_values)
+                critic_loss.backward()
+
+                self.critic_optim.step()
+                self.critic_optim.zero_grad()

 # reinforcement learning related - ppo

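The loop above delegates the policy update to self.actor_loss, which was bound earlier with eps_clip and entropy_weight; its body is not part of this diff. The sketch below shows a standard PPO clipped-surrogate loss with an entropy bonus that is consistent with those kwargs and with the call actor_loss(logits, log_probs, actions, advantages), assuming discrete actions; treat the exact signature as an assumption. Note also that the critic target advantages + old_values is the usual GAE return.

import torch

def ppo_actor_loss(logits, old_log_probs, actions, advantages, eps_clip = 0.2, entropy_weight = 0.01):
    # logits: (batch, num_actions); actions: (batch,) long; old_log_probs, advantages: (batch,)
    log_probs = torch.log_softmax(logits, dim = -1)
    new_log_probs = log_probs.gather(-1, actions[..., None]).squeeze(-1)

    # probability ratio between the current policy and the behavior policy that collected the data
    ratio = (new_log_probs - old_log_probs).exp()

    # clipped surrogate objective
    clipped_ratio = ratio.clamp(1. - eps_clip, 1. + eps_clip)
    surrogate = torch.min(ratio * advantages, clipped_ratio * advantages)

    # entropy bonus encourages exploration
    entropy = -(log_probs.exp() * log_probs).sum(dim = -1)

    return (-surrogate - entropy_weight * entropy).mean()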
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.0.34"
+version = "0.0.36"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }