evolutionary-policy-optimization 0.0.70__py3-none-any.whl → 0.0.72__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +27 -5
- evolutionary_policy_optimization/mock_env.py +3 -2
- {evolutionary_policy_optimization-0.0.70.dist-info → evolutionary_policy_optimization-0.0.72.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.72.dist-info/RECORD +9 -0
- evolutionary_policy_optimization-0.0.70.dist-info/RECORD +0 -9
- {evolutionary_policy_optimization-0.0.70.dist-info → evolutionary_policy_optimization-0.0.72.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.70.dist-info → evolutionary_policy_optimization-0.0.72.dist-info}/licenses/LICENSE +0 -0
| @@ -8,8 +8,10 @@ from functools import partial, wraps | |
| 8 8 | 
             
            from collections import namedtuple
         | 
| 9 9 | 
             
            from random import randrange
         | 
| 10 10 |  | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
             | 
| 11 13 | 
             
            import torch
         | 
| 12 | 
            -
            from torch import nn, cat, stack, is_tensor, tensor, Tensor
         | 
| 14 | 
            +
            from torch import nn, cat, stack, is_tensor, tensor, from_numpy, Tensor
         | 
| 13 15 | 
             
            import torch.nn.functional as F
         | 
| 14 16 | 
             
            import torch.distributed as dist
         | 
| 15 17 | 
             
            from torch.nn import Linear, Module, ModuleList
         | 
| @@ -58,6 +60,21 @@ def divisible_by(num, den): | |
| 58 60 | 
             
            def to_device(inp, device):
         | 
| 59 61 | 
             
                return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
         | 
| 60 62 |  | 
| 63 | 
            +
            def interface_torch_numpy(fn, device):
         | 
| 64 | 
            +
                # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                @wraps(fn)
         | 
| 67 | 
            +
                def decorated_fn(*args, **kwargs):
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    args, kwargs = tree_map(lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t, (args, kwargs))
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    out = fn(*args, **kwargs)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    out = tree_map(lambda t: from_numpy(t).to(device) if isinstance(t, np.ndarray) else t, out)
         | 
| 74 | 
            +
                    return out
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                return decorated_fn
         | 
| 77 | 
            +
             | 
| 61 78 | 
             
            # tensor helpers
         | 
| 62 79 |  | 
| 63 80 | 
             
            def l2norm(t):
         | 
| @@ -1178,7 +1195,7 @@ class EPO(Module): | |
| 1178 1195 | 
             
                    if not exists(memories):
         | 
| 1179 1196 | 
             
                        memories = []
         | 
| 1180 1197 |  | 
| 1181 | 
            -
                    rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
         | 
| 1198 | 
            +
                    rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent), device = self.device)
         | 
| 1182 1199 |  | 
| 1183 1200 | 
             
                    rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)
         | 
| 1184 1201 |  | 
| @@ -1193,7 +1210,7 @@ class EPO(Module): | |
| 1193 1210 | 
             
                        if fix_environ_across_latents:
         | 
| 1194 1211 | 
             
                            reset_kwargs.update(seed = maybe_seed)
         | 
| 1195 1212 |  | 
| 1196 | 
            -
                        state = env.reset(**reset_kwargs)
         | 
| 1213 | 
            +
                        state = interface_torch_numpy(env.reset, device = self.device)(**reset_kwargs)
         | 
| 1197 1214 |  | 
| 1198 1215 | 
             
                        # get latent from pool
         | 
| 1199 1216 |  | 
| @@ -1215,7 +1232,7 @@ class EPO(Module): | |
| 1215 1232 |  | 
| 1216 1233 | 
             
                            # get the next state, action, and reward
         | 
| 1217 1234 |  | 
| 1218 | 
            -
                            state, reward, truncated, terminated = env(action)
         | 
| 1235 | 
            +
                            state, reward, truncated, terminated = interface_torch_numpy(env.forward, device = self.device)(action)
         | 
| 1219 1236 |  | 
| 1220 1237 | 
             
                            done = truncated or terminated
         | 
| 1221 1238 |  | 
| @@ -1264,9 +1281,14 @@ class EPO(Module): | |
| 1264 1281 | 
             
                    self,
         | 
| 1265 1282 | 
             
                    agent: Agent,
         | 
| 1266 1283 | 
             
                    env,
         | 
| 1267 | 
            -
                    num_learning_cycles
         | 
| 1284 | 
            +
                    num_learning_cycles,
         | 
| 1285 | 
            +
                    seed = None
         | 
| 1268 1286 | 
             
                ):
         | 
| 1269 1287 |  | 
| 1288 | 
            +
                    if exists(seed):
         | 
| 1289 | 
            +
                        torch.manual_seed(seed)
         | 
| 1290 | 
            +
                        np.random.seed(seed)
         | 
| 1291 | 
            +
             | 
| 1270 1292 | 
             
                    for _ in tqdm(range(num_learning_cycles), desc = 'learning cycle'):
         | 
| 1271 1293 |  | 
| 1272 1294 | 
             
                        memories = self.gather_experience_from(env)
         | 
| @@ -34,7 +34,7 @@ class Env(Module): | |
| 34 34 | 
             
                ):
         | 
| 35 35 | 
             
                    state = randn(self.state_shape, device = self.device)
         | 
| 36 36 | 
             
                    self.step.zero_()
         | 
| 37 | 
            -
                    return state
         | 
| 37 | 
            +
                    return state.numpy()
         | 
| 38 38 |  | 
| 39 39 | 
             
                def forward(
         | 
| 40 40 | 
             
                    self,
         | 
| @@ -51,4 +51,5 @@ class Env(Module): | |
| 51 51 |  | 
| 52 52 | 
             
                    self.step.add_(1)
         | 
| 53 53 |  | 
| 54 | 
            -
                     | 
| 54 | 
            +
                    out = state, reward, truncated, terminated
         | 
| 55 | 
            +
                    return tuple(t.numpy() for t in out)
         | 
| @@ -1,6 +1,6 @@ | |
| 1 1 | 
             
            Metadata-Version: 2.4
         | 
| 2 2 | 
             
            Name: evolutionary-policy-optimization
         | 
| 3 | 
            -
            Version: 0.0.70
         | 
| 3 | 
            +
            Version: 0.0.72
         | 
| 4 4 | 
             
            Summary: EPO - Pytorch
         | 
| 5 5 | 
             
            Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
         | 
| 6 6 | 
             
            Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
         | 
| @@ -0,0 +1,9 @@ | |
| 1 | 
            +
            evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
         | 
| 2 | 
            +
            evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
         | 
| 3 | 
            +
            evolutionary_policy_optimization/epo.py,sha256=h_BaNLFeU5SfVaOcVeMdEdjfYop6enOPAsESQWZflfA,39449
         | 
| 4 | 
            +
            evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
         | 
| 5 | 
            +
            evolutionary_policy_optimization/mock_env.py,sha256=QhPYIazU3uFPKTMZyFw70KhXiDUBr5aU3v1idotfVFI,1391
         | 
| 6 | 
            +
            evolutionary_policy_optimization-0.0.72.dist-info/METADATA,sha256=kZZupJU1gYm6MEmDYB2D-gIF2NZsAY4qcK9Z5pAT1mQ,6304
         | 
| 7 | 
            +
            evolutionary_policy_optimization-0.0.72.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
         | 
| 8 | 
            +
            evolutionary_policy_optimization-0.0.72.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
         | 
| 9 | 
            +
            evolutionary_policy_optimization-0.0.72.dist-info/RECORD,,
         | 
| @@ -1,9 +0,0 @@ | |
| 1 | 
            -
            evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
         | 
| 2 | 
            -
            evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
         | 
| 3 | 
            -
            evolutionary_policy_optimization/epo.py,sha256=MmsqMwytVqBkb1f2piUygueOfn--Icb817P4bDcfPks,38683
         | 
| 4 | 
            -
            evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
         | 
| 5 | 
            -
            evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
         | 
| 6 | 
            -
            evolutionary_policy_optimization-0.0.70.dist-info/METADATA,sha256=fX_hR3dCjKUQ3VZtT-sUQy0qT8sF-nDPyP0QDAGHf60,6304
         | 
| 7 | 
            -
            evolutionary_policy_optimization-0.0.70.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
         | 
| 8 | 
            -
            evolutionary_policy_optimization-0.0.70.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
         | 
| 9 | 
            -
            evolutionary_policy_optimization-0.0.70.dist-info/RECORD,,
         | 
| 
            File without changes
         |