torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/gradient_approximation/rfdm.py
@@ -1,125 +0,0 @@
- from typing import Literal, Any, cast
-
- import torch
-
- from ...utils.python_tools import _ScalarLoss
- from ...tensorlist import Distributions, TensorList
- from ...core import _ClosureType, OptimizerModule, OptimizationVars
- from ._fd_formulas import _FD_Formulas
- from .base_approximator import GradientApproximatorBase
-
- def _two_point_rcd_(closure: _ClosureType, params: TensorList, perturbation: TensorList, eps: TensorList, fx0: _ScalarLoss | None, ):
-     """Two point randomized finite difference (same signature for all other finite differences functions).
-
-     Args:
-         closure (Callable): A closure that reevaluates the model and returns the loss.
-         params (TensorList): TensorList with parameters.
-         perturbation (TensorList): TensorList with perturbation ALREADY MULTIPLIED BY EPSILON.
-         eps (TensorList): Finite difference epsilon.
-         fx0 (ScalarType): Loss at fx0, to avoid reevaluating it each time. On some functions can be None when it isn't needed.
-
-     Returns:
-         TensorList with gradient estimation and approximate loss.
-     """
-     # positive loss
-     params += perturbation
-     loss_pos = closure(False)
-
-     # negative loss
-     params.sub_(perturbation, alpha = 2)
-     loss_neg = closure(False)
-
-     # restore params
-     params += perturbation
-
-     # calculate gradient estimation using central finite differences formula
-     # (we square eps in denominator because perturbation is already multiplied by eps)
-     # grad_est = (perturbation * (loss_pos - loss_neg)) / (2 * eps**2)
-     # is equivalent to the following:
-     return perturbation * eps.map(lambda x: (loss_pos - loss_neg) / (2 * x**2)), loss_pos
-     # also we can't reuse the perturbatuion tensor and multiply it in place,
-     # since if randomize_every is more than 1, that would break it.
-
- def _two_point_rfd_(closure: _ClosureType, params: TensorList, perturbation: TensorList, eps: TensorList, fx0: _ScalarLoss | None):
-     if fx0 is None: raise ValueError()
-
-     params += perturbation
-     fx1 = closure(False)
-
-     params -= perturbation
-
-     return perturbation * eps.map(lambda x: (fx1 - fx0) / x**2), fx0
-
- def _two_point_rbd_(closure: _ClosureType, params: TensorList, perturbation: TensorList, eps: TensorList, fx0: _ScalarLoss | None):
-     if fx0 is None: raise ValueError()
-
-     params -= perturbation
-     fx1 = closure(False)
-
-     params += perturbation
-
-     return perturbation * eps.map(lambda x: (fx0 - fx1) / x**2), fx0
-
-
- class RandomizedFDM(GradientApproximatorBase):
-     """Gradient approximation via randomized finite difference.
-
-     Args:
-         eps (float, optional): finite difference epsilon. Defaults to 1e-5.
-         formula (_FD_Formulas, optional): Finite difference formula. Defaults to 'forward'.
-         n_samples (int, optional): number of times gradient is approximated and then averaged. Defaults to 1.
-         distribution (Distributions, optional): distribution for random perturbations. Defaults to "normal".
-         target (str, optional):
-             determines what this module sets.
-
-             "ascent" - it creates a new ascent direction but doesn't treat is as gradient.
-
-             "grad" - it creates the gradient and sets it to `.grad` attributes (default).
-
-             "closure" - it makes a new closure that sets the estimated gradient to the `.grad` attributes.
-     """
-     def __init__(
-         self,
-         eps: float = 1e-5,
-         formula: _FD_Formulas = "forward",
-         n_samples: int = 1,
-         distribution: Distributions = "normal",
-         target: Literal['ascent', 'grad', 'closure'] = 'grad',
-     ):
-         defaults = dict(eps = eps)
-
-         if formula == 'forward':
-             self._finite_difference = _two_point_rfd_
-             requires_fx0 = True
-
-         elif formula == 'backward':
-             self._finite_difference = _two_point_rbd_
-             requires_fx0 = True
-
-         elif formula == 'central':
-             self._finite_difference = _two_point_rcd_
-             requires_fx0 = False
-
-         else: raise ValueError(f"Unknown formula: {formula}")
-
-         self.n_samples = n_samples
-         self.distribution: Distributions = distribution
-
-         super().__init__(defaults, requires_fx0=requires_fx0, target = target)
-
-     @torch.no_grad
-     def _make_ascent(self, closure, params, fx0):
-         eps = self.get_group_key('eps')
-         fx0_approx = None
-
-         if self.n_samples == 1:
-             grads, fx0_approx = self._finite_difference(closure, params, params.sample_like(eps, self.distribution), eps, fx0)
-
-         else:
-             grads = params.zeros_like()
-             for i in range(self.n_samples):
-                 g, fx0_approx = self._finite_difference(closure, params, params.sample_like(eps, self.distribution), eps, fx0)
-                 grads += g
-             grads /= self.n_samples
-
-         return grads, fx0, fx0_approx
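
For readers skimming the removed module: the central two-point estimate above reduces, for a single flat parameter tensor, to the sketch below. This is illustrative only; `f`, `x`, and the helper name are not part of torchzero, and the package version operates on TensorList objects rather than a single tensor.

    import torch

    def central_rfd_estimate(f, x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # random perturbation already scaled by eps, as in _two_point_rcd_
        v = torch.randn_like(x) * eps
        loss_pos = f(x + v)
        loss_neg = f(x - v)
        # central difference along v; eps**2 because v already carries one factor of eps
        return v * (loss_pos - loss_neg) / (2 * eps**2)

    # example: estimate the gradient of a quadratic at (1, 2);
    # averaged over random v the estimate matches the true gradient (2, 4)
    g = central_rfd_estimate(lambda p: (p ** 2).sum(), torch.tensor([1.0, 2.0]))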
torchzero/modules/line_search/armijo.py
@@ -1,56 +0,0 @@
- import torch
-
- from ...tensorlist import TensorList
- from ...core import OptimizationVars
- from .base_ls import LineSearchBase
-
-
- class ArmijoLS(LineSearchBase):
-     """Armijo backtracking line search
-
-     Args:
-         alpha (float): initial step size.
-         mul (float, optional): lr multiplier on each iteration. Defaults to 0.5.
-         beta (float, optional):
-             armijo condition parameter, fraction of expected linear loss decrease to accept.
-             Larger values mean loss needs to decrease more for a step sizer to be accepted. Defaults to 1e-4.
-         max_iter (int, optional): maximum iterations. Defaults to 10.
-         log_lrs (bool, optional): logs learning rates. Defaults to False.
-     """
-     def __init__(
-         self,
-         alpha: float = 1,
-         mul: float = 0.5,
-         beta: float = 1e-2,
-         max_iter: int = 10,
-         log_lrs = False,
-     ):
-         defaults = dict(alpha=alpha)
-         super().__init__(defaults, maxiter=None, log_lrs=log_lrs)
-         self.mul = mul
-         self.beta = beta
-         self.max_iter = max_iter
-
-     @torch.no_grad
-     def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-         if vars.closure is None: raise RuntimeError(f"Line searches ({self.__class__.__name__}) require a closure")
-         ascent = vars.maybe_use_grad_(params)
-         grad = vars.maybe_compute_grad_(params)
-         alpha = self.get_first_group_key('alpha')
-         if vars.fx0 is None: vars.fx0 = vars.closure(False)
-
-         # loss decrease per lr=1 if function was linear
-         decrease_per_lr = (grad*ascent).total_sum()
-
-         for _ in range(self.max_iter):
-             loss = self._evaluate_lr_(alpha, vars.closure, ascent, params)
-
-             # expected decrease
-             expected_decrease = decrease_per_lr * alpha
-
-             if (vars.fx0 - loss) / expected_decrease >= self.beta:
-                 return alpha
-
-             alpha *= self.mul
-
-         return 0
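
The acceptance test in the removed ArmijoLS is the standard sufficient-decrease condition: a step size alpha is kept once the achieved loss decrease is at least `beta` times the decrease predicted by the linear model, otherwise alpha is multiplied by `mul`. A minimal standalone sketch on flat 1-D tensors (names are illustrative, not torchzero API):

    import torch

    def armijo_backtrack(f, x, direction, grad, alpha=1.0, mul=0.5, beta=1e-2, max_iter=10):
        # decrease predicted by the linear model per unit step size along -direction
        predicted_per_lr = grad.dot(direction)
        fx0 = f(x)
        for _ in range(max_iter):
            loss = f(x - alpha * direction)
            # accept once the actual decrease reaches beta times the predicted decrease
            if fx0 - loss >= beta * predicted_per_lr * alpha:
                return alpha
            alpha *= mul
        return 0.0

    # example on f(p) = sum(p**2) with the gradient as the search direction:
    # alpha = 1 overshoots (no decrease), alpha = 0.5 lands at the minimum and is accepted
    x = torch.tensor([3.0, -4.0]); g = 2 * x
    step = armijo_backtrack(lambda p: p.pow(2).sum(), x, g, g)   # returns 0.5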
torchzero/modules/line_search/base_ls.py
@@ -1,139 +0,0 @@
- from typing import Literal
- from abc import ABC, abstractmethod
-
-
- import torch
-
- from ...tensorlist import TensorList
- from ...core import _ClosureType, OptimizationVars, OptimizerModule, _maybe_pass_backward
- from ...utils.python_tools import _ScalarLoss
-
-
- class MaxIterReached(Exception): pass
-
- class LineSearchBase(OptimizerModule, ABC):
-     """Base linesearch class. This is an abstract class, please don't use it as the optimizer.
-
-     When inheriting from this class the easiest way is only override `_find_best_lr`, which should
-     return the final lr to use.
-
-     Args:
-         defaults (dict): dictionary with default parameters for the module.
-         target (str, optional):
-             determines how _update method is used in the default step method.
-
-             "ascent" - it updates the ascent
-
-             "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).
-
-             "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
-         maxiter (_type_, optional): maximum line search iterations
-             (useful for things like scipy.optimize.minimize_scalar) as it doesn't have
-             an exact iteration limit. Defaults to None.
-         log_lrs (bool, optional):
-             saves lrs and losses with them into optimizer._lrs (for debugging).
-             Defaults to False.
-     """
-     def __init__(
-         self,
-         defaults: dict,
-         target: Literal['grad', 'ascent', 'closure'] = 'ascent',
-         maxiter=None,
-         log_lrs=False,
-     ):
-         super().__init__(defaults, target=target)
-         self._reset()
-
-         self.maxiter = maxiter
-         self.log_lrs = log_lrs
-         self._lrs: list[dict[float, _ScalarLoss]] = []
-         """this only gets filled if `log_lrs` is True. On each step, a dictionary is added to this list,
-         with all lrs tested at that step as keys and corresponding losses as values."""
-
-     def _reset(self):
-         """Resets `_last_lr`, `_lowest_loss`, `_best_lr`, `_fx0_approx` and `_current_iter`."""
-         self._last_lr = 0
-         self._lowest_loss = float('inf')
-         self._best_lr = 0
-         self._fx0_approx = None
-         self._current_iter = 0
-
-     def _set_lr_(self, lr: float, ascent_direction: TensorList, params: TensorList, ):
-         alpha = self._last_lr - lr
-         if alpha != 0: params.add_(ascent_direction, alpha = alpha)
-         self._last_lr = lr
-
-     # lr is first here so that we can use a partial
-     def _evaluate_lr_(self, lr: float, closure: _ClosureType, ascent: TensorList, params: TensorList, backward=False):
-         """Evaluate `lr`, if loss is better than current lowest loss,
-         overrides `self._lowest_loss` and `self._best_lr`.
-
-         Args:
-             closure (ClosureType): closure.
-             params (tl.TensorList): params.
-             ascent_direction (tl.TensorList): ascent.
-             lr (float): lr to evaluate.
-
-         Returns:
-             Loss with evaluated lr.
-         """
-         # check max iter
-         if self._current_iter == self.maxiter: raise MaxIterReached
-         self._current_iter += 1
-
-         # set new lr and evaluate loss with it
-         self._set_lr_(lr, ascent, params = params)
-         with torch.enable_grad() if backward else torch.no_grad(): self._fx0_approx = _maybe_pass_backward(closure, backward)
-
-         # if it is the best so far, record it
-         if self._fx0_approx < self._lowest_loss:
-             self._lowest_loss = self._fx0_approx
-             self._best_lr = lr
-
-         # log lr and loss
-         if self.log_lrs:
-             self._lrs[-1][lr] = self._fx0_approx
-
-         return self._fx0_approx
-
-     def _evaluate_lr_ensure_float(
-         self,
-         lr: float,
-         closure: _ClosureType,
-         ascent: TensorList,
-         params: TensorList,
-     ) -> float:
-         """Same as _evaluate_lr_ but ensures that the loss value is float."""
-         v = self._evaluate_lr_(lr, closure, ascent, params)
-         if isinstance(v, torch.Tensor): return v.detach().cpu().item()
-         return float(v)
-
-     def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-         """This should return the best lr."""
-         ... # pylint:disable=unnecessary-ellipsis
-
-     @torch.no_grad
-     def step(self, vars: OptimizationVars):
-         self._reset()
-         if self.log_lrs: self._lrs.append({})
-
-         params = self.get_params()
-         ascent_direction = vars.maybe_use_grad_(params)
-
-         try:
-             lr = self._find_best_lr(vars, params) # pylint:disable=assignment-from-no-return
-         except MaxIterReached:
-             lr = self._best_lr
-
-         # if child is None, set best lr which update params and return loss
-         if self.next_module is None:
-             self._set_lr_(lr, ascent_direction, params)
-             return self._lowest_loss
-
-         # otherwise undo the update by setting lr to 0 and instead multiply ascent direction by lr.
-         self._set_lr_(0, ascent_direction, params)
-         ascent_direction.mul_(self._best_lr)
-         vars.ascent = ascent_direction
-         if vars.fx0_approx is None: vars.fx0_approx = self._lowest_loss
-         return self.next_module.step(vars)
-
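
One detail worth noting in the removed base class: `_set_lr_` never restores parameters from a saved copy. It keeps the parameters at `x0 - lr * ascent` and, when a new lr is evaluated, moves them in place by the difference between the previous and the new lr. A minimal standalone sketch of that bookkeeping (class and method names here are illustrative, not torchzero API):

    import torch

    class IncrementalStepper:
        # keeps params at x0 - lr * direction, moving only by the change in lr
        def __init__(self, params, direction):
            self.params, self.direction = params, direction
            self.last_lr = 0.0

        @torch.no_grad()
        def set_lr(self, lr: float) -> None:
            delta = self.last_lr - lr  # how much further to move along -direction
            if delta != 0:
                for p, d in zip(self.params, self.direction):
                    p.add_(d, alpha=delta)
            self.last_lr = lr

    # evaluating lrs 1.0 then 0.5 issues two in-place adds instead of clone/restore pairs
    w = [torch.tensor([1.0, 2.0])]
    stepper = IncrementalStepper(w, [torch.tensor([1.0, 1.0])])
    stepper.set_lr(1.0)   # w == [tensor([0., 1.])]
    stepper.set_lr(0.5)   # w == [tensor([0.5, 1.5])]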
torchzero/modules/line_search/directional_newton.py
@@ -1,217 +0,0 @@
- import numpy as np
- import torch
-
- from ...tensorlist import TensorList
- from ...core import OptimizationVars
- from .base_ls import LineSearchBase
-
- _FloatOrTensor = float | torch.Tensor
- def _fit_and_minimize_quadratic_2points_grad(x1:_FloatOrTensor,y1:_FloatOrTensor,y1_prime:_FloatOrTensor,x2:_FloatOrTensor,y2:_FloatOrTensor):
-     """Fits a quadratic to value and gradient and x1 and value at x2 and returns minima and a parameter."""
-     a = (y1_prime * x2 - y2 - y1_prime*x1 + y1) / (x1**2 - x2**2 - 2*x1**2 + 2*x1*x2)
-     b = y1_prime - 2*a*x1
-     # c = -(a*x1**2 + b*x1 - y1)
-     return -b / (2 * a), a
-
- def _ensure_float(x):
-     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-     elif isinstance(x, np.ndarray): return x.item()
-     return float(x)
-
- class DirectionalNewton(LineSearchBase):
-     """Minimizes a parabola in the direction of the update via one additional forward pass,
-     and uses another forward pass to make sure it didn't overstep (optionally).
-     So in total this performs three forward passes and one backward.
-
-     It is recommented to put LR before DirectionalNewton.
-
-     First forward and backward pass is used to calculate the value and gradient at initial parameters.
-     Then a gradient descent step is performed with `lr` learning rate, and loss is recalculated
-     with new parameters. A quadratic is fitted to two points and gradient,
-     if it has positive curvature, this makes a step towards the minimum, and checks if lr decreased
-     with an additional forward pass.
-
-     Args:
-         eps (float, optional):
-             learning rate, also functions as epsilon for directional second derivative estimation. Defaults to 1.
-         max_dist (float | None, optional):
-             maximum distance to step when minimizing quadratic.
-             If minimum is further than this distance, minimization is not performed. Defaults to 1e4.
-         validate_step (bool, optional):
-             uses an additional forward pass to check
-             if step towards the minimum actually decreased the loss. Defaults to True.
-         alpha (float, optional):
-             epsilon for estimating directional second derivative, also works as learning rate
-             for when curvature is negative or loss increases.
-         log_lrs (bool, optional):
-             saves lrs and losses with them into optimizer._lrs (for debugging).
-             Defaults to False.
-
-     Note:
-         While lr scheduling is supported, this uses lr of the first parameter for all parameters.
-     """
-     def __init__(self, max_dist: float | None = 1e5, validate_step = True, alpha:float=1, log_lrs = False,):
-         super().__init__({"alpha": alpha}, maxiter=None, log_lrs=log_lrs)
-
-         self.max_dist = max_dist
-         self.validate_step = validate_step
-
-     @torch.no_grad
-     def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-         if vars.closure is None: raise ValueError('QuardaticLS requires closure')
-         closure = vars.closure
-
-         params = self.get_params()
-         grad = vars.maybe_compute_grad_(params)
-         ascent = vars.maybe_use_grad_(params)
-         if vars.fx0 is None: vars.fx0 = vars.closure(False) # at this stage maybe_compute_grad could've evaluated fx0
-
-         alpha: float = self.get_first_group_key('alpha') # this doesn't support variable lrs but we still want to support schedulers
-
-         # directional f'(x1)
-         y1_prime = (grad * ascent).total_sum()
-
-         # f(x2)
-         y2 = self._evaluate_lr_(alpha, closure, ascent, params)
-
-         # if gradients weren't 0
-         if y1_prime != 0:
-             xmin, a = _fit_and_minimize_quadratic_2points_grad(
-                 x1=0,
-                 y1=vars.fx0,
-                 y1_prime=-y1_prime,
-                 x2=alpha,
-                 # we stepped in the direction of minus gradient times lr.
-                 # which is why y1_prime is negative and we multiply x2 by lr.
-                 y2=y2
-             )
-             # so we obtained xmin in lr*grad units. We need in lr units.
-             xmin = _ensure_float(xmin)
-
-             # make sure curvature is positive
-             if a > 0:
-
-                 # discard very large steps
-                 if self.max_dist is None or xmin <= self.max_dist:
-
-                     # if validate_step is enabled, make sure loss didn't increase
-                     if self.validate_step:
-                         y_val = self._evaluate_lr_(xmin, closure, ascent, params)
-                         # if it increased, move back to y2.
-                         if y_val > y2:
-                             return float(alpha)
-
-                     return float(xmin)
-
-         return float(alpha)
-
- def _fit_and_minimize_quadratic_3points(
-     x1: _FloatOrTensor,
-     y1: _FloatOrTensor,
-     x2: _FloatOrTensor,
-     y2: _FloatOrTensor,
-     x3: _FloatOrTensor,
-     y3: _FloatOrTensor,
- ):
-     """Fits a quadratic to three points."""
-     a = (x1*(y3-y2) + x2*(y1-y3) + x3*(y2-y1)) / ((x1-x2) * (x1 - x3) * (x2 - x3))
-     b = (y2-y1) / (x2-x1) - a*(x1+x2)
-     # c = (y1 - a*x1**2 - b*x1)
-     return (-b / (2 * a), a)
-
-
- def _newton_step_3points(
-     xneg: _FloatOrTensor,
-     yneg: _FloatOrTensor,
-     x0: _FloatOrTensor,
-     y0: _FloatOrTensor,
-     xpos: _FloatOrTensor, # since points are evenly spaced, xpos is x0 + eps, its turns out unused
-     ypos: _FloatOrTensor,
- ):
-     eps = x0 - xneg
-     dx = (-yneg + ypos) / (2 * eps)
-     ddx = (ypos - 2*y0 + yneg) / (eps**2)
-
-     # xneg is actually x0
-     return xneg - dx / ddx, ddx
-
- class DirectionalNewton3Points(LineSearchBase):
-     """Minimizes a parabola in the direction of the update via two additional forward pass,
-     and uses another forward pass to make sure it didn't overstep (optionally).
-     So in total this performs four forward passes.
-
-     It is recommented to put LR before DirectionalNewton3Points
-
-     Two steps are performed in the direction of the update with `lr` learning rate.
-     A quadratic is fitted to three points, if it has positive curvature,
-     this makes a step towards the minimum, and checks if lr decreased
-     with an additional forward pass.
-
-     Args:
-         for when curvature is negative or loss increases.
-         max_dist (float | None, optional):
-             maximum distance to step when minimizing quadratic.
-             If minimum is further than this distance, minimization is not performed. Defaults to 1e4.
-         validate_step (bool, optional):
-             uses an additional forward pass to check
-             if step towards the minimum actually decreased the loss. Defaults to True.
-         alpha (float, optional):
-             epsilon for estimating directional second derivative, also works as learning rate
-         log_lrs (bool, optional):
-             saves lrs and losses with them into optimizer._lrs (for debugging).
-             Defaults to False.
-
-     Note:
-         While lr scheduling is supported, this uses lr of the first parameter for all parameters.
-     """
-     def __init__(self, max_dist: float | None = 1e4, validate_step = True, alpha: float = 1, log_lrs = False,):
-         super().__init__(dict(alpha = alpha), maxiter=None, log_lrs=log_lrs)
-
-         self.alpha = alpha
-         self.max_dist = max_dist
-         self.validate_step = validate_step
-
-     @torch.no_grad
-     def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-         if vars.closure is None: raise ValueError('QuardaticLS requires closure')
-         closure = vars.closure
-         ascent_direction = vars.ascent
-         if ascent_direction is None: raise ValueError('Ascent direction is None')
-         alpha: float = self.get_first_group_key('alpha')
-
-         if vars.fx0 is None: vars.fx0 = vars.closure(False)
-         params = self.get_params()
-
-         # make a step in the direction and evaluate f(x2)
-         y2 = self._evaluate_lr_(alpha, closure, ascent_direction, params)
-
-         # make a step in the direction and evaluate f(x3)
-         y3 = self._evaluate_lr_(alpha*2, closure, ascent_direction, params)
-
-         # if gradients weren't 0
-         xmin, a = _newton_step_3points(
-             0, vars.fx0,
-             # we stepped in the direction of minus ascent_direction.
-             alpha, y2,
-             alpha * 2, y3
-         )
-         xmin = _ensure_float(xmin)
-
-         # make sure curvature is positive
-         if a > 0:
-
-             # discard very large steps
-             if self.max_dist is None or xmin <= self.max_dist:
-
-                 # if validate_step is enabled, make sure loss didn't increase
-                 if self.validate_step:
-                     y_val = self._evaluate_lr_(xmin, closure, ascent_direction, params)
-                     # if it increased, move back to y2.
-                     if y_val > y2 or y_val > y3:
-                         if y3 > y2: return alpha
-                         else: return alpha * 2
-
-                 return xmin
-
-         if y3 > y2: return alpha
-         else: return alpha * 2
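
The quadratic fit both removed classes rely on is easiest to see in the special case the code actually uses, x1 = 0: fitting q(x) = a*x**2 + b*x + c to the value and directional derivative at 0 and the value at x2 gives a = (y2 - y1 - y1_prime*x2) / x2**2, b = y1_prime, and the minimizer -b / (2*a), which is only trusted when a > 0. A standalone sketch under that simplification (the function name is illustrative):

    def quadratic_minimizer(y1: float, y1_prime: float, x2: float, y2: float):
        # fit q(x) = a*x**2 + b*x + c with q(0) = y1, q'(0) = y1_prime, q(x2) = y2
        a = (y2 - y1 - y1_prime * x2) / x2 ** 2
        b = y1_prime
        return -b / (2 * a), a  # (argmin of q, curvature coefficient)

    # example on f(x) = (x - 3)**2: f(0) = 9, f'(0) = -6, f(1) = 4, so argmin 3 and a = 1
    xmin, a = quadratic_minimizer(9.0, -6.0, 1.0, 4.0)
    assert abs(xmin - 3.0) < 1e-12 and a > 0

The three-point variant used by DirectionalNewton3Points replaces the exact directional derivative with central differences of the two sampled losses, which is what `_newton_step_3points` computes.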
torchzero/modules/line_search/grid_ls.py
@@ -1,158 +0,0 @@
- from typing import Any, Literal
- from collections.abc import Sequence
-
- import numpy as np
- import torch
-
- from ...tensorlist import TensorList
- from ...core import _ClosureType, OptimizationVars
- from .base_ls import LineSearchBase
-
- class GridLS(LineSearchBase):
-     """Test all `lrs` and pick best.
-
-     Args:
-         lrs (Sequence[float] | np.ndarray | torch.Tensor): sequence of lrs to test.
-         stop_on_improvement (bool, optional): stops if loss improves compared to current loss. Defaults to False.
-         stop_on_worsened (bool, optional):
-             stops if next lr loss is worse than previous one.
-             this assumes that lrs are in ascending order. Defaults to False.
-         log_lrs (bool, optional):
-             saves lrs and losses with them into optimizer._lrs (for debugging).
-             Defaults to False.
-     """
-     def __init__(
-         self,
-         lrs: Sequence[float] | np.ndarray | torch.Tensor,
-         stop_on_improvement=False,
-         stop_on_worsened=False,
-         log_lrs = False,
-     ):
-         super().__init__({}, maxiter=None, log_lrs=log_lrs)
-         self.lrs = lrs
-         self.stop_on_improvement = stop_on_improvement
-         self.stop_on_worsened = stop_on_worsened
-
-     @torch.no_grad
-     def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-         if vars.closure is None: raise ValueError("closure is not set")
-         if vars.ascent is None: raise ValueError("ascent_direction is not set")
-
-         if self.stop_on_improvement:
-             if vars.fx0 is None: vars.fx0 = vars.closure(False)
-             self._lowest_loss = vars.fx0
-
-         for lr in self.lrs:
-             loss = self._evaluate_lr_(float(lr), vars.closure, vars.ascent, params)
-
-             # if worsened
-             if self.stop_on_worsened and loss != self._lowest_loss:
-                 break
-
-             # if improved
-             if self.stop_on_improvement and loss == self._lowest_loss:
-                 break
-
-         return float(self._best_lr)
-
-
-
- class MultiplicativeLS(GridLS):
-     """Starts with `init` lr, then keeps multiplying it by `mul` until loss stops decreasing.
-
-     Args:
-         init (float, optional): initial lr. Defaults to 0.001.
-         mul (float, optional): lr multiplier. Defaults to 2.
-         num (int, optional): maximum number of multiplication steps. Defaults to 10.
-         stop_on_improvement (bool, optional): stops if loss improves compared to current loss. Defaults to False.
-         stop_on_worsened (bool, optional):
-             stops if next lr loss is worse than previous one.
-             this assumes that lrs are in ascending order. Defaults to False.
-         log_lrs (bool, optional):
-             saves lrs and losses with them into optimizer._lrs (for debugging).
-             Defaults to False.
-     """
-     def __init__(
-         self,
-         init: float = 0.001,
-         mul: float = 2,
-         num=10,
-         stop_on_improvement=False,
-         stop_on_worsened=True,
-     ):
-         super().__init__(
-             [init * mul**i for i in range(num)],
-             stop_on_improvement=stop_on_improvement,
-             stop_on_worsened=stop_on_worsened,
-         )
-
- class BacktrackingLS(GridLS):
-     """tests `init` lr, and keeps multiplying it by `mul` until loss becomes better than initial loss.
-
-     note: this doesn't include Armijo–Goldstein condition.
-
-     Args:
-         init (float, optional): initial lr. Defaults to 1.
-         mul (float, optional): lr multiplier. Defaults to 0.5.
-         num (int, optional): maximum number of multiplication steps. Defaults to 10.
-         stop_on_improvement (bool, optional): stops if loss improves compared to current loss. Defaults to False.
-         stop_on_worsened (bool, optional):
-             stops if next lr loss is worse than previous one.
-             this assumes that lrs are in ascending order. Defaults to False.
-         log_lrs (bool, optional):
-             saves lrs and losses with them into optimizer._lrs (for debugging).
-             Defaults to False.
-
-     """
-     def __init__(
-         self,
-         init: float = 1,
-         mul: float = 0.5,
-         num=10,
-         stop_on_improvement=True,
-         stop_on_worsened=False,
-         log_lrs = False,
-     ):
-         super().__init__(
-             [init * mul**i for i in range(num)],
-             stop_on_improvement=stop_on_improvement,
-             stop_on_worsened=stop_on_worsened,
-             log_lrs = log_lrs,
-         )
-
- class LinspaceLS(GridLS):
-     """Test all learning rates from a linspace and pick best."""
-     def __init__(
-         self,
-         start: float = 0.001,
-         end: float = 2,
-         steps=10,
-         stop_on_improvement=False,
-         stop_on_worsened=False,
-         log_lrs = False,
-     ):
-         super().__init__(
-             torch.linspace(start, end, steps),
-             stop_on_improvement=stop_on_improvement,
-             stop_on_worsened=stop_on_worsened,
-             log_lrs = log_lrs,
-         )
-
- class ArangeLS(GridLS):
-     """Test all learning rates from a linspace and pick best."""
-     def __init__(
-         self,
-         start: float = 0.001,
-         end: float = 2,
-         step=0.1,
-         stop_on_improvement=False,
-         stop_on_worsened=False,
-         log_lrs = False,
-
-     ):
-         super().__init__(
-             torch.arange(start, end, step),
-             stop_on_improvement=stop_on_improvement,
-             stop_on_worsened=stop_on_worsened,
-             log_lrs = log_lrs,
-         )