evolutionary-policy-optimization 0.0.38__tar.gz → 0.0.40__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (13)
  1. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/PKG-INFO +40 -1
  2. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/README.md +39 -0
  3. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/evolutionary_policy_optimization/epo.py +170 -45
  4. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/pyproject.toml +1 -1
  5. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/tests/test_epo.py +31 -1
  6. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/.github/workflows/python-publish.yml +0 -0
  7. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/.github/workflows/test.yml +0 -0
  8. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/.gitignore +0 -0
  9. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/LICENSE +0 -0
  10. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/evolutionary_policy_optimization/experimental.py +0 -0
  12. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/evolutionary_policy_optimization/mock_env.py +0 -0
  13. {evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/requirements.txt +0 -0
{evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.0.38
+ Version: 0.0.40
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -101,6 +101,45 @@ fitness = torch.randn(128)
  latent_pool.genetic_algorithm_step(fitness) # update latent genes with genetic algorithm
  ```

+ End to end learning
+
+ ```python
+ import torch
+
+ from evolutionary_policy_optimization import (
+     create_agent,
+     EPO,
+     Env
+ )
+
+ agent = create_agent(
+     dim_state = 512,
+     num_latents = 8,
+     dim_latent = 32,
+     actor_num_actions = 5,
+     actor_dim_hiddens = (256, 128),
+     critic_dim_hiddens = (256, 128, 64)
+ )
+
+ epo = EPO(
+     agent,
+     episodes_per_latent = 1,
+     max_episode_length = 10,
+     action_sample_temperature = 1.
+ )
+
+ env = Env((512,))
+
+ memories = epo(env)
+
+ agent(memories)
+
+ # saving and loading
+
+ agent.save('./agent.pt', overwrite = True)
+ agent.load('./agent.pt')
+ ```
+
  ## Citations

  ```bibtex
{evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/README.md

@@ -49,6 +49,45 @@ fitness = torch.randn(128)
  latent_pool.genetic_algorithm_step(fitness) # update latent genes with genetic algorithm
  ```

+ End to end learning
+
+ ```python
+ import torch
+
+ from evolutionary_policy_optimization import (
+     create_agent,
+     EPO,
+     Env
+ )
+
+ agent = create_agent(
+     dim_state = 512,
+     num_latents = 8,
+     dim_latent = 32,
+     actor_num_actions = 5,
+     actor_dim_hiddens = (256, 128),
+     critic_dim_hiddens = (256, 128, 64)
+ )
+
+ epo = EPO(
+     agent,
+     episodes_per_latent = 1,
+     max_episode_length = 10,
+     action_sample_temperature = 1.
+ )
+
+ env = Env((512,))
+
+ memories = epo(env)
+
+ agent(memories)
+
+ # saving and loading
+
+ agent.save('./agent.pt', overwrite = True)
+ agent.load('./agent.pt')
+ ```
+
  ## Citations

  ```bibtex
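
The `action_sample_temperature` passed to `EPO` in the example above drives the Gumbel sampling helpers added in the epo.py changes further down. A minimal standalone sketch (helper bodies copied from that diff; the logits tensor is a hypothetical stand-in for actor output) showing that any temperature at or below `0.` degenerates to a greedy argmax:

```python
import torch

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def gumbel_noise(t):
    return -log(-log(torch.rand_like(t)))

def gumbel_sample(t, temperature = 1.):
    is_greedy = temperature <= 0.

    if not is_greedy:
        t = (t / temperature) + gumbel_noise(t)

    return t.argmax(dim = -1)

logits = torch.randn(2, 5)  # hypothetical actor logits for a batch of 2 states, 5 actions

stochastic_actions = gumbel_sample(logits, temperature = 1.)  # Gumbel-perturbed argmax, i.e. a categorical sample
greedy_actions = gumbel_sample(logits, temperature = 0.)      # temperature <= 0. skips the noise, plain argmax
```

Higher temperatures flatten the distribution being sampled from; lower positive temperatures sharpen it toward the greedy choice.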
{evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/evolutionary_policy_optimization/epo.py

@@ -1,6 +1,6 @@
  from __future__ import annotations

- from functools import partial
+ from functools import partial, wraps
  from pathlib import Path
  from collections import namedtuple

@@ -9,6 +9,7 @@ from torch import nn, cat, stack, is_tensor, tensor
  import torch.nn.functional as F
  from torch.nn import Linear, Module, ModuleList
  from torch.utils.data import TensorDataset, DataLoader
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten

  import einx
  from einops import rearrange, repeat, einsum, pack
@@ -22,6 +23,8 @@ from hl_gauss_pytorch import HLGaussLayer

  from ema_pytorch import EMA

+ from tqdm import tqdm
+
  # helpers

  def exists(v):
@@ -47,9 +50,20 @@ def l2norm(t):
  def log(t, eps = 1e-20):
      return t.clamp(min = eps).log()

+ def gumbel_noise(t):
+     return -log(-log(torch.rand_like(t)))
+
+ def gumbel_sample(t, temperature = 1.):
+     is_greedy = temperature <= 0.
+
+     if not is_greedy:
+         t = (t / temperature) + gumbel_noise(t)
+
+     return t.argmax(dim = -1)
+
  def calc_entropy(logits):
      prob = logits.softmax(dim = -1)
-     return -prob * log(prob)
+     return -(prob * log(prob)).sum(dim = -1)

  def gather_log_prob(
      logits, # Float[b l]
@@ -60,11 +74,24 @@ def gather_log_prob(
      log_prob = log_probs.gather(-1, indices)
      return rearrange(log_prob, '... 1 -> ...')

+ def temp_batch_dim(fn):
+
+     @wraps(fn)
+     def inner(*args, **kwargs):
+         args, kwargs = tree_map(lambda t: rearrange(t, '... -> 1 ...') if is_tensor(t) else t, (args, kwargs))
+
+         out = fn(*args, **kwargs)
+
+         out = tree_map(lambda t: rearrange(t, '1 ... -> ...') if is_tensor(t) else t, out)
+         return out
+
+     return inner
+
  # generalized advantage estimate

  def calc_generalized_advantage_estimate(
-     rewards, # Float[g n]
-     values, # Float[g n+1]
+     rewards, # Float[n]
+     values, # Float[n+1]
      masks, # Bool[n]
      gamma = 0.99,
      lam = 0.95,
@@ -75,9 +102,7 @@ def calc_generalized_advantage_estimate(
      use_accelerated = default(use_accelerated, rewards.is_cuda)
      device = rewards.device

-     masks = repeat(masks, 'n -> g n', g = rewards.shape[0])
-
-     values, values_next = values[:, :-1], values[:, 1:]
+     values, values_next = values[:-1], values[1:]

      delta = rewards + gamma * values_next * masks - values
      gates = gamma * lam * masks
@@ -565,6 +590,8 @@ class LatentGenePool(Module):
          if not exists(latent_id) and self.num_latents == 1:
              latent_id = 0

+         assert exists(latent_id)
+
          if not is_tensor(latent_id):
              latent_id = tensor(latent_id, device = device)

@@ -681,17 +708,38 @@ class Agent(Module):
      def get_actor_actions(
          self,
          state,
-         latent_id
+         latent_id = None,
+         latent = None,
+         sample = False,
+         temperature = 1.
      ):
-         latent = self.latent_gene_pool(latent_id = latent_id, state = state)
-         return self.actor(state, latent)
+         assert exists(latent_id) or exists(latent)
+
+         if not exists(latent):
+             latent = self.latent_gene_pool(latent_id = latent_id)
+
+         logits = self.actor(state, latent)
+
+         if not sample:
+             return logits
+
+         actions = gumbel_sample(logits, temperature = temperature)
+
+         log_probs = gather_log_prob(logits, actions)
+
+         return actions, log_probs

      def get_critic_values(
          self,
          state,
-         latent_id
+         latent_id = None,
+         latent = None
      ):
-         latent = self.latent_gene_pool(latent_id = latent_id, state = state)
+         assert exists(latent_id) or exists(latent)
+
+         if not exists(latent):
+             latent = self.latent_gene_pool(latent_id = latent_id)
+
          return self.critic(state, latent)

      def update_latent_gene_pool_(
@@ -702,13 +750,13 @@ class Agent(Module):

      def forward(
          self,
-         memories_and_next_value: MemoriesAndNextValue,
+         memories_and_fitness_scores: MemoriesAndFitnessScores,
          epochs = 2
      ):
-         memories, next_value = memories_and_next_value
+         memories, fitness_scores = memories_and_fitness_scores

          (
-             _,
+             episode_ids,
              states,
              latent_gene_ids,
              actions,
@@ -718,35 +766,46 @@ class Agent(Module):
              dones
          ) = map(stack, zip(*memories))

-         values_with_next, ps = pack((values, next_value), '*')
+         advantages = self.calc_gae(
+             rewards[:-1],
+             values,
+             dones[:-1],
+         )

-         advantages = self.calc_gae(rewards, values_with_next, dones)
+         valid_episode = episode_ids >= 0

-         dataset = TensorDataset(states, latent_gene_ids, actions, log_probs, advantages, values)
+         dataset = TensorDataset(
+             *[
+                 advantages[valid_episode[:-1]],
+                 *[t[valid_episode] for t in (states, latent_gene_ids, actions, log_probs, values)]
+             ]
+         )

          dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)

          self.actor.train()
          self.critic.train()

-         for _ in range(epochs):
+         for _ in tqdm(range(epochs), desc = 'learning actor/critic epoch'):
              for (
+                 advantages,
                  states,
                  latent_gene_ids,
                  actions,
                  log_probs,
-                 advantages,
                  old_values
              ) in dataloader:

-                 latents = self.latent_gene_pool(latent_gene_ids)
+                 latents = self.latent_gene_pool(latent_id = latent_gene_ids)

                  # learn actor

                  logits = self.actor(states, latents)
+
                  actor_loss = self.actor_loss(logits, log_probs, actions, advantages)

                  actor_loss.backward()
+
                  self.actor_optim.step()
                  self.actor_optim.zero_grad()

@@ -755,7 +814,7 @@ class Agent(Module):
                  critic_loss = self.critic(
                      states,
                      latents,
-                     targets = advantages + old_values
+                     target = advantages + old_values
                  )

                  critic_loss.backward()
@@ -763,6 +822,10 @@ class Agent(Module):
                  self.critic_optim.step()
                  self.critic_optim.zero_grad()

+         # apply evolution
+
+         self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
+

  # reinforcement learning related - ppo
  def actor_loss(
@@ -789,15 +852,7 @@ def actor_loss(

      entropy_aux_loss = -entropy_weight * entropy

-     return actor_loss + entropy_aux_loss
-
- def critic_loss(
-     pred_values, # Float[b]
-     advantages, # Float[b]
-     old_values # Float[b]
- ):
-     discounted_values = advantages + old_values
-     return F.mse_loss(pred_values, discounted_values)
+     return (actor_loss + entropy_aux_loss).mean()

  # agent contains the actor, critic, and the latent genetic pool

@@ -810,6 +865,11 @@ def create_agent(
      critic_dim_hiddens: int | tuple[int, ...],
  ) -> Agent:

+     latent_gene_pool = LatentGenePool(
+         num_latents = num_latents,
+         dim_latent = dim_latent
+     )
+
      actor = Actor(
          num_actions = actor_num_actions,
          dim_state = dim_state,
@@ -821,13 +881,7 @@ def create_agent(
          dim_state = dim_state,
          dim_latent = dim_latent,
          dim_hiddens = critic_dim_hiddens
-     )
-
-     latent_gene_pool = LatentGenePool(
-         num_latents = num_latents,
-         dim_latent = dim_latent,
-     )
-
+     )
      return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)

  # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
@@ -840,14 +894,13 @@ Memory = namedtuple('Memory', [
      'action',
      'log_prob',
      'reward',
-     'values',
+     'value',
      'done'
  ])

- MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', [
+ MemoriesAndFitnessScores = namedtuple('MemoriesAndFitnessScores', [
      'memories',
-     'next_value',
-     'cumulative_rewards'
+     'fitness_scores'
  ])

  class EPO(Module):
@@ -856,10 +909,12 @@ class EPO(Module):
          self,
          agent: Agent,
          episodes_per_latent,
-         max_episode_length
+         max_episode_length,
+         action_sample_temperature = 1.
      ):
          super().__init__()
          self.agent = agent
+         self.action_sample_temperature = action_sample_temperature

          self.num_latents = agent.latent_gene_pool.num_latents
          self.episodes_per_latent = episodes_per_latent
@@ -869,10 +924,80 @@ class EPO(Module):
      def forward(
          self,
          env
-     ) -> MemoriesAndNextValue:
+     ) -> MemoriesAndFitnessScores:

          self.agent.eval()

+         invalid_episode = tensor(-1) # will use `episode_id` value of `-1` for the `next_value`, needed for not discarding last reward for generalized advantage estimate
+
          memories: list[Memory] = []

-         raise NotImplementedError
+         fitness_scores = torch.zeros((self.num_latents))
+
+         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
+
+             for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
+                 time = 0
+
+                 # initial state
+
+                 state = env.reset()
+
+                 # get latent from pool
+
+                 latent = self.agent.latent_gene_pool(latent_id = latent_id)
+
+                 # until maximum episode length
+
+                 done = tensor(False)
+
+                 while time < self.max_episode_length:
+
+                     # sample action
+
+                     action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+
+                     # values
+
+                     value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+
+                     # get the next state, action, and reward
+
+                     state, reward, done = env(action)
+
+                     # update fitness for each gene as cumulative reward received, but make this customizable at some point
+
+                     fitness_scores[latent_id] += reward
+
+                     # store memories
+
+                     memory = Memory(
+                         tensor(episode_id),
+                         state,
+                         tensor(latent_id),
+                         action,
+                         log_prob,
+                         reward,
+                         value,
+                         done
+                     )
+
+                     memories.append(memory)
+
+                     time += 1
+
+                 # need the final next value for GAE, iiuc
+
+                 next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+
+                 memory_for_gae = memory._replace(
+                     episode_id = invalid_episode,
+                     value = next_value
+                 )
+
+                 memories.append(memory_for_gae)
+
+         return MemoriesAndFitnessScores(
+             memories = memories,
+             fitness_scores = fitness_scores
+         )
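
In this release `calc_generalized_advantage_estimate` takes per-timestep 1D tensors (`rewards` of shape `[n]`, `values` of shape `[n + 1]` including a bootstrap value, `masks` of shape `[n]`) instead of the previous grouped `[g n]` layout. A plain-loop sketch of the standard GAE recurrence those shapes imply, for orientation only (the package itself uses the `delta` / `gates` formulation shown above, optionally with an accelerated scan; `gae_reference` is a hypothetical name):

```python
import torch

def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    # rewards: Float[n], values: Float[n + 1], masks: Bool[n] (False where an episode terminated)
    n = rewards.shape[0]
    advantages = torch.zeros(n)
    running = torch.tensor(0.)

    # advantage_t = delta_t + gamma * lam * mask_t * advantage_{t+1}
    for t in reversed(range(n)):
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        running = delta + gamma * lam * masks[t] * running
        advantages[t] = running

    return advantages

rewards = torch.randn(10)
values = torch.randn(11)        # one extra bootstrap value for the state after the last step
masks = torch.ones(10).bool()   # all steps non-terminal in this toy example

advantages = gae_reference(rewards, values, masks)
```

The extra memory appended per episode with `episode_id = invalid_episode` plays exactly this bootstrap role: it carries the final `next_value`, and `valid_episode = episode_ids >= 0` later drops it from the training dataset so only real transitions are learned from.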
{evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "evolutionary-policy-optimization"
- version = "0.0.38"
+ version = "0.0.40"
  description = "EPO - Pytorch"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
{evolutionary_policy_optimization-0.0.38 → evolutionary_policy_optimization-0.0.40}/tests/test_epo.py

@@ -71,4 +71,34 @@ def test_create_agent(
      # saving and loading

      agent.save('./agent.pt', overwrite = True)
-     agent.load('./agent.pt')
+     agent.load('./agent.pt')
+
+ def test_e2e_with_mock_env():
+     from evolutionary_policy_optimization import create_agent, EPO, Env
+
+     agent = create_agent(
+         dim_state = 512,
+         num_latents = 8,
+         dim_latent = 32,
+         actor_num_actions = 5,
+         actor_dim_hiddens = (256, 128),
+         critic_dim_hiddens = (256, 128, 64)
+     )
+
+     epo = EPO(
+         agent,
+         episodes_per_latent = 1,
+         max_episode_length = 10,
+         action_sample_temperature = 1.
+     )
+
+     env = Env((512,))
+
+     memories = epo(env)
+
+     agent(memories)
+
+     # saving and loading
+
+     agent.save('./agent.pt', overwrite = True)
+     agent.load('./agent.pt')
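
The new test runs a single collect-then-learn cycle; since `agent(memories)` now also applies `genetic_algorithm_step` to the latent pool with the collected fitness scores, a longer training run would presumably just repeat that cycle. A hypothetical outer loop (not part of the package; `num_generations` is an assumed hyperparameter):

```python
# hypothetical outer loop, reusing the agent / epo / env objects constructed as in the test above
num_generations = 10  # assumed, not defined anywhere in the package

for _ in range(num_generations):
    memories = epo(env)   # rollout: one episode per latent, accumulating per-latent fitness scores
    agent(memories)       # PPO updates for actor and critic, then a genetic algorithm step on the latent pool
```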