evolutionary-policy-optimization 0.0.67__py3-none-any.whl → 0.0.68__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
evolutionary_policy_optimization/epo.py

@@ -55,6 +55,9 @@ def xnor(x, y):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def to_device(inp, device):
+    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
+
 # tensor helpers
 
 def l2norm(t):
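The new `to_device` helper leans on PyTorch's pytree utilities to move an arbitrarily nested container of tensors in one call. A minimal standalone sketch, assuming `tree_map` comes from `torch.utils._pytree` (its usual source; the import is not shown in the hunk) and using a made-up `batch` dict as input:

    import torch
    from torch import is_tensor, tensor
    from torch.utils._pytree import tree_map

    def to_device(inp, device):
        # map over every leaf of the nested structure (tuples, lists,
        # dicts, namedtuples); tensors are moved, other leaves pass through
        return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)

    batch = dict(obs = tensor([1., 2.]), step = 3)
    moved = to_device(batch, 'cpu')   # 'cuda' when a GPU is available
    assert moved['step'] == 3         # non-tensor leaves are untouched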
@@ -744,6 +747,12 @@ class Agent(Module):
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
         self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
@@ -838,7 +847,10 @@ class Agent(Module):
         self,
         memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
+
     ):
+        memories_and_cumulative_rewards = to_device(memories_and_cumulative_rewards, self.device)
+
         memories, rewards_per_latent_episode = memories_and_cumulative_rewards
 
         # stack memories
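Together with the helper above, this makes `learn` device-agnostic: rollouts can be collected as CPU tensors and are moved to wherever the agent lives in a single call, before the namedtuple is unpacked. A toy round trip, with a hypothetical two-field `Memory` standing in for the package's larger namedtuple:

    from collections import namedtuple
    import torch
    from torch import is_tensor, tensor
    from torch.utils._pytree import tree_map

    Memory = namedtuple('Memory', ['state', 'reward'])

    def to_device(inp, device):
        return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)

    memories = [Memory(tensor([0.]), tensor(1.)), Memory(tensor([1.]), tensor(0.))]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    memories = to_device(memories, device)   # moves every tensor field at once
    assert all(m.state.device.type == device for m in memories)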
@@ -1217,6 +1229,8 @@ class EPO(Module):
                 terminated
             )
 
+            memory = Memory(*tuple(t.cpu() for t in memory))
+
             memories.append(memory)
 
             time += 1
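Moving each freshly built `Memory` to the CPU before appending keeps long rollouts from pinning GPU memory; since `Memory` is a namedtuple, rebuilding it from a generator of `.cpu()` tensors is enough. A reduced sketch (the real namedtuple has more fields):

    from collections import namedtuple
    from torch import tensor

    Memory = namedtuple('Memory', ['state', 'value', 'done'])

    memory = Memory(tensor([0.5]), tensor(0.1), tensor(False))
    # rebuild the namedtuple with every field on the CPU before storing it
    memory = Memory(*tuple(t.cpu() for t in memory))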
@@ -1228,7 +1242,7 @@ class EPO(Module):
 
         memory_for_gae = memory._replace(
             episode_id = invalid_episode,
-            value = next_value,
+            value = next_value.cpu(),
             done = tensor(True)
         )
 
evolutionary_policy_optimization-0.0.67.dist-info/METADATA → evolutionary_policy_optimization-0.0.68.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.67
+Version: 0.0.68
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.67.dist-info/RECORD → evolutionary_policy_optimization-0.0.68.dist-info/RECORD

@@ -1,9 +1,9 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=7WOGl22WknudsNLSZ18AWwW7rPt5ITMAtByetUwLp7M,37654
+evolutionary_policy_optimization/epo.py,sha256=xhE_kHas54xGsgOese9SQEvyK7NKZqEuK3AiVhm0y7Q,38047
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.67.dist-info/METADATA,sha256=S1biwayyDA4vTOXknMU5KeWtJTFvxePHZZ0OZzuaNms,6220
-evolutionary_policy_optimization-0.0.67.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.67.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.67.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.68.dist-info/METADATA,sha256=hOOKOrrPQtQmK3zN1z5nkGJEoaQLyXUzs9ArsEKn1DE,6220
+evolutionary_policy_optimization-0.0.68.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.68.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.68.dist-info/RECORD,,