evolutionary-policy-optimization 0.1.0__tar.gz → 0.1.2__tar.gz

This diff shows the changes between two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
Files changed (16)
  1. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/PKG-INFO +2 -1
  2. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/__init__.py +4 -0
  3. evolutionary_policy_optimization-0.1.2/evolutionary_policy_optimization/env_wrappers.py +36 -0
  4. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/epo.py +81 -14
  5. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/pyproject.toml +2 -1
  6. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/train_gym.py +19 -19
  7. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/.github/workflows/python-publish.yml +0 -0
  8. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/.github/workflows/test.yml +0 -0
  9. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/.gitignore +0 -0
  10. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/LICENSE +0 -0
  11. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/README.md +0 -0
  12. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/distributed.py +0 -0
  13. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/experimental.py +0 -0
  14. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/evolutionary_policy_optimization/mock_env.py +0 -0
  15. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/requirements.txt +0 -0
  16. {evolutionary_policy_optimization-0.1.0 → evolutionary_policy_optimization-0.1.2}/tests/test_epo.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.1.0
+ Version: 0.1.2
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3.8
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
+ Requires-Dist: accelerate>=1.6.0
  Requires-Dist: adam-atan2-pytorch
  Requires-Dist: assoc-scan>=0.0.2
  Requires-Dist: einops>=0.8.1
evolutionary_policy_optimization/__init__.py
@@ -9,3 +9,7 @@ from evolutionary_policy_optimization.epo import (
  )
 
  from evolutionary_policy_optimization.mock_env import Env
+
+ from evolutionary_policy_optimization.env_wrappers import (
+     GymnasiumEnvWrapper
+ )
evolutionary_policy_optimization/env_wrappers.py (new file)
@@ -0,0 +1,36 @@
+ import torch
+ from torch.nn import Module
+
+ from evolutionary_policy_optimization.epo import create_agent, Agent
+
+ class GymnasiumEnvWrapper(Module):
+     def __init__(
+         self,
+         env
+     ):
+         super().__init__()
+         self.env = env
+
+     def reset(self, *args, **kwargs):
+         return self.env.reset(*args, **kwargs)
+
+     def step(self, *args, **kwargs):
+         return self.env.step(*args, **kwargs)
+
+     def to_agent_hparams(self):
+         return dict(
+             dim_state = self.env.observation_space.shape[0],
+             actor_num_actions = self.env.action_space.n
+         )
+
+     def to_epo_agent(
+         self,
+         *args,
+         **kwargs
+     ) -> Agent:
+
+         return create_agent(
+             *args,
+             **self.to_agent_hparams(),
+             **kwargs
+         )
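
GymnasiumEnvWrapper just delegates reset/step and derives dim_state and actor_num_actions from the environment's observation_space and action_space, so an agent can be built directly from a wrapped env. A minimal sketch, assuming a discrete-action Gymnasium environment ('CartPole-v1' and the hyperparameter values below are illustrative, not taken from this package):

import gymnasium as gym

from evolutionary_policy_optimization import GymnasiumEnvWrapper

# wrap any Gymnasium env with a flat observation vector and a discrete action space
env = GymnasiumEnvWrapper(gym.make('CartPole-v1'))

# to_agent_hparams() reads dim_state = 4 and actor_num_actions = 2 off CartPole's spaces,
# and to_epo_agent() forwards them into create_agent along with the kwargs below
agent = env.to_epo_agent(
    num_latents = 8,
    dim_latent = 32,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64)
)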
evolutionary_policy_optimization/epo.py
@@ -40,6 +40,8 @@ from ema_pytorch import EMA
 
  from tqdm import tqdm
 
+ from accelerate import Accelerator
+
  # helpers
 
  def exists(v):
@@ -60,6 +62,17 @@ def divisible_by(num, den):
  def to_device(inp, device):
      return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
 
+ def maybe(fn):
+
+     @wraps(fn)
+     def decorated(inp, *args, **kwargs):
+         if not exists(inp):
+             return None
+
+         return fn(inp, *args, **kwargs)
+
+     return decorated
+
  def interface_torch_numpy(fn, device):
      # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
@@ -721,10 +734,22 @@ class Agent(Module):
          actor_optim_kwargs: dict = dict(),
          critic_optim_kwargs: dict = dict(),
          latent_optim_kwargs: dict = dict(),
-         get_fitness_scores: Callable[..., Tensor] = get_fitness_scores
+         get_fitness_scores: Callable[..., Tensor] = get_fitness_scores,
+         wrap_with_accelerate: bool = True,
+         accelerate_kwargs: dict = dict(),
      ):
          super().__init__()
 
+         # hf accelerate
+
+         self.wrap_with_accelerate = wrap_with_accelerate
+
+         if wrap_with_accelerate:
+             accelerate = Accelerator(**accelerate_kwargs)
+             self.accelerate = accelerate
+
+         # actor, critic, and their shared latent gene pool
+
          self.actor = actor
 
          self.critic = critic
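
Agent now owns a HuggingFace Accelerator when wrap_with_accelerate is set. For readers unfamiliar with accelerate, the prepare / unwrap_model pair used in the next hunk works roughly as follows outside this package (the toy model and optimizer are illustrative only):

import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr = 1e-3)

# prepare() moves objects to accelerator.device and wraps them for the
# current distributed / mixed precision configuration
model, optim = accelerator.prepare(model, optim)

# unwrap_model() recovers the underlying nn.Module, e.g. for saving or for
# calling it outside the prepared training loop
raw_model = accelerator.unwrap_model(model)

print(accelerator.device, type(raw_model))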
@@ -768,12 +793,43 @@ class Agent(Module):
          self.has_diversity_loss = diversity_aux_loss_weight > 0.
          self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
-         self.register_buffer('dummy', tensor(0))
+         # wrap with accelerate
+
+         self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model
+
+         if wrap_with_accelerate:
+             (
+                 self.actor,
+                 self.critic,
+                 self.latent_gene_pool,
+                 self.actor_optim,
+                 self.critic_optim,
+                 self.latent_optim,
+             ) = tuple(
+                 maybe(self.accelerate.prepare)(m) for m in (
+                     self.actor,
+                     self.critic,
+                     self.latent_gene_pool,
+                     self.actor_optim,
+                     self.critic_optim,
+                     self.latent_optim,
+                 )
+             )
+
+         # device tracking
+
+         self.register_buffer('dummy', tensor(0, device = self.accelerate.device))
+
+         self.critic_ema.to(self.accelerate.device)
 
      @property
      def device(self):
          return self.dummy.device
 
+     @property
+     def unwrapped_latent_gene_pool(self):
+         return self.unwrap_model(self.latent_gene_pool)
+
      def save(self, path, overwrite = False):
          path = Path(path)
 
@@ -820,13 +876,15 @@ class Agent(Module):
          latent_id = None,
          latent = None,
          sample = False,
-         temperature = 1.
+         temperature = 1.,
+         use_unwrapped_model = False
      ):
+         maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
          if not exists(latent) and exists(latent_id):
-             latent = self.latent_gene_pool(latent_id = latent_id)
+             latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-         logits = self.actor(state, latent)
+         logits = maybe_unwrap(self.actor)(state, latent)
 
          if not sample:
              return logits
@@ -842,13 +900,15 @@ class Agent(Module):
          state,
          latent_id = None,
          latent = None,
-         use_ema_if_available = False
+         use_ema_if_available = False,
+         use_unwrapped_model = False
      ):
+         maybe_unwrap = identity if not use_unwrapped_model else self.unwrap_model
 
          if not exists(latent) and exists(latent_id):
-             latent = self.latent_gene_pool(latent_id = latent_id)
+             latent = maybe_unwrap(self.latent_gene_pool)(latent_id = latent_id)
 
-         critic_forward = self.critic
+         critic_forward = maybe_unwrap(self.critic)
 
          if use_ema_if_available and self.use_critic_ema:
              critic_forward = self.critic_ema
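
Both sampling helpers gain a use_unwrapped_model flag that routes the call through the raw modules rather than the accelerate-prepared ones; the EPO rollout changes further down pass it as True. A hedged sketch of calling them directly, assuming the agent from the earlier wrapper example, a batch of one state with dim_state = 4, and an integer latent_id (all illustrative values):

import torch

state = torch.randn(1, 4)  # hypothetical batch of one observation

# sample an action and its log-probability from the actor, bypassing the accelerate wrapper
action, log_prob = agent.get_actor_actions(
    state,
    latent_id = 0,
    sample = True,
    temperature = 1.,
    use_unwrapped_model = True
)

# value estimate from the (EMA) critic, also on the unwrapped module
value = agent.get_critic_values(
    state,
    latent_id = 0,
    use_ema_if_available = True,
    use_unwrapped_model = True
)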
@@ -922,6 +982,9 @@ class Agent(Module):
 
          dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
 
+         if self.wrap_with_accelerate:
+             dataloader = self.accelerate.prepare(dataloader)
+
          # updating actor and critic
 
          self.actor.train()
@@ -955,7 +1018,7 @@ class Agent(Module):
              actor_loss.backward()
 
              if exists(self.has_grad_clip):
-                 nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
+                 self.accelerate.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
 
              self.actor_optim.step()
              self.actor_optim.zero_grad()
@@ -971,7 +1034,7 @@
              critic_loss.backward()
 
              if exists(self.has_grad_clip):
-                 nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
+                 self.accelerate.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
 
              self.critic_optim.step()
              self.critic_optim.zero_grad()
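
Gradient clipping for the actor, the critic, and (in the next hunk) the latent gene pool now goes through Accelerator.clip_grad_norm_, which, unlike torch.nn.utils.clip_grad_norm_, also unscales gradients under mixed precision and cooperates with prepared distributed modules. The generic pattern in plain accelerate, outside this package, looks roughly like this (toy model, illustrative only):

import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters())
model, optim = accelerator.prepare(model, optim)

loss = model(torch.randn(8, 4)).sum()
accelerator.backward(loss)  # accelerate's backward, so grad scaling is handled when enabled

# clip after backward, before the optimizer step
accelerator.clip_grad_norm_(model.parameters(), max_norm = 0.5)

optim.step()
optim.zero_grad()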
@@ -994,6 +1057,9 @@
 
              (diversity_loss * self.diversity_aux_loss_weight).backward()
 
+             if exists(self.has_grad_clip):
+                 self.accelerate.clip_grad_norm_(self.latent_gene_pool.parameters(), self.max_grad_norm)
+
              self.latent_optim.step()
              self.latent_optim.zero_grad()
 
@@ -1040,6 +1106,7 @@ def actor_loss(
  # agent contains the actor, critic, and the latent genetic pool
 
  def create_agent(
+     *,
      dim_state,
      num_latents,
      dim_latent,
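
The bare * makes every create_agent argument keyword-only, so positional calls that worked under 0.1.0 now raise a TypeError. A sketch of the required calling style (the values are illustrative; the parameter names come from the hunks above and from train_gym.py below):

from evolutionary_policy_optimization import create_agent

# create_agent(4, 8, 32, ...)  # positional call: TypeError under 0.1.2
agent = create_agent(
    dim_state = 4,
    num_latents = 8,
    dim_latent = 32,
    actor_num_actions = 2,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64)
)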
@@ -1127,7 +1194,7 @@ class EPO(Module):
          self.max_episode_length = max_episode_length
          self.fix_environ_across_latents = fix_environ_across_latents
 
-         self.register_buffer('dummy', tensor(0))
+         self.register_buffer('dummy', tensor(0, device = agent.device))
 
      @property
      def device(self):
@@ -1214,7 +1281,7 @@ class EPO(Module):
 
              # get latent from pool
 
-             latent = self.agent.latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
+             latent = self.agent.unwrapped_latent_gene_pool(latent_id = latent_id) if self.agent.has_latent_genes else None
 
              # until maximum episode length
 
@@ -1224,11 +1291,11 @@ class EPO(Module):
 
                  # sample action
 
-                 action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+                 action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature, use_unwrapped_model = True)
 
                  # values
 
-                 value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True)
+                 value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True, use_unwrapped_model = True)
 
                  # get the next state, action, and reward
 
pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "evolutionary-policy-optimization"
- version = "0.1.0"
+ version = "0.1.2"
  description = "EPO - Pytorch"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,6 +24,7 @@ classifiers = [
  ]
 
  dependencies = [
+     "accelerate>=1.6.0",
      "adam-atan2-pytorch",
      'assoc-scan>=0.0.2',
      'einx>=0.3.0',
train_gym.py
@@ -1,3 +1,10 @@
+ import torch
+
+ from evolutionary_policy_optimization import (
+     EPO,
+     GymnasiumEnvWrapper
+ )
+
  # gymnasium
 
  import gymnasium as gym
@@ -7,38 +14,31 @@ env = gym.make(
      render_mode = 'rgb_array'
  )
 
- state_dim = env.observation_space.shape[0]
- num_actions = env.action_space.n
+ env = GymnasiumEnvWrapper(env)
 
  # epo
 
- import torch
-
- from evolutionary_policy_optimization import (
-     create_agent,
-     EPO,
-     Env
- )
-
- agent = create_agent(
-     dim_state = state_dim,
-     num_latents = 1,
+ agent = env.to_epo_agent(
+     num_latents = 8,
      dim_latent = 32,
-     actor_num_actions = num_actions,
      actor_dim_hiddens = (256, 128),
      critic_dim_hiddens = (256, 128, 64),
      latent_gene_pool_kwargs = dict(
-         frac_natural_selected = 0.5
+         frac_natural_selected = 0.5,
+         frac_tournaments = 0.5
+     ),
+     accelerate_kwargs = dict(
+         cpu = False
      )
  )
 
  epo = EPO(
      agent,
-     episodes_per_latent = 1,
+     episodes_per_latent = 5,
      max_episode_length = 10,
-     action_sample_temperature = 1.
+     action_sample_temperature = 1.,
  )
 
- epo.to('cpu' if not torch.cuda.is_available() else 'cuda')
-
  epo(agent, env, num_learning_cycles = 5)
+
+ agent.save('./agent.pt', overwrite = True)
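
Put together, the new train_gym.py assembled from the added and unchanged lines above reads roughly as follows; the environment id passed to gym.make sits outside the diff hunks, so 'LunarLander-v3' below is only a placeholder:

import torch

from evolutionary_policy_optimization import (
    EPO,
    GymnasiumEnvWrapper
)

# gymnasium

import gymnasium as gym

env = gym.make(
    'LunarLander-v3',  # placeholder - the actual env id is not visible in this diff
    render_mode = 'rgb_array'
)

env = GymnasiumEnvWrapper(env)

# epo

agent = env.to_epo_agent(
    num_latents = 8,
    dim_latent = 32,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64),
    latent_gene_pool_kwargs = dict(
        frac_natural_selected = 0.5,
        frac_tournaments = 0.5
    ),
    accelerate_kwargs = dict(
        cpu = False
    )
)

epo = EPO(
    agent,
    episodes_per_latent = 5,
    max_episode_length = 10,
    action_sample_temperature = 1.,
)

epo(agent, env, num_learning_cycles = 5)

agent.save('./agent.pt', overwrite = True)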