evolutionary-policy-optimization 0.0.46__py3-none-any.whl → 0.0.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,7 @@ from __future__ import annotations
 from functools import partial, wraps
 from pathlib import Path
 from collections import namedtuple
+from random import randrange
 
 import torch
 from torch import nn, cat, stack, is_tensor, tensor
@@ -989,7 +990,8 @@ class EPO(Module):
     @torch.no_grad()
     def forward(
         self,
-        env
+        env,
+        fix_seed_across_latents = True
     ) -> MemoriesAndCumulativeRewards:
 
         self.agent.eval()
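The new fix_seed_across_latents keyword (defaulting to True) makes each episode draw one random seed and reset the environment with it for every latent, so per-latent rollouts within an episode start from the same environment initialization. A minimal usage sketch, assuming an already constructed EPO instance named epo and an environment whose reset accepts a seed keyword (both names are placeholders, not shown in this diff):

    memories = epo(env)                                   # seed fixed across latents (default)
    memories = epo(env, fix_seed_across_latents = False)  # each latent resets independently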
@@ -1002,12 +1004,22 @@ class EPO(Module):
 
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
+            # maybe fix seed for environment across all latents
+
+            env_reset_kwargs = dict()
+
+            if fix_seed_across_latents:
+                seed = randrange(int(1e6))
+                env_reset_kwargs = dict(seed = seed)
+
+            # for each latent (on a single machine for now)
+
             for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
                 time = 0
 
                 # initial state
 
-                state = env.reset()
+                state = env.reset(**env_reset_kwargs)
 
                 # get latent from pool
 
@@ -25,7 +25,8 @@ class Env(Module):
         return self.dummy.device
 
     def reset(
-        self
+        self,
+        seed
     ):
         state = randn(self.state_shape, device = self.device)
         return state
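As shown in this hunk, the mock environment's reset now accepts the seed argument but still draws its state from the global RNG; the randn call does not consume the seed. A reset that actually honors it might look like the following sketch; the SeededEnv class is hypothetical and only illustrates one way to do this, it is not the package's implementation:

    import torch

    class SeededEnv:
        def __init__(self, state_shape = (4,)):
            self.state_shape = state_shape

        def reset(self, seed = None):
            # derive the initial state from a per-call generator, so the same
            # seed always reproduces the same starting state
            generator = torch.Generator()
            if seed is not None:
                generator.manual_seed(seed)
            return torch.randn(self.state_shape, generator = generator)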
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.46
+Version: 0.0.47
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,8 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/epo.py,sha256=-uRpnD0dKF6h4drVSikm9HnlP2OZ0WYQSWRQcghzd9Y,32242
+evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
+evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
+evolutionary_policy_optimization-0.0.47.dist-info/METADATA,sha256=oSI5NowsOOlQZ5cPmCs-8kYeG6TmzUybpRZt_6-cFWk,6213
+evolutionary_policy_optimization-0.0.47.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.47.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.47.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=SAhWgRY8uPQEKFg1_nz1mvh8A6S_sHwnDykhd0F5xEI,31853
-evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
-evolutionary_policy_optimization/mock_env.py,sha256=6AIc4mwL_C6JkAxwESJgCLxXHMzCAu2FcffVg3HkSm0,920
-evolutionary_policy_optimization-0.0.46.dist-info/METADATA,sha256=xP2kdKo52-X4Z5XXTPpW0M_NFI0spuigeL7fvqFlsRM,6213
-evolutionary_policy_optimization-0.0.46.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.46.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.46.dist-info/RECORD,,