evolutionary-policy-optimization 0.0.67-py3-none-any.whl → 0.0.68-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- evolutionary_policy_optimization/epo.py +15 -1
- {evolutionary_policy_optimization-0.0.67.dist-info → evolutionary_policy_optimization-0.0.68.dist-info}/METADATA +1 -1
- {evolutionary_policy_optimization-0.0.67.dist-info → evolutionary_policy_optimization-0.0.68.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.0.67.dist-info → evolutionary_policy_optimization-0.0.68.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.67.dist-info → evolutionary_policy_optimization-0.0.68.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py:

@@ -55,6 +55,9 @@ def xnor(x, y):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def to_device(inp, device):
+    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
+
 # tensor helpers
 
 def l2norm(t):
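The new to_device helper maps over an arbitrarily nested container and moves only the tensor leaves. Below is a minimal sketch of its behavior, assuming tree_map comes from torch.utils._pytree and is_tensor from torch (the import lines are not shown in the diff):

```python
import torch
from torch import is_tensor, tensor
from torch.utils._pytree import tree_map

def to_device(inp, device):
    # visit every leaf of a pytree (dicts, lists, tuples, namedtuples),
    # moving tensor leaves to the target device and passing
    # non-tensor leaves through unchanged
    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)

nested = {'obs': tensor([1., 2.]), 'tag': 'episode-0', 'steps': (tensor(3), 4)}
moved = to_device(nested, 'cpu')  # swap in 'cuda' when a GPU is present
print(moved['obs'].device)        # cpu
```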
@@ -744,6 +747,12 @@ class Agent(Module):
         self.has_diversity_loss = diversity_aux_loss_weight > 0.
         self.diversity_aux_loss_weight = diversity_aux_loss_weight
 
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
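Agent picks up the common buffer-based device-tracking pattern: a scalar buffer registered on the module travels with it through .to(), .cuda() and .cpu(), so reading the buffer's device always reveals where the module currently lives. A standalone sketch of the pattern, not the full Agent class:

```python
import torch
from torch import nn, tensor

class DeviceAware(nn.Module):
    def __init__(self):
        super().__init__()
        # a scalar buffer is moved together with the module's parameters,
        # so it always records the module's current device
        self.register_buffer('dummy', tensor(0))

    @property
    def device(self):
        return self.dummy.device

module = DeviceAware()
print(module.device)  # cpu, until the module is moved elsewhere
```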
@@ -838,7 +847,10 @@ class Agent(Module):
         self,
         memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
+
     ):
+        memories_and_cumulative_rewards = to_device(memories_and_cumulative_rewards, self.device)
+
         memories, rewards_per_latent_episode = memories_and_cumulative_rewards
 
         # stack memories
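Combined with the two epo.py changes above, the learning entry point now migrates incoming rollout data onto whatever device the agent occupies before unpacking it, so rollouts can be collected on CPU while optimization runs on an accelerator. A toy composition of the pieces, with a stand-in class rather than the real Agent:

```python
from collections import namedtuple
import torch
from torch import nn, tensor, is_tensor
from torch.utils._pytree import tree_map

MemoriesAndCumulativeRewards = namedtuple('MemoriesAndCumulativeRewards', ['memories', 'rewards'])

def to_device(inp, device):
    return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)

class ToyAgent(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('dummy', tensor(0))

    @property
    def device(self):
        return self.dummy.device

    def learn(self, memories_and_cumulative_rewards):
        # mirror of the diff: move rollout data to the agent's device
        # before any unpacking or stacking
        return to_device(memories_and_cumulative_rewards, self.device)

agent = ToyAgent()  # agent.to('cuda') would relocate the buffer, and with it .device
out = agent.learn(MemoriesAndCumulativeRewards([tensor([1., 2.])], tensor(5.)))
print(out.rewards.device)  # matches agent.device
```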
@@ -1217,6 +1229,8 @@ class EPO(Module):
                     terminated
                 )
 
+                memory = Memory(*tuple(t.cpu() for t in memory))
+
                 memories.append(memory)
 
                 time += 1
@@ -1228,7 +1242,7 @@ class EPO(Module):
 
                 memory_for_gae = memory._replace(
                     episode_id = invalid_episode,
-                    value = next_value,
+                    value = next_value.cpu(),
                     done = tensor(True)
                 )
 
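Both rollout-side changes push tensors to CPU as soon as a step's memory is finalized, which keeps long rollouts from accumulating tensors in accelerator memory and matches the to_device call at the start of learning. A small sketch using a stand-in Memory namedtuple (the real one carries more fields):

```python
from collections import namedtuple
from torch import tensor

Memory = namedtuple('Memory', ['episode_id', 'value', 'done'])

memory = Memory(tensor(0), tensor(0.5), tensor(False))

# rebuild the namedtuple with every field moved to CPU before storing it
memory = Memory(*tuple(t.cpu() for t in memory))

# _replace swaps only the named fields, leaving the rest intact;
# the bootstrap value used for GAE is explicitly moved to CPU as well
memory_for_gae = memory._replace(value = tensor(0.75).cpu(), done = tensor(True))
print(memory_for_gae)
```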
evolutionary_policy_optimization-0.0.67.dist-info/METADATA → evolutionary_policy_optimization-0.0.68.dist-info/METADATA:

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.67
+Version: 0.0.68
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.67.dist-info/RECORD → evolutionary_policy_optimization-0.0.68.dist-info/RECORD:

@@ -1,9 +1,9 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=…
+evolutionary_policy_optimization/epo.py,sha256=xhE_kHas54xGsgOese9SQEvyK7NKZqEuK3AiVhm0y7Q,38047
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.67.dist-info/METADATA,…
-evolutionary_policy_optimization-0.0.67.dist-info/WHEEL,…
-evolutionary_policy_optimization-0.0.67.dist-info/licenses/LICENSE,…
-evolutionary_policy_optimization-0.0.67.dist-info/RECORD,,
+evolutionary_policy_optimization-0.0.68.dist-info/METADATA,sha256=hOOKOrrPQtQmK3zN1z5nkGJEoaQLyXUzs9ArsEKn1DE,6220
+evolutionary_policy_optimization-0.0.68.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.68.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.68.dist-info/RECORD,,
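Each RECORD line is a comma-separated triple of file path, sha256 digest, and size in bytes, per the wheel binary distribution format; the RECORD file itself is listed with empty hash and size fields, since it cannot contain its own digest.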
The WHEEL and licenses/LICENSE files are unchanged between the two versions.