x-evolution 0.1.17__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_evolution/x_evolution.py +21 -5
- {x_evolution-0.1.17.dist-info → x_evolution-0.1.22.dist-info}/METADATA +2 -1
- x_evolution-0.1.22.dist-info/RECORD +6 -0
- x_evolution-0.1.17.dist-info/RECORD +0 -6
- {x_evolution-0.1.17.dist-info → x_evolution-0.1.22.dist-info}/WHEEL +0 -0
- {x_evolution-0.1.17.dist-info → x_evolution-0.1.22.dist-info}/licenses/LICENSE +0 -0
x_evolution/x_evolution.py
CHANGED
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
|
|
6
6
|
from functools import partial
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
|
+
import torch.distributed as dist
|
|
9
10
|
from torch import tensor, Tensor, stack, is_tensor, arange, randint
|
|
10
11
|
from torch.nn import Module, ModuleList, Parameter, ParameterList
|
|
11
12
|
from torch.optim import SGD, Adam, Optimizer
|
|
@@ -96,19 +97,38 @@ class EvoStrategy(Module):
|
|
|
96
97
|
reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None
|
|
97
98
|
):
|
|
98
99
|
super().__init__()
|
|
100
|
+
self.verbose = verbose
|
|
99
101
|
|
|
100
102
|
if not exists(accelerator):
|
|
101
103
|
accelerator = Accelerator(cpu = cpu, **accelerate_kwargs)
|
|
102
104
|
|
|
103
105
|
self.accelerate = accelerator
|
|
104
106
|
|
|
107
|
+
# environment - with optional init
|
|
108
|
+
|
|
109
|
+
self.environment = environment
|
|
110
|
+
|
|
111
|
+
if accelerator.is_main_process:
|
|
112
|
+
if hasattr(environment, 'pre_main_callback') and callable(environment.pre_main_callback):
|
|
113
|
+
self.print('pre_main_callback detected on environment passed in and is invoked')
|
|
114
|
+
environment.pre_main_callback()
|
|
115
|
+
|
|
116
|
+
accelerator.wait_for_everyone()
|
|
117
|
+
|
|
118
|
+
# take care of model and parameters
|
|
119
|
+
|
|
105
120
|
if isinstance(model, list):
|
|
106
121
|
model = ModuleList(model)
|
|
107
122
|
|
|
108
123
|
self.model = model
|
|
109
124
|
self.noisable_model = Noisable(model, low_rank = noise_low_rank)
|
|
110
125
|
|
|
111
|
-
|
|
126
|
+
# use prepare and run through environment once to sync params
|
|
127
|
+
|
|
128
|
+
wrapped_model = accelerator.prepare(model)
|
|
129
|
+
environment(wrapped_model)
|
|
130
|
+
|
|
131
|
+
# get param dictionary
|
|
112
132
|
|
|
113
133
|
named_parameters_dict = dict(model.named_parameters())
|
|
114
134
|
|
|
@@ -217,10 +237,6 @@ class EvoStrategy(Module):
|
|
|
217
237
|
|
|
218
238
|
self.reject_generation_fitnesses_if = reject_generation_fitnesses_if
|
|
219
239
|
|
|
220
|
-
# verbose
|
|
221
|
-
|
|
222
|
-
self.verbose = verbose
|
|
223
|
-
|
|
224
240
|
# checkpointing
|
|
225
241
|
|
|
226
242
|
self.checkpoint_every = checkpoint_every
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: x-evolution
|
|
3
|
-
Version: 0.1.17
|
|
3
|
+
Version: 0.1.22
|
|
4
4
|
Summary: x-evolution
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/x-evolution/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/x-evolution
|
|
@@ -153,3 +153,4 @@ $ accelerate launch train.py
|
|
|
153
153
|
}
|
|
154
154
|
```
|
|
155
155
|
|
|
156
|
+
*Nothing makes sense except in the light of evolution* - Theodosius Dobzhansky
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
|
|
2
|
+
x_evolution/x_evolution.py,sha256=ui4HgIcJ_qM0JtL-HzsiwWNF6r5Ybo9c2DG0mWTXo4w,19124
|
|
3
|
+
x_evolution-0.1.22.dist-info/METADATA,sha256=FrF0ebgJVX2NrZkjgcH4QHozWfXMhGtTGMpXlRnP-cE,5728
|
|
4
|
+
x_evolution-0.1.22.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
x_evolution-0.1.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
x_evolution-0.1.22.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
|
|
2
|
-
x_evolution/x_evolution.py,sha256=UaDdGLvd2xURQp2SnBjLZkdyOPQ41Y55djo00jnKCNE,18503
|
|
3
|
-
x_evolution-0.1.17.dist-info/METADATA,sha256=KSqDsQ5whxAfgPE6is0xQZTDPAPMJFKSCn6lSGd1Ork,5649
|
|
4
|
-
x_evolution-0.1.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
x_evolution-0.1.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
x_evolution-0.1.17.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|