evolutionary-policy-optimization 0.0.39__py3-none-any.whl → 0.0.40__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
- evolutionary_policy_optimization/epo.py +18 -14
- {evolutionary_policy_optimization-0.0.39.dist-info → evolutionary_policy_optimization-0.0.40.dist-info}/METADATA +1 -1
- {evolutionary_policy_optimization-0.0.39.dist-info → evolutionary_policy_optimization-0.0.40.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.0.39.dist-info → evolutionary_policy_optimization-0.0.40.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.39.dist-info → evolutionary_policy_optimization-0.0.40.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from functools import partial
+from functools import partial, wraps
 from pathlib import Path
 from collections import namedtuple
 
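The newly imported `wraps` supports the decorator added further down in this diff. A minimal reminder of what it provides, with illustrative `logged`/`greet` names that are not from the package:

from functools import wraps

def logged(fn):
    @wraps(fn)                        # copies __name__, __doc__, etc. from fn onto the wrapper
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)
    return inner

@logged
def greet():
    """Say hello."""

assert greet.__name__ == 'greet'      # metadata preserved thanks to wraps
assert greet.__doc__ == 'Say hello.'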
@@ -9,6 +9,7 @@ from torch import nn, cat, stack, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 import einx
 from einops import rearrange, repeat, einsum, pack
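`tree_map` from `torch.utils._pytree` (a private but widely used PyTorch utility) applies a function to every leaf of a nested container, which is what the new decorator below relies on. A rough sketch of that behavior, using an illustrative `nested` structure:

import torch
from torch.utils._pytree import tree_map

# a nested container mixing tensor and non-tensor leaves (illustrative)
nested = {'obs': torch.ones(3), 'extras': [torch.zeros(2), 'metadata']}

# add a leading batch dimension to tensor leaves only; everything else passes through
batched = tree_map(lambda t: t.unsqueeze(0) if torch.is_tensor(t) else t, nested)

assert batched['obs'].shape == (1, 3)
assert batched['extras'][0].shape == (1, 2)
assert batched['extras'][1] == 'metadata'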
@@ -73,6 +74,19 @@ def gather_log_prob(
     log_prob = log_probs.gather(-1, indices)
     return rearrange(log_prob, '... 1 -> ...')
 
+def temp_batch_dim(fn):
+
+    @wraps(fn)
+    def inner(*args, **kwargs):
+        args, kwargs = tree_map(lambda t: rearrange(t, '... -> 1 ...') if is_tensor(t) else t, (args, kwargs))
+
+        out = fn(*args, **kwargs)
+
+        out = tree_map(lambda t: rearrange(t, '1 ... -> ...') if is_tensor(t) else t, out)
+        return out
+
+    return inner
+
 # generalized advantage estimate
 
 def calc_generalized_advantage_estimate(
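The new `temp_batch_dim` helper wraps a function so that unbatched tensor arguments receive a temporary leading batch dimension, which is then stripped from any tensor outputs. A hedged, self-contained sketch of how it behaves, using a toy `fake_actor` that is not part of the package:

import torch
from functools import wraps
from torch import is_tensor
from torch.utils._pytree import tree_map
from einops import rearrange

# the helper added in this diff: temporarily add a leading batch dim around a call
def temp_batch_dim(fn):
    @wraps(fn)
    def inner(*args, **kwargs):
        args, kwargs = tree_map(lambda t: rearrange(t, '... -> 1 ...') if is_tensor(t) else t, (args, kwargs))
        out = fn(*args, **kwargs)
        out = tree_map(lambda t: rearrange(t, '1 ... -> ...') if is_tensor(t) else t, out)
        return out
    return inner

# toy policy head that expects batched input of shape (batch, 4) - purely illustrative
weight = torch.randn(4, 2)

def fake_actor(state):
    return state @ weight             # returns logits of shape (batch, 2)

state = torch.randn(4)                          # single unbatched observation of shape (4,)
logits = temp_batch_dim(fake_actor)(state)      # internally (1, 4) -> (1, 2) -> (2,)
assert logits.shape == (2,)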
@@ -939,20 +953,13 @@ class EPO(Module):
 
         while time < self.max_episode_length:
 
-            batched_state = rearrange(state, '... -> 1 ...')
-
             # sample action
 
-            action, log_prob = self.agent.get_actor_actions(
-
-            action = rearrange(action, '1 ... -> ...')
-            log_prob = rearrange(log_prob, '1 ... -> ...')
+            action, log_prob = temp_batch_dim(self.agent.get_actor_actions)(state, latent = latent, sample = True, temperature = self.action_sample_temperature)
 
             # values
 
-            value = self.agent.get_critic_values(
-
-            value = rearrange(value, '1 ... -> ...')
+            value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
 
             # get the next state, action, and reward
 
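The call-site rewrite is intended to be behavior-preserving: wrapping the call with `temp_batch_dim` replaces the manual add-batch-dim / call / strip-batch-dim sequence. A sketch of that equivalence with a hypothetical single-argument `get_values`, reusing the `temp_batch_dim` definition shown above:

import torch
from einops import rearrange

# stands in for a method like get_critic_values; expects input with a leading batch dim
def get_values(state):
    return state * 2.

state = torch.randn(16)

# old pattern (0.0.39): batch manually, call, then strip the batch dim
batched_state = rearrange(state, '... -> 1 ...')
value_old = rearrange(get_values(batched_state), '1 ... -> ...')

# new pattern (0.0.40): same result through the decorator
value_new = temp_batch_dim(get_values)(state)

assert torch.allclose(value_old, value_new)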
@@ -981,10 +988,7 @@
 
         # need the final next value for GAE, iiuc
 
-
-
-        next_value = self.agent.get_critic_values(batched_state, latent = latent)
-        next_value = rearrange(next_value, '1 ... -> ...')
+        next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
 
         memory_for_gae = memory._replace(
             episode_id = invalid_episode,
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.39
+Version: 0.0.40
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=VZOT1-jdZBE39awP7nhE-I1lGKMTfhUv4Dls9ptNsWk,29854
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=QqVPZVJtrvQmSDcnYDTob_A5sDwiUzGj6_tmo6BII5c,918
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
+evolutionary_policy_optimization-0.0.40.dist-info/METADATA,sha256=5ruqqTCmYto8tqkRlc_peBgRhWkhmRdzUef2ot67ky0,5409
+evolutionary_policy_optimization-0.0.40.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.40.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.40.dist-info/RECORD,,