x-evolution 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,7 @@ from pathlib import Path
6
6
  from functools import partial
7
7
 
8
8
  import torch
9
- from torch import tensor, Tensor, is_tensor, arange, randint
9
+ from torch import tensor, Tensor, stack, is_tensor, arange, randint
10
10
  from torch.nn import Module, ModuleList, Parameter, ParameterList
11
11
  from torch.optim import SGD, Adam, Optimizer
12
12
  from torch.optim.lr_scheduler import LRScheduler
@@ -366,11 +366,19 @@ class EvoStrategy(Module):
366
366
  self,
367
367
  filename = 'evolved.model',
368
368
  num_generations = None,
369
- disable_distributed = False
369
+ disable_distributed = False,
370
+ rollback_model_at_end = False,
371
+ verbose = None
370
372
  ):
373
+ verbose = default(verbose, self.verbose)
371
374
 
372
375
  model = self.noisable_model.to(self.device)
373
376
 
377
+ # maybe save model for rolling back (for meta-evo)
378
+
379
+ if rollback_model_at_end:
380
+ self.checkpoint('initial.model')
381
+
374
382
  # maybe sigmas
375
383
 
376
384
  if self.learned_noise_scale:
@@ -533,9 +541,18 @@ class EvoStrategy(Module):
533
541
 
534
542
  self.print('evolution complete')
535
543
 
544
+ # final checkpoint
545
+
536
546
  self.checkpoint(f'{filename}.final.{generation}')
537
547
 
548
+ # maybe rollback
549
+
550
+ if rollback_model_at_end:
551
+ orig_state_dict = torch.load(str(self.checkpoint_folder / 'initial.model.pt'), weights_only = True)
552
+
553
+ self.model.load_state_dict(orig_state_dict)
554
+
538
555
  # return fitnesses across generations
539
556
  # for meta-evolutionary (nesting EvoStrategy within the environment of another and optimizing some meta-network)
540
557
 
541
- return fitnesses_across_generations
558
+ return stack(fitnesses_across_generations)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-evolution
3
- Version: 0.1.15
3
+ Version: 0.1.17
4
4
  Summary: x-evolution
5
5
  Project-URL: Homepage, https://pypi.org/project/x-evolution/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-evolution
@@ -0,0 +1,6 @@
1
+ x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
+ x_evolution/x_evolution.py,sha256=UaDdGLvd2xURQp2SnBjLZkdyOPQ41Y55djo00jnKCNE,18503
3
+ x_evolution-0.1.17.dist-info/METADATA,sha256=KSqDsQ5whxAfgPE6is0xQZTDPAPMJFKSCn6lSGd1Ork,5649
4
+ x_evolution-0.1.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ x_evolution-0.1.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ x_evolution-0.1.17.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- x_evolution/__init__.py,sha256=XcwXJgIMPnCWGfGws3-vKgoR_7IfVslJBtiMvmEeSg0,57
2
- x_evolution/x_evolution.py,sha256=zW2sSev-jDtkOFNScx4p2Ncrdqj39SYef2ynApjo8RU,17979
3
- x_evolution-0.1.15.dist-info/METADATA,sha256=apdgiJ4-mMt2MU034eLxu6SJkQ-eeRJJqx7w7G9l0YI,5649
4
- x_evolution-0.1.15.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- x_evolution-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- x_evolution-0.1.15.dist-info/RECORD,,