evolutionary-policy-optimization 0.0.66 → 0.0.67 (py3-none-any.whl)

This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
evolutionary_policy_optimization/epo.py

@@ -181,6 +181,8 @@ class MLP(Module):
         dim_latent = 0,
     ):
         super().__init__()
+        dim_latent = default(dim_latent, 0)
+
         assert len(dims) >= 2, 'must have at least two dimensions'
 
         # add the latent to the first dim
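Context for the hunk above: `dim_latent` may now be passed as `None` and falls back to `0`, so the latent-free code paths introduced elsewhere in this release can still construct a plain `MLP`. The `exists`/`default` helpers referenced here are not part of this diff; a minimal sketch of the usual pattern (the package's own definitions in `epo.py` may differ in detail):

```python
# minimal sketch of the exists/default helpers assumed by the diff above;
# the actual definitions live elsewhere in epo.py and may differ in detail
def exists(v):
    return v is not None

def default(v, d):
    # keep v when it was provided, otherwise fall back to the default d
    return v if exists(v) else d

assert default(None, 0) == 0   # dim_latent = None now means "no latent"
assert default(32, 0) == 32    # an explicit dim_latent is kept as-is
```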
@@ -376,6 +378,7 @@ class LatentGenePool(Module):
         init_latent_fn: Callable | None = None
     ):
         super().__init__()
+        assert num_latents > 1
 
         maybe_l2norm = l2norm if l2norm_latent else identity
 
@@ -670,7 +673,7 @@ class Agent(Module):
         self,
         actor: Actor,
         critic: Critic,
-        latent_gene_pool: LatentGenePool,
+        latent_gene_pool: LatentGenePool | None,
         optim_klass = AdoptAtan2,
         actor_lr = 1e-4,
         critic_lr = 1e-4,
@@ -705,10 +708,14 @@ class Agent(Module):
         self.use_critic_ema = use_critic_ema
         self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
-        self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
+        self.num_latents = latent_gene_pool.num_latents if exists(latent_gene_pool) else 1
+        self.has_latent_genes = exists(latent_gene_pool)
+
+        assert actor.dim_latent == critic.dim_latent
 
-        assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent
+        if self.has_latent_genes:
+            assert latent_gene_pool.dim_latent == actor.dim_latent
 
         # gae function
 
@@ -730,7 +737,7 @@ class Agent(Module):
         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
         self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
 
-        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
+        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if exists(latent_gene_pool) and not latent_gene_pool.frozen_latents else None
 
         # promotes latents to be farther apart for diversity maintenance
 
@@ -746,7 +753,7 @@ class Agent(Module):
             actor = self.actor.state_dict(),
             critic = self.critic.state_dict(),
             critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
-            latents = self.latent_gene_pool.state_dict(),
+            latents = self.latent_gene_pool.state_dict() if self.has_latent_genes else None,
             actor_optim = self.actor_optim.state_dict(),
             critic_optim = self.critic_optim.state_dict(),
             latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
@@ -768,7 +775,8 @@ class Agent(Module):
         if self.use_critic_ema:
             self.critic_ema.load_state_dict(pkg['critic_ema'])
 
-        self.latent_gene_pool.load_state_dict(pkg['latents'])
+        if exists(pkg.get('latents', None)):
+            self.latent_gene_pool.load_state_dict(pkg['latents'])
 
         self.actor_optim.load_state_dict(pkg['actor_optim'])
         self.critic_optim.load_state_dict(pkg['critic_optim'])
@@ -784,9 +792,8 @@ class Agent(Module):
         sample = False,
         temperature = 1.
     ):
-        assert exists(latent_id) or exists(latent)
 
-        if not exists(latent):
+        if not exists(latent) and exists(latent_id):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
         logits = self.actor(state, latent)
@@ -807,9 +814,8 @@ class Agent(Module):
         latent = None,
         use_ema_if_available = False
     ):
-        assert exists(latent_id) or exists(latent)
 
-        if not exists(latent):
+        if not exists(latent) and exists(latent_id):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
         critic_forward = self.critic
@@ -823,6 +829,9 @@ class Agent(Module):
         self,
         fitnesses
     ):
+        if not self.has_latent_genes:
+            return
+
         return self.latent_gene_pool.genetic_algorithm_step(fitnesses)
 
     def forward(
@@ -893,11 +902,14 @@ class Agent(Module):
                 old_values
             ) in dataloader:
 
-                latents = self.latent_gene_pool(latent_id = latent_gene_ids)
+                if self.has_latent_genes:
+                    latents = self.latent_gene_pool(latent_id = latent_gene_ids)
 
-                orig_latents = latents
-                latents = latents.detach()
-                latents.requires_grad_()
+                    orig_latents = latents
+                    latents = latents.detach()
+                    latents.requires_grad_()
+                else:
+                    latents = None
 
                 # learn actor
 
@@ -936,7 +948,7 @@ class Agent(Module):
 
                 # maybe update latents, if not frozen
 
-                if self.latent_gene_pool.frozen_latents:
+                if not self.has_latent_genes or self.latent_gene_pool.frozen_latents:
                     continue
 
                 orig_latents.backward(latents.grad)
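The two hunks above rely on a detach-and-reinject gradient pattern: the selected latents are detached before the PPO losses, so gradients accumulate on the detached copy, and `orig_latents.backward(latents.grad)` then pushes those gradients back into the gene pool's parameters. A self-contained PyTorch sketch of that pattern, with illustrative names only (not the library's own):

```python
# standalone sketch of the detach-and-reinject gradient pattern used above;
# names here are illustrative, not taken from epo.py
import torch

pool = torch.nn.Parameter(torch.randn(4, 8))  # stands in for the latent gene pool
orig_latent = pool[0]                          # latent selected for this batch

latent = orig_latent.detach()                  # cut the graph to the pool
latent.requires_grad_()

loss = (latent ** 2).sum()                     # stands in for the actor/critic losses
loss.backward()                                # gradient lands on the detached copy

orig_latent.backward(latent.grad)              # re-inject the gradient into the pool
assert pool.grad is not None
```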
@@ -952,7 +964,9 @@ class Agent(Module):
 
        # apply evolution
 
-        self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
+        if self.has_latent_genes:
+
+            self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
 
        # reinforcement learning related - ppo
 
@@ -1005,11 +1019,16 @@ def create_agent(
     **kwargs
 ) -> Agent:
 
+    has_latent_genes = num_latents > 1
+
+    if not has_latent_genes:
+        dim_latent = None
+
     latent_gene_pool = LatentGenePool(
         num_latents = num_latents,
         dim_latent = dim_latent,
         **latent_gene_pool_kwargs
-    )
+    ) if has_latent_genes else None
 
     actor = Actor(
         num_actions = actor_num_actions,
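With this change, `create_agent` skips constructing a `LatentGenePool` entirely when only a single latent is requested, which effectively reduces EPO to plain PPO. A hedged usage sketch, assuming the keyword names shown in the project README (they may differ from this version's exact signature):

```python
# hedged sketch; keyword names follow the project README and may not match
# this exact version's create_agent signature
from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,
    num_latents = 1,             # a single latent: no LatentGenePool is built
    dim_latent = 32,             # ignored (set to None) when num_latents == 1
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128)
)

assert not agent.has_latent_genes and agent.num_latents == 1
```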
@@ -1069,7 +1088,7 @@ class EPO(Module):
         self.agent = agent
         self.action_sample_temperature = action_sample_temperature
 
-        self.num_latents = agent.latent_gene_pool.num_latents
+        self.num_latents = agent.latent_gene_pool.num_latents if agent.has_latent_genes else 1
         self.episodes_per_latent = episodes_per_latent
         self.max_episode_length = max_episode_length
         self.fix_environ_across_latents = fix_environ_across_latents
@@ -1159,7 +1178,7 @@ class EPO(Module):
 
                # get latent from pool
 
-                latent = self.agent.latent_gene_pool(latent_id = latent_id)
+                latent = self.agent.latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
 
                # until maximum episode length
 
evolutionary_policy_optimization-0.0.66.dist-info/METADATA → evolutionary_policy_optimization-0.0.67.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.66
+Version: 0.0.67
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.66.dist-info/RECORD → evolutionary_policy_optimization-0.0.67.dist-info/RECORD

@@ -1,9 +1,9 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=-ULtP9-EJDQv5TlWwkPwPiGsU7bKnD2qfAYvyDK8GIU,36912
+evolutionary_policy_optimization/epo.py,sha256=7WOGl22WknudsNLSZ18AWwW7rPt5ITMAtByetUwLp7M,37654
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.66.dist-info/METADATA,sha256=YsfcSX3Vf5ALrePot6Eq4NStsVPzNDJ3Py6sMgh7lOE,6220
-evolutionary_policy_optimization-0.0.66.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.66.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.66.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.67.dist-info/METADATA,sha256=S1biwayyDA4vTOXknMU5KeWtJTFvxePHZZ0OZzuaNms,6220
+evolutionary_policy_optimization-0.0.67.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.67.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.67.dist-info/RECORD,,