x-evolution 0.1.26__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,7 @@ from torch.nn import Module, ModuleList, Parameter, ParameterList
11
11
  from torch.optim import SGD, Adam, Optimizer
12
12
  from torch.optim.lr_scheduler import LRScheduler
13
13
 
14
+ import torch.distributed as dist
14
15
  import torch.nn.functional as F
15
16
 
16
17
  from beartype import beartype
@@ -95,7 +96,8 @@ class EvoStrategy(Module):
95
96
  accelerate_kwargs: dict = dict(),
96
97
  reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None,
97
98
  vectorized = False,
98
- vector_size: int | None = None
99
+ vector_size: int | None = None,
100
+ sync_on_init = True
99
101
  ):
100
102
  super().__init__()
101
103
  self.verbose = verbose
@@ -127,12 +129,10 @@ class EvoStrategy(Module):
127
129
  self.model = model
128
130
  self.noisable_model = Noisable(model, low_rank = noise_low_rank)
129
131
 
130
- # use prepare and run through environment once to sync params
132
+ # maybe sync model params and buffers
131
133
 
132
- wrapped_model = accelerator.prepare(model)
133
-
134
- with torch.no_grad():
135
- environment(wrapped_model)
134
+ if sync_on_init:
135
+ self.sync_model_params_and_buffers_()
136
136
 
137
137
  # get param dictionary
138
138
 
@@ -254,6 +254,17 @@ class EvoStrategy(Module):
254
254
  def device(self):
255
255
  return self.accelerate.device
256
256
 
257
+ @torch.no_grad()
258
+ def sync_model_params_and_buffers_(self):
259
+ if not self.accelerate.num_processes > 1:
260
+ return
261
+
262
+ for param in self.model.parameters():
263
+ dist.broadcast(param, src = 0)
264
+
265
+ for buffer in self.model.buffers():
266
+ dist.broadcast(buffer, src = 0)
267
+
257
268
  def print(self, *args, **kwargs):
258
269
  if not self.verbose:
259
270
  return
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-evolution
3
- Version: 0.1.26
3
+ Version: 0.1.27
4
4
  Summary: x-evolution
5
5
  Project-URL: Homepage, https://pypi.org/project/x-evolution/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-evolution
@@ -0,0 +1,6 @@
1
+ x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
+ x_evolution/x_evolution.py,sha256=-G5qXGMjwVzdkxIDR6xL_YGium4KfKC0cnlY76Upy0o,19799
3
+ x_evolution-0.1.27.dist-info/METADATA,sha256=hj0MUpIGVWoOY5wHsoy_ZF_cx7s48_HZicd4IgNUFEo,5853
4
+ x_evolution-0.1.27.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ x_evolution-0.1.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ x_evolution-0.1.27.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
- x_evolution/x_evolution.py,sha256=Jln3wpkIQIp7xwa3KMibh0kSuob1NIi7Aj7Miz8RJdY,19491
3
- x_evolution-0.1.26.dist-info/METADATA,sha256=7zamSGrDtvOUQzpyYCkrt42FmWC4GPYpVSWUAO8a6OA,5853
4
- x_evolution-0.1.26.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- x_evolution-0.1.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- x_evolution-0.1.26.dist-info/RECORD,,