torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/_polyinterp.py (added)
@@ -0,0 +1,289 @@
+ import numpy as np
+ import torch
+
+ from .line_search import LineSearchBase
+
+
+ # polynomial interpolation
+ # this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+ # PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+ def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
+     """
+     Gives the minimizer and minimum of the interpolating polynomial over given points
+     based on function and derivative information. Defaults to bisection if no critical
+     points are valid.
+
+     Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
+     modifications.
+
+     Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
+     Last edited 12/6/18.
+
+     Inputs:
+         points (nparray): two-dimensional array with each point of form [x f g]
+         x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
+         x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
+         plot (bool): plot interpolating polynomial
+
+     Outputs:
+         x_sol (float): minimizer of interpolating polynomial
+         F_min (float): minimum of interpolating polynomial
+
+     Note:
+         . Set f or g to np.nan if they are unknown
+
+     """
+     no_points = points.shape[0]
+     order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+     x_min = np.min(points[:, 0])
+     x_max = np.max(points[:, 0])
+
+     # compute bounds of interpolation area
+     if x_min_bound is None:
+         x_min_bound = x_min
+     if x_max_bound is None:
+         x_max_bound = x_max
+
+     # explicit formula for quadratic interpolation
+     if no_points == 2 and order == 2 and plot is False:
+         # Solution to quadratic interpolation is given by:
+         # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
+         # x_min = x1 - g1/(2a)
+         # if x1 = 0, then is given by:
+         # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
+
+         if points[0, 0] == 0:
+             x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
+         else:
+             a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
+             x_sol = points[0, 0] - points[0, 2]/(2*a)
+
+         x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+
+     # explicit formula for cubic interpolation
+     elif no_points == 2 and order == 3 and plot is False:
+         # Solution to cubic interpolation is given by:
+         # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
+         # d2 = sqrt(d1^2 - g1*g2)
+         # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
+         d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
+         value = d1 ** 2 - points[0, 2] * points[1, 2]
+         if value > 0:
+             d2 = np.sqrt(value)
+             x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
+             x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+         else:
+             x_sol = (x_max_bound + x_min_bound)/2
+
+     # solve linear system
+     else:
+         # define linear constraints
+         A = np.zeros((0, order + 1))
+         b = np.zeros((0, 1))
+
+         # add linear constraints on function values
+         for i in range(no_points):
+             if not np.isnan(points[i, 1]):
+                 constraint = np.zeros((1, order + 1))
+                 for j in range(order, -1, -1):
+                     constraint[0, order - j] = points[i, 0] ** j
+                 A = np.append(A, constraint, 0)
+                 b = np.append(b, points[i, 1])
+
+         # add linear constraints on gradient values
+         for i in range(no_points):
+             if not np.isnan(points[i, 2]):
+                 constraint = np.zeros((1, order + 1))
+                 for j in range(order):
+                     constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+                 A = np.append(A, constraint, 0)
+                 b = np.append(b, points[i, 2])
+
+         # check if system is solvable
+         if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+             x_sol = (x_min_bound + x_max_bound)/2
+             f_min = np.inf
+         else:
+             # solve linear system for interpolating polynomial
+             coeff = np.linalg.solve(A, b)
+
+             # compute critical points
+             dcoeff = np.zeros(order)
+             for i in range(len(coeff) - 1):
+                 dcoeff[i] = coeff[i] * (order - i)
+
+             crit_pts = np.array([x_min_bound, x_max_bound])
+             crit_pts = np.append(crit_pts, points[:, 0])
+
+             if not np.isinf(dcoeff).any():
+                 roots = np.roots(dcoeff)
+                 crit_pts = np.append(crit_pts, roots)
+
+             # test critical points
+             f_min = np.inf
+             x_sol = (x_min_bound + x_max_bound) / 2 # defaults to bisection
+             for crit_pt in crit_pts:
+                 if np.isreal(crit_pt):
+                     if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+                     if crit_pt >= x_min_bound and crit_pt <= x_max_bound:
+                         F_cp = np.polyval(coeff, crit_pt)
+                         if np.isreal(F_cp) and F_cp < f_min:
+                             x_sol = np.real(crit_pt)
+                             f_min = np.real(F_cp)
+
+             if(plot):
+                 import matplotlib.pyplot as plt
+                 plt.figure()
+                 x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
+                 f = np.polyval(coeff, x)
+                 plt.plot(x, f)
+                 plt.plot(x_sol, f_min, 'x')
+
+     return x_sol
+
+
+ # polynomial interpolation
+ # this code is based on https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+ # PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+ # this one is modified where instead of clipping the solution by bounds, it tries a lower degree polynomial
+ # all the way to bisection
+ def _within_bounds(x, lb, ub):
+     if lb is not None and x < lb: return False
+     if ub is not None and x > ub: return False
+     return True
+
+ def _quad_interp(points):
+     assert points.shape[0] == 2, points.shape
+     if points[0, 0] == 0:
+         denom = 2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0])
+         if abs(denom) > 1e-32:
+             return -points[0, 2] * points[1, 0] ** 2 / denom
+     else:
+         denom = (points[0, 0] - points[1, 0]) ** 2
+         if denom > 1e-32:
+             a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / denom
+             if a > 1e-32:
+                 return points[0, 0] - points[0, 2]/(2*a)
+     return None
+
+ def _cubic_interp(points, lb, ub):
+     assert points.shape[0] == 2, points.shape
+     denom = points[0, 0] - points[1, 0]
+     if abs(denom) > 1e-32:
+         d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / denom)
+         value = d1 ** 2 - points[0, 2] * points[1, 2]
+         if value > 0:
+             d2 = np.sqrt(value)
+             denom = points[1, 2] - points[0, 2] + 2 * d2
+             if abs(denom) > 1e-32:
+                 x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / denom)
+                 if _within_bounds(x_sol, lb, ub): return x_sol
+
+     # try quadratic interpolations
+     x_sol = _quad_interp(points)
+     if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+     return None
+
+ def _poly_interp(points, lb, ub):
+     no_points = points.shape[0]
+     assert no_points > 2, points.shape
+     order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+     # define linear constraints
+     A = np.zeros((0, order + 1))
+     b = np.zeros((0, 1))
+
+     # add linear constraints on function values
+     for i in range(no_points):
+         if not np.isnan(points[i, 1]):
+             constraint = np.zeros((1, order + 1))
+             for j in range(order, -1, -1):
+                 constraint[0, order - j] = points[i, 0] ** j
+             A = np.append(A, constraint, 0)
+             b = np.append(b, points[i, 1])
+
+     # add linear constraints on gradient values
+     for i in range(no_points):
+         if not np.isnan(points[i, 2]):
+             constraint = np.zeros((1, order + 1))
+             for j in range(order):
+                 constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+             A = np.append(A, constraint, 0)
+             b = np.append(b, points[i, 2])
+
+     # check if system is solvable
+     if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+         return None
+
+     # solve linear system for interpolating polynomial
+     coeff = np.linalg.solve(A, b)
+
+     # compute critical points
+     dcoeff = np.zeros(order)
+     for i in range(len(coeff) - 1):
+         dcoeff[i] = coeff[i] * (order - i)
+
+     lower = np.min(points[:, 0]) if lb is None else lb
+     upper = np.max(points[:, 0]) if ub is None else ub
+
+     crit_pts = np.array([lower, upper])
+     crit_pts = np.append(crit_pts, points[:, 0])
+
+     if not np.isinf(dcoeff).any():
+         roots = np.roots(dcoeff)
+         crit_pts = np.append(crit_pts, roots)
+
+     # test critical points
+     f_min = np.inf
+     x_sol = None
+     for crit_pt in crit_pts:
+         if np.isreal(crit_pt):
+             if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+             if _within_bounds(crit_pt, lb, ub):
+                 F_cp = np.polyval(coeff, crit_pt)
+                 if np.isreal(F_cp) and F_cp < f_min:
+                     x_sol = np.real(crit_pt)
+                     f_min = np.real(F_cp)
+
+     return x_sol
+
+ def polyinterp2(points, lb, ub, unbounded: bool = False):
+     no_points = points.shape[0]
+     if no_points <= 1:
+         return (lb + ub)/2
+
+     order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+     x_min = np.min(points[:, 0])
+     x_max = np.max(points[:, 0])
+
+     # compute bounds of interpolation area
+     if not unbounded:
+         if lb is None:
+             lb = x_min
+         if ub is None:
+             ub = x_max
+
+     if no_points == 2 and order == 2:
+         x_sol = _quad_interp(points)
+         if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+         return (lb + ub)/2
+
+     if no_points == 2 and order == 3:
+         x_sol = _cubic_interp(points, lb, ub) # includes fallback on _quad_interp
+         if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+         return (lb + ub)/2
+
+     if no_points <= 2: # order < 2
+         return (lb + ub)/2
+
+     if no_points == 3:
+         for p in (points[:2], points[1:], points[::2]):
+             x_sol = _cubic_interp(p, lb, ub)
+             if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+     x_sol = _poly_interp(points, lb, ub)
+     if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+     return polyinterp2(points[1:], lb, ub)
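
The quadratic branch of `polyinterp` and `_quad_interp` uses the closed form noted in the comments above: given two points of the form `[x, f, g]` with `x1 = 0`, the minimizer of the interpolating parabola is `x_min = -(g1*x2^2) / (2*(f2 - f1 - g1*x2))`. A minimal standalone check of that formula (the sample function below is illustrative, not part of the package):

```python
# f(x) = (x - 1)**2 + 3 has its true minimizer at x = 1
f = lambda x: (x - 1) ** 2 + 3
g = lambda x: 2 * (x - 1)          # derivative of f

x1, x2 = 0.0, 2.0                  # two trial step sizes, with x1 == 0
f1, g1, f2 = f(x1), g(x1), f(x2)   # f1 = 4, g1 = -2, f2 = 4

# closed-form quadratic minimizer for the x1 == 0 case (see comment in polyinterp above)
x_min = -(g1 * x2 ** 2) / (2 * (f2 - f1 - g1 * x2))
print(x_min)                       # 1.0, recovers the true minimizer exactly
```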
torchzero/modules/line_search/adaptive.py (added)
@@ -0,0 +1,124 @@
+ import math
+ from bisect import insort
+ from collections import deque
+ from collections.abc import Callable
+ from operator import itemgetter
+
+ import numpy as np
+ import torch
+
+ from .line_search import LineSearchBase, TerminationCondition, termination_condition
+
+
+ def adaptive_tracking(
+     f,
+     a_init,
+     maxiter: int,
+     nplus: float = 2,
+     nminus: float = 0.5,
+     f_0 = None,
+ ):
+     niter = 0
+     if f_0 is None: f_0 = f(0)
+
+     a = a_init
+     f_a = f(a)
+
+     # backtrack
+     a_prev = a
+     f_prev = math.inf
+     if (f_a > f_0) or (not math.isfinite(f_a)):
+         while (f_a < f_prev) or not math.isfinite(f_a):
+             a_prev, f_prev = a, f_a
+             maxiter -= 1
+             if maxiter < 0: break
+
+             a = a*nminus
+             f_a = f(a)
+             niter += 1
+
+         if f_prev < f_0: return a_prev, f_prev, niter
+         return 0, f_0, niter
+
+     # forwardtrack
+     a_prev = a
+     f_prev = math.inf
+     while (f_a <= f_prev) and math.isfinite(f_a):
+         a_prev, f_prev = a, f_a
+         maxiter -= 1
+         if maxiter < 0: break
+
+         a *= nplus
+         f_a = f(a)
+         niter+= 1
+
+     if f_prev < f_0: return a_prev, f_prev, niter
+     return 0, f_0, niter
+
+
+ class AdaptiveTracking(LineSearchBase):
+     """A line search that evaluates previous step size, if value increased, backtracks until the value stops decreasing,
+     otherwise forward-tracks until value stops decreasing.
+
+     Args:
+         init (float, optional): initial step size. Defaults to 1.0.
+         nplus (float, optional): multiplier to step size if initial step size is optimal. Defaults to 2.
+         nminus (float, optional): multiplier to step size if initial step size is too big. Defaults to 0.5.
+         maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
+         adaptive (bool, optional):
+             when enabled, if line search failed, step size will continue decreasing on the next step.
+             Otherwise it will restart the line search from ``init`` step size. Defaults to True.
+     """
+     def __init__(
+         self,
+         init: float = 1.0,
+         nplus: float = 2,
+         nminus: float = 0.5,
+         maxiter: int = 10,
+         adaptive=True,
+     ):
+         defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive)
+         super().__init__(defaults=defaults)
+
+     def reset(self):
+         super().reset()
+
+     @torch.no_grad
+     def search(self, update, var):
+         init, nplus, nminus, maxiter, adaptive = itemgetter(
+             'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.defaults)
+
+         objective = self.make_objective(var=var)
+
+         # scale a_prev
+         a_prev = self.global_state.get('a_prev', init)
+         if adaptive: a_prev = a_prev * self.global_state.get('init_scale', 1)
+
+         a_init = a_prev
+         if a_init < torch.finfo(var.params[0].dtype).tiny * 2:
+             a_init = torch.finfo(var.params[0].dtype).max / 2
+
+         step_size, f, niter = adaptive_tracking(
+             objective,
+             a_init=a_init,
+             maxiter=maxiter,
+             nplus=nplus,
+             nminus=nminus,
+         )
+
+         # found an alpha that reduces loss
+         if step_size != 0:
+             assert (var.loss is None) or (math.isfinite(f) and f < var.loss)
+             self.global_state['init_scale'] = 1
+
+             # if niter == 1, forward tracking failed to decrease function value compared to f_a_prev
+             if niter == 1 and step_size >= a_init: step_size *= nminus
+
+             self.global_state['a_prev'] = step_size
+             return step_size
+
+         # on fail reduce beta scale value
+         self.global_state['init_scale'] = self.global_state.get('init_scale', 1) * nminus**maxiter
+         self.global_state['a_prev'] = init
+         return 0
+
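
The new `AdaptiveTracking` module starts from the previous step size and either backtracks (multiplying by `nminus`) when the objective got worse, or forward-tracks (multiplying by `nplus`) while it keeps improving. A hedged usage sketch, assuming the class is exported under `tz.m` like the other line-search modules in this release:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# assumes AdaptiveTracking is exposed as tz.m.AdaptiveTracking,
# mirroring the tz.m.Backtracking examples in the new docstrings
opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(),                                                      # direction module
    tz.m.AdaptiveTracking(init=1.0, nplus=2, nminus=0.5, maxiter=10),  # step size module
)
```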
torchzero/modules/line_search/backtracking.py
@@ -4,7 +4,7 @@ from operator import itemgetter
  
  import torch
  
- from .line_search import LineSearch
+ from .line_search import LineSearchBase, TerminationCondition, termination_condition
  
  
  def backtracking_line_search(
@@ -14,29 +14,37 @@ def backtracking_line_search(
      beta: float = 0.5,
      c: float = 1e-4,
      maxiter: int = 10,
-     try_negative: bool = False,
+     condition: TerminationCondition = 'armijo',
  ) -> float | None:
      """
  
      Args:
-         objective_fn: evaluates step size along some descent direction.
-         dir_derivative: directional derivative along the descent direction.
-         alpha_init: initial step size.
+         f: evaluates step size along some descent direction.
+         g_0: directional derivative along the descent direction.
+         init: initial step size.
          beta: The factor by which to decrease alpha in each iteration
          c: The constant for the Armijo sufficient decrease condition
-         max_iter: Maximum number of backtracking iterations (default: 10).
+         maxiter: Maximum number of backtracking iterations (default: 10).
  
      Returns:
          step size
      """
  
      a = init
-     f_x = f(0)
+     f_0 = f(0)
+     f_prev = None
  
      for iteration in range(maxiter):
          f_a = f(a)
+         if not math.isfinite(f_a):
+             a *= beta
+             continue
  
-         if f_a <= f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
+         if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_0):
+             return a / beta # new value is larger than previous value
+         f_prev = f_a
+
+         if termination_condition(condition, f_0=f_0, g_0=g_0, f_a=f_a, g_a=None, a=a, c=c):
              # found an acceptable alpha
              return a
  
@@ -44,108 +52,134 @@ def backtracking_line_search(
          a *= beta
  
      # fail
-     if try_negative:
-         def inv_objective(alpha): return f(-alpha)
-
-         v = backtracking_line_search(
-             inv_objective,
-             g_0=-g_0,
-             beta=beta,
-             c=c,
-             maxiter=maxiter,
-             try_negative=False,
-         )
-         if v is not None: return -v
-
      return None
  
- class Backtracking(LineSearch):
-     """Backtracking line search satisfying the Armijo condition.
+ class Backtracking(LineSearchBase):
+     """Backtracking line search.
  
      Args:
          init (float, optional): initial step size. Defaults to 1.0.
          beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-         c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
-         maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+         c (float, optional): sufficient decrease condition. Defaults to 1e-4.
+         condition (TerminationCondition, optional):
+             termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
+             - "armijo" - sufficient decrease condition.
+             - "decrease" - any decrease in objective function value satisfies the condition.
+
+             "goldstein" can techincally be specified but it doesn't make sense because there is not zoom stage.
+             Defaults to 'armijo'.
+         maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
          adaptive (bool, optional):
-             when enabled, if line search failed, initial step size is reduced.
-             Otherwise it is reset to initial value. Defaults to True.
-         try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+             when enabled, if line search failed, step size will continue decreasing on the next step.
+             Otherwise it will restart the line search from ``init`` step size. Defaults to True.
+
+     Examples:
+         Gradient descent with backtracking line search:
+
+         ```python
+         opt = tz.Modular(
+             model.parameters(),
+             tz.m.Backtracking()
+         )
+         ```
+
+         L-BFGS with backtracking line search:
+         ```python
+         opt = tz.Modular(
+             model.parameters(),
+             tz.m.LBFGS(),
+             tz.m.Backtracking()
+         )
+         ```
+
      """
      def __init__(
          self,
          init: float = 1.0,
          beta: float = 0.5,
          c: float = 1e-4,
+         condition: TerminationCondition = 'armijo',
          maxiter: int = 10,
          adaptive=True,
-         try_negative: bool = False,
      ):
-         defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,adaptive=adaptive, try_negative=try_negative)
+         defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,adaptive=adaptive)
          super().__init__(defaults=defaults)
-         self.global_state['beta_scale'] = 1.0
  
      def reset(self):
          super().reset()
-         self.global_state['beta_scale'] = 1.0
  
      @torch.no_grad
      def search(self, update, var):
-         init, beta, c, maxiter, adaptive, try_negative = itemgetter(
-             'init', 'beta', 'c', 'maxiter', 'adaptive', 'try_negative')(self.settings[var.params[0]])
+         init, beta, c, condition, maxiter, adaptive = itemgetter(
+             'init', 'beta', 'c', 'condition', 'maxiter', 'adaptive')(self.defaults)
  
          objective = self.make_objective(var=var)
  
          # # directional derivative
-         d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+         if c == 0: d = 0
+         else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
  
-         # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
-         if adaptive: beta = beta * self.global_state['beta_scale']
+         # scale init
+         init_scale = self.global_state.get('init_scale', 1)
+         if adaptive: init = init * init_scale
  
-         step_size = backtracking_line_search(objective, d, init=init,beta=beta,
-             c=c,maxiter=maxiter, try_negative=try_negative)
+         step_size = backtracking_line_search(objective, d, init=init, beta=beta,c=c, condition=condition, maxiter=maxiter)
  
          # found an alpha that reduces loss
          if step_size is not None:
-             self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+             #self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+             self.global_state['init_scale'] = 1
              return step_size
  
-         # on fail reduce beta scale value
-         self.global_state['beta_scale'] /= 1.5
+         # on fail set init_scale to continue decreasing the step size
+         # or set to large step size when it becomes too small
+         if adaptive:
+             finfo = torch.finfo(var.params[0].dtype)
+             if init_scale <= finfo.tiny * 2:
+                 self.global_state["init_scale"] = finfo.max / 2
+             else:
+                 self.global_state['init_scale'] = init_scale * beta**maxiter
          return 0
  
  def _lerp(start,end,weight):
      return start + weight * (end - start)
  
- class AdaptiveBacktracking(LineSearch):
+ class AdaptiveBacktracking(LineSearchBase):
      """Adaptive backtracking line search. After each line search procedure, a new initial step size is set
      such that optimal step size in the procedure would be found on the second line search iteration.
  
      Args:
-         init (float, optional): step size for the first step. Defaults to 1.0.
+         init (float, optional): initial step size. Defaults to 1.0.
          beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-         c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
-         maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+         c (float, optional): sufficient decrease condition. Defaults to 1e-4.
+         condition (TerminationCondition, optional):
+             termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
+             - "armijo" - sufficient decrease condition.
+             - "decrease" - any decrease in objective function value satisfies the condition.
+
+             "goldstein" can techincally be specified but it doesn't make sense because there is not zoom stage.
+             Defaults to 'armijo'.
+         maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
          target_iters (int, optional):
-             target number of iterations that would be performed until optimal step size is found. Defaults to 1.
+             sets next step size such that this number of iterations are expected
+             to be performed until optimal step size is found. Defaults to 1.
          nplus (float, optional):
-             Multiplier to initial step size if it was found to be the optimal step size. Defaults to 2.0.
+             if initial step size is optimal, it is multiplied by this value. Defaults to 2.0.
          scale_beta (float, optional):
-             Momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
-         try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+             momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
      """
      def __init__(
          self,
          init: float = 1.0,
          beta: float = 0.5,
          c: float = 1e-4,
+         condition: TerminationCondition = 'armijo',
          maxiter: int = 20,
          target_iters = 1,
          nplus = 2.0,
          scale_beta = 0.0,
-         try_negative: bool = False,
      ):
-         defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
+         defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta)
          super().__init__(defaults=defaults)
  
          self.global_state['beta_scale'] = 1.0
@@ -158,8 +192,8 @@ class AdaptiveBacktracking(LineSearch):
  
      @torch.no_grad
      def search(self, update, var):
-         init, beta, c, maxiter, target_iters, nplus, scale_beta, try_negative=itemgetter(
-             'init','beta','c','maxiter','target_iters','nplus','scale_beta', 'try_negative')(self.settings[var.params[0]])
+         init, beta, c,condition, maxiter, target_iters, nplus, scale_beta=itemgetter(
+             'init','beta','c','condition', 'maxiter','target_iters','nplus','scale_beta')(self.defaults)
  
          objective = self.make_objective(var=var)
  
@@ -173,8 +207,7 @@ class AdaptiveBacktracking(LineSearch):
          # scale step size so that decrease is expected at target_iters
          init = init * self.global_state['initial_scale']
  
-         step_size = backtracking_line_search(objective, d, init=init, beta=beta,
-             c=c,maxiter=maxiter, try_negative=try_negative)
+         step_size = backtracking_line_search(objective, d, init=init, beta=beta, c=c, condition=condition, maxiter=maxiter)
  
          # found an alpha that reduces loss
          if step_size is not None:
@@ -183,7 +216,12 @@ class AdaptiveBacktracking(LineSearch):
              # initial step size satisfied conditions, increase initial_scale by nplus
              if step_size == init and target_iters > 0:
                  self.global_state['initial_scale'] *= nplus ** target_iters
-                 self.global_state['initial_scale'] = min(self.global_state['initial_scale'], 1e32) # avoid overflow error
+
+                 # clip by maximum possibel value to avoid overflow exception
+                 self.global_state['initial_scale'] = min(
+                     self.global_state['initial_scale'],
+                     torch.finfo(var.params[0].dtype).max / 2,
+                 )
  
              else:
                  # otherwise make initial_scale such that target_iters iterations will satisfy armijo
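
The rewritten `backtracking_line_search` above delegates acceptance to `termination_condition`; with the default `condition='armijo'` this is the sufficient decrease test `f(a) <= f(0) + c * a * min(g_0, 0)` that the old code checked inline. A standalone sketch of one Armijo backtracking loop with illustrative values (not package code):

```python
import math

def armijo_backtrack(f, g_0, init=1.0, beta=0.5, c=1e-4, maxiter=10):
    """Plain Armijo backtracking: shrink the step by `beta` until
    f(a) <= f(0) + c * a * min(g_0, 0) holds, mirroring the old inline test."""
    f_0 = f(0.0)
    a = init
    for _ in range(maxiter):
        f_a = f(a)
        if math.isfinite(f_a) and f_a <= f_0 + c * a * min(g_0, 0):
            return a           # sufficient decrease reached
        a *= beta              # otherwise shrink the step
    return None                # line search failed

# 1-D slice along a descent direction: phi(a) = (1 - a)**2, so phi'(0) = -2
phi = lambda a: (1.0 - a) ** 2
print(armijo_backtrack(phi, g_0=-2.0, init=4.0))  # 1.0, accepted after two shrinks
```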