x-evolution 0.1.17__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ from pathlib import Path
6
6
  from functools import partial
7
7
 
8
8
  import torch
9
+ import torch.distributed as dist
9
10
  from torch import tensor, Tensor, stack, is_tensor, arange, randint
10
11
  from torch.nn import Module, ModuleList, Parameter, ParameterList
11
12
  from torch.optim import SGD, Adam, Optimizer
@@ -96,19 +97,38 @@ class EvoStrategy(Module):
96
97
  reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None
97
98
  ):
98
99
  super().__init__()
100
+ self.verbose = verbose
99
101
 
100
102
  if not exists(accelerator):
101
103
  accelerator = Accelerator(cpu = cpu, **accelerate_kwargs)
102
104
 
103
105
  self.accelerate = accelerator
104
106
 
107
+ # environment - with optional init
108
+
109
+ self.environment = environment
110
+
111
+ if accelerator.is_main_process:
112
+ if hasattr(environment, 'pre_main_callback') and callable(environment.pre_main_callback):
113
+ self.print('pre_main_callback detected on environment passed in and is invoked')
114
+ environment.pre_main_callback()
115
+
116
+ accelerator.wait_for_everyone()
117
+
118
+ # take care of model and parameters
119
+
105
120
  if isinstance(model, list):
106
121
  model = ModuleList(model)
107
122
 
108
123
  self.model = model
109
124
  self.noisable_model = Noisable(model, low_rank = noise_low_rank)
110
125
 
111
- self.environment = environment
126
+ # use prepare and run through environment once to sync params
127
+
128
+ wrapped_model = accelerator.prepare(model)
129
+ environment(wrapped_model)
130
+
131
+ # get param dictionary
112
132
 
113
133
  named_parameters_dict = dict(model.named_parameters())
114
134
 
@@ -217,10 +237,6 @@ class EvoStrategy(Module):
217
237
 
218
238
  self.reject_generation_fitnesses_if = reject_generation_fitnesses_if
219
239
 
220
- # verbose
221
-
222
- self.verbose = verbose
223
-
224
240
  # checkpointing
225
241
 
226
242
  self.checkpoint_every = checkpoint_every
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-evolution
3
- Version: 0.1.17
3
+ Version: 0.1.22
4
4
  Summary: x-evolution
5
5
  Project-URL: Homepage, https://pypi.org/project/x-evolution/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-evolution
@@ -153,3 +153,4 @@ $ accelerate launch train.py
153
153
  }
154
154
  ```
155
155
 
156
+ *Nothing makes sense except in the light of evolution* - Theodosius Dobzhansky
@@ -0,0 +1,6 @@
1
+ x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
+ x_evolution/x_evolution.py,sha256=ui4HgIcJ_qM0JtL-HzsiwWNF6r5Ybo9c2DG0mWTXo4w,19124
3
+ x_evolution-0.1.22.dist-info/METADATA,sha256=FrF0ebgJVX2NrZkjgcH4QHozWfXMhGtTGMpXlRnP-cE,5728
4
+ x_evolution-0.1.22.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ x_evolution-0.1.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ x_evolution-0.1.22.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
- x_evolution/x_evolution.py,sha256=UaDdGLvd2xURQp2SnBjLZkdyOPQ41Y55djo00jnKCNE,18503
3
- x_evolution-0.1.17.dist-info/METADATA,sha256=KSqDsQ5whxAfgPE6is0xQZTDPAPMJFKSCn6lSGd1Ork,5649
4
- x_evolution-0.1.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- x_evolution-0.1.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- x_evolution-0.1.17.dist-info/RECORD,,