torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/scheduling/step_size.py
@@ -1,80 +0,0 @@
- import random
- from typing import Any
-
- from ...core import OptimizerModule
- from ...tensorlist import TensorList
-
-
- class PolyakStepSize(OptimizerModule):
-     """Polyak step-size. Meant to be used at the beginning when ascent is the gradient but other placements may work.
-     This can also work with SGD as SPS (Stochastic Polyak Step-Size) seems to use the same formula.
-
-     Args:
-         max (float | None, optional): maximum possible step size. Defaults to None.
-         min_obj_value (int, optional): (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
-         use_grad (bool, optional):
-             if True, uses dot product of update and gradient to compute the step size.
-             Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
-             Defaults to True.
-         parameterwise (bool, optional):
-             if True, calculate Polyak step-size for each parameter separately,
-             if False calculate one global step size for all parameters. Defaults to False.
-         alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
-     """
-     def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):
-
-         defaults = dict(alpha = alpha)
-         super().__init__(defaults)
-         self.max = max
-         self.min_obj_value = min_obj_value
-         self.use_grad = use_grad
-         self.parameterwise = parameterwise
-
-     def _update(self, vars, ascent):
-         if vars.closure is None: raise ValueError("PolyakStepSize requires closure")
-         if vars.fx0 is None: vars.fx0 = vars.closure(False) # can only happen when placed after SPSA
-
-         alpha = self.get_group_key('alpha')
-
-         if self.parameterwise:
-             if self.use_grad: denom = (ascent*vars.maybe_compute_grad_(self.get_params())).mean()
-             else: denom = ascent.pow(2).mean()
-             polyak_step_size: TensorList | Any = (vars.fx0 - self.min_obj_value) / denom.where(denom!=0, 1) # type:ignore
-             polyak_step_size = polyak_step_size.where(denom != 0, 0)
-             if self.max is not None: polyak_step_size = polyak_step_size.clamp_max(self.max)
-
-         else:
-             if self.use_grad: denom = (ascent*vars.maybe_compute_grad_(self.get_params())).total_mean()
-             else: denom = ascent.pow(2).total_mean()
-             if denom == 0: polyak_step_size = 0 # we converged
-             else: polyak_step_size = (vars.fx0 - self.min_obj_value) / denom
-
-             if self.max is not None:
-                 if polyak_step_size > self.max: polyak_step_size = self.max
-
-         ascent.mul_(alpha * polyak_step_size)
-         return ascent
-
-
-
- class RandomStepSize(OptimizerModule):
-     """Uses random global step size from `low` to `high`.
-
-     Args:
-         low (float, optional): minimum learning rate. Defaults to 0.
-         high (float, optional): maximum learning rate. Defaults to 1.
-         parameterwise (bool, optional):
-             if True, generate random step size for each parameter separately,
-             if False generate one global random step size. Defaults to False.
-     """
-     def __init__(self, low: float = 0, high: float = 1, parameterwise=False):
-         super().__init__({})
-         self.low = low; self.high = high
-         self.parameterwise = parameterwise
-
-     def _update(self, vars, ascent):
-         if self.parameterwise:
-             lr = [random.uniform(self.low, self.high) for _ in range(len(ascent))]
-         else:
-             lr = random.uniform(self.low, self.high)
-         return ascent.mul_(lr) # type:ignore
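
For reference, the PolyakStepSize module removed above scales the update by (loss - estimated minimum) divided by a product of the update and the gradient (the classic rule divides by the squared gradient norm; the removed code averages the elementwise product instead). A minimal standalone sketch in plain PyTorch, where `loss`, `update`, `grad`, `min_obj_value` and `max_step` are illustrative placeholders rather than either package's API:

import torch

def polyak_step_size(loss: float, update: torch.Tensor, grad: torch.Tensor,
                     min_obj_value: float = 0.0, max_step: float | None = None) -> float:
    # denominator: mean of update * grad, mirroring the removed module
    denom = (update * grad).mean().item()
    if denom == 0:
        return 0.0  # treat a zero denominator as "converged"
    step = (loss - min_obj_value) / denom
    if max_step is not None:
        step = min(step, max_step)
    return step

# the scaled update would then be: update * polyak_step_size(loss, update, grad)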
torchzero/modules/smoothing/gaussian_smoothing.py
@@ -1,90 +0,0 @@
- from contextlib import nullcontext
-
- import numpy as np
- import torch
-
- from ...tensorlist import TensorList, Distributions, mean as tlmean
- from ...utils.python_tools import _ScalarLoss
- from ...core import _ClosureType, OptimizationVars, OptimizerModule, _maybe_pass_backward
-
-
- def _numpy_or_torch_mean(losses: list):
-     """Returns the mean of a list of losses, which can be either numpy arrays or torch tensors."""
-     if isinstance(losses[0], torch.Tensor):
-         return torch.mean(torch.stack(losses))
-     return np.mean(losses).item()
-
- class GaussianHomotopy(OptimizerModule):
-     """Samples and averages value and gradients in multiple random points around current position.
-     This effectively applies smoothing to the function.
-
-     Args:
-         n_samples (int, optional): number of gradient samples from around current position. Defaults to 4.
-         sigma (float, optional): how far from current position to sample from. Defaults to 0.1.
-         distribution (tl.Distributions, optional): distribution for random positions. Defaults to "normal".
-         sample_x0 (bool, optional): 1st sample will be x0. Defaults to False.
-         randomize_every (int | None, optional): randomizes the points every n steps. Defaults to 1.
-     """
-     def __init__(
-         self,
-         n_samples: int = 4,
-         sigma: float = 0.1,
-         distribution: Distributions = "normal",
-         sample_x0 = False,
-         randomize_every: int | None = 1,
-     ):
-         defaults = dict(sigma = sigma)
-         super().__init__(defaults)
-         self.n_samples = n_samples
-         self.distribution: Distributions = distribution
-         self.randomize_every = randomize_every
-         self.current_step = 0
-         self.perturbations = None
-         self.sample_x0 = sample_x0
-
-     @torch.no_grad()
-     def step(self, vars: OptimizationVars):
-         if vars.closure is None: raise ValueError('GaussianSmoothing requires closure.')
-         closure = vars.closure
-         params = self.get_params()
-         sigmas = self.get_group_key('sigma')
-
-         # generate random perturbations
-         if self.perturbations is None or (self.randomize_every is not None and self.current_step % self.randomize_every == 0):
-             if self.sample_x0:
-                 self.perturbations = [params.sample_like(sigmas, distribution=self.distribution) for _ in range(self.n_samples-1)]
-             else:
-                 self.perturbations = [params.sample_like(sigmas, distribution=self.distribution) for _ in range(self.n_samples)]
-
-         @torch.no_grad
-         def smooth_closure(backward = True):
-             losses = []
-             grads = []
-
-             # sample gradient and loss at x0
-             if self.sample_x0:
-                 with torch.enable_grad() if backward else nullcontext():
-                     losses.append(closure())
-                     if backward: grads.append(params.grad.clone())
-
-             # sample gradients from points around current params
-             # and average them
-             if self.perturbations is None: raise ValueError('who set perturbations to None???')
-             for p in self.perturbations:
-                 params.add_(p)
-                 with torch.enable_grad() if backward else nullcontext():
-                     losses.append(_maybe_pass_backward(closure, backward))
-                     if backward: grads.append(params.grad.clone())
-                 params.sub_(p)
-
-             # set the new averaged grads and return average loss
-             if backward: params.set_grad_(tlmean(grads))
-             return _numpy_or_torch_mean(losses)
-
-
-         self.current_step += 1
-         vars.closure = smooth_closure
-         return self._update_params_or_step_with_next(vars)
-
-
- # todo single loop gaussian homotopy?
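
The GaussianHomotopy module removed above wraps the closure so that loss and gradient are averaged over several randomly perturbed copies of the parameters, which amounts to Gaussian smoothing of the objective. A self-contained sketch of the same idea for a single parameter tensor (the names `loss_fn`, `sigma`, `n_samples` are illustrative, not the module's API):

import torch

def smoothed_loss_and_grad(loss_fn, params: torch.Tensor, sigma: float = 0.1, n_samples: int = 4):
    # evaluate loss and gradient at n_samples gaussian perturbations of params, then average
    losses, grads = [], []
    for _ in range(n_samples):
        x = (params + torch.randn_like(params) * sigma).detach().requires_grad_(True)
        loss = loss_fn(x)
        grad, = torch.autograd.grad(loss, x)
        losses.append(loss.detach())
        grads.append(grad)
    return torch.stack(losses).mean(), torch.stack(grads).mean(dim=0)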
torchzero/modules/weight_averaging/__init__.py
@@ -1,2 +0,0 @@
- from .ema import SwitchEMA
- from .swa import PeriodicSWA, CyclicSWA
torchzero/modules/weight_averaging/ema.py
@@ -1,72 +0,0 @@
- import torch
- from ...core import OptimizerModule
-
-
- def _reset_stats_hook(optimizer, state):
-     for module in optimizer.unrolled_modules:
-         module: OptimizerModule
-         module.reset_stats()
-
- # the reason why this needs to be at the end is ??? I NEED TO REMEMBER
- class SwitchEMA(OptimizerModule):
-     """Switch-EMA. Every n steps switches params to an exponential moving average of past weights.
-
-     In the paper the switch happens after each epoch.
-
-     Please put this module at the end, after all other modules.
-
-     This can also function as EMA, set `update_every` to None and instead call `set_ema` and `unset_ema` on this module.
-
-
-     Args:
-         update_every (int): number of steps (batches) between setting model parameters to EMA.
-         momentum (int): EMA momentum factor.
-         reset_stats (bool, optional):
-             if True, when setting model parameters to EMA, resets other modules stats such as momentum velocities.
-             It might be better to set this to False if `update_every` is very small. Defaults to True.
-
-     reference
-         https://arxiv.org/abs/2402.09240
-     """
-     def __init__(self, update_every: int | None, momentum: float = 0.99, reset_stats: bool = True):
-         defaults = dict(momentum=momentum)
-         super().__init__(defaults)
-         self.update_every = update_every
-         self.cur_step = 0
-         self.update_every = update_every
-         self._reset_stats = reset_stats
-         self.orig_params = None
-
-     def set_ema(self):
-         """sets module parameters to EMA, stores original parameters that can be restored by calling `unset_ema`"""
-         params = self.get_params()
-         self.orig_params = params.clone()
-         params.set_(self.get_state_key('ema', init = 'params', params=params))
-
-     def unset_ema(self):
-         """Undoes `set_ema`."""
-         if self.orig_params is None: raise ValueError('call `set_ema` first, and then `unset_ema`.')
-         params = self.get_params()
-         params.set_(self.orig_params)
-
-     @torch.no_grad
-     def step(self, vars):
-         # if self.next_module is not None:
-         #     warn(f'EMA should usually be the last module, but {self.next_module.__class__.__name__} is after it.')
-         self.cur_step += 1
-
-         params = self.get_params()
-         # state.maybe_use_grad_(params)
-         # update params with the child. Averaging is always applied at the end.
-         ret = self._update_params_or_step_with_next(vars, params)
-
-         ema = self.get_state_key('ema', init = 'params', params=params)
-         momentum = self.get_group_key('momentum')
-
-         ema.lerp_compat_(params, 1 - momentum)
-
-         if (self.update_every is not None) and (self.cur_step % self.update_every == 0):
-             params.set_(ema.clone())
-             if self._reset_stats: vars.add_post_step_hook(_reset_stats_hook)
-
-         return ret
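
The SwitchEMA module removed above keeps ema = momentum * ema + (1 - momentum) * params after every step and, every `update_every` steps, copies the EMA back into the live parameters (the "switch" of https://arxiv.org/abs/2402.09240). A minimal sketch of that update detached from the OptimizerModule machinery (all names below are illustrative):

import torch

@torch.no_grad()
def switch_ema_step(params: list[torch.Tensor], ema: list[torch.Tensor],
                    step: int, momentum: float = 0.99, update_every: int | None = None):
    for p, e in zip(params, ema):
        e.lerp_(p, 1 - momentum)   # ema <- momentum * ema + (1 - momentum) * p
    if update_every is not None and step % update_every == 0:
        for p, e in zip(params, ema):
            p.copy_(e)             # switch: load the EMA into the live weights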
torchzero/modules/weight_averaging/swa.py
@@ -1,171 +0,0 @@
- from ...core import OptimizerModule
-
-
- def _reset_stats_hook(optimizer, state):
-     for module in optimizer.unrolled_modules:
-         module: OptimizerModule
-         module.reset_stats()
-
- class PeriodicSWA(OptimizerModule):
-     """Periodic Stochastic Weight Averaging.
-
-     Please put this module at the end, after all other modules.
-
-     The algorithm is as follows:
-
-     1. perform `pswa_start` normal steps before starting PSWA.
-
-     2. Perform multiple SWA iterations. On each iteration,
-         run SWA algorithm for `num_cycles` cycles,
-         and set weights to the weighted average before starting the next SWA iteration.
-
-     SWA iteration is as follows:
-
-     1. perform `cycle_start` initial steps (can be 0)
-
-     2. for `num_cycles`, after every `cycle_length` steps passed, update the weight average with current model weights.
-
-     3. After `num_cycles` cycles passed, set model parameters to the weight average.
-
-     Args:
-         first_swa (int):
-             number of steps before starting PSWA, authors run PSWA starting from 40th epoch out ot 150 epochs in total.
-         cycle_length (int):
-             number of steps betwen updating the weight average. Authors update it once per epoch.
-         num_cycles (int):
-             Number of weight average updates before setting model weights to the average and proceding to the next cycle.
-             Authors use 20 (meaning 20 epochs since each cycle is 1 epoch).
-         cycle_start (int, optional):
-             number of steps at the beginning of each SWA period before updating the weight average (default: 0).
-         reset_stats (bool, optional):
-             if True, when setting model parameters to SWA, resets other modules stats such as momentum velocities (default: True).
-     """
-     def __init__(self, pswa_start: int, cycle_length: int, num_cycles: int, cycle_start: int = 0, reset_stats:bool = True):
-
-         super().__init__({})
-         self.pswa_start = pswa_start
-         self.cycle_start = cycle_start
-         self.cycle_length = cycle_length
-         self.num_cycles = num_cycles
-         self._reset_stats = reset_stats
-
-
-         self.cur = 0
-         self.period_cur = 0
-         self.swa_cur = 0
-         self.n_models = 0
-
-     def step(self, vars):
-         swa = None
-         params = self.get_params()
-         ret = self._update_params_or_step_with_next(vars, params)
-
-         # start first period after `pswa_start` steps
-         if self.cur >= self.pswa_start:
-
-             # start swa after `cycle_start` steps in the current period
-             if self.period_cur >= self.cycle_start:
-
-                 # swa updates on every `cycle_length`th step
-                 if self.swa_cur % self.cycle_length == 0:
-                     swa = self.get_state_key('swa') # initialized to zeros for simplicity
-                     swa.mul_(self.n_models).add_(params).div_(self.n_models + 1)
-                     self.n_models += 1
-
-                 self.swa_cur += 1
-
-             self.period_cur += 1
-
-         self.cur += 1
-
-         # passed num_cycles in period, set model parameters to SWA
-         if self.n_models == self.num_cycles:
-             self.period_cur = 0
-             self.swa_cur = 0
-             self.n_models = 0
-
-             assert swa is not None # it's created above self.n_models += 1
-
-             params.set_(swa)
-             # add a hook that resets momentum, which also deletes `swa` in this module
-             if self._reset_stats: vars.add_post_step_hook(_reset_stats_hook)
-
-         return ret
-
- class CyclicSWA(OptimizerModule):
-     """Periodic SWA with cyclic learning rate. So it samples the weights, increases lr to `peak_lr`, samples the weights again,
-     decreases lr back to `init_lr`, and samples the weights last time. Then model weights are replaced with the average of the three sampled weights,
-     and next cycle starts. I made this due to a horrible misreading of the original SWA paper but it seems to work well.
-
-     Please put this module at the end, after all other modules.
-
-     Args:
-         cswa_start (int): number of steps before starting the first CSWA cycle.
-         cycle_length (int): length of each cycle in steps.
-         steps_between (int): number of steps between cycles.
-         init_lr (float, optional): initial and final learning rate in each cycle. Defaults to 0.
-         peak_lr (float, optional): peak learning rate of each cycle. Defaults to 1.
-         sample_all (float, optional): if True, instead of sampling 3 weights, it samples all weights in the cycle. Defaults to False.
-         reset_stats (bool, optional):
-             if True, when setting model parameters to SWA, resets other modules stats such as momentum velocities (default: True).
-
-     """
-     def __init__(self, cswa_start: int, cycle_length: int, steps_between: int, init_lr: float = 0, peak_lr: float = 1, sample_all = False, reset_stats: bool=True,):
-         defaults = dict(init_lr = init_lr, peak_lr = peak_lr)
-         super().__init__(defaults)
-         self.cswa_start = cswa_start
-         self.cycle_length = cycle_length
-         self.init_lr = init_lr
-         self.peak_lr = peak_lr
-         self.steps_between = steps_between
-         self.sample_all = sample_all
-         self._reset_stats = reset_stats
-
-         self.cur = 0
-         self.cycle_cur = 0
-         self.n_models = 0
-
-         self.cur_lr = self.init_lr
-
-     def step(self, vars):
-         params = self.get_params()
-
-         # start first period after `cswa_start` steps
-         if self.cur >= self.cswa_start:
-
-             ascent = vars.maybe_use_grad_(params)
-
-             # determine the lr
-             point = self.cycle_cur / self.cycle_length
-             init_lr, peak_lr = self.get_group_keys('init_lr', 'peak_lr')
-             if point < 0.5:
-                 p2 = point*2
-                 lr = init_lr * (1-p2) + peak_lr * p2
-             else:
-                 p2 = (1 - point)*2
-                 lr = init_lr * (1-p2) + peak_lr * p2
-
-             ascent *= lr
-             ret = self._update_params_or_step_with_next(vars, params)
-
-             if self.sample_all or self.cycle_cur in (0, self.cycle_length, self.cycle_length // 2):
-                 swa = self.get_state_key('swa')
-                 swa.mul_(self.n_models).add_(params).div_(self.n_models + 1)
-                 self.n_models += 1
-
-             if self.cycle_cur == self.cycle_length:
-                 if not self.sample_all: assert self.n_models == 3, self.n_models
-                 self.n_models = 0
-                 self.cycle_cur = -1
-
-                 params.set_(swa)
-                 if self._reset_stats: vars.add_post_step_hook(_reset_stats_hook)
-
-             self.cycle_cur += 1
-
-         else:
-             ret = self._update_params_or_step_with_next(vars, params)
-
-         self.cur += 1
-
-         return ret
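
Both SWA modules removed above maintain an incremental average of sampled weights, swa = (swa * n + params) / (n + 1), and periodically write it back into the model. The core bookkeeping in isolation, as a sketch (names are illustrative):

import torch

@torch.no_grad()
def update_swa(swa: list[torch.Tensor], params: list[torch.Tensor], n_models: int) -> int:
    # incremental mean: swa <- (swa * n + params) / (n + 1)
    for s, p in zip(swa, params):
        s.mul_(n_models).add_(p).div_(n_models + 1)
    return n_models + 1

@torch.no_grad()
def load_swa(params: list[torch.Tensor], swa: list[torch.Tensor]) -> None:
    # replace live weights with the averaged weights
    for p, s in zip(params, swa):
        p.copy_(s)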
torchzero/optim/experimental/__init__.py
@@ -1,20 +0,0 @@
- """Optimizers that I haven't tested and various (mostly stupid) ideas go there.
- If something works well I will move it outside of experimental folder.
- Otherwise all optimizers in this category should be considered unlikely to good for most tasks."""
- from .experimental import (
-     HVPDiagNewton,
-     ExaggeratedNesterov,
-     ExtraCautiousAdam,
-     GradMin,
-     InwardSGD,
-     MinibatchRprop,
-     MomentumDenominator,
-     MomentumNumerator,
-     MultistepSGD,
-     RandomCoordinateMomentum,
-     ReciprocalSGD,
-     NoiseSign,
- )
-
-
- from .ray_search import NewtonFDMRaySearch, LBFGSRaySearch