x-evolution 0.1.11__tar.gz → 0.1.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-evolution
3
- Version: 0.1.11
3
+ Version: 0.1.15
4
4
  Summary: x-evolution
5
5
  Project-URL: Homepage, https://pypi.org/project/x-evolution/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-evolution
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "x-evolution"
3
- version = "0.1.11"
3
+ version = "0.1.15"
4
4
  description = "x-evolution"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -80,6 +80,8 @@ from x_mlps_pytorch.residual_normed_mlp import ResidualNormedMLP
80
80
 
81
81
  actor = ResidualNormedMLP(dim_in = 8, dim = 24, depth = 2, residual_every = 1, dim_out = 4)
82
82
 
83
+ from torch.optim.lr_scheduler import CosineAnnealingLR
84
+
83
85
  evo_strat = EvoStrategy(
84
86
  actor,
85
87
  environment = LunarEnvironment(repeats = 2),
@@ -91,7 +93,10 @@ evo_strat = EvoStrategy(
91
93
  learned_noise_scale = True,
92
94
  use_sigma_optimizer = True,
93
95
  learning_rate = 1e-3,
94
- noise_scale_learning_rate = 1e-5
96
+ noise_scale_learning_rate = 1e-4,
97
+ use_scheduler = True,
98
+ scheduler_klass = CosineAnnealingLR,
99
+ scheduler_kwargs = dict(T_max = 50_000)
95
100
  )
96
101
 
97
102
  evo_strat()
@@ -52,12 +52,12 @@ evo_strat = EvoStrategy(
52
52
  environment = loss_mnist,
53
53
  noise_population_size = 100,
54
54
  noise_scale = 1e-2,
55
- noise_scale_clamp_range = (5e-3, 2e-2),
56
- noise_low_rank = 2,
55
+ noise_scale_clamp_range = (8e-3, 2e-2),
56
+ noise_low_rank = 1,
57
57
  num_generations = 10_000,
58
58
  learning_rate = 1e-3,
59
59
  learned_noise_scale = True,
60
- noise_scale_learning_rate = 1e-5
60
+ noise_scale_learning_rate = 2e-5
61
61
  )
62
62
 
63
63
  evo_strat()
@@ -1,6 +1,7 @@
1
1
  import torch
2
2
  from torch import tensor
3
3
  import torch.nn.functional as F
4
+ from torch.optim.lr_scheduler import LambdaLR
4
5
 
5
6
  # model
6
7
 
@@ -44,7 +45,13 @@ evo_strat = EvoStrategy(
44
45
  noise_scale = 1e-1,
45
46
  noise_scale_clamp_range = (5e-2, 2e-1),
46
47
  learned_noise_scale = True,
47
- noise_scale_learning_rate = 5e-4
48
+ noise_scale_learning_rate = 5e-4,
49
+ use_scheduler = True,
50
+ scheduler_klass = LambdaLR,
51
+ scheduler_kwargs = dict(lr_lambda = lambda step: min(1., step / 10.)),
52
+ use_sigma_scheduler = True,
53
+ sigma_scheduler_klass = LambdaLR,
54
+ sigma_scheduler_kwargs = dict(lr_lambda = lambda step: min(1., step / 10.))
48
55
  )
49
56
 
50
57
  evo_strat()
@@ -8,7 +8,8 @@ from functools import partial
8
8
  import torch
9
9
  from torch import tensor, Tensor, is_tensor, arange, randint
10
10
  from torch.nn import Module, ModuleList, Parameter, ParameterList
11
- from torch.optim import SGD, Adam
11
+ from torch.optim import SGD, Adam, Optimizer
12
+ from torch.optim.lr_scheduler import LRScheduler
12
13
 
13
14
  import torch.nn.functional as F
14
15
 
@@ -73,11 +74,17 @@ class EvoStrategy(Module):
73
74
  noise_scale_learning_rate = 1e-5,
74
75
  noise_scale_clamp_range: tuple[float, float] = (1e-3, 1e-1),
75
76
  use_optimizer = True,
76
- optimizer_klass = partial(SGD, nesterov = True, momentum = 0.1, weight_decay = 1e-2),
77
+ optimizer_klass: type[Optimizer] | Callable = partial(SGD, nesterov = True, momentum = 0.1, weight_decay = 1e-2),
77
78
  optimizer_kwargs: dict = dict(),
78
79
  use_sigma_optimizer = True,
79
- sigma_optimizer_klass = partial(SGD, nesterov = True, momentum = 0.1),
80
+ sigma_optimizer_klass: type[Optimizer] | Callable = partial(SGD, nesterov = True, momentum = 0.1),
80
81
  sigma_optimizer_kwargs: dict = dict(),
82
+ use_scheduler = False,
83
+ scheduler_klass: type[LRScheduler] | None = None,
84
+ scheduler_kwargs: dict = dict(),
85
+ use_sigma_scheduler = False,
86
+ sigma_scheduler_klass: type[LRScheduler] | None = None,
87
+ sigma_scheduler_kwargs: dict = dict(),
81
88
  transform_fitness: Callable = identity,
82
89
  fitness_to_weighted_factor: Callable[[Tensor], Tensor] = normalize,
83
90
  checkpoint_every = None, # saving every number of generations
@@ -198,6 +205,16 @@ class EvoStrategy(Module):
198
205
 
199
206
  # rejecting the fitnesses for a certain generation if this function is true
200
207
 
208
+ self.use_scheduler = use_scheduler
209
+
210
+ if use_scheduler and exists(scheduler_klass) and use_optimizer:
211
+ self.scheduler = scheduler_klass(self.optimizer, **scheduler_kwargs)
212
+
213
+ self.use_sigma_scheduler = use_sigma_scheduler
214
+
215
+ if use_sigma_scheduler and exists(sigma_scheduler_klass) and use_sigma_optimizer:
216
+ self.sigma_scheduler = sigma_scheduler_klass(self.sigma_optimizer, **sigma_scheduler_kwargs)
217
+
201
218
  self.reject_generation_fitnesses_if = reject_generation_fitnesses_if
202
219
 
203
220
  # verbose
@@ -310,9 +327,6 @@ class EvoStrategy(Module):
310
327
 
311
328
  if self.use_sigma_optimizer:
312
329
  accum_grad_(sigma, -one_grad_sigma)
313
-
314
- self.sigma_optimizer.step()
315
- self.sigma_optimizer.zero_grad()
316
330
  else:
317
331
  sigma.add_(one_grad_sigma * self.noise_scale_learning_rate)
318
332
 
@@ -324,11 +338,17 @@ class EvoStrategy(Module):
324
338
  self.optimizer.step()
325
339
  self.optimizer.zero_grad()
326
340
 
341
+ if self.use_scheduler and exists(self.scheduler):
342
+ self.scheduler.step()
343
+
327
344
  if self.learned_noise_scale:
328
345
  if self.use_sigma_optimizer:
329
346
  self.sigma_optimizer.step()
330
347
  self.sigma_optimizer.zero_grad()
331
348
 
349
+ if self.use_sigma_scheduler and exists(self.sigma_scheduler):
350
+ self.sigma_scheduler.step()
351
+
332
352
  for sigma in self.sigmas:
333
353
  self.sigma_clamp_(sigma)
334
354
 
@@ -495,7 +515,7 @@ class EvoStrategy(Module):
495
515
  if self.learned_noise_scale:
496
516
  packed_sigma, _ = pack(list(self.sigmas), '*')
497
517
  avg_sigma = packed_sigma.mean().item()
498
- msg += f' | avg sigma: {avg_sigma:.3f}'
518
+ msg += f' | average sigma: {avg_sigma:.3f}'
499
519
 
500
520
  self.print(msg)
501
521
 
File without changes
File without changes
File without changes