evolutionary-policy-optimization 0.1.1__tar.gz → 0.1.4__tar.gz

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (16)
  1. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/PKG-INFO +3 -3
  2. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/README.md +1 -1
  3. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/epo.py +100 -15
  4. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/pyproject.toml +2 -2
  5. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/train_gym.py +6 -5
  6. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/.github/workflows/python-publish.yml +0 -0
  7. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/.github/workflows/test.yml +0 -0
  8. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/.gitignore +0 -0
  9. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/LICENSE +0 -0
  10. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/distributed.py +0 -0
  12. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  13. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/experimental.py +0 -0
  14. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/evolutionary_policy_optimization/mock_env.py +0 -0
  15. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/requirements.txt +0 -0
  16. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.4}/tests/test_epo.py +0 -0
--- evolutionary_policy_optimization-0.1.1/PKG-INFO
+++ evolutionary_policy_optimization-0.1.4/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.1
+Version: 0.1.4
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
-Requires-Dist: accelerate
+Requires-Dist: accelerate>=1.6.0
 Requires-Dist: adam-atan2-pytorch
 Requires-Dist: assoc-scan>=0.0.2
 Requires-Dist: einops>=0.8.1
@@ -53,7 +53,7 @@ Description-Content-Type: text/markdown
 
 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
 
-## Evolutionary Policy Optimization (wip)
+## Evolutionary Policy Optimization
 
 Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
 
--- evolutionary_policy_optimization-0.1.1/README.md
+++ evolutionary_policy_optimization-0.1.4/README.md
@@ -1,6 +1,6 @@
 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
 
-## Evolutionary Policy Optimization (wip)
+## Evolutionary Policy Optimization
 
 Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
 
--- evolutionary_policy_optimization-0.1.1/evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization-0.1.4/evolutionary_policy_optimization/epo.py
@@ -40,6 +40,8 @@ from ema_pytorch import EMA
 
 from tqdm import tqdm
 
+from accelerate import Accelerator
+
 # helpers
 
 def exists(v):
@@ -60,13 +62,24 @@ def divisible_by(num, den):
 def to_device(inp, device):
     return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
 
+def maybe(fn):
+
+    @wraps(fn)
+    def decorated(inp, *args, **kwargs):
+        if not exists(inp):
+            return None
+
+        return fn(inp, *args, **kwargs)
+
+    return decorated
+
 def interface_torch_numpy(fn, device):
     # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
 
     @wraps(fn)
     def decorated_fn(*args, **kwargs):
 
-        args, kwargs = tree_map(lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t, (args, kwargs))
+        args, kwargs = tree_map(lambda t: t.cpu().numpy() if is_tensor(t) else t, (args, kwargs))
 
         out = fn(*args, **kwargs)
 
@@ -75,6 +88,16 @@ def interface_torch_numpy(fn, device):
 
     return decorated_fn
 
+def move_input_tensors_to_device(fn):
+
+    @wraps(fn)
+    def decorated_fn(self, *args, **kwargs):
+        args, kwargs = tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, (args, kwargs))
+
+        return fn(self, *args, **kwargs)
+
+    return decorated_fn
+
 # tensor helpers
 
 def l2norm(t):
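The two helpers added above are small decorators: maybe turns a function into a no-op that returns None when its first argument is None, and move_input_tensors_to_device moves any tensor arguments of a method onto the owning module's device before the call. A minimal standalone sketch of how they behave follows; the DummyModule class and the example calls are illustrative only, not part of the package.

    import torch
    from torch import nn, is_tensor
    from torch.utils._pytree import tree_map
    from functools import wraps

    def exists(v):
        return v is not None

    def maybe(fn):
        # skip the call entirely when the input is None
        @wraps(fn)
        def decorated(inp, *args, **kwargs):
            if not exists(inp):
                return None
            return fn(inp, *args, **kwargs)
        return decorated

    def move_input_tensors_to_device(fn):
        # move every tensor in args / kwargs onto self.device before calling the method
        @wraps(fn)
        def decorated_fn(self, *args, **kwargs):
            args, kwargs = tree_map(lambda t: t.to(self.device) if is_tensor(t) else t, (args, kwargs))
            return fn(self, *args, **kwargs)
        return decorated_fn

    class DummyModule(nn.Module):          # illustrative stand-in for Agent
        def __init__(self):
            super().__init__()
            self.register_buffer('dummy', torch.tensor(0))

        @property
        def device(self):
            return self.dummy.device

        @move_input_tensors_to_device
        def forward(self, state):
            return state.device            # state has already been moved to self.device

    print(maybe(torch.abs)(None))                  # None - the wrapped fn is skipped
    print(maybe(torch.abs)(torch.tensor(-1.)))     # tensor(1.)
    print(DummyModule()(torch.zeros(3)))           # device of the 'dummy' buffer (cpu here)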
@@ -721,10 +744,22 @@ class Agent(Module):
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
         latent_optim_kwargs: dict = dict(),
-        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores
+        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores,
+        wrap_with_accelerate: bool = True,
+        accelerate_kwargs: dict = dict(),
     ):
         super().__init__()
 
+        # hf accelerate
+
+        self.wrap_with_accelerate = wrap_with_accelerate
+
+        if wrap_with_accelerate:
+            accelerate = Accelerator(**accelerate_kwargs)
+            self.accelerate = accelerate
+
+        # actor, critic, and their shared latent gene pool
+
         self.actor = actor
 
         self.critic = critic
@@ -768,12 +803,48 @@ class Agent(Module):
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
         self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
-        self.register_buffer('dummy', tensor(0))
+        # wrap with accelerate
+
+        self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model
+
+        dummy = tensor(0)
+
+        if wrap_with_accelerate:
+            (
+                self.actor,
+                self.critic,
+                self.latent_gene_pool,
+                self.actor_optim,
+                self.critic_optim,
+                self.latent_optim,
+            ) = tuple(
+                maybe(self.accelerate.prepare)(m) for m in (
+                    self.actor,
+                    self.critic,
+                    self.latent_gene_pool,
+                    self.actor_optim,
+                    self.critic_optim,
+                    self.latent_optim,
+                )
+            )
+
+            if exists(self.critic_ema):
+                self.critic_ema.to(self.accelerate.device)
+
+            dummy = dummy.to(self.accelerate.device)
+
+        # device tracking
+
+        self.register_buffer('dummy', dummy)
 
     @property
     def device(self):
         return self.dummy.device
 
+    @property
+    def unwrapped_latent_gene_pool(self):
+        return self.unwrap_model(self.latent_gene_pool)
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
@@ -814,19 +885,23 @@ class Agent(Module):
         if exists(pkg.get('latent_optim', None)):
             self.latent_optim.load_state_dict(pkg['latent_optim'])
 
+    @move_input_tensors_to_device
     def get_actor_actions(
         self,
         state,
         latent_id = None,
         latent = None,
         sample = False,
-        temperature = 1.
+        temperature = 1.,
+        use_unwrapped_model = False
     ):
+        maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
         if not exists(latent) and exists(latent_id):
-            latent = self.latent_gene_pool(latent_id = latent_id)
+            latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        logits = self.actor(state, latent)
+        print(self.device, state.device, next(self.actor.parameters()).device)
+        logits = maybe_unwrap(self.actor)(state, latent)
 
         if not sample:
             return logits
@@ -837,18 +912,22 @@ class Agent(Module):
 
         return actions, log_probs
 
+    @move_input_tensors_to_device
     def get_critic_values(
         self,
         state,
         latent_id = None,
         latent = None,
-        use_ema_if_available = False
+        use_ema_if_available = False,
+        use_unwrapped_model = False
     ):
 
+        maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
+
         if not exists(latent) and exists(latent_id):
-            latent = self.latent_gene_pool(latent_id = latent_id)
+            latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        critic_forward = self.critic
+        critic_forward = maybe_unwrap(self.critic)
 
         if use_ema_if_available and self.use_critic_ema:
             critic_forward = self.critic_ema
@@ -922,6 +1001,9 @@ class Agent(Module):
 
         dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
 
+        if self.wrap_with_accelerate:
+            dataloader = self.accelerate.prepare(dataloader)
+
         # updating actor and critic
 
         self.actor.train()
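The constructor hunk earlier and the dataloader hunk above both route objects through Accelerator.prepare (via maybe, so entries that are None are skipped), and Accelerator.unwrap_model is kept around for the places that need the raw module back. A minimal sketch of that general accelerate pattern, using a toy linear model and optimizer rather than the package's actor/critic, is:

    import torch
    from torch import nn
    from accelerate import Accelerator

    accelerator = Accelerator()

    model = nn.Linear(4, 2)                                  # stand-in for actor / critic
    optim = torch.optim.Adam(model.parameters(), lr = 1e-3)

    # prepare handles device placement and, in distributed runs, DDP wrapping
    model, optim = accelerator.prepare(model, optim)

    x = torch.randn(8, 4, device = accelerator.device)
    loss = model(x).sum()

    accelerator.backward(loss)                               # in place of loss.backward()
    accelerator.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    # recover the plain nn.Linear, e.g. for saving or for rollouts outside prepare
    unwrapped = accelerator.unwrap_model(model)
    torch.save(unwrapped.state_dict(), './model.pt')

Note that in the diff itself the losses still call .backward() directly; only device placement, gradient clipping, and the dataloader go through the accelerator.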
@@ -955,7 +1037,7 @@ class Agent(Module):
             actor_loss.backward()
 
             if exists(self.has_grad_clip):
-                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
+                self.accelerate.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
 
             self.actor_optim.step()
             self.actor_optim.zero_grad()
@@ -971,7 +1053,7 @@ class Agent(Module):
             critic_loss.backward()
 
             if exists(self.has_grad_clip):
-                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
+                self.accelerate.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
 
             self.critic_optim.step()
             self.critic_optim.zero_grad()
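Both clip_grad_norm_ call sites above, plus a new one added for the latent gene pool in the hunk that follows, now go through the accelerator. Once modules and optimizers have been prepared, the accelerate documentation recommends Accelerator.clip_grad_norm_ over torch.nn.utils.clip_grad_norm_, since it also unscales gradients when mixed precision is enabled.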
@@ -994,6 +1076,9 @@ class Agent(Module):
 
                 (diversity_loss * self.diversity_aux_loss_weight).backward()
 
+                if exists(self.has_grad_clip):
+                    self.accelerate.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)
+
                 self.latent_optim.step()
                 self.latent_optim.zero_grad()
 
@@ -1128,7 +1213,7 @@ class EPO(Module):
         self.max_episode_length = max_episode_length
         self.fix_environ_across_latents = fix_environ_across_latents
 
-        self.register_buffer('dummy', tensor(0))
+        self.register_buffer('dummy', tensor(0, device = agent.device))
 
     @property
     def device(self):
@@ -1215,7 +1300,7 @@ class EPO(Module):
 
                 # get latent from pool
 
-                latent = self.agent.latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
+                latent = self.agent.unwrapped_latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
 
                 # until maximum episode length
 
@@ -1225,11 +1310,11 @@ class EPO(Module):
 
                     # sample action
 
-                    action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+                    action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature, use_unwrapped_model = True)
 
                     # values
 
-                    value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
+                    value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True, use_unwrapped_model = True)
 
                     # get the next state, action, and reward
 
--- evolutionary_policy_optimization-0.1.1/pyproject.toml
+++ evolutionary_policy_optimization-0.1.4/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.1.1"
+version = "0.1.4"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,7 +24,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "accelerate",
+    "accelerate>=1.6.0",
     "adam-atan2-pytorch",
     'assoc-scan>=0.0.2',
     'einx>=0.3.0',
--- evolutionary_policy_optimization-0.1.1/train_gym.py
+++ evolutionary_policy_optimization-0.1.4/train_gym.py
@@ -26,18 +26,19 @@ agent = env.to_epo_agent(
     latent_gene_pool_kwargs = dict(
         frac_natural_selected = 0.5,
         frac_tournaments = 0.5
+    ),
+    accelerate_kwargs = dict(
+        cpu = False
     )
 )
 
 epo = EPO(
     agent,
-    episodes_per_latent = 1,
+    episodes_per_latent = 5,
     max_episode_length = 10,
-    action_sample_temperature = 1.
+    action_sample_temperature = 1.,
 )
 
-epo.to('cpu' if not torch.cuda.is_available() else 'cuda')
-
-epo(agent, env, num_learning_cycles = 1)
+epo(agent, env, num_learning_cycles = 5)
 
 agent.save('./agent.pt', overwrite = True)
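With the Accelerator now constructed inside Agent, the training script no longer moves the EPO module onto a device by hand; device placement (and, when configured, mixed precision or multi-process execution) is delegated to accelerate. Presumably the script can still be run directly with python train_gym.py, or launched through the accelerate CLI (accelerate config once, then accelerate launch train_gym.py); the launch setup itself is not part of this diff.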