evolutionary-policy-optimization 0.0.66-py3-none-any.whl → 0.0.68-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +53 -20
- {evolutionary_policy_optimization-0.0.66.dist-info → evolutionary_policy_optimization-0.0.68.dist-info}/METADATA +1 -1
- {evolutionary_policy_optimization-0.0.66.dist-info → evolutionary_policy_optimization-0.0.68.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.0.66.dist-info → evolutionary_policy_optimization-0.0.68.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.66.dist-info → evolutionary_policy_optimization-0.0.68.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -55,6 +55,9 @@ def xnor(x, y):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def to_device(inp, device):
+    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
+
 # tensor helpers
 
 def l2norm(t):
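The new `to_device` helper walks an arbitrarily nested structure and moves only the tensor leaves. Below is a minimal standalone sketch of the same pattern; it assumes `tree_map` comes from `torch.utils._pytree` (a private but long-standing PyTorch utility), while epo.py's own import may differ.

```python
# Standalone sketch of the to_device pattern, assuming torch.utils._pytree.tree_map.
import torch
from torch import is_tensor
from torch.utils._pytree import tree_map

def to_device(inp, device):
    # move tensor leaves of any nested dict/list/tuple, leave non-tensor leaves alone
    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)

batch = {'state': torch.randn(2, 4), 'meta': {'step': 3, 'mask': torch.ones(2)}}
moved = to_device(batch, 'cpu')
print(moved['meta']['mask'].device)  # cpu
```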
@@ -181,6 +184,8 @@ class MLP(Module):
         dim_latent = 0,
     ):
         super().__init__()
+        dim_latent = default(dim_latent, 0)
+
         assert len(dims) >= 2, 'must have at least two dimensions'
 
         # add the latent to the first dim
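The `default(dim_latent, 0)` call above, and the `exists(...)` checks throughout this diff, rely on the small helpers defined near the top of epo.py. For reference, a sketch of what they amount to (the canonical definitions live in the source):

```python
# Reference sketch of the exists / default helpers used throughout this diff.
def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

assert default(None, 0) == 0   # a latent dim of None now falls back to 0
assert default(16, 0) == 16
```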
@@ -376,6 +381,7 @@ class LatentGenePool(Module):
         init_latent_fn: Callable | None = None
     ):
         super().__init__()
+        assert num_latents > 1
 
         maybe_l2norm = l2norm if l2norm_latent else identity
 
@@ -670,7 +676,7 @@ class Agent(Module):
         self,
         actor: Actor,
         critic: Critic,
-        latent_gene_pool: LatentGenePool,
+        latent_gene_pool: LatentGenePool | None,
         optim_klass = AdoptAtan2,
         actor_lr = 1e-4,
         critic_lr = 1e-4,
@@ -705,10 +711,14 @@ class Agent(Module):
         self.use_critic_ema = use_critic_ema
         self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
-        self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
+        self.num_latents = latent_gene_pool.num_latents if exists(latent_gene_pool) else 1
+        self.has_latent_genes = exists(latent_gene_pool)
+
+        assert actor.dim_latent == critic.dim_latent
 
-
+        if self.has_latent_genes:
+            assert latent_gene_pool.dim_latent == actor.dim_latent
 
         # gae function
 
@@ -730,13 +740,19 @@ class Agent(Module):
         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
         self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
 
-        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
+        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if exists(latent_gene_pool) and not latent_gene_pool.frozen_latents else None
 
         # promotes latents to be farther apart for diversity maintenance
 
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
         self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
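The `dummy` buffer plus `device` property added here is a common idiom for letting a module report its current device without holding any trainable state. A minimal sketch of the idiom, independent of the Agent class:

```python
# Minimal sketch of the dummy-buffer device idiom added to Agent above.
import torch
from torch import nn, tensor

class DeviceAware(nn.Module):
    def __init__(self):
        super().__init__()
        # a buffer moves with the module on .to() / .cuda() but is not a parameter
        self.register_buffer('dummy', tensor(0))

    @property
    def device(self):
        return self.dummy.device

module = DeviceAware()
print(module.device)                         # cpu
# module.to('cuda'); print(module.device)    # cuda:0 on a GPU machine
```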
@@ -746,7 +762,7 @@ class Agent(Module):
             actor = self.actor.state_dict(),
             critic = self.critic.state_dict(),
             critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
-            latents = self.latent_gene_pool.state_dict(),
+            latents = self.latent_gene_pool.state_dict() if self.has_latent_genes else None,
            actor_optim = self.actor_optim.state_dict(),
             critic_optim = self.critic_optim.state_dict(),
             latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
@@ -768,7 +784,8 @@ class Agent(Module):
         if self.use_critic_ema:
             self.critic_ema.load_state_dict(pkg['critic_ema'])
 
-
+        if exists(pkg.get('latents', None)):
+            self.latent_gene_pool.load_state_dict(pkg['latents'])
 
         self.actor_optim.load_state_dict(pkg['actor_optim'])
         self.critic_optim.load_state_dict(pkg['critic_optim'])
@@ -784,9 +801,8 @@ class Agent(Module):
         sample = False,
         temperature = 1.
     ):
-        assert exists(latent_id) or exists(latent)
 
-        if not exists(latent):
+        if not exists(latent) and exists(latent_id):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
         logits = self.actor(state, latent)
@@ -807,9 +823,8 @@ class Agent(Module):
         latent = None,
         use_ema_if_available = False
     ):
-        assert exists(latent_id) or exists(latent)
 
-        if not exists(latent):
+        if not exists(latent) and exists(latent_id):
             latent = self.latent_gene_pool(latent_id = latent_id)
 
         critic_forward = self.critic
@@ -823,13 +838,19 @@ class Agent(Module):
         self,
         fitnesses
     ):
+        if not self.has_latent_genes:
+            return
+
         return self.latent_gene_pool.genetic_algorithm_step(fitnesses)
 
     def forward(
         self,
         memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
+
     ):
+        memories_and_cumulative_rewards = to_device(memories_and_cumulative_rewards, self.device)
+
         memories, rewards_per_latent_episode = memories_and_cumulative_rewards
 
         # stack memories
@@ -893,11 +914,14 @@ class Agent(Module):
                 old_values
             ) in dataloader:
 
-
+                if self.has_latent_genes:
+                    latents = self.latent_gene_pool(latent_id = latent_gene_ids)
 
-
-
-
+                    orig_latents = latents
+                    latents = latents.detach()
+                    latents.requires_grad_()
+                else:
+                    latents = None
 
                 # learn actor
 
@@ -936,7 +960,7 @@ class Agent(Module):
 
                 # maybe update latents, if not frozen
 
-                if self.latent_gene_pool.frozen_latents:
+                if not self.has_latent_genes or self.latent_gene_pool.frozen_latents:
                     continue
 
                 orig_latents.backward(latents.grad)
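The `detach()` / `requires_grad_()` pair introduced above, together with `orig_latents.backward(latents.grad)`, routes gradients from the PPO losses back into the gene pool in a second backward pass. A self-contained sketch of that routing pattern, with a plain parameter standing in for the latent gene pool:

```python
# Sketch of the gradient-routing pattern: compute losses on a detached copy,
# then push the accumulated grad back through the original (differentiable) gather.
import torch

pool = torch.nn.Parameter(torch.randn(4, 8))      # stand-in for the latent gene pool
orig_latents = pool[torch.tensor([0, 2])]         # differentiable selection of latents
latents = orig_latents.detach()
latents.requires_grad_()                          # grads accumulate here in the main backward

loss = (latents * 2).sum()                        # stand-in for the actor/critic losses
loss.backward()                                   # fills latents.grad, leaves pool untouched

orig_latents.backward(latents.grad)               # second backward routes the grad into pool
print(pool.grad.shape)                            # torch.Size([4, 8])
```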
@@ -952,7 +976,9 @@ class Agent(Module):
 
         # apply evolution
 
-        self.
+        if self.has_latent_genes:
+
+            self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
 
         # reinforcement learning related - ppo
 
@@ -1005,11 +1031,16 @@ def create_agent(
     **kwargs
 ) -> Agent:
 
+    has_latent_genes = num_latents > 1
+
+    if not has_latent_genes:
+        dim_latent = None
+
     latent_gene_pool = LatentGenePool(
         num_latents = num_latents,
         dim_latent = dim_latent,
         **latent_gene_pool_kwargs
-    )
+    ) if has_latent_genes else None
 
     actor = Actor(
         num_actions = actor_num_actions,
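In `create_agent`, a pool is only built when there is more than one latent to evolve. A hedged sketch of that gate, with a hypothetical `PlaceholderGenePool` standing in for the library's `LatentGenePool`:

```python
# Hypothetical illustration of the num_latents > 1 gate added to create_agent;
# PlaceholderGenePool is not the real LatentGenePool, just a stand-in for the sketch.
class PlaceholderGenePool:
    def __init__(self, num_latents, dim_latent):
        self.num_latents = num_latents
        self.dim_latent = dim_latent

def build_pool(num_latents, dim_latent):
    has_latent_genes = num_latents > 1

    if not has_latent_genes:
        dim_latent = None        # a single latent carries no evolutionary dimension

    return PlaceholderGenePool(num_latents, dim_latent) if has_latent_genes else None

assert build_pool(1, 32) is None             # gene pool (and evolution) disabled
assert build_pool(8, 32).num_latents == 8    # normal evolutionary setup
```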
@@ -1069,7 +1100,7 @@ class EPO(Module):
         self.agent = agent
         self.action_sample_temperature = action_sample_temperature
 
-        self.num_latents = agent.latent_gene_pool.num_latents
+        self.num_latents = agent.latent_gene_pool.num_latents if agent.has_latent_genes else 1
         self.episodes_per_latent = episodes_per_latent
         self.max_episode_length = max_episode_length
         self.fix_environ_across_latents = fix_environ_across_latents
@@ -1159,7 +1190,7 @@ class EPO(Module):
 
                 # get latent from pool
 
-                latent = self.agent.latent_gene_pool(latent_id = latent_id)
+                latent = self.agent.latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
 
                 # until maximum episode length
 
@@ -1198,6 +1229,8 @@ class EPO(Module):
                         terminated
                     )
 
+                    memory = Memory(*tuple(t.cpu() for t in memory))
+
                     memories.append(memory)
 
                     time += 1
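Rollout memories are now moved to CPU as soon as they are created, which keeps long episodes from piling up on the GPU; the `to_device(..., self.device)` call added to `Agent.forward` brings them back for learning. A small sketch of the `Memory(*tuple(t.cpu() for t in memory))` move, using a hypothetical namedtuple in place of the library's `Memory`:

```python
# Sketch of moving every field of a namedtuple of tensors to CPU.
# 'Memory' here is a hypothetical stand-in; the real Memory in epo.py has more fields.
from collections import namedtuple
import torch

Memory = namedtuple('Memory', ['state', 'action', 'reward'])

memory = Memory(torch.randn(4), torch.tensor(1), torch.tensor(0.5))
memory = Memory(*tuple(t.cpu() for t in memory))       # rebuild with CPU tensors
print(all(t.device.type == 'cpu' for t in memory))     # True
```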
@@ -1209,7 +1242,7 @@ class EPO(Module):
 
                 memory_for_gae = memory._replace(
                     episode_id = invalid_episode,
-                    value = next_value,
+                    value = next_value.cpu(),
                     done = tensor(True)
                 )
 
{evolutionary_policy_optimization-0.0.66.dist-info → evolutionary_policy_optimization-0.0.68.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.66
+Version: 0.0.68
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
{evolutionary_policy_optimization-0.0.66.dist-info → evolutionary_policy_optimization-0.0.68.dist-info}/RECORD

@@ -1,9 +1,9 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256
+evolutionary_policy_optimization/epo.py,sha256=xhE_kHas54xGsgOese9SQEvyK7NKZqEuK3AiVhm0y7Q,38047
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
+evolutionary_policy_optimization-0.0.68.dist-info/METADATA,sha256=hOOKOrrPQtQmK3zN1z5nkGJEoaQLyXUzs9ArsEKn1DE,6220
+evolutionary_policy_optimization-0.0.68.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.68.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.68.dist-info/RECORD,,
Files without changes: WHEEL, licenses/LICENSE