evolutionary-policy-optimization 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/__init__.py +4 -0
- evolutionary_policy_optimization/env_wrappers.py +36 -0
- evolutionary_policy_optimization/epo.py +81 -14
- {evolutionary_policy_optimization-0.1.0.dist-info → evolutionary_policy_optimization-0.1.2.dist-info}/METADATA +2 -1
- evolutionary_policy_optimization-0.1.2.dist-info/RECORD +10 -0
- evolutionary_policy_optimization-0.1.0.dist-info/RECORD +0 -9
- {evolutionary_policy_optimization-0.1.0.dist-info → evolutionary_policy_optimization-0.1.2.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.0.dist-info → evolutionary_policy_optimization-0.1.2.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/env_wrappers.py

@@ -0,0 +1,36 @@
+import torch
+from torch.nn import Module
+
+from evolutionary_policy_optimization.epo import create_agent, Agent
+
+class GymnasiumEnvWrapper(Module):
+    def __init__(
+        self,
+        env
+    ):
+        super().__init__()
+        self.env = env
+
+    def reset(self, *args, **kwargs):
+        return self.env.reset(*args, **kwargs)
+
+    def step(self, *args, **kwargs):
+        return self.env.step(*args, **kwargs)
+
+    def to_agent_hparams(self):
+        return dict(
+            dim_state = self.env.observation_space.shape[0],
+            actor_num_actions = self.env.action_space.n
+        )
+
+    def to_epo_agent(
+        self,
+        *args,
+        **kwargs
+    ) -> Agent:
+
+        return create_agent(
+            *args,
+            **self.to_agent_hparams(),
+            **kwargs
+        )
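For orientation, here is a minimal sketch of how the new wrapper could be used end to end. The gymnasium environment and the agent hyperparameters (`num_latents`, `dim_latent`) are illustrative assumptions; `create_agent` may require further keyword arguments not shown in this diff.

```python
# hypothetical usage sketch -- not part of the diff
import gymnasium as gym

from evolutionary_policy_optimization.env_wrappers import GymnasiumEnvWrapper

env = GymnasiumEnvWrapper(gym.make('CartPole-v1'))

# dim_state and actor_num_actions are derived from the env's spaces by to_agent_hparams();
# the values below are placeholders, and any other required create_agent
# keyword arguments would be passed here as well
agent = env.to_epo_agent(
    num_latents = 16,
    dim_latent = 32,
)
```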
evolutionary_policy_optimization/epo.py

@@ -40,6 +40,8 @@ from ema_pytorch import EMA
 
 from tqdm import tqdm
 
+from accelerate import Accelerator
+
 # helpers
 
 def exists(v):

@@ -60,6 +62,17 @@ def divisible_by(num, den):
 def to_device(inp, device):
     return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
 
+def maybe(fn):
+
+    @wraps(fn)
+    def decorated(inp, *args, **kwargs):
+        if not exists(inp):
+            return None
+
+        return fn(inp, *args, **kwargs)
+
+    return decorated
+
 def interface_torch_numpy(fn, device):
     # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
 
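The new `maybe` helper wraps a function so that a `None` first argument short-circuits to `None` instead of raising; this is what lets the `accelerate.prepare` call further down be applied uniformly to components that may be absent. A self-contained illustration (the doubling function is only for demonstration; `exists` and `maybe` are copied from the hunk above):

```python
from functools import wraps

def exists(v):
    return v is not None

def maybe(fn):

    @wraps(fn)
    def decorated(inp, *args, **kwargs):
        if not exists(inp):
            return None

        return fn(inp, *args, **kwargs)

    return decorated

double = maybe(lambda x: x * 2)

assert double(3) == 6
assert double(None) is None  # a missing input passes through instead of erroring
```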
@@ -721,10 +734,22 @@ class Agent(Module):
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
         latent_optim_kwargs: dict = dict(),
-        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores
+        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores,
+        wrap_with_accelerate: bool = True,
+        accelerate_kwargs: dict = dict(),
     ):
         super().__init__()
 
+        # hf accelerate
+
+        self.wrap_with_accelerate = wrap_with_accelerate
+
+        if wrap_with_accelerate:
+            accelerate = Accelerator(**accelerate_kwargs)
+            self.accelerate = accelerate
+
+        # actor, critic, and their shared latent gene pool
+
         self.actor = actor
 
         self.critic = critic

@@ -768,12 +793,43 @@ class Agent(Module):
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
         self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
-
+        # wrap with accelerate
+
+        self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model
+
+        if wrap_with_accelerate:
+            (
+                self.actor,
+                self.critic,
+                self.latent_gene_pool,
+                self.actor_optim,
+                self.critic_optim,
+                self.latent_optim,
+            ) = tuple(
+                maybe(self.accelerate.prepare)(m) for m in (
+                    self.actor,
+                    self.critic,
+                    self.latent_gene_pool,
+                    self.actor_optim,
+                    self.critic_optim,
+                    self.latent_optim,
+                )
+            )
+
+        # device tracking
+
+        self.register_buffer('dummy', tensor(0, device = self.accelerate.device))
+
+        self.critic_ema.to(self.accelerate.device)
 
     @property
     def device(self):
         return self.dummy.device
 
+    @property
+    def unwrapped_latent_gene_pool(self):
+        return self.unwrap_model(self.latent_gene_pool)
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
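For readers unfamiliar with Hugging Face Accelerate, the pattern being adopted above is: wrap modules and optimizers once with `prepare`, then use `unwrap_model` whenever the bare module is needed outside the prepared training step (here, the new `unwrapped_latent_gene_pool` property and the rollout code further down). A standalone sketch of that pattern, independent of this package; the toy model, optimizer, and dimensions are placeholders:

```python
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr = 1e-3)

# prepare handles device placement and, in multi-process runs, DDP wrapping
model, optim = accelerator.prepare(model, optim)

# unwrap_model strips any distributed wrapper to get the underlying module back
raw_model = accelerator.unwrap_model(model)
assert isinstance(raw_model, nn.Linear)
```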
@@ -820,13 +876,15 @@ class Agent(Module):
         latent_id = None,
         latent = None,
         sample = False,
-        temperature = 1
+        temperature = 1.,
+        use_unwrapped_model = False
     ):
+        maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
         if not exists(latent) and exists(latent_id):
-            latent = self.latent_gene_pool(latent_id = latent_id)
+            latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        logits = self.actor(state, latent)
+        logits = maybe_unwrap(self.actor)(state, latent)
 
         if not sample:
             return logits

@@ -842,13 +900,15 @@ class Agent(Module):
         state,
         latent_id = None,
         latent = None,
-        use_ema_if_available = False
+        use_ema_if_available = False,
+        use_unwrapped_model = False
     ):
+        maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
         if not exists(latent) and exists(latent_id):
-            latent = self.latent_gene_pool(latent_id = latent_id)
+            latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        critic_forward = self.critic
+        critic_forward = maybe_unwrap(self.critic)
 
         if use_ema_if_available and self.use_critic_ema:
             critic_forward = self.critic_ema

@@ -922,6 +982,9 @@ class Agent(Module):
 
         dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
 
+        if self.wrap_with_accelerate:
+            dataloader = self.accelerate.prepare(dataloader)
+
         # updating actor and critic
 
         self.actor.train()
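Preparing the DataLoader is what gives multi-process batch sharding and automatic device placement when a run is launched through accelerate. A toy sketch of that behavior (the dataset and batch size are placeholders, not values from this package):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

dataset = TensorDataset(torch.randn(64, 4))
dataloader = accelerator.prepare(DataLoader(dataset, batch_size = 16, shuffle = True))

for (batch,) in dataloader:
    # batches arrive on accelerator.device; in multi-process runs each
    # process iterates over its own shard of the data
    assert batch.device == accelerator.device
```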
@@ -955,7 +1018,7 @@ class Agent(Module):
             actor_loss.backward()
 
             if exists(self.has_grad_clip):
-
+                self.accelerate.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
 
             self.actor_optim.step()
             self.actor_optim.zero_grad()

@@ -971,7 +1034,7 @@ class Agent(Module):
             critic_loss.backward()
 
             if exists(self.has_grad_clip):
-
+                self.accelerate.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
 
             self.critic_optim.step()
             self.critic_optim.zero_grad()

@@ -994,6 +1057,9 @@ class Agent(Module):
 
             (diversity_loss * self.diversity_aux_loss_weight).backward()
 
+            if exists(self.has_grad_clip):
+                self.accelerate.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)
+
             self.latent_optim.step()
             self.latent_optim.zero_grad()
 
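Accelerate's `clip_grad_norm_` plays the same role as `torch.nn.utils.clip_grad_norm_`, but also unscales gradients correctly under mixed precision and distributed wrapping. A minimal sketch of the call in isolation; the toy model, loss, and `max_norm` value are illustrative, and `accelerator.backward` is the library's general recommendation rather than something this diff uses:

```python
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(4, 1)
optim = torch.optim.SGD(model.parameters(), lr = 1e-2)
model, optim = accelerator.prepare(model, optim)

loss = model(torch.randn(8, 4, device = accelerator.device)).pow(2).mean()
accelerator.backward(loss)

# clip gradients through the Accelerator so mixed precision / DDP are handled
accelerator.clip_grad_norm_(model.parameters(), max_norm = 0.5)

optim.step()
optim.zero_grad()
```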
@@ -1040,6 +1106,7 @@ def actor_loss(
 # agent contains the actor, critic, and the latent genetic pool
 
 def create_agent(
+    *,
     dim_state,
     num_latents,
     dim_latent,
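The bare `*` added to `create_agent`'s signature makes every parameter keyword-only, so positional calls now raise a `TypeError`. A tiny illustration of the mechanism with a generic function (not the package's real signature):

```python
def create_thing(*, dim_state, num_latents):
    return (dim_state, num_latents)

create_thing(dim_state = 4, num_latents = 16)   # ok

try:
    create_thing(4, 16)                          # positional arguments are rejected
except TypeError as e:
    print(e)                                     # "takes 0 positional arguments ..."
```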
@@ -1127,7 +1194,7 @@ class EPO(Module):
         self.max_episode_length = max_episode_length
         self.fix_environ_across_latents = fix_environ_across_latents
 
-        self.register_buffer('dummy', tensor(0))
+        self.register_buffer('dummy', tensor(0, device = agent.device))
 
     @property
     def device(self):

@@ -1214,7 +1281,7 @@ class EPO(Module):
 
             # get latent from pool
 
-            latent = self.agent.
+            latent = self.agent.unwrapped_latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
 
             # until maximum episode length
 

@@ -1224,11 +1291,11 @@ class EPO(Module):
 
                 # sample action
 
-                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature, use_unwrapped_model = True)
 
                 # values
 
-                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
+                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True, use_unwrapped_model = True)
 
                 # get the next state, action, and reward
 
{evolutionary_policy_optimization-0.1.0.dist-info → evolutionary_policy_optimization-0.1.2.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.0
+Version: 0.1.2
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization

@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
+Requires-Dist: accelerate>=1.6.0
 Requires-Dist: adam-atan2-pytorch
 Requires-Dist: assoc-scan>=0.0.2
 Requires-Dist: einops>=0.8.1
evolutionary_policy_optimization-0.1.2.dist-info/RECORD

@@ -0,0 +1,10 @@
+evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
+evolutionary_policy_optimization/epo.py,sha256=VgiPvaG-Lib1JZRS3oNV7BSv-19YSWKK2yCMtqDqP-M,41682
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
+evolutionary_policy_optimization-0.1.2.dist-info/METADATA,sha256=kXkidoqrBxyMtQvh7GHr1wh1pexy2WDSDWN5laQSPhI,6336
+evolutionary_policy_optimization-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.2.dist-info/RECORD,,
evolutionary_policy_optimization-0.1.0.dist-info/RECORD

@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=VrFD5lFQrS7KeYTC-WavEMTHgQXoq7vNPVXnRJwFSDI,39491
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.0.dist-info/METADATA,sha256=Nl3JSTSireDXXvX1ZrMQqTA8CXz1USUQOwPuuCtgCJw,6303
-evolutionary_policy_optimization-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.0.dist-info/RECORD,,
The WHEEL and licenses/LICENSE files are unchanged between 0.1.0 and 0.1.2.