evolutionary-policy-optimization 0.0.39__py3-none-any.whl → 0.0.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +31 -15
- {evolutionary_policy_optimization-0.0.39.dist-info → evolutionary_policy_optimization-0.0.41.dist-info}/METADATA +22 -1
- {evolutionary_policy_optimization-0.0.39.dist-info → evolutionary_policy_optimization-0.0.41.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.0.39.dist-info → evolutionary_policy_optimization-0.0.41.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.39.dist-info → evolutionary_policy_optimization-0.0.41.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py:

```diff
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from functools import partial
+from functools import partial, wraps
 from pathlib import Path
 from collections import namedtuple

@@ -9,6 +9,7 @@ from torch import nn, cat, stack, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten

 import einx
 from einops import rearrange, repeat, einsum, pack
```
```diff
@@ -73,6 +74,19 @@ def gather_log_prob(
     log_prob = log_probs.gather(-1, indices)
     return rearrange(log_prob, '... 1 -> ...')

+def temp_batch_dim(fn):
+
+    @wraps(fn)
+    def inner(*args, **kwargs):
+        args, kwargs = tree_map(lambda t: rearrange(t, '... -> 1 ...') if is_tensor(t) else t, (args, kwargs))
+
+        out = fn(*args, **kwargs)
+
+        out = tree_map(lambda t: rearrange(t, '1 ... -> ...') if is_tensor(t) else t, out)
+        return out
+
+    return inner
+
 # generalized advantage estimate

 def calc_generalized_advantage_estimate(
```
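The new `temp_batch_dim` helper temporarily adds a batch dimension of size 1 to every tensor argument (via `tree_map` over args and kwargs), calls the wrapped function, then strips that dimension from every tensor output. A standalone sketch of the same pattern, with the helper copied from the hunk above and a toy `batched_policy` function standing in for the agent's batched methods:

```python
import torch
from functools import wraps

from torch import is_tensor
from torch.utils._pytree import tree_map
from einops import rearrange

def temp_batch_dim(fn):

    @wraps(fn)
    def inner(*args, **kwargs):
        # add a leading batch dimension of size 1 to every tensor argument
        args, kwargs = tree_map(lambda t: rearrange(t, '... -> 1 ...') if is_tensor(t) else t, (args, kwargs))

        out = fn(*args, **kwargs)

        # strip the temporary batch dimension from every tensor output
        out = tree_map(lambda t: rearrange(t, '1 ... -> ...') if is_tensor(t) else t, out)
        return out

    return inner

# toy stand-in for a batched agent method such as get_actor_actions

def batched_policy(state):
    assert state.ndim == 2          # expects (batch, dim)
    return state.sum(dim = -1)      # returns (batch,)

state = torch.randn(5)              # a single, unbatched state

out = temp_batch_dim(batched_policy)(state)
print(out.shape)                    # torch.Size([]) - batch dim added, then removed
```

This is what lets the rollout code further down pass single, unbatched states directly to `get_actor_actions` and `get_critic_values`.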
```diff
@@ -784,6 +798,10 @@ class Agent(Module):

         latents = self.latent_gene_pool(latent_id = latent_gene_ids)

+        orig_latents = latents
+        latents = latents.detach()
+        latents.requires_grad_()
+
         # learn actor

         logits = self.actor(states, latents)
```
```diff
@@ -808,6 +826,14 @@ class Agent(Module):
         self.critic_optim.step()
         self.critic_optim.zero_grad()

+        # maybe update latents, if not frozen
+
+        if not self.latent_gene_pool.frozen_latents:
+            orig_latents.backward(latents.grad)
+
+            self.latent_optim.step()
+            self.latent_optim.zero_grad()
+
         # apply evolution

         self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
```
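These two hunks in `Agent.learn` route the actor and critic gradients back into the latent gene pool without letting the actor/critic optimizer steps touch it: the latents are detached so the policy losses only fill `latents.grad`, and `orig_latents.backward(latents.grad)` then re-injects that gradient into the gene pool's graph so a separate `latent_optim` can step, skipped entirely when `frozen_latents` is set. A self-contained sketch of the pattern, using an `nn.Embedding` and an `nn.Linear` as illustrative stand-ins for the gene pool and actor (not the package's actual modules):

```python
import torch
from torch import nn

gene_pool = nn.Embedding(4, 8)      # illustrative stand-in for the latent gene pool
actor     = nn.Linear(8, 2)

latent_optim = torch.optim.Adam(gene_pool.parameters(), lr = 1e-3)
actor_optim  = torch.optim.Adam(actor.parameters(), lr = 1e-3)

orig_latents = gene_pool(torch.tensor([1]))   # keeps the graph into the gene pool
latents = orig_latents.detach()               # cut the graph for the policy loss
latents.requires_grad_()                      # but still accumulate grads w.r.t. the latent

loss = actor(latents).sum()
loss.backward()                               # fills latents.grad, leaves gene_pool untouched

actor_optim.step()
actor_optim.zero_grad()

# re-inject the accumulated gradient into the gene pool's graph
orig_latents.backward(latents.grad)
latent_optim.step()
latent_optim.zero_grad()
```

The detour through the detached copy lets the latents receive a learning signal from the actor and critic losses while keeping their update on a separate optimizer that can be switched off.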
```diff
@@ -937,22 +963,15 @@ class EPO(Module):

             done = tensor(False)

-            while time < self.max_episode_length:
-
-                batched_state = rearrange(state, '... -> 1 ...')
+            while time < self.max_episode_length and not done:

                 # sample action

-                action, log_prob = self.agent.get_actor_actions(
-
-                action = rearrange(action, '1 ... -> ...')
-                log_prob = rearrange(log_prob, '1 ... -> ...')
+                action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)

                 # values

-                value = self.agent.get_critic_values(
-
-                value = rearrange(value, '1 ... -> ...')
+                value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)

                 # get the next state, action, and reward

```
```diff
@@ -981,10 +1000,7 @@

             # need the final next value for GAE, iiuc

-
-
-            next_value = self.agent.get_critic_values(batched_state, latent = latent)
-            next_value = rearrange(next_value, '1 ... -> ...')
+            next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)

             memory_for_gae = memory._replace(
                 episode_id = invalid_episode,
```
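With `temp_batch_dim` in place, the rollout passes unbatched states straight into the agent's batched methods, and the episode loop now also stops as soon as the environment reports `done`. A toy sketch of that loop shape, assuming version 0.0.41 of the package is installed so the helper can be imported from `evolutionary_policy_optimization.epo` (the environment and `toy_actor` below are made-up stand-ins, not the package's `Agent` or `mock_env`):

```python
import torch
from evolutionary_policy_optimization.epo import temp_batch_dim

max_episode_length = 10

def toy_env_step(state):
    # hypothetical environment: returns next state, reward, and a done flag
    next_state = state + torch.randn_like(state) * 0.1
    reward = next_state.mean()
    done = next_state.abs().max() > 2.
    return next_state, reward, done

def toy_actor(state):
    # stand-in for a batched agent method - expects a leading batch dimension
    assert state.ndim == 2
    return state.argmax(dim = -1)

state = torch.zeros(4)
time = 0
done = torch.tensor(False)

while time < max_episode_length and not done:
    # unbatched state goes straight in, unbatched action comes straight out
    action = temp_batch_dim(toy_actor)(state)
    state, reward, done = toy_env_step(state)
    time += 1
```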
{evolutionary_policy_optimization-0.0.39.dist-info → evolutionary_policy_optimization-0.0.41.dist-info}/METADATA:

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.39
+Version: 0.0.41
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
```
````diff
@@ -162,4 +162,25 @@ agent.load('./agent.pt')
 }
 ```

+```bibtex
+@inproceedings{Khadka2018EvolutionGuidedPG,
+    title = {Evolution-Guided Policy Gradient in Reinforcement Learning},
+    author = {Shauharda Khadka and Kagan Tumer},
+    booktitle = {Neural Information Processing Systems},
+    year = {2018},
+    url = {https://api.semanticscholar.org/CorpusID:53096951}
+}
+```
+
+```bibtex
+@article{Fortunato2017NoisyNF,
+    title = {Noisy Networks for Exploration},
+    author = {Meire Fortunato and Mohammad Gheshlaghi Azar and Bilal Piot and Jacob Menick and Ian Osband and Alex Graves and Vlad Mnih and R{\'e}mi Munos and Demis Hassabis and Olivier Pietquin and Charles Blundell and Shane Legg},
+    journal = {ArXiv},
+    year = {2017},
+    volume = {abs/1706.10295},
+    url = {https://api.semanticscholar.org/CorpusID:5176587}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
````
{evolutionary_policy_optimization-0.0.39.dist-info → evolutionary_policy_optimization-0.0.41.dist-info}/RECORD:

```diff
@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=GL3nH5crOj4y_Amu2BY0s95MJL7F2t-X085y40SgUK0,30260
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=QqVPZVJtrvQmSDcnYDTob_A5sDwiUzGj6_tmo6BII5c,918
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
+evolutionary_policy_optimization-0.0.41.dist-info/METADATA,sha256=TFKI2B2PeyU6pHwqmCu130k-U2Li_QmUkvVB39-4uDw,6213
+evolutionary_policy_optimization-0.0.41.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.41.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.41.dist-info/RECORD,,
```
The remaining files, WHEEL and licenses/LICENSE, are unchanged between the two versions.