evolutionary-policy-optimization 0.0.66-py3-none-any.whl → 0.0.68-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

evolutionary_policy_optimization/epo.py

@@ -55,6 +55,9 @@ def xnor(x, y):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def to_device(inp, device):
+    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
+
 # tensor helpers
 
 def l2norm(t):
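
The to_device helper added above moves every tensor inside a nested container onto a target device and leaves non-tensor leaves alone. A minimal, self-contained sketch of the same pattern, assuming tree_map comes from torch.utils._pytree (which flattens namedtuples and standard containers as pytree nodes):

# minimal sketch of the to_device pattern above; assumes tree_map from
# torch.utils._pytree, so nested namedtuples such as Memory are moved wholesale
from collections import namedtuple

import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

def to_device(inp, device):
    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)

Memory = namedtuple('Memory', ['state', 'action', 'reward'])

memory = Memory(torch.randn(4), torch.tensor(1), torch.tensor(0.5))
memory = to_device(memory, 'cpu')  # tensors moved; non-tensor leaves untouched
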
@@ -181,6 +184,8 @@ class MLP(Module):
         dim_latent = 0,
     ):
         super().__init__()
+        dim_latent = default(dim_latent, 0)
+
         assert len(dims) >= 2, 'must have at least two dimensions'
 
         # add the latent to the first dim
@@ -376,6 +381,7 @@ class LatentGenePool(Module):
         init_latent_fn: Callable | None = None
     ):
         super().__init__()
+        assert num_latents > 1
 
         maybe_l2norm = l2norm if l2norm_latent else identity
 
@@ -670,7 +676,7 @@ class Agent(Module):
         self,
         actor: Actor,
         critic: Critic,
-        latent_gene_pool: LatentGenePool,
+        latent_gene_pool: LatentGenePool | None,
         optim_klass = AdoptAtan2,
         actor_lr = 1e-4,
         critic_lr = 1e-4,
@@ -705,10 +711,14 @@ class Agent(Module):
         self.use_critic_ema = use_critic_ema
         self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
-        self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
+        self.num_latents = latent_gene_pool.num_latents if exists(latent_gene_pool) else 1
+        self.has_latent_genes = exists(latent_gene_pool)
+
+        assert actor.dim_latent == critic.dim_latent
 
-        assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent
+        if self.has_latent_genes:
+            assert latent_gene_pool.dim_latent == actor.dim_latent
 
         # gae function
 
@@ -730,13 +740,19 @@ class Agent(Module):
         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
         self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
 
-        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
+        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if exists(latent_gene_pool) and not latent_gene_pool.frozen_latents else None
 
         # promotes latents to be farther apart for diversity maintenance
 
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
         self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
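
The dummy buffer registered at the end of this hunk is a common way to let a plain nn.Module report which device it currently lives on, since buffers follow the module through .to() and .cuda(). A minimal sketch of the pattern (the class name here is illustrative):

# minimal sketch of the dummy-buffer device property added to Agent above;
# the zero-dim buffer tracks the module through .to(...) / .cuda() calls,
# so self.dummy.device always reflects where the module currently lives
import torch
from torch import nn, tensor

class DeviceAware(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('dummy', tensor(0))

    @property
    def device(self):
        return self.dummy.device

module = DeviceAware()
print(module.device)  # device(type='cpu') until the module is moved elsewhere
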
@@ -746,7 +762,7 @@ class Agent(Module):
             actor = self.actor.state_dict(),
             critic = self.critic.state_dict(),
             critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
-            latents = self.latent_gene_pool.state_dict(),
+            latents = self.latent_gene_pool.state_dict() if self.has_latent_genes else None,
             actor_optim = self.actor_optim.state_dict(),
             critic_optim = self.critic_optim.state_dict(),
             latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
@@ -768,7 +784,8 @@ class Agent(Module):
         if self.use_critic_ema:
             self.critic_ema.load_state_dict(pkg['critic_ema'])
 
-        self.latent_gene_pool.load_state_dict(pkg['latents'])
+        if exists(pkg.get('latents', None)):
+            self.latent_gene_pool.load_state_dict(pkg['latents'])
 
         self.actor_optim.load_state_dict(pkg['actor_optim'])
         self.critic_optim.load_state_dict(pkg['critic_optim'])
@@ -784,9 +801,8 @@ class Agent(Module):
         sample = False,
         temperature = 1.
     ):
-        assert exists(latent_id) or exists(latent)
 
-        if not exists(latent):
+        if not exists(latent) and exists(latent_id):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
         logits = self.actor(state, latent)
@@ -807,9 +823,8 @@ class Agent(Module):
         latent = None,
         use_ema_if_available = False
     ):
-        assert exists(latent_id) or exists(latent)
 
-        if not exists(latent):
+        if not exists(latent) and exists(latent_id):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
         critic_forward = self.critic
@@ -823,13 +838,19 @@ class Agent(Module):
         self,
         fitnesses
     ):
+        if not self.has_latent_genes:
+            return
+
         return self.latent_gene_pool.genetic_algorithm_step(fitnesses)
 
     def forward(
         self,
         memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
+
     ):
+        memories_and_cumulative_rewards = to_device(memories_and_cumulative_rewards, self.device)
+
         memories, rewards_per_latent_episode = memories_and_cumulative_rewards
 
         # stack memories
@@ -893,11 +914,14 @@ class Agent(Module):
                 old_values
             ) in dataloader:
 
-                latents = self.latent_gene_pool(latent_id = latent_gene_ids)
+                if self.has_latent_genes:
+                    latents = self.latent_gene_pool(latent_id = latent_gene_ids)
 
-                orig_latents = latents
-                latents = latents.detach()
-                latents.requires_grad_()
+                    orig_latents = latents
+                    latents = latents.detach()
+                    latents.requires_grad_()
+                else:
+                    latents = None
 
                 # learn actor
 
@@ -936,7 +960,7 @@
 
                 # maybe update latents, if not frozen
 
-                if self.latent_gene_pool.frozen_latents:
+                if not self.has_latent_genes or self.latent_gene_pool.frozen_latents:
                     continue
 
                 orig_latents.backward(latents.grad)
@@ -952,7 +976,9 @@
 
         # apply evolution
 
-        self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
+        if self.has_latent_genes:
+
+            self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
 
         # reinforcement learning related - ppo
 
@@ -1005,11 +1031,16 @@ def create_agent(
     **kwargs
 ) -> Agent:
 
+    has_latent_genes = num_latents > 1
+
+    if not has_latent_genes:
+        dim_latent = None
+
     latent_gene_pool = LatentGenePool(
         num_latents = num_latents,
         dim_latent = dim_latent,
         **latent_gene_pool_kwargs
-    )
+    ) if has_latent_genes else None
 
     actor = Actor(
         num_actions = actor_num_actions,
@@ -1069,7 +1100,7 @@ class EPO(Module):
         self.agent = agent
         self.action_sample_temperature = action_sample_temperature
 
-        self.num_latents = agent.latent_gene_pool.num_latents
+        self.num_latents = agent.latent_gene_pool.num_latents if agent.has_latent_genes else 1
         self.episodes_per_latent = episodes_per_latent
         self.max_episode_length = max_episode_length
         self.fix_environ_across_latents = fix_environ_across_latents
@@ -1159,7 +1190,7 @@ class EPO(Module):
 
             # get latent from pool
 
-            latent = self.agent.latent_gene_pool(latent_id = latent_id)
+            latent = self.agent.latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
 
             # until maximum episode length
 
@@ -1198,6 +1229,8 @@ class EPO(Module):
                     terminated
                 )
 
+                memory = Memory(*tuple(t.cpu() for t in memory))
+
                 memories.append(memory)
 
                 time += 1
@@ -1209,7 +1242,7 @@ class EPO(Module):
 
                 memory_for_gae = memory._replace(
                     episode_id = invalid_episode,
-                    value = next_value,
+                    value = next_value.cpu(),
                     done = tensor(True)
                 )
 
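
Taken together, the epo.py changes make the latent gene pool optional: with a single latent the pool is never constructed, the evolutionary step becomes a no-op, and the agent degenerates to plain PPO. A hedged usage sketch — num_latents, dim_latent and actor_num_actions appear in the hunks above, while the remaining keyword names are recalled from the project README and are not verified against 0.0.68:

# hedged sketch: plain PPO agent with the gene pool disabled (num_latents = 1);
# keyword names other than num_latents / dim_latent / actor_num_actions are
# assumptions based on the project README, not verified against this release
from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,
    num_latents = 1,        # num_latents > 1 is what enables the gene pool
    dim_latent = 32,        # set to None internally when the pool is disabled
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64)
)

assert agent.latent_gene_pool is None and agent.num_latents == 1
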

.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.66
+Version: 0.0.68
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization

.dist-info/RECORD

@@ -1,9 +1,9 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=-ULtP9-EJDQv5TlWwkPwPiGsU7bKnD2qfAYvyDK8GIU,36912
+evolutionary_policy_optimization/epo.py,sha256=xhE_kHas54xGsgOese9SQEvyK7NKZqEuK3AiVhm0y7Q,38047
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.66.dist-info/METADATA,sha256=YsfcSX3Vf5ALrePot6Eq4NStsVPzNDJ3Py6sMgh7lOE,6220
-evolutionary_policy_optimization-0.0.66.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.66.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.66.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.68.dist-info/METADATA,sha256=hOOKOrrPQtQmK3zN1z5nkGJEoaQLyXUzs9ArsEKn1DE,6220
+evolutionary_policy_optimization-0.0.68.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.68.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.68.dist-info/RECORD,,