evolutionary-policy-optimization 0.0.37__py3-none-any.whl → 0.0.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
evolutionary_policy_optimization/__init__.py
@@ -4,5 +4,8 @@ from evolutionary_policy_optimization.epo import (
     Critic,
     create_agent,
     Agent,
-    LatentGenePool
+    LatentGenePool,
+    EPO
 )
+
+from evolutionary_policy_optimization.mock_env import Env
evolutionary_policy_optimization/epo.py
@@ -5,13 +5,13 @@ from pathlib import Path
 from collections import namedtuple

 import torch
-from torch import nn, cat, is_tensor, tensor
+from torch import nn, cat, stack, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader

 import einx
-from einops import rearrange, repeat, einsum
+from einops import rearrange, repeat, einsum, pack
 from einops.layers.torch import Rearrange

 from assoc_scan import AssocScan
@@ -22,6 +22,8 @@ from hl_gauss_pytorch import HLGaussLayer

 from ema_pytorch import EMA

+from tqdm import tqdm
+
 # helpers

 def exists(v):
@@ -47,9 +49,20 @@ def l2norm(t):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()

+def gumbel_noise(t):
+    return -log(-log(torch.rand_like(t)))
+
+def gumbel_sample(t, temperature = 1.):
+    is_greedy = temperature <= 0.
+
+    if not is_greedy:
+        t = (t / temperature) + gumbel_noise(t)
+
+    return t.argmax(dim = -1)
+
 def calc_entropy(logits):
     prob = logits.softmax(dim = -1)
-    return -prob * log(prob)
+    return -(prob * log(prob)).sum(dim = -1)

 def gather_log_prob(
     logits, # Float[b l]
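The new `gumbel_noise` / `gumbel_sample` helpers sample discrete actions directly from logits, falling back to a plain argmax when `temperature <= 0.`. A minimal self-contained sketch of the idea (the logits here are made up for illustration):

```python
import torch

def log(t, eps = 1e-20):
    return t.clamp(min = eps).log()

def gumbel_noise(t):
    # -log(-log(U)) turns uniform noise U into Gumbel(0, 1) noise
    return -log(-log(torch.rand_like(t)))

def gumbel_sample(t, temperature = 1.):
    if temperature > 0.:
        t = (t / temperature) + gumbel_noise(t)
    return t.argmax(dim = -1)

logits = torch.randn(4, 5)                                  # 4 states, 5 discrete actions

stochastic_actions = gumbel_sample(logits)                  # varies run to run
greedy_actions = gumbel_sample(logits, temperature = 0.)    # degenerates to argmax

assert torch.equal(greedy_actions, logits.argmax(dim = -1))
```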
@@ -63,8 +76,8 @@ def gather_log_prob(
 # generalized advantage estimate

 def calc_generalized_advantage_estimate(
-    rewards, # Float[g n]
-    values, # Float[g n+1]
+    rewards, # Float[n]
+    values, # Float[n+1]
     masks, # Bool[n]
     gamma = 0.99,
     lam = 0.95,
@@ -75,9 +88,7 @@ def calc_generalized_advantage_estimate(
     use_accelerated = default(use_accelerated, rewards.is_cuda)
     device = rewards.device

-    masks = repeat(masks, 'n -> g n', g = rewards.shape[0])
-
-    values, values_next = values[:, :-1], values[:, 1:]
+    values, values_next = values[:-1], values[1:]

     delta = rewards + gamma * values_next * masks - values
     gates = gamma * lam * masks
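With the gene/batch dimension gone, `calc_generalized_advantage_estimate` now works on a single trajectory: `rewards` and `masks` of length `n`, `values` of length `n + 1` (the extra entry being the bootstrap value), using the recursion `A_t = delta_t + gamma * lam * mask_t * A_{t+1}`. A slow reference loop, written only to pin down that shape convention (names are illustrative, not the package's accelerated scan):

```python
import torch

def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    # rewards: Float[n], values: Float[n + 1], masks: Bool/Float[n]
    values, values_next = values[:-1], values[1:]

    advantages = torch.zeros_like(rewards)
    running = torch.tensor(0.)

    # backward recursion over the trajectory
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * values_next[t] * masks[t] - values[t]
        running = delta + gamma * lam * masks[t] * running
        advantages[t] = running

    return advantages

rewards, values, masks = torch.randn(10), torch.randn(11), torch.ones(10)
assert gae_reference(rewards, values, masks).shape == (10,)
```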
@@ -319,7 +330,6 @@ class LatentGenePool(Module):
         num_latents, # same as gene pool size
         dim_latent, # gene dimension
         num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
-        dim_state = None,
         frozen_latents = True,
         crossover_random = True, # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
         l2norm_latent = False, # whether to enforce latents on hypersphere,
@@ -384,7 +394,6 @@ class LatentGenePool(Module):
         fitness,
         beta0 = 2., # exploitation factor, moving fireflies of low light intensity to high
         gamma = 1., # controls light intensity decay over distance - setting this to zero will make firefly equivalent to vanilla PSO
-        alpha = 0.1, # exploration factor
         inplace = True,
     ):
         islands = self.num_islands
@@ -555,7 +564,6 @@ class LatentGenePool(Module):
     def forward(
         self,
         *args,
-        state: Tensor | None = None,
         latent_id: int | None = None,
         net: Module | None = None,
         net_latent_kwarg_name = 'latent',
@@ -568,6 +576,8 @@ class LatentGenePool(Module):
        if not exists(latent_id) and self.num_latents == 1:
            latent_id = 0

+        assert exists(latent_id)
+
        if not is_tensor(latent_id):
            latent_id = tensor(latent_id, device = device)

@@ -575,8 +585,6 @@ class LatentGenePool(Module):

        # fetch latent

-        fetching_multiple_latents = latent_id.numel() > 1
-
        latent = self.latents[latent_id]

        latent = self.maybe_l2norm(latent)
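`LatentGenePool.forward` no longer accepts a `state` and now asserts that a `latent_id` was given (except when the pool holds a single latent). It takes either a plain integer or a tensor of ids, which is how the PPO update fetches a whole batch of genes at once. A minimal hedged usage sketch:

```python
import torch
from evolutionary_policy_optimization import LatentGenePool

latent_pool = LatentGenePool(num_latents = 8, dim_latent = 32)

latent = latent_pool(latent_id = 3)                         # single gene
latents = latent_pool(latent_id = torch.tensor([0, 1, 5]))  # batch of genes, as in the learning loop
```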
@@ -686,17 +694,38 @@ class Agent(Module):
     def get_actor_actions(
         self,
         state,
-        latent_id
+        latent_id = None,
+        latent = None,
+        sample = False,
+        temperature = 1.
     ):
-        latent = self.latent_gene_pool(latent_id = latent_id, state = state)
-        return self.actor(state, latent)
+        assert exists(latent_id) or exists(latent)
+
+        if not exists(latent):
+            latent = self.latent_gene_pool(latent_id = latent_id)
+
+        logits = self.actor(state, latent)
+
+        if not sample:
+            return logits
+
+        actions = gumbel_sample(logits, temperature = temperature)
+
+        log_probs = gather_log_prob(logits, actions)
+
+        return actions, log_probs

     def get_critic_values(
         self,
         state,
-        latent_id
+        latent_id = None,
+        latent = None
     ):
-        latent = self.latent_gene_pool(latent_id = latent_id, state = state)
+        assert exists(latent_id) or exists(latent)
+
+        if not exists(latent):
+            latent = self.latent_gene_pool(latent_id = latent_id)
+
         return self.critic(state, latent)

     def update_latent_gene_pool_(
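With these changes `get_actor_actions` either returns raw logits (the default) or, with `sample = True`, a tuple of sampled actions and their log probabilities; both methods also accept a pre-fetched `latent` instead of a `latent_id`. A hedged usage sketch against an agent built with `create_agent` (keyword arguments mirror the README example further down; the state shape is an assumption for illustration):

```python
import torch
from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,
    num_latents = 8,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64)
)

state = torch.randn(1, 512)

# either pass a latent_id ...
logits = agent.get_actor_actions(state, latent_id = 3)

# ... or fetch the latent once and reuse it, sampling an action this time
latent = agent.latent_gene_pool(latent_id = 3)
action, log_prob = agent.get_actor_actions(state, latent = latent, sample = True, temperature = 1.)

value = agent.get_critic_values(state, latent = latent)
```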
@@ -707,12 +736,13 @@ class Agent(Module):

     def forward(
         self,
-        memories_and_next_value: MemoriesAndNextValue,
+        memories_and_fitness_scores: MemoriesAndFitnessScores,
         epochs = 2
     ):
-        memories, next_value = memories_and_next_value
+        memories, fitness_scores = memories_and_fitness_scores

         (
+            episode_ids,
             states,
             latent_gene_ids,
             actions,
@@ -722,35 +752,46 @@ class Agent(Module):
             dones
         ) = map(stack, zip(*memories))

-        values_with_next, ps = pack((values, next_value), '*')
+        advantages = self.calc_gae(
+            rewards[:-1],
+            values,
+            dones[:-1],
+        )

-        advantages = self.calc_gae(rewards, values_with_next, dones)
+        valid_episode = episode_ids >= 0

-        dataset = TensorDataset(states, latent_gene_ids, actions, log_probs, advantages, values)
+        dataset = TensorDataset(
+            *[
+                advantages[valid_episode[:-1]],
+                *[t[valid_episode] for t in (states, latent_gene_ids, actions, log_probs, values)]
+            ]
+        )

         dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)

         self.actor.train()
         self.critic.train()

-        for _ in range(epochs):
+        for _ in tqdm(range(epochs), desc = 'learning actor/critic epoch'):
             for (
+                advantages,
                 states,
                 latent_gene_ids,
                 actions,
                 log_probs,
-                advantages,
                 old_values
             ) in dataloader:

-                latents = self.latent_gene_pool(latent_gene_ids)
+                latents = self.latent_gene_pool(latent_id = latent_gene_ids)

                 # learn actor

                 logits = self.actor(states, latents)
+
                 actor_loss = self.actor_loss(logits, log_probs, actions, advantages)

                 actor_loss.backward()
+
                 self.actor_optim.step()
                 self.actor_optim.zero_grad()

@@ -759,7 +800,7 @@ class Agent(Module):
                 critic_loss = self.critic(
                     states,
                     latents,
-                    targets = advantages + old_values
+                    target = advantages + old_values
                 )

                 critic_loss.backward()
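Two conventions introduced in the reshuffled `Agent.forward` are easy to miss: the rollout appends one extra memory per episode whose `episode_id` is `-1` and whose `value` slot carries the bootstrap value (so the stacked `values` run one longer than the usable `rewards` / `dones`), and the critic is now handed `target = advantages + old_values` and returns its own loss. A toy illustration of the sentinel masking, with made-up numbers:

```python
import torch

# one episode of length 3 plus the sentinel bootstrap row appended by the rollout
episode_ids = torch.tensor([0, 0, 0, -1])
rewards     = torch.tensor([1., 0., 1., 0.])
values      = torch.tensor([0.10, 0.20, 0.30, 0.40])  # last entry = bootstrap value

# GAE consumes all n + 1 values but only the first n rewards / dones
rewards_for_gae = rewards[:-1]   # length 3
values_for_gae  = values         # length 4

# the sentinel row is excluded from the PPO training dataset
valid_episode = episode_ids >= 0
assert values[valid_episode].shape == (3,)
```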
@@ -767,6 +808,10 @@ class Agent(Module):
                 self.critic_optim.step()
                 self.critic_optim.zero_grad()

+        # apply evolution
+
+        self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
+
 # reinforcement learning related - ppo

 def actor_loss(
@@ -785,7 +830,7 @@ def actor_loss(

    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)

-    actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)
+    actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)

    # add entropy loss for exploration

@@ -793,15 +838,7 @@ def actor_loss(

    entropy_aux_loss = -entropy_weight * entropy

-    return actor_loss + entropy_aux_loss
-
-def critic_loss(
-    pred_values, # Float[b]
-    advantages, # Float[b]
-    old_values # Float[b]
-):
-    discounted_values = advantages + old_values
-    return F.mse_loss(pred_values, discounted_values)
+    return (actor_loss + entropy_aux_loss).mean()

 # agent contains the actor, critic, and the latent genetic pool

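Alongside the reworked `actor_loss` (the separate `critic_loss` helper is gone, since the critic now computes its own loss from `target = advantages + old_values`, and the actor loss is reduced to a scalar with `.mean()`), here is a self-contained sketch of the clipped surrogate objective in the same spirit. It mirrors textbook PPO rather than the package's exact function; all names are illustrative:

```python
import torch

def clipped_surrogate_loss(logits, old_log_probs, actions, advantages, eps_clip = 0.2, entropy_weight = 0.01):
    # log prob of the taken action under the current policy
    log_probs = logits.log_softmax(dim = -1).gather(-1, actions[:, None]).squeeze(-1)

    ratio = (log_probs - old_log_probs).exp()
    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)

    # pessimistic bound of the clipped and unclipped surrogates
    surrogate = -torch.min(clipped_ratio * advantages, ratio * advantages)

    # entropy bonus for exploration
    prob = logits.softmax(dim = -1)
    entropy = -(prob * prob.clamp(min = 1e-20).log()).sum(dim = -1)

    return (surrogate - entropy_weight * entropy).mean()

logits = torch.randn(8, 5, requires_grad = True)
loss = clipped_surrogate_loss(logits, torch.randn(8), torch.randint(0, 5, (8,)), torch.randn(8))
loss.backward()
```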
@@ -814,6 +851,11 @@ def create_agent(
     critic_dim_hiddens: int | tuple[int, ...],
 ) -> Agent:

+    latent_gene_pool = LatentGenePool(
+        num_latents = num_latents,
+        dim_latent = dim_latent
+    )
+
     actor = Actor(
         num_actions = actor_num_actions,
         dim_state = dim_state,
@@ -825,46 +867,133 @@
         dim_state = dim_state,
         dim_latent = dim_latent,
         dim_hiddens = critic_dim_hiddens
-    )
-
-    latent_gene_pool = LatentGenePool(
-        dim_state = dim_state,
-        num_latents = num_latents,
-        dim_latent = dim_latent,
-    )
-
+    )
     return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)

 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked

 Memory = namedtuple('Memory', [
+    'episode_id',
     'state',
     'latent_gene_id',
     'action',
     'log_prob',
     'reward',
-    'values',
+    'value',
     'done'
 ])

-MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', [
+MemoriesAndFitnessScores = namedtuple('MemoriesAndFitnessScores', [
     'memories',
-    'next_value'
+    'fitness_scores'
 ])

 class EPO(Module):

     def __init__(
         self,
-        agent: Agent
+        agent: Agent,
+        episodes_per_latent,
+        max_episode_length,
+        action_sample_temperature = 1.
     ):
         super().__init__()
         self.agent = agent
+        self.action_sample_temperature = action_sample_temperature
+
+        self.num_latents = agent.latent_gene_pool.num_latents
+        self.episodes_per_latent = episodes_per_latent
+        self.max_episode_length = max_episode_length

+    @torch.no_grad()
     def forward(
         self,
         env
-    ) -> MemoriesAndNextValue:
+    ) -> MemoriesAndFitnessScores:
+
+        self.agent.eval()
+
+        invalid_episode = tensor(-1) # will use `episode_id` value of `-1` for the `next_value`, needed for not discarding last reward for generalized advantage estimate
+
+        memories: list[Memory] = []
+
+        fitness_scores = torch.zeros((self.num_latents))
+
+        for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
+
+            for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
+                time = 0
+
+                # initial state
+
+                state = env.reset()
+
+                # get latent from pool
+
+                latent = self.agent.latent_gene_pool(latent_id = latent_id)
+
+                # until maximum episode length
+
+                done = tensor(False)
+
+                while time < self.max_episode_length:
+
+                    batched_state = rearrange(state, '... -> 1 ...')
+
+                    # sample action

-        raise NotImplementedError
+                    action, log_prob = self.agent.get_actor_actions(batched_state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+
+                    action = rearrange(action, '1 ... -> ...')
+                    log_prob = rearrange(log_prob, '1 ... -> ...')
+
+                    # values
+
+                    value = self.agent.get_critic_values(batched_state, latent = latent)
+
+                    value = rearrange(value, '1 ... -> ...')
+
+                    # get the next state, action, and reward
+
+                    state, reward, done = env(action)
+
+                    # update fitness for each gene as cumulative reward received, but make this customizable at some point
+
+                    fitness_scores[latent_id] += reward
+
+                    # store memories
+
+                    memory = Memory(
+                        tensor(episode_id),
+                        state,
+                        tensor(latent_id),
+                        action,
+                        log_prob,
+                        reward,
+                        value,
+                        done
+                    )
+
+                    memories.append(memory)
+
+                    time += 1
+
+                # need the final next value for GAE, iiuc
+
+                batched_state = rearrange(state, '... -> 1 ...')
+
+                next_value = self.agent.get_critic_values(batched_state, latent = latent)
+                next_value = rearrange(next_value, '1 ... -> ...')
+
+                memory_for_gae = memory._replace(
+                    episode_id = invalid_episode,
+                    value = next_value
+                )
+
+                memories.append(memory_for_gae)
+
+        return MemoriesAndFitnessScores(
+            memories = memories,
+            fitness_scores = fitness_scores
+        )
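The rollout above only touches the environment through `env.reset()` for the initial state and `env(action)` returning a `(state, reward, done)` tuple of tensors. A hedged sketch of a minimal custom environment satisfying that interface (this class is illustrative, not part of the package):

```python
import torch
from torch import tensor
from torch.nn import Module

class ConstantRewardEnv(Module):
    # hypothetical environment: random states, reward of 1 every step, never terminates
    def __init__(self, dim_state = 512):
        super().__init__()
        self.dim_state = dim_state

    def reset(self):
        return torch.randn(self.dim_state)

    def forward(self, action):
        state = torch.randn(self.dim_state)
        reward = tensor(1.)
        done = tensor(False)
        return state, reward, done
```

It can then be handed to `EPO` exactly like the bundled mock `Env` in the README example further down.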
evolutionary_policy_optimization/experimental.py
@@ -1,27 +1,47 @@
 import torch
+from einops import rearrange

 def crossover_weights(w1, w2, transpose = False):
     assert w2.shape == w2.shape
-    assert w1.ndim == 2
+
+    no_batch = w1.ndim == 2
+
+    if no_batch:
+        w1, w2 = tuple(rearrange(t, '... -> 1 ...') for t in (w1, w2))
+
+    assert w1.ndim == 3

     if transpose:
-        w1, w2 = w1.t(), w2.t()
+        w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))

-    rank = min(w2.shape)
+    rank = min(w2.shape[1:])
     assert rank >= 2

+    batch = w1.shape[0]
+
     u1, s1, v1 = torch.svd(w1)
     u2, s2, v2 = torch.svd(w2)

-    mask = torch.randperm(rank) < (rank // 2)
+    batch_randperm = torch.randn((batch, rank), device = w1.device).argsort(dim = -1)
+    mask = batch_randperm < (rank // 2)

-    u = torch.where(mask[None, :], u1, u2)
+    u = torch.where(mask[:, None, :], u1, u2)
     s = torch.where(mask, s1, s2)
-    v = torch.where(mask[None, :], v1, v2)
+    v = torch.where(mask[:, None, :], v1, v2)

     out = u @ torch.diag_embed(s) @ v.mT

     if transpose:
-        out = out.t()
+        out = rearrange(out, 'b j i -> b i j')
+
+    if no_batch:
+        out = rearrange(out, '1 ... -> ...')

     return out
+
+if __name__ == '__main__':
+    w1 = torch.randn(32, 16)
+    w2 = torch.randn(32, 16)
+    child = crossover_weights(w2, w2)
+
+    assert child.shape == w2.shape
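`crossover_weights` now accepts an optional leading batch dimension and recombines the parents' SVD factors with a per-sample random split of the singular directions. A hedged usage sketch, importing from the module path listed in the RECORD below (note that the `__main__` block above crosses `w2` with itself, and the first assert compares `w2.shape` against itself; both read as if they were meant to involve `w1`):

```python
import torch
from evolutionary_policy_optimization.experimental import crossover_weights

# two parent weight matrices, batched over a population of 4
w1 = torch.randn(4, 32, 16)
w2 = torch.randn(4, 32, 16)

child = crossover_weights(w1, w2)
assert child.shape == w1.shape

# unbatched matrices still work; the batch dim is added and removed internally
child_single = crossover_weights(torch.randn(32, 16), torch.randn(32, 16))
assert child_single.shape == (32, 16)
```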
evolutionary_policy_optimization/mock_env.py
@@ -4,15 +4,20 @@ import torch
 from torch import tensor, randn, randint
 from torch.nn import Module

+# functions
+
+def cast_tuple(v):
+    return v if isinstance(v, tuple) else (v,)
+
 # mock env

 class Env(Module):
     def __init__(
         self,
-        state_shape: tuple[int, ...]
+        state_shape: int | tuple[int, ...]
     ):
         super().__init__()
-        self.state_shape = state_shape
+        self.state_shape = cast_tuple(state_shape)
         self.register_buffer('dummy', tensor(0))

     @property
@@ -31,6 +36,6 @@ class Env(Module):
     ):
         state = randn(self.state_shape, device = self.device)
         reward = randint(0, 5, (), device = self.device).float()
-        done = zeros((), device = self.device, dtype = torch.bool)
+        done = torch.zeros((), device = self.device, dtype = torch.bool)

         return state, reward, done
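With `cast_tuple` in place, the mock `Env` accepts either a bare int or a tuple for `state_shape`, and each step returns a random state of that shape plus scalar reward and done tensors. A quick hedged smoke test, using the `Env` export added to `__init__.py` above:

```python
import torch
from evolutionary_policy_optimization import Env

env = Env(512)            # bare int is cast to (512,)
state = env.reset()
assert state.shape == (512,)

state, reward, done = env(torch.tensor(2))   # the mock ignores the action
assert reward.ndim == 0 and done.dtype == torch.bool
```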
evolutionary_policy_optimization-0.0.39.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.37
+Version: 0.0.39
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -60,8 +60,6 @@ This paper stands out, as I have witnessed the positive effects first hand in an
 
 Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
 
-Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
-
 ## Install
 
 ```bash
@@ -103,6 +101,45 @@ fitness = torch.randn(128)
 latent_pool.genetic_algorithm_step(fitness) # update latent genes with genetic algorithm
 ```
 
+End to end learning
+
+```python
+import torch
+
+from evolutionary_policy_optimization import (
+    create_agent,
+    EPO,
+    Env
+)
+
+agent = create_agent(
+    dim_state = 512,
+    num_latents = 8,
+    dim_latent = 32,
+    actor_num_actions = 5,
+    actor_dim_hiddens = (256, 128),
+    critic_dim_hiddens = (256, 128, 64)
+)
+
+epo = EPO(
+    agent,
+    episodes_per_latent = 1,
+    max_episode_length = 10,
+    action_sample_temperature = 1.
+)
+
+env = Env((512,))
+
+memories = epo(env)
+
+agent(memories)
+
+# saving and loading
+
+agent.save('./agent.pt', overwrite = True)
+agent.load('./agent.pt')
+```
+
 ## Citations
 
 ```bibtex
evolutionary_policy_optimization-0.0.39.dist-info/RECORD
@@ -0,0 +1,8 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/epo.py,sha256=lzxPamJahE5KqBwzyYlGOwNeUoB2vONLwtRcWqCI_Jw,29800
+evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
+evolutionary_policy_optimization/mock_env.py,sha256=QqVPZVJtrvQmSDcnYDTob_A5sDwiUzGj6_tmo6BII5c,918
+evolutionary_policy_optimization-0.0.39.dist-info/METADATA,sha256=TTNQD7sTWIgpVwnrQrFFBD-cyySkvwJr_J3ABxTpor8,5409
+evolutionary_policy_optimization-0.0.39.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.39.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.39.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.37.dist-info/RECORD
@@ -1,8 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=onIGNWHg1EGQwJ9TfkkJ8Yz8_S-BPoaqrxJwq54BXp0,25992
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization/mock_env.py,sha256=3xrd-gwjZeVd_sEvxIyX0lppnMWcfQGOapO-XjKmExI,816
-evolutionary_policy_optimization-0.0.37.dist-info/METADATA,sha256=nPWBCvx02MHWdKu5cEoPmHFMFKhwepOfStkXIXR2NHc,4992
-evolutionary_policy_optimization-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.37.dist-info/RECORD,,