evolutionary-policy-optimization 0.1.1__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/PKG-INFO +3 -3
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/README.md +1 -1
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/epo.py +100 -15
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/pyproject.toml +2 -2
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/train_gym.py +6 -5
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/.github/workflows/python-publish.yml +0 -0
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/.github/workflows/test.yml +0 -0
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/.gitignore +0 -0
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/LICENSE +0 -0
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/__init__.py +0 -0
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/distributed.py +0 -0
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/env_wrappers.py +0 -0
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/experimental.py +0 -0
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/mock_env.py +0 -0
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/requirements.txt +0 -0
- {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/tests/test_epo.py +0 -0
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.1
+Version: 0.1.4
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
-Requires-Dist: accelerate
+Requires-Dist: accelerate>=1.6.0
 Requires-Dist: adam-atan2-pytorch
 Requires-Dist: assoc-scan>=0.0.2
 Requires-Dist: einops>=0.8.1
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown
 
 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
 
-## Evolutionary Policy Optimization
+## Evolutionary Policy Optimization
 
 Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
 
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/README.md
RENAMED
@@ -1,6 +1,6 @@
 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
 
-## Evolutionary Policy Optimization
+## Evolutionary Policy Optimization
 
 Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
 
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/epo.py
RENAMED
@@ -40,6 +40,8 @@ from ema_pytorch import EMA
 
 from tqdm import tqdm
 
+from accelerate import Accelerator
+
 # helpers
 
 def exists(v):
@@ -60,13 +62,24 @@ def divisible_by(num, den):
 def to_device(inp, device):
     return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
 
+def maybe(fn):
+
+    @wraps(fn)
+    def decorated(inp, *args, **kwargs):
+        if not exists(inp):
+            return None
+
+        return fn(inp, *args, **kwargs)
+
+    return decorated
+
 def interface_torch_numpy(fn, device):
     # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
 
     @wraps(fn)
     def decorated_fn(*args, **kwargs):
 
-        args, kwargs = tree_map(lambda t: t.cpu().numpy() if
+        args, kwargs = tree_map(lambda t: t.cpu().numpy() if is_tensor(t) else t, (args, kwargs))
 
         out = fn(*args, **kwargs)
 
@@ -75,6 +88,16 @@ def interface_torch_numpy(fn, device):
 
     return decorated_fn
 
+def move_input_tensors_to_device(fn):
+
+    @wraps(fn)
+    def decorated_fn(self, *args, **kwargs):
+        args, kwargs = tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, (args, kwargs))
+
+        return fn(self, *args, **kwargs)
+
+    return decorated_fn
+
 # tensor helpers
 
 def l2norm(t):
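The two decorators added above are small utilities: `maybe` turns a function into a no-op that returns `None` when its first argument is `None` (so `accelerate.prepare` can later be mapped over optional modules), and `move_input_tensors_to_device` moves a method's tensor arguments onto its module's device before the call. A standalone sketch of the same behavior follows; `ToyModule` is an illustrative stand-in and not part of the package:

```python
# standalone sketch of the `maybe` and `move_input_tensors_to_device` helpers
# added in the hunk above; ToyModule is a stand-in, not the package's code
from functools import wraps

import torch
from torch import is_tensor, nn
from torch.utils._pytree import tree_map

def exists(v):
    return v is not None

def maybe(fn):
    @wraps(fn)
    def decorated(inp, *args, **kwargs):
        if not exists(inp):
            return None  # skip the wrapped call entirely on None input
        return fn(inp, *args, **kwargs)
    return decorated

def move_input_tensors_to_device(fn):
    @wraps(fn)
    def decorated_fn(self, *args, **kwargs):
        # move every tensor argument onto the owning module's device
        args, kwargs = tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, (args, kwargs))
        return fn(self, *args, **kwargs)
    return decorated_fn

class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('dummy', torch.tensor(0))

    @property
    def device(self):
        return self.dummy.device

    @move_input_tensors_to_device
    def forward(self, x):
        return x + 1

assert maybe(lambda t: t * 2)(None) is None   # None passes straight through
assert torch.equal(ToyModule()(torch.zeros(2)), torch.ones(2))
```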
@@ -721,10 +744,22 @@ class Agent(Module):
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
         latent_optim_kwargs: dict = dict(),
-        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores
+        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores,
+        wrap_with_accelerate: bool = True,
+        accelerate_kwargs: dict = dict(),
     ):
         super().__init__()
 
+        # hf accelerate
+
+        self.wrap_with_accelerate = wrap_with_accelerate
+
+        if wrap_with_accelerate:
+            accelerate = Accelerator(**accelerate_kwargs)
+            self.accelerate = accelerate
+
+        # actor, critic, and their shared latent gene pool
+
         self.actor = actor
 
         self.critic = critic
@@ -768,12 +803,48 @@ class Agent(Module):
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
         self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
-        self.register_buffer('dummy', tensor(0))
+        # wrap with accelerate
+
+        self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model
+
+        dummy = tensor(0)
+
+        if wrap_with_accelerate:
+            (
+                self.actor,
+                self.critic,
+                self.latent_gene_pool,
+                self.actor_optim,
+                self.critic_optim,
+                self.latent_optim,
+            ) = tuple(
+                maybe(self.accelerate.prepare)(m) for m in (
+                    self.actor,
+                    self.critic,
+                    self.latent_gene_pool,
+                    self.actor_optim,
+                    self.critic_optim,
+                    self.latent_optim,
+                )
+            )
+
+            if exists(self.critic_ema):
+                self.critic_ema.to(self.accelerate.device)
+
+            dummy = dummy.to(self.accelerate.device)
+
+        # device tracking
+
+        self.register_buffer('dummy', dummy)
 
     @property
     def device(self):
         return self.dummy.device
 
+    @property
+    def unwrapped_latent_gene_pool(self):
+        return self.unwrap_model(self.latent_gene_pool)
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
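This is the heart of the 0.1.4 change: with `wrap_with_accelerate = True` (the default), the actor, critic, latent gene pool and their optimizers are all passed through `Accelerator.prepare`, while `unwrap_model` is kept around so the raw modules can be recovered for rollouts and saving. A minimal sketch of that prepare/unwrap round trip, assuming `accelerate>=1.6.0`; the linear layer and the `cpu = True` flag are illustrative choices only:

```python
# minimal sketch of the Accelerator.prepare / unwrap_model round trip the
# Agent now performs internally; the model and optimizer here are placeholders
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator(cpu = True)

model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr = 1e-3)

# prepare wraps / moves the module and optimizer for the current device setup
model, optim = accelerator.prepare(model, optim)

# unwrap_model recovers the underlying nn.Module from any distributed wrapper
plain = accelerator.unwrap_model(model)
assert isinstance(plain, nn.Linear)
```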
@@ -814,19 +885,23 @@ class Agent(Module):
         if exists(pkg.get('latent_optim', None)):
             self.latent_optim.load_state_dict(pkg['latent_optim'])
 
+    @move_input_tensors_to_device
     def get_actor_actions(
         self,
         state,
         latent_id = None,
         latent = None,
         sample = False,
-        temperature = 1
+        temperature = 1.,
+        use_unwrapped_model = False
     ):
+        maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
         if not exists(latent) and exists(latent_id):
-            latent = self.latent_gene_pool(latent_id = latent_id)
+            latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        logits = self.actor(state, latent)
+        print(self.device, state.device, next(self.actor.parameters()).device)
+        logits = maybe_unwrap(self.actor)(state, latent)
 
         if not sample:
             return logits
@@ -837,18 +912,22 @@ class Agent(Module):
 
         return actions, log_probs
 
+    @move_input_tensors_to_device
     def get_critic_values(
         self,
         state,
         latent_id = None,
         latent = None,
-        use_ema_if_available = False
+        use_ema_if_available = False,
+        use_unwrapped_model = False
     ):
 
+        maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
+
         if not exists(latent) and exists(latent_id):
-            latent = self.latent_gene_pool(latent_id = latent_id)
+            latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        critic_forward = self.critic
+        critic_forward = maybe_unwrap(self.critic)
 
         if use_ema_if_available and self.use_critic_ema:
             critic_forward = self.critic_ema
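Both inference methods now pick between the wrapped and raw module with the same one-liner, controlled by the new `use_unwrapped_model` flag. Isolated from the class, the selection pattern looks roughly like this; the stand-in actor below is not the package's actual `Actor`:

```python
# sketch of the maybe_unwrap selection used by get_actor_actions / get_critic_values
import torch
from torch import nn
from accelerate import Accelerator

def identity(t):
    return t

accelerator = Accelerator(cpu = True)
actor = accelerator.prepare(nn.Linear(4, 3))   # stand-in for the policy network

use_unwrapped_model = True
maybe_unwrap = identity if not use_unwrapped_model else accelerator.unwrap_model

# rollouts call the bare module directly, bypassing any distributed wrapper
logits = maybe_unwrap(actor)(torch.randn(1, 4))
assert logits.shape == (1, 3)
```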
@@ -922,6 +1001,9 @@ class Agent(Module):
 
         dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
 
+        if self.wrap_with_accelerate:
+            dataloader = self.accelerate.prepare(dataloader)
+
         # updating actor and critic
 
         self.actor.train()
@@ -955,7 +1037,7 @@ class Agent(Module):
             actor_loss.backward()
 
             if exists(self.has_grad_clip):
-
+                self.accelerate.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
 
             self.actor_optim.step()
             self.actor_optim.zero_grad()
@@ -971,7 +1053,7 @@ class Agent(Module):
             critic_loss.backward()
 
             if exists(self.has_grad_clip):
-
+                self.accelerate.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
 
             self.critic_optim.step()
             self.critic_optim.zero_grad()
@@ -994,6 +1076,9 @@ class Agent(Module):
 
             (diversity_loss * self.diversity_aux_loss_weight).backward()
 
+            if exists(self.has_grad_clip):
+                self.accelerate.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)
+
             self.latent_optim.step()
             self.latent_optim.zero_grad()
 
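All three optimizers now share the same clip-then-step pattern, with gradient clipping routed through `Accelerator.clip_grad_norm_` for the actor, critic, and latent gene pool. A self-contained sketch of that pattern; the model, loss, and the `0.5` max norm are illustrative and not the package defaults:

```python
# sketch of the backward -> clip_grad_norm_ -> step pattern from the hunks above
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator(cpu = True)

model = nn.Linear(4, 1)
optim = torch.optim.SGD(model.parameters(), lr = 1e-2)
model, optim = accelerator.prepare(model, optim)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# clip gradients with the accelerate helper before stepping (0.5 is illustrative)
accelerator.clip_grad_norm_(model.parameters(), 0.5)

optim.step()
optim.zero_grad()
```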
@@ -1128,7 +1213,7 @@ class EPO(Module):
         self.max_episode_length = max_episode_length
         self.fix_environ_across_latents = fix_environ_across_latents
 
-        self.register_buffer('dummy', tensor(0))
+        self.register_buffer('dummy', tensor(0, device = agent.device))
 
     @property
     def device(self):
@@ -1215,7 +1300,7 @@ class EPO(Module):
 
                 # get latent from pool
 
-                latent = self.agent.
+                latent = self.agent.unwrapped_latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
 
                 # until maximum episode length
 
@@ -1225,11 +1310,11 @@ class EPO(Module):
 
                     # sample action
 
-                    action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+                    action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature, use_unwrapped_model = True)
 
                     # values
 
-                    value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
+                    value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True, use_unwrapped_model = True)
 
                     # get the next state, action, and reward
 
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/pyproject.toml
RENAMED
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.1.1"
+version = "0.1.4"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,7 +24,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "accelerate",
+    "accelerate>=1.6.0",
     "adam-atan2-pytorch",
     'assoc-scan>=0.0.2',
     'einx>=0.3.0',
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/train_gym.py
RENAMED
@@ -26,18 +26,19 @@ agent = env.to_epo_agent(
     latent_gene_pool_kwargs = dict(
         frac_natural_selected = 0.5,
         frac_tournaments = 0.5
+    ),
+    accelerate_kwargs = dict(
+        cpu = False
     )
 )
 
 epo = EPO(
     agent,
-    episodes_per_latent =
+    episodes_per_latent = 5,
     max_episode_length = 10,
-    action_sample_temperature = 1
+    action_sample_temperature = 1.,
 )
 
-epo
-
-epo(agent, env, num_learning_cycles = 1)
+epo(agent, env, num_learning_cycles = 5)
 
 agent.save('./agent.pt', overwrite = True)
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/.github/workflows/python-publish.yml
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/.github/workflows/test.yml
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/.gitignore
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/LICENSE
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/__init__.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/distributed.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/env_wrappers.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/experimental.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/mock_env.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/requirements.txt
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/tests/test_epo.py
RENAMED
File without changes
|