evolutionary-policy-optimization 0.0.65__py3-none-any.whl → 0.0.67__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +39 -20
- {evolutionary_policy_optimization-0.0.65.dist-info → evolutionary_policy_optimization-0.0.67.dist-info}/METADATA +1 -1
- {evolutionary_policy_optimization-0.0.65.dist-info → evolutionary_policy_optimization-0.0.67.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.0.65.dist-info → evolutionary_policy_optimization-0.0.67.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.65.dist-info → evolutionary_policy_optimization-0.0.67.dist-info}/licenses/LICENSE +0 -0
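Summary of the epo.py changes in this release: the latent gene pool becomes optional end to end. Agent now accepts latent_gene_pool: LatentGenePool | None, tracks a has_latent_genes flag, and guards its latent optimizer, save/load, and genetic_algorithm_step paths when no pool is present; create_agent only constructs a LatentGenePool when num_latents > 1; EPO rollouts fall back to latent = None; MLP tolerates dim_latent = None via default(dim_latent, 0); and LatentGenePool itself now asserts num_latents > 1.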
evolutionary_policy_optimization/epo.py

@@ -181,6 +181,8 @@ class MLP(Module):
         dim_latent = 0,
     ):
         super().__init__()
+        dim_latent = default(dim_latent, 0)
+
         assert len(dims) >= 2, 'must have at least two dimensions'
 
         # add the latent to the first dim
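The added default(dim_latent, 0) call, and the exists(...) checks used throughout the rest of this diff, rely on the small helpers usually defined near the top of lucidrains-style modules. A minimal sketch, assuming these exact one-line definitions rather than quoting them from the package:

    # assumed helper definitions - the diff only shows them being used
    def exists(v):
        # a value "exists" if it is not None
        return v is not None

    def default(v, d):
        # fall back to d when v is None
        return v if exists(v) else d

    # with these, MLP(..., dim_latent = None) degrades gracefully:
    # default(None, 0) -> 0, so no latent dimension is added to the first layer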
@@ -376,6 +378,7 @@ class LatentGenePool(Module):
         init_latent_fn: Callable | None = None
     ):
         super().__init__()
+        assert num_latents > 1
 
         maybe_l2norm = l2norm if l2norm_latent else identity
 
@@ -626,8 +629,8 @@ class LatentGenePool(Module):
 
     def forward(
         self,
-        *args,
         latent_id: int | None = None,
+        *args,
         net: Module | None = None,
         net_latent_kwarg_name = 'latent',
         **kwargs,
@@ -670,7 +673,7 @@ class Agent(Module):
         self,
         actor: Actor,
         critic: Critic,
-        latent_gene_pool: LatentGenePool,
+        latent_gene_pool: LatentGenePool | None,
         optim_klass = AdoptAtan2,
         actor_lr = 1e-4,
         critic_lr = 1e-4,
@@ -705,10 +708,14 @@ class Agent(Module):
         self.use_critic_ema = use_critic_ema
         self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
-        self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
+        self.num_latents = latent_gene_pool.num_latents if exists(latent_gene_pool) else 1
+        self.has_latent_genes = exists(latent_gene_pool)
 
-        assert actor.dim_latent == critic.dim_latent
+        assert actor.dim_latent == critic.dim_latent
+
+        if self.has_latent_genes:
+            assert latent_gene_pool.dim_latent == actor.dim_latent
 
         # gae function
 
@@ -730,7 +737,7 @@ class Agent(Module):
         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
         self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
 
-        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
+        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if exists(latent_gene_pool) and not latent_gene_pool.frozen_latents else None
 
         # promotes latents to be farther apart for diversity maintenance
 
@@ -746,7 +753,7 @@ class Agent(Module):
            actor = self.actor.state_dict(),
            critic = self.critic.state_dict(),
            critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
-            latents = self.latent_gene_pool.state_dict(),
+            latents = self.latent_gene_pool.state_dict() if self.has_latent_genes else None,
            actor_optim = self.actor_optim.state_dict(),
            critic_optim = self.critic_optim.state_dict(),
            latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
@@ -768,7 +775,8 @@ class Agent(Module):
         if self.use_critic_ema:
             self.critic_ema.load_state_dict(pkg['critic_ema'])
 
-        self.latent_gene_pool.load_state_dict(pkg['latents'])
+        if exists(pkg.get('latents', None)):
+            self.latent_gene_pool.load_state_dict(pkg['latents'])
 
         self.actor_optim.load_state_dict(pkg['actor_optim'])
         self.critic_optim.load_state_dict(pkg['critic_optim'])
@@ -784,9 +792,8 @@ class Agent(Module):
         sample = False,
         temperature = 1.
     ):
-        assert exists(latent_id) or exists(latent)
 
-        if not exists(latent):
+        if not exists(latent) and exists(latent_id):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
         logits = self.actor(state, latent)
@@ -807,9 +814,8 @@ class Agent(Module):
         latent = None,
         use_ema_if_available = False
     ):
-        assert exists(latent_id) or exists(latent)
 
-        if not exists(latent):
+        if not exists(latent) and exists(latent_id):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
         critic_forward = self.critic
@@ -823,6 +829,9 @@ class Agent(Module):
         self,
         fitnesses
     ):
+        if not self.has_latent_genes:
+            return
+
         return self.latent_gene_pool.genetic_algorithm_step(fitnesses)
 
     def forward(
@@ -893,11 +902,14 @@ class Agent(Module):
            old_values
         ) in dataloader:
 
-            latents = self.latent_gene_pool(latent_id = latent_gene_ids)
+            if self.has_latent_genes:
+                latents = self.latent_gene_pool(latent_id = latent_gene_ids)
 
-            orig_latents = latents
-            latents = latents.detach()
-            latents.requires_grad_()
+                orig_latents = latents
+                latents = latents.detach()
+                latents.requires_grad_()
+            else:
+                latents = None
 
             # learn actor
 
@@ -936,7 +948,7 @@ class Agent(Module):
 
            # maybe update latents, if not frozen
 
-            if self.latent_gene_pool.frozen_latents:
+            if not self.has_latent_genes or self.latent_gene_pool.frozen_latents:
                continue
 
            orig_latents.backward(latents.grad)
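The orig_latents.backward(latents.grad) line above is the second half of the detach / requires_grad_ pattern introduced in the training loop: the gathered latents are detached so the PPO losses only accumulate gradients on a leaf copy, and those gradients are then routed back into the gene pool in one extra backward call. A standalone illustration of the pattern, using toy tensors rather than the package's classes:

    import torch

    pool = torch.nn.Parameter(torch.randn(4, 8))   # stand-in for the latent gene pool
    orig_latents = pool[torch.tensor([0, 2])]      # gathered latents, still attached to pool

    latents = orig_latents.detach()                # leaf copy used inside the losses
    latents.requires_grad_()

    loss = (latents ** 2).sum()                    # stand-in for the actor / critic losses
    loss.backward()                                # populates latents.grad only

    orig_latents.backward(latents.grad)            # routes the same gradients into pool.grad
    assert pool.grad is not None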
@@ -952,7 +964,9 @@ class Agent(Module):
 
        # apply evolution
 
-        self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
+        if self.has_latent_genes:
+
+            self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
 
        # reinforcement learning related - ppo
 
@@ -1005,11 +1019,16 @@ def create_agent(
    **kwargs
 ) -> Agent:
 
+    has_latent_genes = num_latents > 1
+
+    if not has_latent_genes:
+        dim_latent = None
+
     latent_gene_pool = LatentGenePool(
        num_latents = num_latents,
        dim_latent = dim_latent,
        **latent_gene_pool_kwargs
-    )
+    ) if has_latent_genes else None
 
     actor = Actor(
        num_actions = actor_num_actions,
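With this change, passing num_latents = 1 presumably yields a plain PPO agent with no gene pool at all. A rough usage sketch; only num_latents, dim_latent, and actor_num_actions appear in this diff, so every other argument below is a placeholder for whatever create_agent actually requires:

    # hypothetical call - argument values are illustrative, and any parameter not
    # shown in the diff above is an assumption about create_agent's signature
    agent = create_agent(
        num_latents = 1,          # <= 1 latent: no LatentGenePool is constructed
        dim_latent = 32,          # forced to None internally when num_latents == 1
        actor_num_actions = 5,
        # ... state dimension / hidden sizes / other required kwargs ...
    )

    assert agent.latent_gene_pool is None
    assert not agent.has_latent_genes
    assert agent.num_latents == 1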
@@ -1069,7 +1088,7 @@ class EPO(Module):
         self.agent = agent
         self.action_sample_temperature = action_sample_temperature
 
-        self.num_latents = agent.latent_gene_pool.num_latents
+        self.num_latents = agent.latent_gene_pool.num_latents if agent.has_latent_genes else 1
         self.episodes_per_latent = episodes_per_latent
         self.max_episode_length = max_episode_length
         self.fix_environ_across_latents = fix_environ_across_latents
@@ -1159,7 +1178,7 @@ class EPO(Module):
 
            # get latent from pool
 
-            latent = self.agent.latent_gene_pool(latent_id = latent_id)
+            latent = self.agent.latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
 
            # until maximum episode length
 
{evolutionary_policy_optimization-0.0.65.dist-info → evolutionary_policy_optimization-0.0.67.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.65
+Version: 0.0.67
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
{evolutionary_policy_optimization-0.0.65.dist-info → evolutionary_policy_optimization-0.0.67.dist-info}/RECORD

@@ -1,9 +1,9 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=7WOGl22WknudsNLSZ18AWwW7rPt5ITMAtByetUwLp7M,37654
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.65.dist-info/METADATA,sha256=
-evolutionary_policy_optimization-0.0.65.dist-info/WHEEL,sha256=
-evolutionary_policy_optimization-0.0.65.dist-info/licenses/LICENSE,sha256=
-evolutionary_policy_optimization-0.0.65.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.67.dist-info/METADATA,sha256=S1biwayyDA4vTOXknMU5KeWtJTFvxePHZZ0OZzuaNms,6220
+evolutionary_policy_optimization-0.0.67.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.67.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.67.dist-info/RECORD,,
File without changes