evolutionary-policy-optimization 0.1.1__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (16)
  1. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/PKG-INFO +2 -2
  2. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/epo.py +80 -14
  3. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/pyproject.toml +2 -2
  4. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/train_gym.py +6 -5
  5. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/.github/workflows/python-publish.yml +0 -0
  6. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/.github/workflows/test.yml +0 -0
  7. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/.gitignore +0 -0
  8. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/LICENSE +0 -0
  9. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/README.md +0 -0
  10. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/__init__.py +0 -0
  11. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/distributed.py +0 -0
  12. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  13. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/experimental.py +0 -0
  14. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/mock_env.py +0 -0
  15. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/requirements.txt +0 -0
  16. {evolutionary_policy_optimization-0.1.1 → evolutionary_policy_optimization-0.1.2}/tests/test_epo.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.1
+Version: 0.1.2
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
-Requires-Dist: accelerate
+Requires-Dist: accelerate>=1.6.0
 Requires-Dist: adam-atan2-pytorch
 Requires-Dist: assoc-scan>=0.0.2
 Requires-Dist: einops>=0.8.1
evolutionary_policy_optimization/epo.py
@@ -40,6 +40,8 @@ from ema_pytorch import EMA
 
 from tqdm import tqdm
 
+from accelerate import Accelerator
+
 # helpers
 
 def exists(v):
@@ -60,6 +62,17 @@ def divisible_by(num, den):
 def to_device(inp, device):
     return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
 
+def maybe(fn):
+
+    @wraps(fn)
+    def decorated(inp, *args, **kwargs):
+        if not exists(inp):
+            return None
+
+        return fn(inp, *args, **kwargs)
+
+    return decorated
+
 def interface_torch_numpy(fn, device):
     # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
 
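The `maybe` helper added above is a small None-passthrough decorator; later in this diff it lets optional modules and optimizers be routed through `Accelerator.prepare` without special-casing missing components. A minimal, self-contained sketch of the same pattern (helpers re-declared here purely for illustration):

from functools import wraps

def exists(v):
    return v is not None

def maybe(fn):
    # returns a wrapper that short-circuits to None when the first argument is None
    @wraps(fn)
    def decorated(inp, *args, **kwargs):
        if not exists(inp):
            return None
        return fn(inp, *args, **kwargs)
    return decorated

maybe_len = maybe(len)
assert maybe_len('abc') == 3     # behaves like len for real inputs
assert maybe_len(None) is None   # None simply passes through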
@@ -721,10 +734,22 @@ class Agent(Module):
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
         latent_optim_kwargs: dict = dict(),
-        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores
+        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores,
+        wrap_with_accelerate: bool = True,
+        accelerate_kwargs: dict = dict(),
     ):
         super().__init__()
 
+        # hf accelerate
+
+        self.wrap_with_accelerate = wrap_with_accelerate
+
+        if wrap_with_accelerate:
+            accelerate = Accelerator(**accelerate_kwargs)
+            self.accelerate = accelerate
+
+        # actor, critic, and their shared latent gene pool
+
         self.actor = actor
 
         self.critic = critic
@@ -768,12 +793,43 @@ class Agent(Module):
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
         self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
-        self.register_buffer('dummy', tensor(0))
+        # wrap with accelerate
+
+        self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model
+
+        if wrap_with_accelerate:
+            (
+                self.actor,
+                self.critic,
+                self.latent_gene_pool,
+                self.actor_optim,
+                self.critic_optim,
+                self.latent_optim,
+            ) = tuple(
+                maybe(self.accelerate.prepare)(m) for m in (
+                    self.actor,
+                    self.critic,
+                    self.latent_gene_pool,
+                    self.actor_optim,
+                    self.critic_optim,
+                    self.latent_optim,
+                )
+            )
+
+        # device tracking
+
+        self.register_buffer('dummy', tensor(0, device = self.accelerate.device))
+
+        self.critic_ema.to(self.accelerate.device)
 
     @property
     def device(self):
         return self.dummy.device
 
+    @property
+    def unwrapped_latent_gene_pool(self):
+        return self.unwrap_model(self.latent_gene_pool)
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
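For reference, `Accelerator.prepare` is the HF Accelerate call used above to wrap each module and optimizer (and, further down, the dataloader) for the active device and distributed backend, while `unwrap_model` recovers the bare module again. A minimal sketch outside this package (the linear model and optimizer are just placeholders):

import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

# prepare() moves / wraps objects for the current device and any distributed setup
model, optim = accelerator.prepare(model, optim)

# unwrap_model strips wrappers (e.g. DDP) so the underlying module can be called
# directly or have its state dict saved
bare_model = accelerator.unwrap_model(model)
assert isinstance(bare_model, nn.Linear)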
@@ -820,13 +876,15 @@ class Agent(Module):
         latent_id = None,
         latent = None,
         sample = False,
-        temperature = 1.
+        temperature = 1.,
+        use_unwrapped_model = False
     ):
+        maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
         if not exists(latent) and exists(latent_id):
-            latent = self.latent_gene_pool(latent_id = latent_id)
+            latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        logits = self.actor(state, latent)
+        logits = maybe_unwrap(self.actor)(state, latent)
 
         if not sample:
             return logits
@@ -842,13 +900,15 @@ class Agent(Module):
         state,
         latent_id = None,
         latent = None,
-        use_ema_if_available = False
+        use_ema_if_available = False,
+        use_unwrapped_model = False
     ):
+        maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
         if not exists(latent) and exists(latent_id):
-            latent = self.latent_gene_pool(latent_id = latent_id)
+            latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-        critic_forward = self.critic
+        critic_forward = maybe_unwrap(self.critic)
 
         if use_ema_if_available and self.use_critic_ema:
             critic_forward = self.critic_ema
@@ -922,6 +982,9 @@ class Agent(Module):
 
         dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
 
+        if self.wrap_with_accelerate:
+            dataloader = self.accelerate.prepare(dataloader)
+
         # updating actor and critic
 
         self.actor.train()
@@ -955,7 +1018,7 @@ class Agent(Module):
                 actor_loss.backward()
 
                 if exists(self.has_grad_clip):
-                    nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
+                    self.accelerate.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
 
                 self.actor_optim.step()
                 self.actor_optim.zero_grad()
@@ -971,7 +1034,7 @@ class Agent(Module):
                 critic_loss.backward()
 
                 if exists(self.has_grad_clip):
-                    nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
+                    self.accelerate.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
 
                 self.critic_optim.step()
                 self.critic_optim.zero_grad()
@@ -994,6 +1057,9 @@ class Agent(Module):
 
                 (diversity_loss * self.diversity_aux_loss_weight).backward()
 
+                if exists(self.has_grad_clip):
+                    self.accelerate.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)
+
                 self.latent_optim.step()
                 self.latent_optim.zero_grad()
 
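The three gradient-clipping sites above switch from `torch.nn.utils.clip_grad_norm_` to `Accelerator.clip_grad_norm_`, which also unscales gradients first when mixed precision is active. A rough sketch of the pattern in isolation (model, data, and hyperparameters are placeholders):

import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(8, 1)
optim = torch.optim.SGD(model.parameters(), lr = 1e-2)
model, optim = accelerator.prepare(model, optim)

loss = model(torch.randn(2, 8)).sum()
loss.backward()

# clip through the accelerator so gradients are unscaled before the norm is taken
accelerator.clip_grad_norm_(model.parameters(), max_norm = 0.5)

optim.step()
optim.zero_grad()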
@@ -1128,7 +1194,7 @@ class EPO(Module):
         self.max_episode_length = max_episode_length
         self.fix_environ_across_latents = fix_environ_across_latents
 
-        self.register_buffer('dummy', tensor(0))
+        self.register_buffer('dummy', tensor(0, device = agent.device))
 
     @property
     def device(self):
@@ -1215,7 +1281,7 @@ class EPO(Module):
 
                 # get latent from pool
 
-                latent = self.agent.latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
+                latent = self.agent.unwrapped_latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
 
                 # until maximum episode length
 
@@ -1225,11 +1291,11 @@ class EPO(Module):
 
                     # sample action
 
-                    action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+                    action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature, use_unwrapped_model = True)
 
                     # values
 
-                    value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
+                    value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True, use_unwrapped_model = True)
 
                     # get the next state, action, and reward
 
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.1.1"
+version = "0.1.2"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,7 +24,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "accelerate",
+    "accelerate>=1.6.0",
     "adam-atan2-pytorch",
     'assoc-scan>=0.0.2',
     'einx>=0.3.0',
train_gym.py
@@ -26,18 +26,19 @@ agent = env.to_epo_agent(
     latent_gene_pool_kwargs = dict(
         frac_natural_selected = 0.5,
         frac_tournaments = 0.5
+    ),
+    accelerate_kwargs = dict(
+        cpu = False
     )
 )
 
 epo = EPO(
     agent,
-    episodes_per_latent = 1,
+    episodes_per_latent = 5,
     max_episode_length = 10,
-    action_sample_temperature = 1.
+    action_sample_temperature = 1.,
 )
 
-epo.to('cpu' if not torch.cuda.is_available() else 'cuda')
-
-epo(agent, env, num_learning_cycles = 1)
+epo(agent, env, num_learning_cycles = 5)
 
 agent.save('./agent.pt', overwrite = True)
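With the Agent now constructing its own `Accelerator` (and moving its modules to `accelerate.device`), the explicit `epo.to(...)` call is dropped from train_gym.py. The `accelerate_kwargs` shown are forwarded unchanged to `Accelerator`, so `cpu = False` amounts to the following sketch (device selection falls back to CPU when no accelerator hardware is present):

from accelerate import Accelerator

# equivalent of accelerate_kwargs = dict(cpu = False) above
accelerator = Accelerator(cpu = False)
print(accelerator.device)   # e.g. cuda, mps, or cpu depending on the machine

When the script is started via `accelerate launch`, the same constructor picks up the distributed configuration from the launcher automatically.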