evolutionary-policy-optimization 0.0.70-py3-none-any.whl → 0.0.72-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py
@@ -8,8 +8,10 @@ from functools import partial, wraps
 from collections import namedtuple
 from random import randrange

+import numpy as np
+
 import torch
-from torch import nn, cat, stack, is_tensor, tensor, Tensor
+from torch import nn, cat, stack, is_tensor, tensor, from_numpy, Tensor
 import torch.nn.functional as F
 import torch.distributed as dist
 from torch.nn import Linear, Module, ModuleList
@@ -58,6 +60,21 @@ def divisible_by(num, den):
 def to_device(inp, device):
     return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)

+def interface_torch_numpy(fn, device):
+    # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
+
+    @wraps(fn)
+    def decorated_fn(*args, **kwargs):
+
+        args, kwargs = tree_map(lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t, (args, kwargs))
+
+        out = fn(*args, **kwargs)
+
+        out = tree_map(lambda t: from_numpy(t).to(device) if isinstance(t, np.ndarray) else t, out)
+        return out
+
+    return decorated_fn
+
 # tensor helpers

 def l2norm(t):
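
The new `interface_torch_numpy` helper wraps a function so that any torch tensors among its inputs are moved to CPU and converted to numpy before the call, and any numpy arrays in its output are converted back to torch tensors on the requested device. A minimal usage sketch, assuming the helper is importable from `evolutionary_policy_optimization.epo`; the `numpy_double` function below is hypothetical and only stands in for a numpy-only environment method:

    import numpy as np
    import torch
    from torch import Tensor

    # assumption: the helper is importable from the module that defines it
    from evolutionary_policy_optimization.epo import interface_torch_numpy

    def numpy_double(x):
        # hypothetical numpy-only function standing in for an env method
        return x * 2.0

    wrapped = interface_torch_numpy(numpy_double, device = 'cpu')

    out = wrapped(torch.ones(3))    # torch tensor in ...
    assert isinstance(out, Tensor)  # ... torch tensor out, on the requested device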
@@ -1178,7 +1195,7 @@ class EPO(Module):
         if not exists(memories):
             memories = []

-        rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
+        rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent), device = self.device)

         rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)

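Allocating the per-latent reward buffer with `device = self.device` places it on the same device as the rest of the rollout tensors, so rewards produced on an accelerator can be written into it without a cross-device copy. A tiny sketch of the allocation pattern in isolation (the shapes and device choice here are placeholders):

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_latents, episodes_per_latent = 8, 2
    rewards = torch.zeros((num_latents, episodes_per_latent), device = device)

    # a reward produced during the rollout, already living on `device`
    rewards[0, 0] += torch.tensor(1.0, device = device)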
@@ -1193,7 +1210,7 @@ class EPO(Module):
             if fix_environ_across_latents:
                 reset_kwargs.update(seed = maybe_seed)

-            state = env.reset(**reset_kwargs)
+            state = interface_torch_numpy(env.reset, device = self.device)(**reset_kwargs)

             # get latent from pool

@@ -1215,7 +1232,7 @@ class EPO(Module):

                 # get the next state, action, and reward

-                state, reward, truncated, terminated = env(action)
+                state, reward, truncated, terminated = interface_torch_numpy(env.forward, device = self.device)(action)

                 done = truncated or terminated

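Both environment entry points, `env.reset` and `env.forward`, are now routed through `interface_torch_numpy`, so the environment can speak numpy on its side while the rollout loop keeps working with torch tensors on `self.device`. A hedged sketch of that boundary, using a toy numpy environment; the `ToyNumpyEnv` class and its methods are illustrative, not part of the package, and the import path is assumed:

    import numpy as np
    import torch

    # assumption: the helper is importable from the module that defines it
    from evolutionary_policy_optimization.epo import interface_torch_numpy

    class ToyNumpyEnv:
        # illustrative environment that accepts and returns numpy only
        def reset(self, seed = None):
            self.rng = np.random.default_rng(seed)
            return self.rng.standard_normal(4).astype(np.float32)

        def forward(self, action):
            state = self.rng.standard_normal(4).astype(np.float32)
            reward = np.asarray(action.sum(), dtype = np.float32)
            truncated, terminated = False, False
            return state, reward, truncated, terminated

    device = 'cpu'
    env = ToyNumpyEnv()

    state = interface_torch_numpy(env.reset, device = device)(seed = 42)
    action = torch.randn(4)
    state, reward, truncated, terminated = interface_torch_numpy(env.forward, device = device)(action)
    # state and reward come back as torch tensors on `device`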
@@ -1264,9 +1281,14 @@ class EPO(Module):
         self,
         agent: Agent,
         env,
-        num_learning_cycles
+        num_learning_cycles,
+        seed = None
     ):

+        if exists(seed):
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+
         for _ in tqdm(range(num_learning_cycles), desc = 'learning cycle'):

             memories = self.gather_experience_from(env)
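
The new optional `seed` argument makes a learning run reproducible by seeding both the torch and numpy global generators before any rollouts are gathered. The same seeding pattern in isolation, as a minimal sketch (the wrapper function name is a placeholder, not part of the package):

    import numpy as np
    import torch

    def seed_everything(seed = None):
        # mirrors what the diff adds at the top of the learning loop
        if seed is not None:
            torch.manual_seed(seed)   # seeds torch's generators on all devices
            np.random.seed(seed)      # seeds numpy's legacy global generator

    seed_everything(42)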
--- evolutionary_policy_optimization/mock_env.py
+++ evolutionary_policy_optimization/mock_env.py
@@ -34,7 +34,7 @@ class Env(Module):
     ):
         state = randn(self.state_shape, device = self.device)
         self.step.zero_()
-        return state
+        return state.numpy()

     def forward(
         self,
@@ -51,4 +51,5 @@ class Env(Module):

         self.step.add_(1)

-        return state, reward, truncated, terminated
+        out = state, reward, truncated, terminated
+        return tuple(t.numpy() for t in out)
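
The mock environment now returns numpy arrays from both `reset` and `forward`, matching the numpy-facing contract that the wrapped calls in EPO expect. The conversion is the mirror image of what `interface_torch_numpy` does on the agent side; a small round-trip sketch, assuming CPU tensors (as the mock env's are, since `Tensor.numpy()` only works on CPU):

    import torch
    from torch import from_numpy

    state_t = torch.randn(5)             # tensor produced inside the env
    state_np = state_t.numpy()           # what the env now hands back (cpu tensors only)
    state_back = from_numpy(state_np)    # what the rollout loop reconstructs on the agent side

    assert torch.equal(state_t, state_back)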
--- evolutionary_policy_optimization-0.0.70.dist-info/METADATA
+++ evolutionary_policy_optimization-0.0.72.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.70
+Version: 0.0.72
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
--- /dev/null
+++ evolutionary_policy_optimization-0.0.72.dist-info/RECORD
@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=h_BaNLFeU5SfVaOcVeMdEdjfYop6enOPAsESQWZflfA,39449
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=QhPYIazU3uFPKTMZyFw70KhXiDUBr5aU3v1idotfVFI,1391
+evolutionary_policy_optimization-0.0.72.dist-info/METADATA,sha256=kZZupJU1gYm6MEmDYB2D-gIF2NZsAY4qcK9Z5pAT1mQ,6304
+evolutionary_policy_optimization-0.0.72.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.72.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.72.dist-info/RECORD,,
--- evolutionary_policy_optimization-0.0.70.dist-info/RECORD
+++ /dev/null
@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=MmsqMwytVqBkb1f2piUygueOfn--Icb817P4bDcfPks,38683
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.70.dist-info/METADATA,sha256=fX_hR3dCjKUQ3VZtT-sUQy0qT8sF-nDPyP0QDAGHf60,6304
-evolutionary_policy_optimization-0.0.70.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.70.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.70.dist-info/RECORD,,