evolutionary-policy-optimization 0.0.46__py3-none-any.whl → 0.0.47__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +14 -2
- evolutionary_policy_optimization/mock_env.py +2 -1
- {evolutionary_policy_optimization-0.0.46.dist-info → evolutionary_policy_optimization-0.0.47.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.47.dist-info/RECORD +8 -0
- evolutionary_policy_optimization-0.0.46.dist-info/RECORD +0 -8
- {evolutionary_policy_optimization-0.0.46.dist-info → evolutionary_policy_optimization-0.0.47.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.46.dist-info → evolutionary_policy_optimization-0.0.47.dist-info}/licenses/LICENSE +0 -0
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
from functools import partial, wraps
|
4
4
|
from pathlib import Path
|
5
5
|
from collections import namedtuple
|
6
|
+
from random import randrange
|
6
7
|
|
7
8
|
import torch
|
8
9
|
from torch import nn, cat, stack, is_tensor, tensor
|
@@ -989,7 +990,8 @@ class EPO(Module):
|
|
989
990
|
@torch.no_grad()
|
990
991
|
def forward(
|
991
992
|
self,
|
992
|
-
env
|
993
|
+
env,
|
994
|
+
fix_seed_across_latents = True
|
993
995
|
) -> MemoriesAndCumulativeRewards:
|
994
996
|
|
995
997
|
self.agent.eval()
|
@@ -1002,12 +1004,22 @@ class EPO(Module):
|
|
1002
1004
|
|
1003
1005
|
for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
|
1004
1006
|
|
1007
|
+
# maybe fix seed for environment across all latents
|
1008
|
+
|
1009
|
+
env_reset_kwargs = dict()
|
1010
|
+
|
1011
|
+
if fix_seed_across_latents:
|
1012
|
+
seed = randrange(int(1e6))
|
1013
|
+
env_reset_kwargs = dict(seed = seed)
|
1014
|
+
|
1015
|
+
# for each latent (on a single machine for now)
|
1016
|
+
|
1005
1017
|
for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
|
1006
1018
|
time = 0
|
1007
1019
|
|
1008
1020
|
# initial state
|
1009
1021
|
|
1010
|
-
state = env.reset()
|
1022
|
+
state = env.reset(**env_reset_kwargs)
|
1011
1023
|
|
1012
1024
|
# get latent from pool
|
1013
1025
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: evolutionary-policy-optimization
|
3
|
-
Version: 0.0.46
|
3
|
+
Version: 0.0.47
|
4
4
|
Summary: EPO - Pytorch
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
|
@@ -0,0 +1,8 @@
|
|
1
|
+
evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
|
2
|
+
evolutionary_policy_optimization/epo.py,sha256=-uRpnD0dKF6h4drVSikm9HnlP2OZ0WYQSWRQcghzd9Y,32242
|
3
|
+
evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
|
4
|
+
evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
|
5
|
+
evolutionary_policy_optimization-0.0.47.dist-info/METADATA,sha256=oSI5NowsOOlQZ5cPmCs-8kYeG6TmzUybpRZt_6-cFWk,6213
|
6
|
+
evolutionary_policy_optimization-0.0.47.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
evolutionary_policy_optimization-0.0.47.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
evolutionary_policy_optimization-0.0.47.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
|
2
|
-
evolutionary_policy_optimization/epo.py,sha256=SAhWgRY8uPQEKFg1_nz1mvh8A6S_sHwnDykhd0F5xEI,31853
|
3
|
-
evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
|
4
|
-
evolutionary_policy_optimization/mock_env.py,sha256=6AIc4mwL_C6JkAxwESJgCLxXHMzCAu2FcffVg3HkSm0,920
|
5
|
-
evolutionary_policy_optimization-0.0.46.dist-info/METADATA,sha256=xP2kdKo52-X4Z5XXTPpW0M_NFI0spuigeL7fvqFlsRM,6213
|
6
|
-
evolutionary_policy_optimization-0.0.46.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
evolutionary_policy_optimization-0.0.46.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
evolutionary_policy_optimization-0.0.46.dist-info/RECORD,,
|
File without changes
|