x-evolution 0.1.26__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_evolution/x_evolution.py +17 -6
- {x_evolution-0.1.26.dist-info → x_evolution-0.1.27.dist-info}/METADATA +1 -1
- x_evolution-0.1.27.dist-info/RECORD +6 -0
- x_evolution-0.1.26.dist-info/RECORD +0 -6
- {x_evolution-0.1.26.dist-info → x_evolution-0.1.27.dist-info}/WHEEL +0 -0
- {x_evolution-0.1.26.dist-info → x_evolution-0.1.27.dist-info}/licenses/LICENSE +0 -0
x_evolution/x_evolution.py
CHANGED
|
@@ -11,6 +11,7 @@ from torch.nn import Module, ModuleList, Parameter, ParameterList
|
|
|
11
11
|
from torch.optim import SGD, Adam, Optimizer
|
|
12
12
|
from torch.optim.lr_scheduler import LRScheduler
|
|
13
13
|
|
|
14
|
+
import torch.distributed as dist
|
|
14
15
|
import torch.nn.functional as F
|
|
15
16
|
|
|
16
17
|
from beartype import beartype
|
|
@@ -95,7 +96,8 @@ class EvoStrategy(Module):
|
|
|
95
96
|
accelerate_kwargs: dict = dict(),
|
|
96
97
|
reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None,
|
|
97
98
|
vectorized = False,
|
|
98
|
-
vector_size: int | None = None
|
|
99
|
+
vector_size: int | None = None,
|
|
100
|
+
sync_on_init = True
|
|
99
101
|
):
|
|
100
102
|
super().__init__()
|
|
101
103
|
self.verbose = verbose
|
|
@@ -127,12 +129,10 @@ class EvoStrategy(Module):
|
|
|
127
129
|
self.model = model
|
|
128
130
|
self.noisable_model = Noisable(model, low_rank = noise_low_rank)
|
|
129
131
|
|
|
130
|
-
#
|
|
132
|
+
# maybe sync model params and buffers
|
|
131
133
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
with torch.no_grad():
|
|
135
|
-
environment(wrapped_model)
|
|
134
|
+
if sync_on_init:
|
|
135
|
+
self.sync_model_params_and_buffers_()
|
|
136
136
|
|
|
137
137
|
# get param dictionary
|
|
138
138
|
|
|
@@ -254,6 +254,17 @@ class EvoStrategy(Module):
|
|
|
254
254
|
def device(self):
|
|
255
255
|
return self.accelerate.device
|
|
256
256
|
|
|
257
|
+
@torch.no_grad()
|
|
258
|
+
def sync_model_params_and_buffers_(self):
|
|
259
|
+
if not self.accelerate.num_processes > 1:
|
|
260
|
+
return
|
|
261
|
+
|
|
262
|
+
for param in self.model.parameters():
|
|
263
|
+
dist.broadcast(param, src = 0)
|
|
264
|
+
|
|
265
|
+
for buffer in self.model.buffers():
|
|
266
|
+
dist.broadcast(buffer, src = 0)
|
|
267
|
+
|
|
257
268
|
def print(self, *args, **kwargs):
|
|
258
269
|
if not self.verbose:
|
|
259
270
|
return
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
|
|
2
|
+
x_evolution/x_evolution.py,sha256=-G5qXGMjwVzdkxIDR6xL_YGium4KfKC0cnlY76Upy0o,19799
|
|
3
|
+
x_evolution-0.1.27.dist-info/METADATA,sha256=hj0MUpIGVWoOY5wHsoy_ZF_cx7s48_HZicd4IgNUFEo,5853
|
|
4
|
+
x_evolution-0.1.27.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
x_evolution-0.1.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
x_evolution-0.1.27.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
|
|
2
|
-
x_evolution/x_evolution.py,sha256=Jln3wpkIQIp7xwa3KMibh0kSuob1NIi7Aj7Miz8RJdY,19491
|
|
3
|
-
x_evolution-0.1.26.dist-info/METADATA,sha256=7zamSGrDtvOUQzpyYCkrt42FmWC4GPYpVSWUAO8a6OA,5853
|
|
4
|
-
x_evolution-0.1.26.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
x_evolution-0.1.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
x_evolution-0.1.26.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|