x-evolution 0.1.19__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ from pathlib import Path
6
6
  from functools import partial
7
7
 
8
8
  import torch
9
+ import torch.distributed as dist
9
10
  from torch import tensor, Tensor, stack, is_tensor, arange, randint
10
11
  from torch.nn import Module, ModuleList, Parameter, ParameterList
11
12
  from torch.optim import SGD, Adam, Optimizer
@@ -112,6 +113,8 @@ class EvoStrategy(Module):
112
113
  self.print('pre_main_callback detected on environment passed in and is invoked')
113
114
  environment.pre_main_callback()
114
115
 
116
+ accelerator.wait_for_everyone()
117
+
115
118
  # take care of model and parameters
116
119
 
117
120
  if isinstance(model, list):
@@ -120,6 +123,13 @@ class EvoStrategy(Module):
120
123
  self.model = model
121
124
  self.noisable_model = Noisable(model, low_rank = noise_low_rank)
122
125
 
126
+ # use prepare and run through environment once to sync params
127
+
128
+ wrapped_model = accelerator.prepare(model)
129
+ environment(wrapped_model)
130
+
131
+ # get param dictionary
132
+
123
133
  named_parameters_dict = dict(model.named_parameters())
124
134
 
125
135
  param_to_name_index = {param: name for name, param in named_parameters_dict.items()}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-evolution
3
- Version: 0.1.19
3
+ Version: 0.1.22
4
4
  Summary: x-evolution
5
5
  Project-URL: Homepage, https://pypi.org/project/x-evolution/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-evolution
@@ -0,0 +1,6 @@
1
+ x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
+ x_evolution/x_evolution.py,sha256=ui4HgIcJ_qM0JtL-HzsiwWNF6r5Ybo9c2DG0mWTXo4w,19124
3
+ x_evolution-0.1.22.dist-info/METADATA,sha256=FrF0ebgJVX2NrZkjgcH4QHozWfXMhGtTGMpXlRnP-cE,5728
4
+ x_evolution-0.1.22.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ x_evolution-0.1.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ x_evolution-0.1.22.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
- x_evolution/x_evolution.py,sha256=C70fX5PdLx1VLBXe1DYGBIOExucvLFbKDWiofBMrvQA,18860
3
- x_evolution-0.1.19.dist-info/METADATA,sha256=7xr4uJpkmkkg0bi9eByDktqRvm_IIcvHJcLAJUjB2QM,5728
4
- x_evolution-0.1.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- x_evolution-0.1.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- x_evolution-0.1.19.dist-info/RECORD,,