torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/higher_order/higher_order_newton.py (new file)
@@ -0,0 +1,319 @@
+ import itertools
+ import math
+ import warnings
+ from collections.abc import Callable
+ from contextlib import nullcontext
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import scipy.optimize
+ import torch
+
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
+ from ...utils.derivatives import (
+     hessian_list_to_mat,
+     jacobian_wrt,
+ )
+
+ _LETTERS = 'abcdefghijklmnopqrstuvwxyz'
+ def _poly_eval(s: np.ndarray, c, derivatives):
+     val = float(c)
+     for i,T in enumerate(derivatives, 1):
+         s1 = ''.join(_LETTERS[:i]) # abcd
+         s2 = ',...'.join(_LETTERS[:i]) # a,b,c,d
+         # this would make einsum('abcd,a,b,c,d', T, x, x, x, x)
+         val += np.einsum(f"...{s1},...{s2}", T, *(s for _ in range(i))) / math.factorial(i)
+     return val
+
+ def _proximal_poly_v(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+     if x.ndim == 2: x = x.T # DE passes (ndim, batch_size)
+     s = x - x0
+     val = _poly_eval(s, c, derivatives)
+     penalty = 0
+     if prox != 0: penalty = (prox / 2) * (s**2).sum(-1) # proximal penalty
+     return val + penalty
+
+ def _proximal_poly_g(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+     s = x - x0
+     g = derivatives[0].copy()
+     if len(derivatives) > 1:
+         for i, T in enumerate(derivatives[1:], 2):
+             s1 = ''.join(_LETTERS[:i]) # abcd
+             s2 = ','.join(_LETTERS[1:i]) # b,c,d
+             # this would make einsum('abcd,b,c,d->a', T, x, x, x)
+             g += np.einsum(f"{s1},{s2}->a", T, *(s for _ in range(i-1))) / math.factorial(i - 1)
+
+     g_prox = 0
+     if prox != 0: g_prox = prox * s
+     return g + g_prox
+
+ def _proximal_poly_H(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+     s = x - x0
+     n = x.shape[0]
+     if len(derivatives) == 1:
+         H = np.zeros((n, n))
+     else:
+         H = derivatives[1].copy()
+     if len(derivatives) > 2:
+         for i, T in enumerate(derivatives[2:], 3):
+             s1 = ''.join(_LETTERS[:i]) # abcd
+             s2 = ','.join(_LETTERS[2:i]) # c,d
+             # this would make einsum('abcd,c,d->ab', T, x, x)
+             H += np.einsum(f"{s1},{s2}->ab", T, *(s for _ in range(i-2))) / math.factorial(i - 2)
+
+     H_prox = 0
+     if prox != 0: H_prox = np.eye(n) * prox
+     return H + H_prox
+
+ def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
+     derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
+     x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
+
+     # notes
+     # 1. since we have exact hessian we use trust methods
+
+     # 2. if len(derivatives) is 1, only gradient is available,
+     # thus use slsqp depending on whether trust region is enabled
+     # this is just so that I can test that trust region works
+     if trust_region is None:
+         if len(derivatives) == 1: raise RuntimeError("trust region must be enabled because 1st order has no minima")
+         method = 'trust-exact'
+         de_bounds = list(zip(x0 - 10, x0 + 10))
+         constraints = None
+
+     else:
+         if len(derivatives) == 1: method = 'slsqp'
+         else: method = 'trust-constr'
+         de_bounds = list(zip(x0 - trust_region, x0 + trust_region))
+
+         def l2_bound_f(x):
+             if x.ndim == 2: return np.sum((x - x0[:,None])**2, axis=0)[None,:] # DE passes (ndim, batch_size) and expects (M, S)
+             return np.sum((x - x0)**2, axis=0)
+
+         def l2_bound_g(x):
+             return 2 * (x - x0)
+
+         def l2_bound_h(x, v):
+             return v[0] * 2 * np.eye(x0.shape[0])
+
+         constraint = scipy.optimize.NonlinearConstraint(
+             fun=l2_bound_f,
+             lb=0, # 0 <= ||x-x0||^2
+             ub=trust_region**2, # ||x-x0||^2 <= R^2
+             jac=l2_bound_g, # pyright:ignore[reportArgumentType]
+             hess=l2_bound_h,
+             keep_feasible=False
+         )
+         constraints = [constraint]
+
+     x_init = x0.copy()
+     v0 = _proximal_poly_v(x0, c, prox, x0, derivatives)
+
+     # ---------------------------------- run DE ---------------------------------- #
+     if de_iters is not None and de_iters != 0:
+         if de_iters == -1: de_iters = None # let scipy decide
+
+         # DE needs bounds so use linf ig
+         res = scipy.optimize.differential_evolution(
+             _proximal_poly_v,
+             de_bounds,
+             args=(c, prox, x0.copy(), derivatives),
+             maxiter=de_iters,
+             vectorized=True,
+             constraints = constraints,
+             updating='deferred',
+         )
+         if res.fun < v0 and np.all(np.isfinite(res.x)): x_init = res.x
+
+     # ------------------------------- run minimize ------------------------------- #
+     try:
+         res = scipy.optimize.minimize(
+             _proximal_poly_v,
+             x_init,
+             method=method,
+             args=(c, prox, x0.copy(), derivatives),
+             jac=_proximal_poly_g,
+             hess=_proximal_poly_H,
+             constraints = constraints,
+         )
+     except ValueError:
+         return x, -float('inf')
+     return torch.from_numpy(res.x).to(x), res.fun
+
+
+
+ class HigherOrderNewton(Module):
+     """Arbitrary order Newton's method with optional trust region and proximal penalty.
+
+     This constructs an nth order Taylor approximation via autograd and minimizes it with
+     scipy.optimize.minimize trust-region Newton solvers, optionally with a proximal penalty.
+
+     .. note::
+         In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+     .. note::
+         This module requires a closure to be passed to the optimizer step,
+         as it needs to re-evaluate the loss and gradients for calculating higher order derivatives.
+         The closure must accept a ``backward`` argument (refer to documentation).
+
+     .. warning::
+         This uses roughly O(N^order) memory, and solving the subproblem can be very expensive.
+
+     .. warning::
+         "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
+
+     Args:
+         order (int, optional):
+             Order of the method, i.e. the number of Taylor series terms (orders of derivatives) used to approximate the function. Defaults to 4.
+         trust_method (str | None, optional):
+             Method used for the trust region.
+
+             - "bounds" - the model is minimized within bounds defined by the trust region.
+             - "proximal" - the model is minimized with a penalty for moving too far from the current point.
+             - "none" - disables the trust region.
+
+             Defaults to 'bounds'.
+         nplus (float, optional): trust region multiplier on good steps. Defaults to 2.
+         nminus (float, optional): trust region multiplier on bad steps. Defaults to 0.25.
+         init (float | None, optional):
+             initial trust region size. If None, defaults to 1 with :code:`trust_method="bounds"` and 0.1 with :code:`"proximal"`. Defaults to None.
+         eta (float, optional):
+             minimum ratio of actual to predicted loss reduction required to accept a step. Defaults to 1e-6.
+         max_attempts (int, optional):
+             maximum number of trust region adjustments per step. Defaults to 10.
+         de_iters (int | None, optional):
+             If specified, the model is first minimized via differential evolution to possibly escape local minima,
+             and the result is then passed to scipy.optimize.minimize. Defaults to None.
+         vectorize (bool, optional): whether to use vectorized jacobians (usually faster). Defaults to True.
+     """
+     def __init__(
+         self,
+         order: int = 4,
+         trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
+         nplus: float = 2,
+         nminus: float = 0.25,
+         init: float | None = None,
+         eta: float = 1e-6,
+         max_attempts = 10,
+         de_iters: int | None = None,
+         vectorize: bool = True,
+     ):
+         if init is None:
+             if trust_method == 'bounds': init = 1
+             else: init = 0.1
+
+         defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
+         if closure is None: raise RuntimeError('HigherOrderNewton requires closure')
+
+         settings = self.settings[params[0]]
+         order = settings['order']
+         nplus = settings['nplus']
+         nminus = settings['nminus']
+         eta = settings['eta']
+         init = settings['init']
+         trust_method = settings['trust_method']
+         de_iters = settings['de_iters']
+         max_attempts = settings['max_attempts']
+         vectorize = settings['vectorize']
+
+         # ------------------------ calculate grad and hessian ------------------------ #
+         with torch.enable_grad():
+             loss = var.loss = var.loss_approx = closure(False)
+
+             g_list = torch.autograd.grad(loss, params, create_graph=True)
+             var.grad = list(g_list)
+
+             g = torch.cat([t.ravel() for t in g_list])
+             n = g.numel()
+             derivatives = [g]
+             T = g # current derivatives tensor
+
+             # get all derivative up to order
+             for o in range(2, order + 1):
+                 is_last = o == order
+                 T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
+                 with torch.no_grad() if is_last else nullcontext():
+                     # the shape is (ndim, ) * order
+                     T = hessian_list_to_mat(T_list).view(n, n, *T.shape[1:])
+                     derivatives.append(T)
+
+         x0 = torch.cat([p.ravel() for p in params])
+
+         success = False
+         x_star = None
+         while not success:
+             max_attempts -= 1
+             if max_attempts < 0: break
+
+             # load trust region value
+             trust_value = self.global_state.get('trust_region', init)
+             if trust_value < 1e-8 or trust_value > 1e16: trust_value = self.global_state['trust_region'] = settings['init']
+
+             if trust_method is None: trust_method = 'none'
+             else: trust_method = trust_method.lower()
+
+             if trust_method == 'none':
+                 trust_region = None
+                 prox = 0
+
+             elif trust_method == 'bounds':
+                 trust_region = trust_value
+                 prox = 0
+
+             elif trust_method == 'proximal':
+                 trust_region = None
+                 prox = 1 / trust_value
+
+             else:
+                 raise ValueError(trust_method)
+
+             # minimize the model
+             x_star, expected_loss = _poly_minimize(
+                 trust_region=trust_region,
+                 prox=prox,
+                 de_iters=de_iters,
+                 c=loss.item(),
+                 x=x0,
+                 derivatives=derivatives,
+             )
+
+             # update trust region
+             if trust_method == 'none':
+                 success = True
+             else:
+                 pred_reduction = loss - expected_loss
+
+                 vec_to_tensors_(x_star, params)
+                 loss_star = closure(False)
+                 vec_to_tensors_(x0, params)
+                 reduction = loss - loss_star
+
+                 rho = reduction / (max(pred_reduction, 1e-8))
+                 # failed step
+                 if rho < 0.25:
+                     self.global_state['trust_region'] = trust_value * nminus
+
+                 # very good step
+                 elif rho > 0.75:
+                     diff = trust_value - (x0 - x_star).abs_()
+                     if (diff.amin() / trust_value) > 1e-4: # hits boundary
+                         self.global_state['trust_region'] = trust_value * nplus
+
+                 # if the ratio is high enough then accept the proposed step
+                 success = rho > eta
+
+         assert x_star is not None
+         if success:
+             difference = vec_to_tensors(x0 - x_star, params)
+             var.update = list(difference)
+         else:
+             var.update = params.zeros_like()
+         return var
+
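For orientation, the docstring above implies the module is driven through the modular optimizer interface; a minimal sketch, assuming the ``tz.Modular`` / ``tz.m`` entry points shown in the Backtracking docstring later in this diff, that ``HigherOrderNewton`` is re-exported under ``tz.m``, and that the wrapper exposes the usual ``zero_grad``/``step(closure)`` API (an illustration, not the package's documented example):

import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

# placed first in the chain, since it re-evaluates the closure and differentiates the loss itself
opt = tz.Modular(model.parameters(), tz.m.HigherOrderNewton(order=3, trust_method="bounds"))

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:          # HigherOrderNewton calls closure(False) and uses autograd directly
        opt.zero_grad()   # assumes the Modular wrapper exposes the usual Optimizer API
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)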
torchzero/modules/line_search/__init__.py
@@ -1,5 +1,5 @@
- from .line_search import LineSearch, GridLineSearch
- from .backtracking import backtracking_line_search, Backtracking, AdaptiveBacktracking
- from .strong_wolfe import StrongWolfe
+ from .adaptive import AdaptiveLineSearch
+ from .backtracking import AdaptiveBacktracking, Backtracking
+ from .line_search import LineSearchBase
  from .scipy import ScipyMinimizeScalar
- from .trust_region import TrustRegion
+ from .strong_wolfe import StrongWolfe
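After this change the subpackage exports exactly the names listed above; a hedged sketch of what downstream imports look like (the old ``LineSearch``, ``GridLineSearch`` and ``TrustRegion`` exports are gone):

# torchzero 0.3.11: the line search base class is now LineSearchBase, AdaptiveLineSearch is new
from torchzero.modules.line_search import (
    AdaptiveBacktracking,
    AdaptiveLineSearch,
    Backtracking,
    LineSearchBase,
    ScipyMinimizeScalar,
    StrongWolfe,
)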
torchzero/modules/line_search/adaptive.py (new file)
@@ -0,0 +1,99 @@
+ import math
+ from collections.abc import Callable
+ from operator import itemgetter
+
+ import torch
+
+ from .line_search import LineSearchBase
+
+
+
+ def adaptive_tracking(
+     f,
+     x_0,
+     maxiter: int,
+     nplus: float = 2,
+     nminus: float = 0.5,
+ ):
+     f_0 = f(0)
+
+     t = x_0
+     f_t = f(t)
+
+     # backtrack
+     if f_t > f_0:
+         while f_t > f_0:
+             maxiter -= 1
+             if maxiter < 0: return 0, f_0
+             t = t*nminus
+             f_t = f(t)
+         return t, f_t
+
+     # forwardtrack
+     f_prev = f_t
+     t *= nplus
+     f_t = f(t)
+     if f_prev < f_t: return t / nplus, f_prev
+     while f_prev >= f_t:
+         maxiter -= 1
+         if maxiter < 0: return t, f_t
+         f_prev = f_t
+         t *= nplus
+         f_t = f(t)
+     return t / nplus, f_prev
+
+ class AdaptiveLineSearch(LineSearchBase):
+     """Adaptive line search, similar to backtracking but with an additional forward-tracking mode.
+     Currently does not check the weak curvature condition.
+
+     Args:
+         init (float, optional): initial step size. Defaults to 1.0.
+         nplus (float, optional): multiplier to the step size while the objective keeps decreasing (forward tracking). Defaults to 2.
+         nminus (float, optional): multiplier to the step size while the objective is increasing (backtracking). Defaults to 0.5.
+         maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+         adaptive (bool, optional):
+             when enabled, the backtracking factor is reduced after a failed line search
+             and restored after successful ones. Defaults to True.
+     """
+     def __init__(
+         self,
+         init: float = 1.0,
+         nplus: float = 2,
+         nminus: float = 0.5,
+         maxiter: int = 10,
+         adaptive=True,
+     ):
+         defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive,)
+         super().__init__(defaults=defaults)
+         self.global_state['beta_scale'] = 1.0
+
+     def reset(self):
+         super().reset()
+         self.global_state['beta_scale'] = 1.0
+
+     @torch.no_grad
+     def search(self, update, var):
+         init, nplus, nminus, maxiter, adaptive = itemgetter(
+             'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.settings[var.params[0]])
+
+         objective = self.make_objective(var=var)
+
+         # # directional derivative
+         # d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+
+         # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
+         beta_scale = self.global_state.get('beta_scale', 1)
+         x_prev = self.global_state.get('prev_x', 1)
+
+         if adaptive: nminus = nminus * beta_scale
+
+
+         step_size, f = adaptive_tracking(objective, x_prev, maxiter, nplus=nplus, nminus=nminus)
+
+         # found an alpha that reduces loss
+         if step_size != 0:
+             self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+             return step_size
+
+         # on fail reduce beta scale value
+         self.global_state['beta_scale'] /= 1.5
+         return 0
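To make the forward-tracking behaviour of ``adaptive_tracking`` concrete, here is a small standalone sketch; the quadratic objective and starting step are illustrative, not part of the package:

# 1-D objective along the search direction: phi(t) = (t - 3)^2, minimized at t = 3
def phi(t):
    return (t - 3.0) ** 2

# phi(0.5) < phi(0), so the search forward-tracks 0.5 -> 1 -> 2 -> 4 -> 8; since 8 is worse,
# it returns the last improving step 4.0 together with its value phi(4.0) == 1.0
t, f_t = adaptive_tracking(phi, x_0=0.5, maxiter=10, nplus=2, nminus=0.5)
print(t, f_t)  # 4.0 1.0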
torchzero/modules/line_search/backtracking.py
@@ -4,7 +4,7 @@ from operator import itemgetter
 
  import torch
 
- from .line_search import LineSearch
+ from .line_search import LineSearchBase
 
 
  def backtracking_line_search(
@@ -14,19 +14,17 @@ def backtracking_line_search(
      beta: float = 0.5,
      c: float = 1e-4,
      maxiter: int = 10,
-     a_min: float | None = None,
      try_negative: bool = False,
  ) -> float | None:
      """
 
      Args:
-         objective_fn: evaluates step size along some descent direction.
-         dir_derivative: directional derivative along the descent direction.
-         alpha_init: initial step size.
+         f: evaluates step size along some descent direction.
+         g_0: directional derivative along the descent direction.
+         init: initial step size.
          beta: The factor by which to decrease alpha in each iteration
          c: The constant for the Armijo sufficient decrease condition
-         max_iter: Maximum number of backtracking iterations (default: 10).
-         min_alpha: Minimum allowable step size to prevent near-zero values (default: 1e-16).
+         maxiter: Maximum number of backtracking iterations (default: 10).
 
      Returns:
          step size
@@ -34,21 +32,21 @@
 
      a = init
      f_x = f(0)
+     f_prev = None
 
      for iteration in range(maxiter):
          f_a = f(a)
 
-         if f_a <= f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
+         if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_x): return a / beta
+         f_prev = f_a
+
+         if f_a < f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
              # found an acceptable alpha
              return a
 
          # decrease alpha
          a *= beta
 
-     # alpha too small
-     if a_min is not None and a < a_min:
-         return a_min
-
      # fail
      if try_negative:
          def inv_objective(alpha): return f(-alpha)
@@ -59,25 +57,56 @@
              beta=beta,
              c=c,
              maxiter=maxiter,
-             a_min=a_min,
              try_negative=False,
          )
          if v is not None: return -v
 
      return None
 
- class Backtracking(LineSearch):
+ class Backtracking(LineSearchBase):
+     """Backtracking line search satisfying the Armijo condition.
+
+     Args:
+         init (float, optional): initial step size. Defaults to 1.0.
+         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
+         c (float, optional): acceptance value for the Armijo condition. Defaults to 1e-4.
+         maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+         adaptive (bool, optional):
+             when enabled, if the line search fails, beta is reduced for subsequent searches;
+             otherwise it is reset to its initial value. Defaults to True.
+         try_negative (bool, optional): Whether to perform line search in the opposite direction on fail. Defaults to False.
+
+     Examples:
+         Gradient descent with backtracking line search:
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.Backtracking()
+             )
+
+         LBFGS with backtracking line search:
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.LBFGS(),
+                 tz.m.Backtracking()
+             )
+
+     """
      def __init__(
          self,
          init: float = 1.0,
          beta: float = 0.5,
          c: float = 1e-4,
          maxiter: int = 10,
-         min_alpha: float | None = None,
          adaptive=True,
          try_negative: bool = False,
      ):
-         defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,min_alpha=min_alpha,adaptive=adaptive, try_negative=try_negative)
+         defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,adaptive=adaptive, try_negative=try_negative)
          super().__init__(defaults=defaults)
          self.global_state['beta_scale'] = 1.0
 
@@ -86,20 +115,20 @@ class Backtracking(LineSearch):
          self.global_state['beta_scale'] = 1.0
 
      @torch.no_grad
-     def search(self, update, vars):
-         init, beta, c, maxiter, min_alpha, adaptive, try_negative = itemgetter(
-             'init', 'beta', 'c', 'maxiter', 'min_alpha', 'adaptive', 'try_negative')(self.settings[vars.params[0]])
+     def search(self, update, var):
+         init, beta, c, maxiter, adaptive, try_negative = itemgetter(
+             'init', 'beta', 'c', 'maxiter', 'adaptive', 'try_negative')(self.settings[var.params[0]])
 
-         objective = self.make_objective(vars=vars)
+         objective = self.make_objective(var=var)
 
          # # directional derivative
-         d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), vars.get_update()))
+         d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
 
          # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
          if adaptive: beta = beta * self.global_state['beta_scale']
 
          step_size = backtracking_line_search(objective, d, init=init,beta=beta,
-                                              c=c,maxiter=maxiter,a_min=min_alpha, try_negative=try_negative)
+                                              c=c,maxiter=maxiter, try_negative=try_negative)
 
          # found an alpha that reduces loss
          if step_size is not None:
@@ -113,20 +142,35 @@ class Backtracking(LineSearch):
  def _lerp(start,end,weight):
      return start + weight * (end - start)
 
- class AdaptiveBacktracking(LineSearch):
+ class AdaptiveBacktracking(LineSearchBase):
+     """Adaptive backtracking line search. After each line search procedure, a new initial step size is set
+     such that the optimal step size of that procedure would be found on the second line search iteration.
+
+     Args:
+         init (float, optional): step size for the first step. Defaults to 1.0.
+         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
+         c (float, optional): acceptance value for the Armijo condition. Defaults to 1e-4.
+         maxiter (int, optional): Maximum line search function evaluations. Defaults to 20.
+         target_iters (int, optional):
+             target number of iterations to be performed until the optimal step size is found. Defaults to 1.
+         nplus (float, optional):
+             Multiplier to the initial step size if it was found to be the optimal step size. Defaults to 2.0.
+         scale_beta (float, optional):
+             Momentum for the initial step size; 0 disables momentum. Defaults to 0.0.
+         try_negative (bool, optional): Whether to perform line search in the opposite direction on fail. Defaults to False.
+     """
      def __init__(
          self,
          init: float = 1.0,
          beta: float = 0.5,
          c: float = 1e-4,
          maxiter: int = 20,
-         min_alpha: float | None = None,
          target_iters = 1,
          nplus = 2.0,
          scale_beta = 0.0,
          try_negative: bool = False,
      ):
-         defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,min_alpha=min_alpha,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
+         defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
          super().__init__(defaults=defaults)
 
          self.global_state['beta_scale'] = 1.0
@@ -138,15 +182,15 @@ class AdaptiveBacktracking(LineSearch):
          self.global_state['initial_scale'] = 1.0
 
      @torch.no_grad
-     def search(self, update, vars):
-         init, beta, c, maxiter, min_alpha, target_iters, nplus, scale_beta, try_negative=itemgetter(
-             'init','beta','c','maxiter','min_alpha','target_iters','nplus','scale_beta', 'try_negative')(self.settings[vars.params[0]])
+     def search(self, update, var):
+         init, beta, c, maxiter, target_iters, nplus, scale_beta, try_negative=itemgetter(
+             'init','beta','c','maxiter','target_iters','nplus','scale_beta', 'try_negative')(self.settings[var.params[0]])
 
-         objective = self.make_objective(vars=vars)
+         objective = self.make_objective(var=var)
 
          # directional derivative (0 if c = 0 because it is not needed)
          if c == 0: d = 0
-         else: d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), update))
+         else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
 
          # scale beta
          beta = beta * self.global_state['beta_scale']
@@ -155,7 +199,7 @@ class AdaptiveBacktracking(LineSearch):
          init = init * self.global_state['initial_scale']
 
          step_size = backtracking_line_search(objective, d, init=init, beta=beta,
-                                              c=c,maxiter=maxiter,a_min=min_alpha, try_negative=try_negative)
+                                              c=c,maxiter=maxiter, try_negative=try_negative)
 
          # found an alpha that reduces loss
          if step_size is not None:
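For reference, the Armijo sufficient-decrease test that ``backtracking_line_search`` applies (now with a strict ``<``) is easy to reproduce in isolation; a standalone sketch with an illustrative quadratic, not package code:

# f(x) = x^2 at x = 2, descent direction d = -f'(2) = -4
def f(a):
    x = 2.0 - 4.0 * a           # point reached with step size a along d
    return x ** 2

f_0 = f(0.0)                    # 4.0
g_0 = 4.0 * (-4.0)              # directional derivative f'(2) * d = -16
c, beta, a = 1e-4, 0.5, 1.0

# shrink a by beta until f(a) < f(0) + c * a * min(g_0, 0), as in the accepted-step branch above
while not f(a) < f_0 + c * a * min(g_0, 0.0):
    a *= beta
print(a)                        # 0.5: f(0.5) == 0.0 satisfies the condition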