evolutionary-policy-optimization 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff compares publicly available package versions as they were released to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- evolutionary_policy_optimization/epo.py +80 -14
- {evolutionary_policy_optimization-0.1.1.dist-info → evolutionary_policy_optimization-0.1.2.dist-info}/METADATA +2 -2
- {evolutionary_policy_optimization-0.1.1.dist-info → evolutionary_policy_optimization-0.1.2.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.1.1.dist-info → evolutionary_policy_optimization-0.1.2.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.1.dist-info → evolutionary_policy_optimization-0.1.2.dist-info}/licenses/LICENSE +0 -0

evolutionary_policy_optimization/epo.py

@@ -40,6 +40,8 @@ from ema_pytorch import EMA
 
 from tqdm import tqdm
 
+from accelerate import Accelerator
+
 # helpers
 
 def exists(v):

@@ -60,6 +62,17 @@ def divisible_by(num, den):
 def to_device(inp, device):
     return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
 
+def maybe(fn):
+
+    @wraps(fn)
+    def decorated(inp, *args, **kwargs):
+        if not exists(inp):
+            return None
+
+        return fn(inp, *args, **kwargs)
+
+    return decorated
+
 def interface_torch_numpy(fn, device):
     # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
 
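
The new `maybe` decorator turns a function into a no-op when its first argument is `None`; the constructor later uses it to pass only the modules and optimizers that actually exist through `accelerate.prepare`. A minimal standalone sketch of that behavior (the toy `double` function is purely illustrative):

    from functools import wraps

    def exists(v):
        return v is not None

    def maybe(fn):
        @wraps(fn)
        def decorated(inp, *args, **kwargs):
            if not exists(inp):
                return None
            return fn(inp, *args, **kwargs)
        return decorated

    double = maybe(lambda x: x * 2)
    assert double(3) == 6
    assert double(None) is None  # None passes straight through instead of raising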

@@ -721,10 +734,22 @@ class Agent(Module):
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
         latent_optim_kwargs: dict = dict(),
-        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores
+        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores,
+        wrap_with_accelerate: bool = True,
+        accelerate_kwargs: dict = dict(),
     ):
         super().__init__()
 
+        # hf accelerate
+
+        self.wrap_with_accelerate = wrap_with_accelerate
+
+        if wrap_with_accelerate:
+            accelerate = Accelerator(**accelerate_kwargs)
+            self.accelerate = accelerate
+
+        # actor, critic, and their shared latent gene pool
+
         self.actor = actor
 
         self.critic = critic
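
`Agent` now takes `wrap_with_accelerate` (default `True`) and an `accelerate_kwargs` dict that is forwarded verbatim to `Accelerator(...)`. A hedged sketch of what that forwarding amounts to; the `gradient_accumulation_steps` option is just an example of an Accelerator keyword, not something the package sets:

    from accelerate import Accelerator

    accelerate_kwargs = dict(gradient_accumulation_steps = 1)  # any Accelerator(...) options

    accelerator = Accelerator(**accelerate_kwargs)  # what Agent does when wrap_with_accelerate = True
    print(accelerator.device)                       # the device later used for the dummy buffer and critic EMA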

@@ -768,12 +793,43 @@ class Agent(Module):
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
         self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
-
+        # wrap with accelerate
+
+        self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model
+
+        if wrap_with_accelerate:
+            (
+                self.actor,
+                self.critic,
+                self.latent_gene_pool,
+                self.actor_optim,
+                self.critic_optim,
+                self.latent_optim,
+            ) = tuple(
+                maybe(self.accelerate.prepare)(m) for m in (
+                    self.actor,
+                    self.critic,
+                    self.latent_gene_pool,
+                    self.actor_optim,
+                    self.critic_optim,
+                    self.latent_optim,
+                )
+            )
+
+        # device tracking
+
+        self.register_buffer('dummy', tensor(0, device = self.accelerate.device))
+
+        self.critic_ema.to(self.accelerate.device)
 
     @property
     def device(self):
         return self.dummy.device
 
+    @property
+    def unwrapped_latent_gene_pool(self):
+        return self.unwrap_model(self.latent_gene_pool)
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
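
The constructor now runs the actor, critic, latent gene pool, and their optimizers through `accelerate.prepare` (skipping any that are `None` via `maybe`), and keeps `accelerate.unwrap_model` around to recover the raw modules. A minimal sketch of that prepare/unwrap pattern, using a toy model and optimizer rather than the package's actual actor or critic:

    import torch
    from torch import nn
    from accelerate import Accelerator

    accelerator = Accelerator()

    model = nn.Linear(4, 2)
    optim = torch.optim.Adam(model.parameters(), lr = 1e-3)

    model, optim = accelerator.prepare(model, optim)  # may return wrapped (e.g. DDP) objects

    inner = accelerator.unwrap_model(model)           # the original nn.Linear, for inference or saving
    assert isinstance(inner, nn.Linear)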

@@ -820,13 +876,15 @@ class Agent(Module):
         latent_id = None,
         latent = None,
         sample = False,
-        temperature = 1
+        temperature = 1.,
+        use_unwrapped_model = False
     ):
+        maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
         if not exists(latent) and exists(latent_id):
-            latent = self.latent_gene_pool(latent_id = latent_id)
+            latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        logits = self.actor(state, latent)
+        logits = maybe_unwrap(self.actor)(state, latent)
 
         if not sample:
             return logits

@@ -842,13 +900,15 @@ class Agent(Module):
         state,
         latent_id = None,
         latent = None,
-        use_ema_if_available = False
+        use_ema_if_available = False,
+        use_unwrapped_model = False
     ):
+        maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
         if not exists(latent) and exists(latent_id):
-            latent = self.latent_gene_pool(latent_id = latent_id)
+            latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        critic_forward = maybe_unwrap(self.critic)
+        critic_forward = maybe_unwrap(self.critic)
 
         if use_ema_if_available and self.use_critic_ema:
             critic_forward = self.critic_ema
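
Both `get_actor_actions` and `get_critic_values` gain a `use_unwrapped_model` flag, so rollout code (which later passes `use_unwrapped_model = True`) can call the raw modules rather than their accelerate-prepared wrappers. A standalone sketch of the selection both methods now perform; `identity` mirrors the helper of the same name in epo.py and the toy `actor` is illustrative:

    from torch import nn
    from accelerate import Accelerator

    def identity(t):
        return t

    accelerator = Accelerator()
    actor = accelerator.prepare(nn.Linear(4, 2))

    use_unwrapped_model = True
    maybe_unwrap = identity if not use_unwrapped_model else accelerator.unwrap_model

    raw_actor = maybe_unwrap(actor)  # plain nn.Linear when unwrapping is requested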

@@ -922,6 +982,9 @@ class Agent(Module):
 
         dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
 
+        if self.wrap_with_accelerate:
+            dataloader = self.accelerate.prepare(dataloader)
+
         # updating actor and critic
 
         self.actor.train()
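
Preparing the DataLoader lets Accelerate shard batches across processes and place them on the right device before the update loop runs. A small sketch under a toy dataset (the dataset and batch size are illustrative, not the package's):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator

    accelerator = Accelerator()

    dataset = TensorDataset(torch.randn(8, 4))
    dataloader = accelerator.prepare(DataLoader(dataset, batch_size = 2, shuffle = True))

    for (batch,) in dataloader:
        # under a multi-process launch each rank sees its own shard, already on accelerator.device
        assert batch.device == accelerator.device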

@@ -955,7 +1018,7 @@ class Agent(Module):
                 actor_loss.backward()
 
                 if exists(self.has_grad_clip):
-
+                    self.accelerate.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
 
                 self.actor_optim.step()
                 self.actor_optim.zero_grad()

@@ -971,7 +1034,7 @@ class Agent(Module):
                 critic_loss.backward()
 
                 if exists(self.has_grad_clip):
-
+                    self.accelerate.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
 
                 self.critic_optim.step()
                 self.critic_optim.zero_grad()

@@ -994,6 +1057,9 @@ class Agent(Module):
 
             (diversity_loss * self.diversity_aux_loss_weight).backward()
 
+            if exists(self.has_grad_clip):
+                self.accelerate.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)
+
             self.latent_optim.step()
             self.latent_optim.zero_grad()
 
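
Gradient clipping for the actor, critic, and latent gene pool now goes through `accelerator.clip_grad_norm_`, which is the Accelerate-aware counterpart of `torch.nn.utils.clip_grad_norm_` (it handles unscaling under mixed precision and clips correctly across processes). A hedged sketch with a toy module; the `0.5` max norm is an arbitrary illustrative value and, as in the diff, `backward()` is called on the loss directly:

    import torch
    from torch import nn
    from accelerate import Accelerator

    accelerator = Accelerator()

    model = accelerator.prepare(nn.Linear(4, 1))
    loss = model(torch.randn(2, 4, device = accelerator.device)).sum()

    loss.backward()
    accelerator.clip_grad_norm_(model.parameters(), 0.5)  # in-place clip, mixed-precision aware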

@@ -1128,7 +1194,7 @@ class EPO(Module):
         self.max_episode_length = max_episode_length
         self.fix_environ_across_latents = fix_environ_across_latents
 
-        self.register_buffer('dummy', tensor(0))
+        self.register_buffer('dummy', tensor(0, device = agent.device))
 
     @property
     def device(self):

@@ -1215,7 +1281,7 @@ class EPO(Module):
 
             # get latent from pool
 
-            latent = self.agent.
+            latent = self.agent.unwrapped_latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
 
             # until maximum episode length
 

@@ -1225,11 +1291,11 @@
 
                 # sample action
 
-                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature, use_unwrapped_model = True)
 
                 # values
 
-                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
+                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True, use_unwrapped_model = True)
 
                 # get the next state, action, and reward
 

{evolutionary_policy_optimization-0.1.1.dist-info → evolutionary_policy_optimization-0.1.2.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.1
+Version: 0.1.2
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization

@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
-Requires-Dist: accelerate
+Requires-Dist: accelerate>=1.6.0
 Requires-Dist: adam-atan2-pytorch
 Requires-Dist: assoc-scan>=0.0.2
 Requires-Dist: einops>=0.8.1

{evolutionary_policy_optimization-0.1.1.dist-info → evolutionary_policy_optimization-0.1.2.dist-info}/RECORD

@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=VgiPvaG-Lib1JZRS3oNV7BSv-19YSWKK2yCMtqDqP-M,41682
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
+evolutionary_policy_optimization-0.1.2.dist-info/METADATA,sha256=kXkidoqrBxyMtQvh7GHr1wh1pexy2WDSDWN5laQSPhI,6336
+evolutionary_policy_optimization-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.2.dist-info/RECORD,,

The remaining files (WHEEL, licenses/LICENSE) are unchanged.