x-evolution 0.1.16__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,7 @@ from pathlib import Path
6
6
  from functools import partial
7
7
 
8
8
  import torch
9
- from torch import tensor, Tensor, is_tensor, arange, randint
9
+ from torch import tensor, Tensor, stack, is_tensor, arange, randint
10
10
  from torch.nn import Module, ModuleList, Parameter, ParameterList
11
11
  from torch.optim import SGD, Adam, Optimizer
12
12
  from torch.optim.lr_scheduler import LRScheduler
@@ -96,20 +96,30 @@ class EvoStrategy(Module):
96
96
  reject_generation_fitnesses_if: Callable[[Tensor], bool] | None = None
97
97
  ):
98
98
  super().__init__()
99
+ self.verbose = verbose
99
100
 
100
101
  if not exists(accelerator):
101
102
  accelerator = Accelerator(cpu = cpu, **accelerate_kwargs)
102
103
 
103
104
  self.accelerate = accelerator
104
105
 
106
+ # environment - with optional init
107
+
108
+ self.environment = environment
109
+
110
+ if accelerator.is_main_process:
111
+ if hasattr(environment, 'pre_main_callback') and callable(environment.pre_main_callback):
112
+ self.print('pre_main_callback detected on environment passed in and is invoked')
113
+ environment.pre_main_callback()
114
+
115
+ # take care of model and parameters
116
+
105
117
  if isinstance(model, list):
106
118
  model = ModuleList(model)
107
119
 
108
120
  self.model = model
109
121
  self.noisable_model = Noisable(model, low_rank = noise_low_rank)
110
122
 
111
- self.environment = environment
112
-
113
123
  named_parameters_dict = dict(model.named_parameters())
114
124
 
115
125
  param_to_name_index = {param: name for name, param in named_parameters_dict.items()}
@@ -217,10 +227,6 @@ class EvoStrategy(Module):
217
227
 
218
228
  self.reject_generation_fitnesses_if = reject_generation_fitnesses_if
219
229
 
220
- # verbose
221
-
222
- self.verbose = verbose
223
-
224
230
  # checkpointing
225
231
 
226
232
  self.checkpoint_every = checkpoint_every
@@ -367,8 +373,10 @@ class EvoStrategy(Module):
367
373
  filename = 'evolved.model',
368
374
  num_generations = None,
369
375
  disable_distributed = False,
370
- rollback_model_at_end = False
376
+ rollback_model_at_end = False,
377
+ verbose = None
371
378
  ):
379
+ verbose = default(verbose, self.verbose)
372
380
 
373
381
  model = self.noisable_model.to(self.device)
374
382
 
@@ -553,4 +561,4 @@ class EvoStrategy(Module):
553
561
  # return fitnesses across generations
554
562
  # for meta-evolutionary (nesting EvoStrategy within the environment of another and optimizing some meta-network)
555
563
 
556
- return fitnesses_across_generations
564
+ return stack(fitnesses_across_generations)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-evolution
3
- Version: 0.1.16
3
+ Version: 0.1.19
4
4
  Summary: x-evolution
5
5
  Project-URL: Homepage, https://pypi.org/project/x-evolution/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-evolution
@@ -153,3 +153,4 @@ $ accelerate launch train.py
153
153
  }
154
154
  ```
155
155
 
156
+ *Nothing in biology makes sense except in the light of evolution* - Theodosius Dobzhansky
@@ -0,0 +1,6 @@
1
+ x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
+ x_evolution/x_evolution.py,sha256=C70fX5PdLx1VLBXe1DYGBIOExucvLFbKDWiofBMrvQA,18860
3
+ x_evolution-0.1.19.dist-info/METADATA,sha256=7xr4uJpkmkkg0bi9eByDktqRvm_IIcvHJcLAJUjB2QM,5728
4
+ x_evolution-0.1.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ x_evolution-0.1.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ x_evolution-0.1.19.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
- x_evolution/x_evolution.py,sha256=2zK32idsczSaGlLZDg0IzQmptZiVkE049inbafy2Hh8,18416
3
- x_evolution-0.1.16.dist-info/METADATA,sha256=XtCSXqLRM3BYXu5UwrwCq4RiQIKLCPEN8KxYstyTwPM,5649
4
- x_evolution-0.1.16.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- x_evolution-0.1.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- x_evolution-0.1.16.dist-info/RECORD,,