evolutionary-policy-optimization: 0.0.39-py3-none-any.whl → 0.0.41-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
--- a/evolutionary_policy_optimization/epo.py
+++ b/evolutionary_policy_optimization/epo.py
@@ -1,6 +1,6 @@
  from __future__ import annotations

- from functools import partial
+ from functools import partial, wraps
  from pathlib import Path
  from collections import namedtuple

@@ -9,6 +9,7 @@ from torch import nn, cat, stack, is_tensor, tensor
  import torch.nn.functional as F
  from torch.nn import Linear, Module, ModuleList
  from torch.utils.data import TensorDataset, DataLoader
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten

  import einx
  from einops import rearrange, repeat, einsum, pack
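For readers unfamiliar with the pytree utilities pulled in here: `tree_map` applies a function to every leaf of an arbitrarily nested container, leaving the container structure intact. A minimal sketch of its behavior (note that `torch.utils._pytree` is a private, unstable PyTorch API; the example data is hypothetical):

```python
import torch
from torch.utils._pytree import tree_map

# tree_map visits every leaf of the nested structure and preserves its
# layout; non-tensor leaves are passed through unchanged by this lambda
nested = {'obs': torch.zeros(3), 'aux': [torch.ones(2), 'not a tensor']}

doubled = tree_map(lambda t: t * 2 if torch.is_tensor(t) else t, nested)
# {'obs': tensor([0., 0., 0.]), 'aux': [tensor([2., 2.]), 'not a tensor']}
```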
@@ -73,6 +74,19 @@ def gather_log_prob(
      log_prob = log_probs.gather(-1, indices)
      return rearrange(log_prob, '... 1 -> ...')

+ def temp_batch_dim(fn):
+
+     @wraps(fn)
+     def inner(*args, **kwargs):
+         args, kwargs = tree_map(lambda t: rearrange(t, '... -> 1 ...') if is_tensor(t) else t, (args, kwargs))
+
+         out = fn(*args, **kwargs)
+
+         out = tree_map(lambda t: rearrange(t, '1 ... -> ...') if is_tensor(t) else t, out)
+         return out
+
+     return inner
+
  # generalized advantage estimate

  def calc_generalized_advantage_estimate(
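The new `temp_batch_dim` decorator factors out a pattern that previously appeared inline at every call site: temporarily add a leading batch dimension of size 1 to every tensor argument, call the wrapped function, then strip that dimension from every tensor output. A minimal usage sketch, assuming `temp_batch_dim` as defined above and a hypothetical `policy` module standing in for the actor:

```python
import torch
from torch import nn

policy = nn.Linear(4, 2)   # expects batched input of shape (batch, 4)
state = torch.randn(4)     # a single, unbatched state

# temp_batch_dim adds the batch dim on the way in, strips it on the way out
logits = temp_batch_dim(policy)(state)
assert logits.shape == (2,)

# equivalent to the manual pattern the decorator replaces:
#   batched = rearrange(state, '... -> 1 ...')
#   logits  = rearrange(policy(batched), '1 ... -> ...')
```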
@@ -784,6 +798,10 @@ class Agent(Module):

          latents = self.latent_gene_pool(latent_id = latent_gene_ids)

+         orig_latents = latents
+         latents = latents.detach()
+         latents.requires_grad_()
+
          # learn actor

          logits = self.actor(states, latents)
@@ -808,6 +826,14 @@ class Agent(Module):
          self.critic_optim.step()
          self.critic_optim.zero_grad()

+         # maybe update latents, if not frozen
+
+         if not self.latent_gene_pool.frozen_latents:
+             orig_latents.backward(latents.grad)
+
+             self.latent_optim.step()
+             self.latent_optim.zero_grad()
+
          # apply evolution

          self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
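Together, these two hunks implement a two-stage backward through the latent gene pool: the latents are detached before being fed to the actor and critic, so those losses accumulate a gradient on the detached copy (`latents.grad`) without touching the pool; that collected gradient is then pushed through the original graph with `orig_latents.backward(latents.grad)`, and only when the latents are not frozen. A minimal self-contained sketch of the pattern, with hypothetical stand-in modules:

```python
import torch
from torch import nn

gene_pool = nn.Embedding(8, 16)   # stands in for the latent gene pool
head = nn.Linear(16, 1)           # stands in for the actor/critic heads

orig_latents = gene_pool(torch.tensor([3]))

latents = orig_latents.detach()   # cut the graph at the latents...
latents.requires_grad_()          # ...but still accumulate a gradient here

loss = head(latents).sum()
loss.backward()                   # fills latents.grad; gene_pool untouched

frozen_latents = False            # hypothetical flag

if not frozen_latents:
    # second stage: route the collected gradient through the original graph
    orig_latents.backward(latents.grad)
    assert gene_pool.weight.grad is not None
```

The payoff of this split is control: the expensive actor/critic backward runs once regardless, and the decision to propagate into the gene pool is deferred to a cheap, separate step guarded by the frozen flag.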
@@ -937,22 +963,15 @@ class EPO(Module):

          done = tensor(False)

-         while time < self.max_episode_length:
-
-             batched_state = rearrange(state, '... -> 1 ...')
+         while time < self.max_episode_length and not done:

              # sample action

-             action, log_prob = self.agent.get_actor_actions(batched_state, latent = latent, sample = True, temperature = self.action_sample_temperature)
-
-             action = rearrange(action, '1 ... -> ...')
-             log_prob = rearrange(log_prob, '1 ... -> ...')
+             action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)

              # values

-             value = self.agent.get_critic_values(batched_state, latent = latent)
-
-             value = rearrange(value, '1 ... -> ...')
+             value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)

              # get the next state, action, and reward

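Two behavioral notes on this hunk: the rollout loop now exits as soon as the environment reports `done`, where it previously always ran to `max_episode_length`, and the manual add/strip of the size-1 batch dimension is replaced by the `temp_batch_dim` helper introduced above.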
@@ -981,10 +1000,7 @@ class EPO(Module):

          # need the final next value for GAE, iiuc

-         batched_state = rearrange(state, '... -> 1 ...')
-
-         next_value = self.agent.get_critic_values(batched_state, latent = latent)
-         next_value = rearrange(next_value, '1 ... -> ...')
+         next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)

          memory_for_gae = memory._replace(
              episode_id = invalid_episode,
--- a/evolutionary_policy_optimization-0.0.39.dist-info/METADATA
+++ b/evolutionary_policy_optimization-0.0.41.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.0.39
+ Version: 0.0.41
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -162,4 +162,25 @@ agent.load('./agent.pt')
  }
  ```

+ ```bibtex
+ @inproceedings{Khadka2018EvolutionGuidedPG,
+     title     = {Evolution-Guided Policy Gradient in Reinforcement Learning},
+     author    = {Shauharda Khadka and Kagan Tumer},
+     booktitle = {Neural Information Processing Systems},
+     year      = {2018},
+     url       = {https://api.semanticscholar.org/CorpusID:53096951}
+ }
+ ```
+
+ ```bibtex
+ @article{Fortunato2017NoisyNF,
+     title   = {Noisy Networks for Exploration},
+     author  = {Meire Fortunato and Mohammad Gheshlaghi Azar and Bilal Piot and Jacob Menick and Ian Osband and Alex Graves and Vlad Mnih and R{\'e}mi Munos and Demis Hassabis and Olivier Pietquin and Charles Blundell and Shane Legg},
+     journal = {ArXiv},
+     year    = {2017},
+     volume  = {abs/1706.10295},
+     url     = {https://api.semanticscholar.org/CorpusID:5176587}
+ }
+ ```
+
  *Evolution is cleverer than you are.* - Leslie Orgel
--- a/evolutionary_policy_optimization-0.0.39.dist-info/RECORD
+++ b/evolutionary_policy_optimization-0.0.41.dist-info/RECORD
@@ -1,8 +1,8 @@
  evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
- evolutionary_policy_optimization/epo.py,sha256=lzxPamJahE5KqBwzyYlGOwNeUoB2vONLwtRcWqCI_Jw,29800
+ evolutionary_policy_optimization/epo.py,sha256=GL3nH5crOj4y_Amu2BY0s95MJL7F2t-X085y40SgUK0,30260
  evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
  evolutionary_policy_optimization/mock_env.py,sha256=QqVPZVJtrvQmSDcnYDTob_A5sDwiUzGj6_tmo6BII5c,918
- evolutionary_policy_optimization-0.0.39.dist-info/METADATA,sha256=TTNQD7sTWIgpVwnrQrFFBD-cyySkvwJr_J3ABxTpor8,5409
- evolutionary_policy_optimization-0.0.39.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- evolutionary_policy_optimization-0.0.39.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- evolutionary_policy_optimization-0.0.39.dist-info/RECORD,,
+ evolutionary_policy_optimization-0.0.41.dist-info/METADATA,sha256=TFKI2B2PeyU6pHwqmCu130k-U2Li_QmUkvVB39-4uDw,6213
+ evolutionary_policy_optimization-0.0.41.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ evolutionary_policy_optimization-0.0.41.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ evolutionary_policy_optimization-0.0.41.dist-info/RECORD,,