evolutionary-policy-optimization 0.0.70__tar.gz → 0.0.72__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/PKG-INFO +1 -1
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/evolutionary_policy_optimization/epo.py +27 -5
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/evolutionary_policy_optimization/mock_env.py +3 -2
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/pyproject.toml +1 -1
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/.github/workflows/python-publish.yml +0 -0
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/.github/workflows/test.yml +0 -0
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/.gitignore +0 -0
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/LICENSE +0 -0
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/README.md +0 -0
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/evolutionary_policy_optimization/__init__.py +0 -0
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/evolutionary_policy_optimization/distributed.py +0 -0
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/evolutionary_policy_optimization/experimental.py +0 -0
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/requirements.txt +0 -0
- {evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/tests/test_epo.py +0 -0
{evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.70
+Version: 0.0.72
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
{evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/evolutionary_policy_optimization/epo.py
RENAMED
@@ -8,8 +8,10 @@ from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
 
+import numpy as np
+
 import torch
-from torch import nn, cat, stack, is_tensor, tensor, Tensor
+from torch import nn, cat, stack, is_tensor, tensor, from_numpy, Tensor
 import torch.nn.functional as F
 import torch.distributed as dist
 from torch.nn import Linear, Module, ModuleList
@@ -58,6 +60,21 @@ def divisible_by(num, den):
 def to_device(inp, device):
     return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
 
+def interface_torch_numpy(fn, device):
+    # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
+
+    @wraps(fn)
+    def decorated_fn(*args, **kwargs):
+
+        args, kwargs = tree_map(lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t, (args, kwargs))
+
+        out = fn(*args, **kwargs)
+
+        out = tree_map(lambda t: from_numpy(t).to(device) if isinstance(t, np.ndarray) else t, out)
+        return out
+
+    return decorated_fn
+
 # tensor helpers
 
 def l2norm(t):
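The new interface_torch_numpy helper lets the torch-based agent code drive an environment that consumes and produces numpy arrays: tensor inputs are moved to CPU and converted to numpy before the wrapped call, and any numpy arrays in the output are converted back to torch tensors on the given device. A minimal usage sketch, assuming the helper is imported from epo.py and using a made-up stand-in for an environment step function:

    import numpy as np
    import torch
    from evolutionary_policy_optimization.epo import interface_torch_numpy

    def numpy_step(action: np.ndarray):
        # stand-in for an env step that consumes and produces numpy arrays
        next_state = np.random.randn(4).astype(np.float32)
        reward = np.array(1.0, dtype = np.float32)
        return next_state, reward

    wrapped_step = interface_torch_numpy(numpy_step, device = 'cpu')

    action = torch.tensor([0.1, -0.2])         # torch input is converted to numpy before the call
    next_state, reward = wrapped_step(action)  # numpy outputs come back as torch tensors on the device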
@@ -1178,7 +1195,7 @@ class EPO(Module):
         if not exists(memories):
             memories = []
 
-        rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
+        rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent), device = self.device)
 
         rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)
 
@@ -1193,7 +1210,7 @@ class EPO(Module):
                 if fix_environ_across_latents:
                     reset_kwargs.update(seed = maybe_seed)
 
-                state = env.reset(**reset_kwargs)
+                state = interface_torch_numpy(env.reset, device = self.device)(**reset_kwargs)
 
                 # get latent from pool
 
@@ -1215,7 +1232,7 @@ class EPO(Module):
 
                     # get the next state, action, and reward
 
-                    state, reward, truncated, terminated = env(action)
+                    state, reward, truncated, terminated = interface_torch_numpy(env.forward, device = self.device)(action)
 
                     done = truncated or terminated
 
@@ -1264,9 +1281,14 @@ class EPO(Module):
         self,
         agent: Agent,
         env,
-        num_learning_cycles
+        num_learning_cycles,
+        seed = None
     ):
 
+        if exists(seed):
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+
         for _ in tqdm(range(num_learning_cycles), desc = 'learning cycle'):
 
             memories = self.gather_experience_from(env)
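EPO.forward now accepts an optional seed, applied to both torch and numpy before the first learning cycle so rollouts are easier to reproduce. A brief sketch of a call, assuming the agent, environment, and EPO instance were constructed elsewhere:

    # epo is an EPO instance, agent an Agent, env a compatible environment (all constructed elsewhere)
    epo(agent, env, num_learning_cycles = 5, seed = 42)  # seeds torch.manual_seed and np.random.seed up front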
{evolutionary_policy_optimization-0.0.70 → evolutionary_policy_optimization-0.0.72}/evolutionary_policy_optimization/mock_env.py
RENAMED
@@ -34,7 +34,7 @@ class Env(Module):
     ):
         state = randn(self.state_shape, device = self.device)
         self.step.zero_()
-        return state
+        return state.numpy()
 
     def forward(
         self,
@@ -51,4 +51,5 @@ class Env(Module):
 
         self.step.add_(1)
 
-        return state, reward, truncated, terminated
+        out = state, reward, truncated, terminated
+        return tuple(t.numpy() for t in out)
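With this change the mock environment returns numpy arrays from both reset and forward, matching the numpy interface the rollout loop now expects and bridges via interface_torch_numpy. A rough sketch of the resulting call pattern; the Env constructor argument shown is an assumption, not confirmed by this diff:

    import numpy as np
    from evolutionary_policy_optimization.mock_env import Env

    env = Env((4,))                      # hypothetical state shape argument
    state = env.reset()                  # now a numpy.ndarray rather than a torch tensor
    action = np.array([0])               # placeholder action
    state, reward, truncated, terminated = env(action)  # all four values returned as numpy arrays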