evolutionary-policy-optimization 0.0.54-py3-none-any.whl → 0.0.55-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -25,6 +25,12 @@ def pad_dim_to(t, length, dim = 0):
 def is_distributed():
     return dist.is_initialized() and dist.get_world_size() > 1
 
+def get_world_and_rank():
+    if not is_distributed():
+        return 1, 0
+
+    return dist.get_world_size(), dist.get_rank()
+
 def maybe_sync_seed(device, max_size = int(1e6)):
     rand_int = torch.randint(0, max_size, (), device = device)
 
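The new `get_world_and_rank` helper (added to the package's `distributed` module, which the epo.py import hunk below pulls from) folds the single-process case into the distributed one: callers always receive a `(world_size, rank)` pair, with `(1, 0)` standing in when no process group is active. A minimal, self-contained usage sketch built from the code in the hunk above:

    import torch.distributed as dist

    def is_distributed():
        # true only when a process group exists and spans more than one rank
        return dist.is_initialized() and dist.get_world_size() > 1

    def get_world_and_rank():
        # single-process fallback: behave as rank 0 of a world of size 1
        if not is_distributed():
            return 1, 0

        return dist.get_world_size(), dist.get_rank()

    world_size, rank = get_world_and_rank()
    print(world_size, rank)  # prints "1 0" when torch.distributed was never initialized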
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
-from functools import partial, wraps
 from pathlib import Path
+from math import ceil
+from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
 
@@ -19,6 +20,7 @@ from einops.layers.torch import Rearrange
 
 from evolutionary_policy_optimization.distributed import (
     is_distributed,
+    get_world_and_rank,
     maybe_sync_seed,
     all_gather_variable_dim,
     maybe_barrier
@@ -1061,22 +1063,20 @@ class EPO(Module):
     def latents_for_machine(self):
         num_latents = self.num_latents
 
-        if not is_distributed():
-            return list(range(self.num_latents))
+        world_size, rank = get_world_and_rank()
 
-        world_size, rank = dist.get_world_size(), dist.get_rank()
         assert num_latents >= world_size, 'number of latents must be greater than world size for now'
         assert rank < world_size
 
-        pad_id = -1
-        num_latents_rounded_up = ceil(num_latents / world_size) * world_size
-        latent_ids = torch.arange(num_latents_rounded_up)
-        latent_ids[latent_ids >= num_latents] = pad_id
+        num_latents_per_machine = ceil(num_latents / world_size)
+
+        for i in range(num_latents_per_machine):
+            latent_id = rank * num_latents_per_machine + i
 
-        latent_ids = rearrange(latent_ids, '(world latents) -> world latents', world = world_size)
-        out = latent_ids[rank]
+            if latent_id >= num_latents:
+                continue
 
-        return out[out != pad_id].tolist()
+            yield i
 
     @torch.no_grad()
     def forward(
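The rewritten `latents_for_machine` drops the padded-tensor bookkeeping in favor of plain arithmetic: each rank owns a contiguous chunk of `ceil(num_latents / world_size)` latent ids and skips any id that runs past `num_latents`. A standalone sketch of that chunking with example values (the helper name is illustrative, and it yields the global latent id for readability, whereas the method above yields the local loop index `i`):

    from math import ceil

    def latent_ids_for_rank(num_latents, world_size, rank):
        # contiguous chunk per rank, mirroring the arithmetic in latents_for_machine
        per_machine = ceil(num_latents / world_size)

        for i in range(per_machine):
            latent_id = rank * per_machine + i

            if latent_id >= num_latents:
                continue  # the last rank may own a short chunk

            yield latent_id

    # example: 10 latents split across 4 ranks
    for rank in range(4):
        print(rank, list(latent_ids_for_rank(10, 4, rank)))

    # 0 [0, 1, 2]
    # 1 [3, 4, 5]
    # 2 [6, 7, 8]
    # 3 [9]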
@@ -1093,7 +1093,7 @@ class EPO(Module):
 
         cumulative_rewards = torch.zeros((self.num_latents))
 
-        latent_ids = self.latents_for_machine()
+        latent_ids_gen = self.latents_for_machine()
 
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
@@ -1109,7 +1109,7 @@ class EPO(Module):
 
             # for each latent (on a single machine for now)
 
-            for latent_id in tqdm(latent_ids, desc = 'latent'):
+            for latent_id in tqdm(latent_ids_gen, desc = 'latent'):
                 time = 0
 
                 # initial state
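One detail behind the `latent_ids` → `latent_ids_gen` rename: `latents_for_machine` now yields its ids lazily instead of returning a list, so the variable holds a single-use generator. A toy illustration of that difference (not package code):

    def toy_gen():
        yield from range(3)

    g = toy_gen()
    print(list(g))  # [0, 1, 2]
    print(list(g))  # [] -- a generator is exhausted after one pass, unlike a list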
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.54
+Version: 0.0.55
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=e0AI7S5QK_uLfokzWTnsAua_HcPW0PyqY-GzUUev0R8,35123
+evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
+evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
+evolutionary_policy_optimization-0.0.55.dist-info/METADATA,sha256=nsWgp2caBwAiWKMU_BH6Sw58gHdpxE29vXxbAXxWa70,6213
+evolutionary_policy_optimization-0.0.55.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.55.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.55.dist-info/RECORD,,
@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=lSSf_vB04NgVJFBh2n36cGuKZWgOpp8PnPpLDmHT6nU,2296
-evolutionary_policy_optimization/epo.py,sha256=5QJj_l4pihbSdRk1aZnE2dUyWlaqb_VjIKo6Azzksgs,35292
-evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
-evolutionary_policy_optimization/mock_env.py,sha256=202KJ5g57wQvOzhGYzgHfBa7Y2do5uuDvl5kFg5o73g,934
-evolutionary_policy_optimization-0.0.54.dist-info/METADATA,sha256=phQq8QaMT7TQQG2Sqz1BW4E1dln1HU10DMExwRvGGkg,6213
-evolutionary_policy_optimization-0.0.54.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.54.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.54.dist-info/RECORD,,