torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/line_search/line_search.py
@@ -0,0 +1,181 @@
+ import math
+ from abc import ABC, abstractmethod
+ from collections.abc import Sequence
+ from functools import partial
+ from operator import itemgetter
+ from typing import Any
+
+ import numpy as np
+ import torch
+
+ from ...core import Module, Target, Vars
+ from ...utils import tofloat
+
+
+ class MaxLineSearchItersReached(Exception): pass
+
+
+ class LineSearch(Module, ABC):
+     """Base class for line searches.
+     This is an abstract class, to use it, subclass it and override `search`.
+
+     Args:
+         defaults (dict[str, Any] | None): dictionary with defaults.
+         maxiter (int | None, optional):
+             if this is specified, the search method will terminate upon evaluating
+             the objective this many times, and the step size with the lowest loss value will be used.
+             This is useful when passing `make_objective` to an external library which
+             doesn't have a maxiter option. Defaults to None.
+     """
+     def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
+         super().__init__(defaults)
+         self._maxiter = maxiter
+         self._reset()
+
+     def _reset(self):
+         self._current_step_size: float = 0
+         self._lowest_loss = float('inf')
+         self._best_step_size: float = 0
+         self._current_iter = 0
+
+     def set_step_size_(
+         self,
+         step_size: float,
+         params: list[torch.Tensor],
+         update: list[torch.Tensor],
+     ):
+         if not math.isfinite(step_size): return
+         step_size = max(min(tofloat(step_size), 1e36), -1e36) # fixes overflow when backtracking keeps increasing alpha after converging
+         alpha = self._current_step_size - step_size
+         if alpha != 0:
+             torch._foreach_add_(params, update, alpha=alpha)
+             self._current_step_size = step_size
+
+     def _set_per_parameter_step_size_(
+         self,
+         step_size: Sequence[float],
+         params: list[torch.Tensor],
+         update: list[torch.Tensor],
+     ):
+         if not np.all(np.isfinite(step_size)): step_size = [0 for _ in step_size]
+         alpha = [self._current_step_size - s for s in step_size]
+         if any(a != 0 for a in alpha):
+             torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+
+     def _loss(self, step_size: float, vars: Vars, closure, params: list[torch.Tensor],
+               update: list[torch.Tensor], backward: bool = False) -> float:
+
+         # if step_size is 0, we might already know the loss
+         if (vars.loss is not None) and (step_size == 0):
+             return tofloat(vars.loss)
+
+         # check max iter
+         if self._maxiter is not None and self._current_iter >= self._maxiter: raise MaxLineSearchItersReached
+         self._current_iter += 1
+
+         # set new lr and evaluate loss with it
+         self.set_step_size_(step_size, params=params, update=update)
+         if backward:
+             with torch.enable_grad(): loss = closure()
+         else:
+             loss = closure(False)
+
+         # if it is the best so far, record it
+         if loss < self._lowest_loss:
+             self._lowest_loss = tofloat(loss)
+             self._best_step_size = step_size
+
+         # if evaluated loss at step size 0, set it to vars.loss
+         if step_size == 0:
+             vars.loss = loss
+             if backward: vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+         return tofloat(loss)
+
+     def _loss_derivative(self, step_size: float, vars: Vars, closure,
+                          params: list[torch.Tensor], update: list[torch.Tensor]):
+         # if step_size is 0, we might already know the derivative
+         if (vars.grad is not None) and (step_size == 0):
+             loss = self._loss(step_size=step_size, vars=vars, closure=closure, params=params, update=update, backward=False)
+             derivative = - sum(t.sum() for t in torch._foreach_mul(vars.grad, update))
+
+         else:
+             # loss with a backward pass sets params.grad
+             loss = self._loss(step_size=step_size, vars=vars, closure=closure, params=params, update=update, backward=True)
+
+             # directional derivative
+             derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
+                                                                     else torch.zeros_like(p) for p in params], update))
+
+         return loss, tofloat(derivative)
+
+     def evaluate_step_size(self, step_size: float, vars: Vars, backward: bool = False):
+         closure = vars.closure
+         if closure is None: raise RuntimeError('line search requires closure')
+         return self._loss(step_size=step_size, vars=vars, closure=closure, params=vars.params, update=vars.get_update(), backward=backward)
+
+     def evaluate_step_size_loss_and_derivative(self, step_size: float, vars: Vars):
+         closure = vars.closure
+         if closure is None: raise RuntimeError('line search requires closure')
+         return self._loss_derivative(step_size=step_size, vars=vars, closure=closure, params=vars.params, update=vars.get_update())
+
+     def make_objective(self, vars: Vars, backward: bool = False):
+         closure = vars.closure
+         if closure is None: raise RuntimeError('line search requires closure')
+         return partial(self._loss, vars=vars, closure=closure, params=vars.params, update=vars.get_update(), backward=backward)
+
+     def make_objective_with_derivative(self, vars: Vars):
+         closure = vars.closure
+         if closure is None: raise RuntimeError('line search requires closure')
+         return partial(self._loss_derivative, vars=vars, closure=closure, params=vars.params, update=vars.get_update())
+
+     @abstractmethod
+     def search(self, update: list[torch.Tensor], vars: Vars) -> float:
+         """Finds the step size to use"""
+
+     @torch.no_grad
+     def step(self, vars: Vars) -> Vars:
+         self._reset()
+         params = vars.params
+         update = vars.get_update()
+
+         try:
+             step_size = self.search(update=update, vars=vars)
+         except MaxLineSearchItersReached:
+             step_size = self._best_step_size
+
+         # set loss_approx
+         if vars.loss_approx is None: vars.loss_approx = self._lowest_loss
+
+         # this is the last module - set step size to found step_size times lr
+         if vars.is_last:
+
+             if vars.last_module_lrs is None:
+                 self.set_step_size_(step_size, params=params, update=update)
+
+             else:
+                 self._set_per_parameter_step_size_([step_size * lr for lr in vars.last_module_lrs], params=params, update=update)
+
+             vars.stop = True; vars.skip_update = True
+             return vars
+
+         # revert parameters and multiply update by step size
+         self.set_step_size_(0, params=params, update=update)
+         torch._foreach_mul_(vars.update, step_size)
+         return vars
+
+
+ class GridLineSearch(LineSearch):
+     """Mostly for testing, this is not practical"""
+     def __init__(self, start, end, num):
+         defaults = dict(start=start, end=end, num=num)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def search(self, update, vars):
+         start, end, num = itemgetter('start', 'end', 'num')(self.settings[vars.params[0]])
+
+         for lr in torch.linspace(start, end, num):
+             self.evaluate_step_size(lr.item(), vars=vars, backward=False)
+
+         return self._best_step_size
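The `LineSearch` base class above does all the bookkeeping (moving the parameters, counting evaluations, tracking the best loss), so a subclass only has to implement `search`. Below is a minimal sketch of a custom line search written against the API shown in this hunk; the class name and candidate list are illustrative and not part of the package, and the import path is assumed from the file list above.

    import torch
    from torchzero.modules.line_search.line_search import LineSearch

    class CandidateLineSearch(LineSearch):
        """Illustrative only: tries a fixed list of candidate step sizes."""
        def __init__(self, candidates=(0.01, 0.1, 0.5, 1.0)):
            defaults = dict(candidates=list(candidates))
            super().__init__(defaults)

        @torch.no_grad
        def search(self, update, vars):
            candidates = self.settings[vars.params[0]]['candidates']
            for step_size in candidates:
                # moves the parameters to x0 - step_size * update, evaluates the
                # closure, and records the lowest loss seen so far
                self.evaluate_step_size(step_size, vars=vars, backward=False)
            # the base class `step()` reverts the parameters afterwards
            return self._best_step_size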
torchzero/modules/line_search/scipy.py
@@ -0,0 +1,37 @@
+ from collections.abc import Mapping
+ from operator import itemgetter
+
+ import torch
+
+ from .line_search import LineSearch
+
+
+ class ScipyMinimizeScalar(LineSearch):
+     def __init__(
+         self,
+         method: str | None = None,
+         maxiter: int | None = None,
+         bracket=None,
+         bounds=None,
+         tol: float | None = None,
+         options=None,
+     ):
+         defaults = dict(method=method, bracket=bracket, bounds=bounds, tol=tol, options=options, maxiter=maxiter)
+         super().__init__(defaults)
+
+         import scipy.optimize
+         self.scopt = scipy.optimize
+
+
+     @torch.no_grad
+     def search(self, update, vars):
+         objective = self.make_objective(vars=vars)
+         method, bracket, bounds, tol, options, maxiter = itemgetter(
+             'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[vars.params[0]])
+
+         if maxiter is not None:
+             options = dict(options) if isinstance(options, Mapping) else {}
+             options['maxiter'] = maxiter
+
+         res = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
+         return res.x
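`ScipyMinimizeScalar` hands the closure-backed objective produced by `make_objective` straight to `scipy.optimize.minimize_scalar`, merging `maxiter` into the `options` dict. For reference, this is the kind of call it ends up making, sketched here on a plain quadratic standing in for the real objective (function and values are illustrative):

    import scipy.optimize

    def objective(alpha):              # stand-in for make_objective(vars)
        return (alpha - 0.3) ** 2 + 1.0

    res = scipy.optimize.minimize_scalar(objective, options={'maxiter': 10})
    print(res.x, res.fun)              # roughly 0.3 and 1.0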
torchzero/modules/line_search/strong_wolfe.py
@@ -0,0 +1,260 @@
+ import math
+ import warnings
+ from operator import itemgetter
+
+ import torch
+ from torch.optim.lbfgs import _cubic_interpolate
+
+ from .line_search import LineSearch
+ from .backtracking import backtracking_line_search
+ from ...utils import totensor
+
+
+ def _zoom(f,
+           a_l, a_h,
+           f_l, g_l,
+           f_h, g_h,
+           f_0, g_0,
+           c1, c2,
+           maxzoom):
+
+     for i in range(maxzoom):
+         a_j = _cubic_interpolate(
+             *(totensor(i) for i in (a_l, f_l, g_l, a_h, f_h, g_h))
+
+         )
+
+         # if interpolation fails or produces an endpoint, bisect
+         delta = abs(a_h - a_l)
+         if a_j is None or a_j == a_l or a_j == a_h:
+             a_j = a_l + 0.5 * delta
+
+
+         f_j, g_j = f(a_j)
+
+         # check armijo
+         armijo = f_j <= f_0 + c1 * a_j * g_0
+
+         # check strong wolfe
+         wolfe = abs(g_j) <= c2 * abs(g_0)
+
+
+         # minimum between alpha_low and alpha_j
+         if not armijo or f_j >= f_l:
+             a_h = a_j
+             f_h = f_j
+             g_h = g_j
+         else:
+             # alpha_j satisfies armijo
+             if wolfe:
+                 return a_j, f_j
+
+             # minimum between alpha_j and alpha_high
+             if g_j * (a_h - a_l) >= 0:
+                 # between alpha_low and alpha_j
+                 # a_h = a_l
+                 # f_h = f_l
+                 # g_h = g_l
+                 a_h = a_j
+                 f_h = f_j
+                 g_h = g_j
+
+                 # is this messing it up?
+             else:
+                 a_l = a_j
+                 f_l = f_j
+                 g_l = g_j
+
+
+
+
+         # check if interval too small
+         delta = abs(a_h - a_l)
+         if delta <= 1e-9 or delta <= 1e-6 * max(abs(a_l), abs(a_h)):
+             l_satisfies_wolfe = (f_l <= f_0 + c1 * a_l * g_0) and (abs(g_l) <= c2 * abs(g_0))
+             h_satisfies_wolfe = (f_h <= f_0 + c1 * a_h * g_0) and (abs(g_h) <= c2 * abs(g_0))
+
+             if l_satisfies_wolfe and h_satisfies_wolfe: return a_l if f_l <= f_h else a_h, f_h
+             if l_satisfies_wolfe: return a_l, f_l
+             if h_satisfies_wolfe: return a_h, f_h
+             if f_l <= f_0 + c1 * a_l * g_0: return a_l, f_l
+             return None, None
+
+         if a_j is None or a_j == a_l or a_j == a_h:
+             a_j = a_l + 0.5 * delta
+
+
+     return None, None
+
+
+ def strong_wolfe(
+     f,
+     f_0,
+     g_0,
+     init: float = 1.0,
+     c1: float = 1e-4,
+     c2: float = 0.9,
+     maxiter: int = 25,
+     maxzoom: int = 15,
+     # a_max: float = 1e30,
+     expand: float = 2.0, # Factor to increase alpha in bracketing
+     plus_minus: bool = False,
+ ) -> tuple[float, float] | tuple[None, None]:
+     a_prev = 0.0
+
+     if g_0 == 0: return None, None
+     if g_0 > 0:
+         # if direction is not a descent direction, perform line search in opposite direction
+         if plus_minus:
+             def inverted_objective(alpha):
+                 l, g = f(-alpha)
+                 return l, -g
+             a, v = strong_wolfe(
+                 inverted_objective,
+                 init=init,
+                 f_0=f_0,
+                 g_0=-g_0,
+                 c1=c1,
+                 c2=c2,
+                 maxiter=maxiter,
+                 # a_max=a_max,
+                 expand=expand,
+                 plus_minus=False,
+             )
+             if a is not None and v is not None: return -a, v
+         return None, None
+
+     f_prev = f_0
+     g_prev = g_0
+     a_cur = init
+
+     # bracket
+     for i in range(maxiter):
+
+         f_cur, g_cur = f(a_cur)
+
+         # check armijo
+         armijo_violated = f_cur > f_0 + c1 * a_cur * g_0
+         func_increased = f_cur >= f_prev and i > 0
+
+         if armijo_violated or func_increased:
+             return _zoom(f,
+                          a_prev, a_cur,
+                          f_prev, g_prev,
+                          f_cur, g_cur,
+                          f_0, g_0,
+                          c1, c2,
+                          maxzoom=maxzoom,
+                          )
+
+
+
+         # check strong wolfe
+         if abs(g_cur) <= c2 * abs(g_0):
+             return a_cur, f_cur
+
+         # minimum is bracketed
+         if g_cur >= 0:
+             return _zoom(f,
+                          # alpha_curr, alpha_prev,
+                          a_prev, a_cur,
+                          # phi_curr, phi_prime_curr,
+                          f_prev, g_prev,
+                          f_cur, g_cur,
+                          f_0, g_0,
+                          c1, c2,
+                          maxzoom=maxzoom,)
+
+         # otherwise continue bracketing
+         a_next = a_cur * expand
+
+         # update previous point and continue loop with increased step size
+         a_prev = a_cur
+         f_prev = f_cur
+         g_prev = g_cur
+         a_cur = a_next
+
+
+     # max iters reached
+     return None, None
+
+ def _notfinite(x):
+     if isinstance(x, torch.Tensor): return not torch.isfinite(x).all()
+     return not math.isfinite(x)
+
+ class StrongWolfe(LineSearch):
+     def __init__(
+         self,
+         init: float = 1.0,
+         c1: float = 1e-4,
+         c2: float = 0.9,
+         maxiter: int = 25,
+         maxzoom: int = 10,
+         # a_max: float = 1e10,
+         expand: float = 2.0,
+         adaptive = True,
+         fallback = False,
+         plus_minus = False,
+     ):
+         defaults = dict(init=init, c1=c1, c2=c2, maxiter=maxiter, maxzoom=maxzoom,
+                         expand=expand, adaptive=adaptive, fallback=fallback, plus_minus=plus_minus)
+         super().__init__(defaults=defaults)
+
+         self.global_state['initial_scale'] = 1.0
+         self.global_state['beta_scale'] = 1.0
+
+     @torch.no_grad
+     def search(self, update, vars):
+         objective = self.make_objective_with_derivative(vars=vars)
+
+         init, c1, c2, maxiter, maxzoom, expand, adaptive, fallback, plus_minus = itemgetter(
+             'init', 'c1', 'c2', 'maxiter', 'maxzoom',
+             'expand', 'adaptive', 'fallback', 'plus_minus')(self.settings[vars.params[0]])
+
+         f_0, g_0 = objective(0)
+
+         step_size, f_a = strong_wolfe(
+             objective,
+             f_0=f_0, g_0=g_0,
+             init=init * self.global_state.setdefault("initial_scale", 1),
+             c1=c1,
+             c2=c2,
+             maxiter=maxiter,
+             maxzoom=maxzoom,
+             expand=expand,
+             plus_minus=plus_minus,
+         )
+
+         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
+         if step_size is not None and step_size != 0 and not _notfinite(step_size):
+             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
+             return step_size
+
+         # fallback to backtracking on fail
+         if adaptive: self.global_state['initial_scale'] *= 0.5
+         if not fallback: return 0
+
+         objective = self.make_objective(vars=vars)
+
+         # # directional derivative
+         g_0 = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), vars.get_update()))
+
+         step_size = backtracking_line_search(
+             objective,
+             g_0,
+             init=init * self.global_state["initial_scale"],
+             beta=0.5 * self.global_state["beta_scale"],
+             c=c1,
+             maxiter=maxiter * 2,
+             a_min=None,
+             try_negative=plus_minus,
+         )
+
+         # found an alpha that reduces loss
+         if step_size is not None:
+             self.global_state['beta_scale'] = min(1.0, self.global_state.get('beta_scale', 1) * math.sqrt(1.5))
+             return step_size
+
+         # on fail reduce beta scale value
+         self.global_state['beta_scale'] /= 1.5
+         return 0
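The `strong_wolfe` helper above is a plain bracketing routine and can be exercised on its own: it expects a callable returning `(loss, directional derivative)` for a given step size, plus the values at step size 0. A toy sketch on the quadratic phi(a) = (a - 2)^2; the import path is assumed from the file list above:

    from torchzero.modules.line_search.strong_wolfe import strong_wolfe

    def phi(a):
        # loss and directional derivative along the search direction
        return (a - 2.0) ** 2, 2.0 * (a - 2.0)

    step, loss = strong_wolfe(phi, f_0=4.0, g_0=-4.0, init=1.0, c1=1e-4, c2=0.9)
    # step=1.0 already satisfies both the Armijo and curvature conditions here;
    # (None, None) is returned when bracketing or zooming fails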
torchzero/modules/line_search/trust_region.py
@@ -0,0 +1,61 @@
+ from operator import itemgetter
+
+ import torch
+
+ from .line_search import LineSearch
+
+
+ class TrustRegion(LineSearch):
+     """Basic first order trust region, re-evaluates closure with updated parameters and scales step size based on function value change"""
+     def __init__(self, nplus: float = 1.5, nminus: float = 0.75, c: float = 1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
+         defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def search(self, update, vars):
+
+         nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus', 'nminus', 'c', 'init', 'backtrack', 'adaptive')(self.settings[vars.params[0]])
+         step_size = self.global_state.setdefault('step_size', init)
+         previous_success = self.global_state.setdefault('previous_success', False)
+         nplus_mul = self.global_state.setdefault('nplus_mul', 1)
+         nminus_mul = self.global_state.setdefault('nminus_mul', 1)
+
+
+         f_0 = self.evaluate_step_size(0, vars, backward=False)
+
+         # directional derivative (0 if c = 0 because it is not needed)
+         if c == 0: d = 0
+         else: d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), update))
+
+         # test step size
+         sufficient_f = f_0 + c * step_size * min(d, 0) # pyright:ignore[reportArgumentType]
+
+         f_1 = self.evaluate_step_size(step_size, vars, backward=False)
+
+         proposed = step_size
+
+         # very good step
+         if f_1 < sufficient_f:
+             self.global_state['step_size'] *= nplus * nplus_mul
+
+             # two very good steps in a row - increase nplus_mul
+             if adaptive:
+                 if previous_success: self.global_state['nplus_mul'] *= nplus
+                 else: self.global_state['nplus_mul'] = 1
+
+         # acceptable step
+         # elif f_1 <= f_0: pass
+
+         # bad step
+         if f_1 >= f_0:
+             self.global_state['step_size'] *= nminus * nminus_mul
+
+             # two bad steps in a row - decrease nminus_mul
+             if adaptive:
+                 if previous_success: self.global_state['nminus_mul'] *= nminus
+                 else: self.global_state['nminus_mul'] = 1
+
+             if backtrack: proposed = 0
+             else: proposed *= nminus * nminus_mul
+
+         return proposed
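`TrustRegion` keeps a running step size in `global_state` and rescales it by `nplus` after a sufficiently good step and by `nminus` after a step that did not decrease the loss. Ignoring the adaptive multipliers and the in-between case, the stored step size evolves roughly like this (illustrative numbers only):

    # nplus=1.5, nminus=0.75; True = sufficient decrease, False = loss went up
    step_size = 1.0
    for improved in (True, True, False, True, False, False):
        step_size *= 1.5 if improved else 0.75
    print(step_size)   # ~1.42 after 6 steps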
torchzero/modules/lr/__init__.py
@@ -0,0 +1,2 @@
+ from .lr import LR, StepSize, Warmup
+ from .step_size import PolyakStepSize, RandomStepSize
torchzero/modules/lr/lr.py
@@ -0,0 +1,59 @@
+ import torch
+
+ from ...core import Transform
+ from ...utils import NumberList, TensorList, generic_eq
+
+
+ def lazy_lr(tensors: TensorList, lr: float | list, inplace: bool):
+     """multiplies by lr if lr is not 1"""
+     if generic_eq(lr, 1): return tensors
+     if inplace: return tensors.mul_(lr)
+     return tensors * lr
+
+ class LR(Transform):
+     def __init__(self, lr: float):
+         defaults = dict(lr=lr)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         return lazy_lr(TensorList(tensors), lr=self.get_settings('lr', params=params), inplace=True)
+
+ class StepSize(Transform):
+     """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name"""
+     def __init__(self, step_size: float, key = 'step_size'):
+         defaults = {"key": key, key: step_size}
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         lrs = []
+         for p in params:
+             settings = self.settings[p]
+             lrs.append(settings[settings['key']])
+         return lazy_lr(TensorList(tensors), lr=lrs, inplace=True)
+
+
+ def warmup(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
+     """returns warm up lr scalar"""
+     if step > steps: return end_lr
+     return start_lr + (end_lr - start_lr) * (step / steps)
+
+ class Warmup(Transform):
+     def __init__(self, start_lr = 1e-5, end_lr: float = 1, steps = 100):
+         defaults = dict(start_lr=start_lr, end_lr=end_lr, steps=steps)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         start_lr, end_lr = self.get_settings('start_lr', 'end_lr', params=params, cls=NumberList)
+         num_steps = self.settings[params[0]]['steps']
+         step = self.global_state.get('step', 0)
+
+         target = lazy_lr(
+             TensorList(tensors),
+             lr=warmup(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
+             inplace=True
+         )
+         self.global_state['step'] = step + 1
+         return target
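`warmup` is a plain linear interpolation from `start_lr` to `end_lr` over `steps` steps, clamped to `end_lr` afterwards; the `Warmup` transform multiplies the update by this factor and bumps its step counter. A quick check of the scalar (import path assumed from the file list above):

    from torchzero.modules.lr.lr import warmup

    print(warmup(step=0,   start_lr=1e-5, end_lr=1.0, steps=100))   # 1e-5
    print(warmup(step=50,  start_lr=1e-5, end_lr=1.0, steps=100))   # ~0.5
    print(warmup(step=200, start_lr=1e-5, end_lr=1.0, steps=100))   # 1.0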
torchzero/modules/lr/step_size.py
@@ -0,0 +1,97 @@
+ import random
+ from typing import Any
+
+ import torch
+
+ from ...core import Transform
+ from ...utils import TensorList, NumberList
+
+
+ class PolyakStepSize(Transform):
+     """Polyak step-size.
+
+     Args:
+         max (float | None, optional): maximum possible step size. Defaults to None.
+         min_obj_value (float, optional): (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+         use_grad (bool, optional):
+             if True, uses the dot product of update and gradient to compute the step size.
+             Otherwise, the dot product of the update with itself is used, which has no geometric meaning, so it probably won't work well.
+             Defaults to True.
+         parameterwise (bool, optional):
+             if True, calculate the Polyak step-size for each parameter separately,
+             if False calculate one global step size for all parameters. Defaults to False.
+         alpha (float, optional): multiplier to the Polyak step-size. Defaults to 1.
+     """
+     def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):
+
+         defaults = dict(alpha=alpha, max=max, min_obj_value=min_obj_value, use_grad=use_grad, parameterwise=parameterwise)
+         super().__init__(defaults, uses_grad=use_grad)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         loss = vars.get_loss(False)
+         assert grads is not None
+         tensors = TensorList(tensors)
+         grads = TensorList(grads)
+         alpha = self.get_settings('alpha', params=params, cls=NumberList)
+         settings = self.settings[params[0]]
+         parameterwise = settings['parameterwise']
+         use_grad = settings['use_grad']
+         max = settings['max']
+         min_obj_value = settings['min_obj_value']
+
+         if parameterwise:
+             if use_grad: denom = (tensors * grads).sum()
+             else: denom = tensors.pow(2).sum()
+             polyak_step_size: TensorList | Any = (loss - min_obj_value) / denom.where(denom != 0, 1)
+             polyak_step_size = polyak_step_size.where(denom != 0, 0)
+             if max is not None: polyak_step_size = polyak_step_size.clamp_max(max)
+
+         else:
+             if use_grad: denom = tensors.dot(grads)
+             else: denom = tensors.dot(tensors)
+             if denom == 0: polyak_step_size = 0 # we converged
+             else: polyak_step_size = (loss - min_obj_value) / denom
+
+             if max is not None:
+                 if polyak_step_size > max: polyak_step_size = max
+
+         tensors.mul_(alpha * polyak_step_size)
+         return tensors
+
+
+
+ class RandomStepSize(Transform):
+     """Uses a random global step size from `low` to `high`.
+
+     Args:
+         low (float, optional): minimum learning rate. Defaults to 0.
+         high (float, optional): maximum learning rate. Defaults to 1.
+         parameterwise (bool, optional):
+             if True, generate a random step size for each parameter separately,
+             if False generate one global random step size. Defaults to False.
+     """
+     def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed: int | None = None):
+         defaults = dict(low=low, high=high, parameterwise=parameterwise, seed=seed)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         settings = self.settings[params[0]]
+         parameterwise = settings['parameterwise']
+
+         seed = settings['seed']
+         if 'generator' not in self.global_state:
+             self.global_state['generator'] = random.Random(seed)
+         generator: random.Random = self.global_state['generator']
+
+         if parameterwise:
+             low, high = self.get_settings('low', 'high', params=params)
+             lr = [generator.uniform(l, h) for l, h in zip(low, high)]
+         else:
+             low = settings['low']
+             high = settings['high']
+             lr = generator.uniform(low, high)
+
+         torch._foreach_mul_(tensors, lr)
+         return tensors
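In the non-parameterwise branch, `PolyakStepSize` computes the classic Polyak rule: step = (loss - min_obj_value) / <update, grad> (or <update, update> when `use_grad=False`), optionally clamped by `max` and scaled by `alpha`. A small numeric sketch of that arithmetic (values are made up):

    loss, min_obj_value = 2.5, 0.0
    dot = 5.0                           # <update, grad> for a descent update
    t = (loss - min_obj_value) / dot    # 0.5
    # the update tensors are then multiplied by alpha * t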