evolutionary-policy-optimization 0.0.38-py3-none-any.whl → 0.0.39-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py
@@ -22,6 +22,8 @@ from hl_gauss_pytorch import HLGaussLayer
 
 from ema_pytorch import EMA
 
+from tqdm import tqdm
+
 # helpers
 
 def exists(v):
@@ -47,9 +49,20 @@ def l2norm(t):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
+def gumbel_noise(t):
+    return -log(-log(torch.rand_like(t)))
+
+def gumbel_sample(t, temperature = 1.):
+    is_greedy = temperature <= 0.
+
+    if not is_greedy:
+        t = (t / temperature) + gumbel_noise(t)
+
+    return t.argmax(dim = -1)
+
 def calc_entropy(logits):
     prob = logits.softmax(dim = -1)
-    return -prob * log(prob)
+    return -(prob * log(prob)).sum(dim = -1)
 
 def gather_log_prob(
     logits, # Float[b l]
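
Two things change in the helpers above: the new `gumbel_noise` / `gumbel_sample` pair samples discrete actions from logits via the Gumbel-max trick (adding Gumbel(0, 1) noise to temperature-scaled logits and taking the argmax draws from the corresponding softmax distribution, while `temperature <= 0.` falls back to greedy argmax), and `calc_entropy` now sums over the action dimension so it yields one entropy value per distribution rather than an elementwise tensor. A self-contained sketch of the same helpers, assuming nothing beyond `torch`:

```python
import torch

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def gumbel_noise(t):
    # Gumbel(0, 1) noise via inverse transform sampling
    return -log(-log(torch.rand_like(t)))

def gumbel_sample(t, temperature = 1.):
    # temperature <= 0 degenerates to greedy argmax
    if temperature > 0.:
        t = (t / temperature) + gumbel_noise(t)
    return t.argmax(dim = -1)

def calc_entropy(logits):
    prob = logits.softmax(dim = -1)
    return -(prob * log(prob)).sum(dim = -1)  # one entropy per distribution

logits = torch.randn(4, 5)        # batch of 4 categorical distributions over 5 actions
actions = gumbel_sample(logits)   # Long tensor of shape (4,)
entropy = calc_entropy(logits)    # Float tensor of shape (4,)
print(actions.shape, entropy.shape)
```
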
@@ -63,8 +76,8 @@ def gather_log_prob(
 # generalized advantage estimate
 
 def calc_generalized_advantage_estimate(
-    rewards, # Float[g n]
-    values, # Float[g n+1]
+    rewards, # Float[n]
+    values, # Float[n+1]
     masks, # Bool[n]
     gamma = 0.99,
     lam = 0.95,
@@ -75,9 +88,7 @@ def calc_generalized_advantage_estimate(
     use_accelerated = default(use_accelerated, rewards.is_cuda)
     device = rewards.device
 
-    masks = repeat(masks, 'n -> g n', g = rewards.shape[0])
-
-    values, values_next = values[:, :-1], values[:, 1:]
+    values, values_next = values[:-1], values[1:]
 
     delta = rewards + gamma * values_next * masks - values
     gates = gamma * lam * masks
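
`calc_generalized_advantage_estimate` now works on unbatched sequences: `rewards` of shape `[n]`, `values` of shape `[n+1]` (the extra entry is the bootstrap value for the state after the last step), and a `[n]` mask that cuts the recursion at episode boundaries; the `repeat` over a group dimension is gone. The recurrence being computed is `delta_t = r_t + gamma * V_{t+1} * mask_t - V_t` followed by `A_t = delta_t + gamma * lam * mask_t * A_{t+1}`. A plain-loop reference sketch of that recurrence (the package itself can dispatch to an accelerated scan; this is only to illustrate the math):

```python
import torch

def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    # rewards: [n], values: [n+1], masks: [n] (1. while the episode continues, 0. at terminal steps)
    values, values_next = values[:-1], values[1:]

    delta = rewards + gamma * values_next * masks - values   # TD residuals
    gates = gamma * lam * masks

    advantages = torch.zeros_like(rewards)
    running = 0.
    for t in reversed(range(rewards.shape[0])):
        running = delta[t] + gates[t] * running              # A_t = delta_t + gamma * lam * mask_t * A_{t+1}
        advantages[t] = running

    return advantages

rewards = torch.randn(6)
values  = torch.randn(7)   # includes the bootstrap value for the state after the last step
masks   = torch.ones(6)
print(gae_reference(rewards, values, masks))
```
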
@@ -565,6 +576,8 @@ class LatentGenePool(Module):
         if not exists(latent_id) and self.num_latents == 1:
             latent_id = 0
 
+        assert exists(latent_id)
+
         if not is_tensor(latent_id):
            latent_id = tensor(latent_id, device = device)
 
@@ -681,17 +694,38 @@ class Agent(Module):
     def get_actor_actions(
         self,
         state,
-        latent_id
+        latent_id = None,
+        latent = None,
+        sample = False,
+        temperature = 1.
     ):
-        latent = self.latent_gene_pool(latent_id = latent_id, state = state)
-        return self.actor(state, latent)
+        assert exists(latent_id) or exists(latent)
+
+        if not exists(latent):
+            latent = self.latent_gene_pool(latent_id = latent_id)
+
+        logits = self.actor(state, latent)
+
+        if not sample:
+            return logits
+
+        actions = gumbel_sample(logits, temperature = temperature)
+
+        log_probs = gather_log_prob(logits, actions)
+
+        return actions, log_probs
 
     def get_critic_values(
         self,
         state,
-        latent_id
+        latent_id = None,
+        latent = None
     ):
-        latent = self.latent_gene_pool(latent_id = latent_id, state = state)
+        assert exists(latent_id) or exists(latent)
+
+        if not exists(latent):
+            latent = self.latent_gene_pool(latent_id = latent_id)
+
         return self.critic(state, latent)
 
     def update_latent_gene_pool_(
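
`get_actor_actions` and `get_critic_values` now accept either a `latent_id` (looked up in the gene pool) or a precomputed `latent`, and the actor can either return raw logits or, with `sample = True`, a sampled `(action, log_prob)` pair using the Gumbel helper above. A hedged usage sketch against these signatures; the `create_agent` call and dimensions are copied from the README example later in this diff, and the batch of one is only for illustration:

```python
import torch
from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,
    num_latents = 8,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64)
)

state = torch.randn(1, 512)

logits = agent.get_actor_actions(state, latent_id = 0)   # raw logits, no sampling

action, log_prob = agent.get_actor_actions(state, latent_id = 0, sample = True, temperature = 1.)

value = agent.get_critic_values(state, latent_id = 0)
```
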
@@ -702,13 +736,13 @@ class Agent(Module):
 
     def forward(
         self,
-        memories_and_next_value: MemoriesAndNextValue,
+        memories_and_fitness_scores: MemoriesAndFitnessScores,
         epochs = 2
     ):
-        memories, next_value = memories_and_next_value
+        memories, fitness_scores = memories_and_fitness_scores
 
         (
-            _,
+            episode_ids,
             states,
             latent_gene_ids,
             actions,
@@ -718,35 +752,46 @@ class Agent(Module):
             dones
         ) = map(stack, zip(*memories))
 
-        values_with_next, ps = pack((values, next_value), '*')
+        advantages = self.calc_gae(
+            rewards[:-1],
+            values,
+            dones[:-1],
+        )
 
-        advantages = self.calc_gae(rewards, values_with_next, dones)
+        valid_episode = episode_ids >= 0
 
-        dataset = TensorDataset(states, latent_gene_ids, actions, log_probs, advantages, values)
+        dataset = TensorDataset(
+            *[
+                advantages[valid_episode[:-1]],
+                *[t[valid_episode] for t in (states, latent_gene_ids, actions, log_probs, values)]
+            ]
+        )
 
         dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
 
         self.actor.train()
         self.critic.train()
 
-        for _ in range(epochs):
+        for _ in tqdm(range(epochs), desc = 'learning actor/critic epoch'):
             for (
+                advantages,
                 states,
                 latent_gene_ids,
                 actions,
                 log_probs,
-                advantages,
                 old_values
             ) in dataloader:
 
-                latents = self.latent_gene_pool(latent_gene_ids)
+                latents = self.latent_gene_pool(latent_id = latent_gene_ids)
 
                 # learn actor
 
                 logits = self.actor(states, latents)
+
                 actor_loss = self.actor_loss(logits, log_probs, actions, advantages)
 
                 actor_loss.backward()
+
                 self.actor_optim.step()
                 self.actor_optim.zero_grad()
 
@@ -755,7 +800,7 @@ class Agent(Module):
                 critic_loss = self.critic(
                     states,
                     latents,
-                    targets = advantages + old_values
+                    target = advantages + old_values
                 )
 
                 critic_loss.backward()
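
In `Agent.forward`, GAE is now computed over the whole rolled-out stream, including the bootstrap rows the rollout appends with `episode_id = -1` (see the `EPO.forward` changes further down), and those rows are then filtered out of the training set with `valid_episode = episode_ids >= 0`. Because `calc_gae(rewards[:-1], values, dones[:-1])` yields one fewer element than the memory stream, `advantages` is masked with `valid_episode[:-1]` while the other tensors use the full mask; both selections come out the same length because the final row is always a bootstrap row. A toy sketch of that alignment with dummy tensors:

```python
import torch

n = 7                                                 # 6 real transitions + 1 trailing bootstrap row
episode_ids = torch.tensor([0, 0, 0, -1, 1, 1, -1])   # -1 marks the bootstrap row appended after each episode
rewards     = torch.randn(n)

advantages = torch.randn(n - 1)                       # calc_gae(rewards[:-1], values, dones[:-1]) -> one fewer element

valid_episode = episode_ids >= 0

# advantages is one element shorter, so it gets the truncated mask; everything else uses the full mask
print(advantages[valid_episode[:-1]].shape)           # torch.Size([5])
print(rewards[valid_episode].shape)                   # torch.Size([5])
```
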
@@ -763,6 +808,10 @@ class Agent(Module):
                 self.critic_optim.step()
                 self.critic_optim.zero_grad()
 
+        # apply evolution
+
+        self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
+
 # reinforcement learning related - ppo
 
 def actor_loss(
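
`Agent.forward` now finishes by handing the accumulated per-latent fitness scores to the gene pool, so a single call runs the PPO epochs and then the evolutionary selection step. Standalone, the call looks like the latent-pool snippet in the README later in this diff; a short sketch assuming `LatentGenePool` is importable from the package top level as in that README (sizes here are arbitrary):

```python
import torch
from evolutionary_policy_optimization import LatentGenePool

latent_pool = LatentGenePool(num_latents = 8, dim_latent = 32)

fitness = torch.randn(8)                     # one cumulative-reward score per latent
latent_pool.genetic_algorithm_step(fitness)  # update latent genes with the genetic algorithm
```
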
@@ -789,15 +838,7 @@ def actor_loss(
 
     entropy_aux_loss = -entropy_weight * entropy
 
-    return actor_loss + entropy_aux_loss
-
-def critic_loss(
-    pred_values, # Float[b]
-    advantages, # Float[b]
-    old_values # Float[b]
-):
-    discounted_values = advantages + old_values
-    return F.mse_loss(pred_values, discounted_values)
+    return (actor_loss + entropy_aux_loss).mean()
 
 # agent contains the actor, critic, and the latent genetic pool
 
@@ -810,6 +851,11 @@ def create_agent(
     critic_dim_hiddens: int | tuple[int, ...],
 ) -> Agent:
 
+    latent_gene_pool = LatentGenePool(
+        num_latents = num_latents,
+        dim_latent = dim_latent
+    )
+
     actor = Actor(
         num_actions = actor_num_actions,
         dim_state = dim_state,
@@ -821,13 +867,7 @@ def create_agent(
         dim_state = dim_state,
         dim_latent = dim_latent,
         dim_hiddens = critic_dim_hiddens
-    )
-
-    latent_gene_pool = LatentGenePool(
-        num_latents = num_latents,
-        dim_latent = dim_latent,
-    )
-
+    )
     return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
 
 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
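
On the loss side, the standalone `critic_loss` helper is gone (the critic module now receives `target = advantages + old_values` directly) and `actor_loss` reduces its per-sample objective with `.mean()`. For orientation on the comment just above — EPO being PPO plus natural selection over a population of latents — here is the standard PPO clipped surrogate with an entropy bonus that an actor loss of this shape typically implements; a generic sketch, not the package's exact `actor_loss` body:

```python
import torch

def ppo_actor_loss_sketch(logits, old_log_probs, actions, advantages, eps_clip = 0.2, entropy_weight = 0.01):
    # log-prob of the taken actions under the current policy
    log_probs = logits.log_softmax(dim = -1).gather(-1, actions[:, None]).squeeze(-1)

    # clipped importance-weighted surrogate
    ratio = (log_probs - old_log_probs).exp()
    clipped_ratio = ratio.clamp(1. - eps_clip, 1. + eps_clip)
    surrogate = torch.min(ratio * advantages, clipped_ratio * advantages)

    # entropy bonus, summed over the action dimension as in calc_entropy above
    prob = logits.softmax(dim = -1)
    entropy = -(prob * prob.clamp(min = 1e-20).log()).sum(dim = -1)

    return (-surrogate - entropy_weight * entropy).mean()

logits = torch.randn(4, 5, requires_grad = True)
actions = torch.randint(0, 5, (4,))
old_log_probs = torch.randn(4, 5).log_softmax(dim = -1).gather(-1, actions[:, None]).squeeze(-1)
advantages = torch.randn(4)

ppo_actor_loss_sketch(logits, old_log_probs, actions, advantages).backward()
```
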
@@ -840,14 +880,13 @@ Memory = namedtuple('Memory', [
     'action',
     'log_prob',
     'reward',
-    'values',
+    'value',
     'done'
 ])
 
-MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', [
+MemoriesAndFitnessScores = namedtuple('MemoriesAndFitnessScores', [
     'memories',
-    'next_value',
-    'cumulative_rewards'
+    'fitness_scores'
 ])
 
 class EPO(Module):
@@ -856,10 +895,12 @@ class EPO(Module):
         self,
         agent: Agent,
         episodes_per_latent,
-        max_episode_length
+        max_episode_length,
+        action_sample_temperature = 1.
     ):
         super().__init__()
         self.agent = agent
+        self.action_sample_temperature = action_sample_temperature
 
         self.num_latents = agent.latent_gene_pool.num_latents
         self.episodes_per_latent = episodes_per_latent
@@ -869,10 +910,90 @@ class EPO(Module):
     def forward(
         self,
         env
-    ) -> MemoriesAndNextValue:
+    ) -> MemoriesAndFitnessScores:
 
         self.agent.eval()
 
+        invalid_episode = tensor(-1) # will use `episode_id` value of `-1` for the `next_value`, needed for not discarding last reward for generalized advantage estimate
+
         memories: list[Memory] = []
 
-        raise NotImplementedError
+        fitness_scores = torch.zeros((self.num_latents))
+
+        for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
+
+            for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
+                time = 0
+
+                # initial state
+
+                state = env.reset()
+
+                # get latent from pool
+
+                latent = self.agent.latent_gene_pool(latent_id = latent_id)
+
+                # until maximum episode length
+
+                done = tensor(False)
+
+                while time < self.max_episode_length:
+
+                    batched_state = rearrange(state, '... -> 1 ...')
+
+                    # sample action
+
+                    action, log_prob = self.agent.get_actor_actions(batched_state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+
+                    action = rearrange(action, '1 ... -> ...')
+                    log_prob = rearrange(log_prob, '1 ... -> ...')
+
+                    # values
+
+                    value = self.agent.get_critic_values(batched_state, latent = latent)
+
+                    value = rearrange(value, '1 ... -> ...')
+
+                    # get the next state, action, and reward
+
+                    state, reward, done = env(action)
+
+                    # update fitness for each gene as cumulative reward received, but make this customizable at some point
+
+                    fitness_scores[latent_id] += reward
+
+                    # store memories
+
+                    memory = Memory(
+                        tensor(episode_id),
+                        state,
+                        tensor(latent_id),
+                        action,
+                        log_prob,
+                        reward,
+                        value,
+                        done
+                    )
+
+                    memories.append(memory)
+
+                    time += 1
+
+                # need the final next value for GAE, iiuc
+
+                batched_state = rearrange(state, '... -> 1 ...')
+
+                next_value = self.agent.get_critic_values(batched_state, latent = latent)
+                next_value = rearrange(next_value, '1 ... -> ...')
+
+                memory_for_gae = memory._replace(
+                    episode_id = invalid_episode,
+                    value = next_value
+                )
+
+                memories.append(memory_for_gae)
+
+        return MemoriesAndFitnessScores(
+            memories = memories,
+            fitness_scores = fitness_scores
+        )
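
Each latent's rollout ends by cloning the last `Memory` with `namedtuple._replace`, swapping in the sentinel `episode_id = -1` and the critic's value of the final state; that extra row is what later lets `Agent.forward` bootstrap GAE without discarding the last real reward, and it is exactly what `valid_episode` filters back out before training. A tiny sketch of the `_replace` mechanics on a stand-in namedtuple; the first three field names are inferred from the unpacking in `Agent.forward`:

```python
from collections import namedtuple
import torch

Memory = namedtuple('Memory', [
    'episode_id', 'state', 'latent_gene_id', 'action',
    'log_prob', 'reward', 'value', 'done'
])

memory = Memory(
    torch.tensor(0), torch.randn(4), torch.tensor(2), torch.tensor(1),
    torch.tensor(-0.3), torch.tensor(1.0), torch.tensor(0.5), torch.tensor(False)
)

# bootstrap row: same fields, but an invalid episode id and the value of the state after the last step
memory_for_gae = memory._replace(episode_id = torch.tensor(-1), value = torch.tensor(0.7))

print(memory_for_gae.episode_id, memory_for_gae.value)
```
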

--- evolutionary_policy_optimization-0.0.38.dist-info/METADATA
+++ evolutionary_policy_optimization-0.0.39.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.38
+Version: 0.0.39
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -101,6 +101,45 @@ fitness = torch.randn(128)
 latent_pool.genetic_algorithm_step(fitness) # update latent genes with genetic algorithm
 ```
 
+End to end learning
+
+```python
+import torch
+
+from evolutionary_policy_optimization import (
+    create_agent,
+    EPO,
+    Env
+)
+
+agent = create_agent(
+    dim_state = 512,
+    num_latents = 8,
+    dim_latent = 32,
+    actor_num_actions = 5,
+    actor_dim_hiddens = (256, 128),
+    critic_dim_hiddens = (256, 128, 64)
+)
+
+epo = EPO(
+    agent,
+    episodes_per_latent = 1,
+    max_episode_length = 10,
+    action_sample_temperature = 1.
+)
+
+env = Env((512,))
+
+memories = epo(env)
+
+agent(memories)
+
+# saving and loading
+
+agent.save('./agent.pt', overwrite = True)
+agent.load('./agent.pt')
+```
+
 ## Citations
 
 ```bibtex

--- evolutionary_policy_optimization-0.0.38.dist-info/RECORD
+++ evolutionary_policy_optimization-0.0.39.dist-info/RECORD
@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=z41p5LmvOHULq6o5aIj9Q6lpyka5DvkqsJ493-WL-EQ,26175
+evolutionary_policy_optimization/epo.py,sha256=lzxPamJahE5KqBwzyYlGOwNeUoB2vONLwtRcWqCI_Jw,29800
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=QqVPZVJtrvQmSDcnYDTob_A5sDwiUzGj6_tmo6BII5c,918
-evolutionary_policy_optimization-0.0.38.dist-info/METADATA,sha256=Lofrc6waEB8qBD19pjBKQjbKBYMNUyjYZZrJCO1fji8,4818
-evolutionary_policy_optimization-0.0.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.38.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.39.dist-info/METADATA,sha256=TTNQD7sTWIgpVwnrQrFFBD-cyySkvwJr_J3ABxTpor8,5409
+evolutionary_policy_optimization-0.0.39.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.39.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.39.dist-info/RECORD,,