evolutionary-policy-optimization 0.0.54-py3-none-any.whl → 0.0.55-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/distributed.py +6 -0
- evolutionary_policy_optimization/epo.py +13 -13
- {evolutionary_policy_optimization-0.0.54.dist-info → evolutionary_policy_optimization-0.0.55.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.55.dist-info/RECORD +9 -0
- evolutionary_policy_optimization-0.0.54.dist-info/RECORD +0 -9
- {evolutionary_policy_optimization-0.0.54.dist-info → evolutionary_policy_optimization-0.0.55.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.54.dist-info → evolutionary_policy_optimization-0.0.55.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/distributed.py

```diff
@@ -25,6 +25,12 @@ def pad_dim_to(t, length, dim = 0):
 def is_distributed():
     return dist.is_initialized() and dist.get_world_size() > 1
 
+def get_world_and_rank():
+    if not is_distributed():
+        return 1, 0
+
+    return dist.get_world_size(), dist.get_rank()
+
 def maybe_sync_seed(device, max_size = int(1e6)):
     rand_int = torch.randint(0, max_size, (), device = device)
 
```
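The added `get_world_and_rank()` helper gives callers a single code path for distributed and single-process runs: when no process group is initialized (or the world has only one process) it reports a world of size 1 and rank 0. Below is a minimal, self-contained sketch of that fallback; the `__main__` demo is illustrative and not part of the package.

```python
import torch.distributed as dist

def is_distributed():
    # a process group must exist and span more than one process
    return dist.is_initialized() and dist.get_world_size() > 1

def get_world_and_rank():
    # single-process fallback: behave as rank 0 of a world of size 1
    if not is_distributed():
        return 1, 0

    return dist.get_world_size(), dist.get_rank()

if __name__ == '__main__':
    # without dist.init_process_group(...) this prints "rank 0 / world size 1",
    # so downstream sharding code never needs to special-case the local path
    world_size, rank = get_world_and_rank()
    print(f'rank {rank} / world size {world_size}')
```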
evolutionary_policy_optimization/epo.py

```diff
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
-from functools import partial, wraps
 from pathlib import Path
+from math import ceil
+from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
 
```
```diff
@@ -19,6 +20,7 @@ from einops.layers.torch import Rearrange
 
 from evolutionary_policy_optimization.distributed import (
     is_distributed,
+    get_world_and_rank,
     maybe_sync_seed,
     all_gather_variable_dim,
     maybe_barrier
```
```diff
@@ -1061,22 +1063,20 @@ class EPO(Module):
     def latents_for_machine(self):
         num_latents = self.num_latents
 
-        if not is_distributed():
-            return list(range(self.num_latents))
+        world_size, rank = get_world_and_rank()
 
-        world_size, rank = dist.get_world_size(), dist.get_rank()
         assert num_latents >= world_size, 'number of latents must be greater than world size for now'
         assert rank < world_size
 
-
-
-
-
+        num_latents_per_machine = ceil(num_latents / world_size)
+
+        for i in range(num_latents_per_machine):
+            latent_id = rank * num_latents_per_machine + i
 
-
-
+            if latent_id >= num_latents:
+                continue
 
-
+            yield i
 
     @torch.no_grad()
     def forward(
```
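The rewritten `latents_for_machine` shards latent ids across ranks by ceiling division: each rank takes a contiguous block of at most `ceil(num_latents / world_size)` ids and skips anything past `num_latents`. A standalone sketch of that partitioning scheme, written as a hypothetical helper that yields the global latent id for each rank (not the library function itself):

```python
from math import ceil

def latent_ids_for_rank(num_latents, world_size, rank):
    # ceiling-division sharding: each rank owns a contiguous block,
    # dropping any ids that run past num_latents on the last rank
    per_machine = ceil(num_latents / world_size)

    for i in range(per_machine):
        latent_id = rank * per_machine + i

        if latent_id >= num_latents:
            continue

        yield latent_id

# example: 10 latents over 4 ranks -> blocks of 3, last rank keeps only one id
for rank in range(4):
    print(rank, list(latent_ids_for_rank(num_latents = 10, world_size = 4, rank = rank)))
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7, 8]
# 3 [9]
```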
```diff
@@ -1093,7 +1093,7 @@ class EPO(Module):
 
         cumulative_rewards = torch.zeros((self.num_latents))
 
-
+        latent_ids_gen = self.latents_for_machine()
 
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
```
```diff
@@ -1109,7 +1109,7 @@ class EPO(Module):
 
             # for each latent (on a single machine for now)
 
-            for latent_id in tqdm(
+            for latent_id in tqdm(latent_ids_gen, desc = 'latent'):
                time = 0
 
                # initial state
```
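With the sharding in place, the forward pass obtains its per-machine latent ids once and iterates only those ids inside the episode loop. A toy, framework-free sketch of that loop shape, assuming placeholder names (`rollout_episode`, `NUM_LATENTS`, `EPISODES`) and materializing the local ids into a list so they can be reused across episodes:

```python
from math import ceil

# placeholder constants and rollout; these are not library APIs
NUM_LATENTS, EPISODES = 6, 2
world_size, rank = 2, 0   # in a real run these would come from get_world_and_rank()

def rollout_episode(latent_id):
    return 1.0  # stand-in for an environment rollout returning a reward

per_machine = ceil(NUM_LATENTS / world_size)
local_latents = [
    rank * per_machine + i
    for i in range(per_machine)
    if rank * per_machine + i < NUM_LATENTS
]

# buffer sized for every latent, but each rank only fills its own slice
cumulative_rewards = [0.0] * NUM_LATENTS

for episode in range(EPISODES):
    for latent_id in local_latents:
        cumulative_rewards[latent_id] += rollout_episode(latent_id)

print(local_latents)        # rank 0 of 2 -> [0, 1, 2]
print(cumulative_rewards)   # [2.0, 2.0, 2.0, 0.0, 0.0, 0.0]
```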
{evolutionary_policy_optimization-0.0.54.dist-info → evolutionary_policy_optimization-0.0.55.dist-info}/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.54
+Version: 0.0.55
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
```
evolutionary_policy_optimization-0.0.55.dist-info/RECORD (added)

```diff
@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=e0AI7S5QK_uLfokzWTnsAua_HcPW0PyqY-GzUUev0R8,35123
+evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
+evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
+evolutionary_policy_optimization-0.0.55.dist-info/METADATA,sha256=nsWgp2caBwAiWKMU_BH6Sw58gHdpxE29vXxbAXxWa70,6213
+evolutionary_policy_optimization-0.0.55.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.55.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.55.dist-info/RECORD,,
```
evolutionary_policy_optimization-0.0.54.dist-info/RECORD (removed)

```diff
@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=lSSf_vB04NgVJFBh2n36cGuKZWgOpp8PnPpLDmHT6nU,2296
-evolutionary_policy_optimization/epo.py,sha256=5QJj_l4pihbSdRk1aZnE2dUyWlaqb_VjIKo6Azzksgs,35292
-evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
-evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.54.dist-info/METADATA,sha256=phQq8QaMT7TQQG2Sqz1BW4E1dln1HU10DMExwRvGGkg,6213
-evolutionary_policy_optimization-0.0.54.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.54.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.54.dist-info/RECORD,,
```
File without changes