x-evolution 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_evolution/x_evolution.py +20 -3
- {x_evolution-0.1.15.dist-info → x_evolution-0.1.17.dist-info}/METADATA +1 -1
- x_evolution-0.1.17.dist-info/RECORD +6 -0
- x_evolution-0.1.15.dist-info/RECORD +0 -6
- {x_evolution-0.1.15.dist-info → x_evolution-0.1.17.dist-info}/WHEEL +0 -0
- {x_evolution-0.1.15.dist-info → x_evolution-0.1.17.dist-info}/licenses/LICENSE +0 -0
x_evolution/x_evolution.py
CHANGED
|
@@ -6,7 +6,7 @@ from pathlib import Path
|
|
|
6
6
|
from functools import partial
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
|
-
from torch import tensor, Tensor, is_tensor, arange, randint
|
|
9
|
+
from torch import tensor, Tensor, stack, is_tensor, arange, randint
|
|
10
10
|
from torch.nn import Module, ModuleList, Parameter, ParameterList
|
|
11
11
|
from torch.optim import SGD, Adam, Optimizer
|
|
12
12
|
from torch.optim.lr_scheduler import LRScheduler
|
|
@@ -366,11 +366,19 @@ class EvoStrategy(Module):
|
|
|
366
366
|
self,
|
|
367
367
|
filename = 'evolved.model',
|
|
368
368
|
num_generations = None,
|
|
369
|
-
disable_distributed = False
|
|
369
|
+
disable_distributed = False,
|
|
370
|
+
rollback_model_at_end = False,
|
|
371
|
+
verbose = None
|
|
370
372
|
):
|
|
373
|
+
verbose = default(verbose, self.verbose)
|
|
371
374
|
|
|
372
375
|
model = self.noisable_model.to(self.device)
|
|
373
376
|
|
|
377
|
+
# maybe save model for rolling back (for meta-evo)
|
|
378
|
+
|
|
379
|
+
if rollback_model_at_end:
|
|
380
|
+
self.checkpoint('initial.model')
|
|
381
|
+
|
|
374
382
|
# maybe sigmas
|
|
375
383
|
|
|
376
384
|
if self.learned_noise_scale:
|
|
@@ -533,9 +541,18 @@ class EvoStrategy(Module):
|
|
|
533
541
|
|
|
534
542
|
self.print('evolution complete')
|
|
535
543
|
|
|
544
|
+
# final checkpoint
|
|
545
|
+
|
|
536
546
|
self.checkpoint(f'{filename}.final.{generation}')
|
|
537
547
|
|
|
548
|
+
# maybe rollback
|
|
549
|
+
|
|
550
|
+
if rollback_model_at_end:
|
|
551
|
+
orig_state_dict = torch.load(str(self.checkpoint_folder / 'initial.model.pt'), weights_only = True)
|
|
552
|
+
|
|
553
|
+
self.model.load_state_dict(orig_state_dict)
|
|
554
|
+
|
|
538
555
|
# return fitnesses across generations
|
|
539
556
|
# for meta-evolutionary (nesting EvoStrategy within the environment of another and optimizing some meta-network)
|
|
540
557
|
|
|
541
|
-
return fitnesses_across_generations
|
|
558
|
+
return stack(fitnesses_across_generations)
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
|
|
2
|
+
x_evolution/x_evolution.py,sha256=UaDdGLvd2xURQp2SnBjLZkdyOPQ41Y55djo00jnKCNE,18503
|
|
3
|
+
x_evolution-0.1.17.dist-info/METADATA,sha256=KSqDsQ5whxAfgPE6is0xQZTDPAPMJFKSCn6lSGd1Ork,5649
|
|
4
|
+
x_evolution-0.1.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
x_evolution-0.1.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
x_evolution-0.1.17.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
|
|
2
|
-
x_evolution/x_evolution.py,sha256=zW2sSev-jDtkOFNScx4p2Ncrdqj39SYef2ynApjo8RU,17979
|
|
3
|
-
x_evolution-0.1.15.dist-info/METADATA,sha256=apdgiJ4-mMt2MU034eLxu6SJkQ-eeRJJqx7w7G9l0YI,5649
|
|
4
|
-
x_evolution-0.1.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
x_evolution-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
x_evolution-0.1.15.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|