torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/least_squares/gn.py
@@ -0,0 +1,161 @@
+ import torch
+ from ...core import Module
+
+ from ...utils.derivatives import jacobian_wrt, flatten_jacobian
+ from ...utils import vec_to_tensors
+ from ...utils.linalg import linear_operator
+ class SumOfSquares(Module):
+     """Sets loss to be the sum of squares of values returned by the closure.
+
+     This is meant to be used to test least squares methods against ordinary minimization methods.
+
+     To use this, the closure should return a vector of values to minimize sum of squares of.
+     Please add the `backward` argument, it will always be False but it is required.
+     """
+     def __init__(self):
+         super().__init__()
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+
+         if closure is not None:
+             def sos_closure(backward=True):
+                 if backward:
+                     var.zero_grad()
+                     with torch.enable_grad():
+                         loss = closure(False)
+                         loss = loss.pow(2).sum()
+                         loss.backward()
+                     return loss
+
+                 loss = closure(False)
+                 return loss.pow(2).sum()
+
+             var.closure = sos_closure
+
+         if var.loss is not None:
+             var.loss = var.loss.pow(2).sum()
+
+         if var.loss_approx is not None:
+             var.loss_approx = var.loss_approx.pow(2).sum()
+
+         return var
+
+
+ class GaussNewton(Module):
+     """Gauss-newton method.
+
+     To use this, the closure should return a vector of values to minimize sum of squares of.
+     Please add the ``backward`` argument, it will always be False but it is required.
+     Gradients will be calculated via batched autograd within this module, you don't need to
+     implement the backward pass. Please see below for an example.
+
+     Note:
+         This method requires ``ndim^2`` memory, however, if it is used within ``tz.m.TrustCG`` trust region,
+         the memory requirement is ``ndim*m``, where ``m`` is number of values in the output.
+
+     Args:
+         reg (float, optional): regularization parameter. Defaults to 1e-8.
+         batched (bool, optional): whether to use vmapping. Defaults to True.
+
+     Examples:
+
+     minimizing the rosenbrock function:
+     ```python
+     def rosenbrock(X):
+         x1, x2 = X
+         return torch.stack([(1 - x1), 100 * (x2 - x1**2)])
+
+     X = torch.tensor([-1.1, 2.5], requires_grad=True)
+     opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.Backtracking())
+
+     # define the closure for line search
+     def closure(backward=True):
+         return rosenbrock(X)
+
+     # minimize
+     for iter in range(10):
+         loss = opt.step(closure)
+         print(f'{loss = }')
+     ```
+
+     training a neural network with a matrix-free GN trust region:
+     ```python
+     X = torch.randn(64, 20)
+     y = torch.randn(64, 10)
+
+     model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.TrustCG(tz.m.GaussNewton()),
+     )
+
+     def closure(backward=True):
+         y_hat = model(X) # (64, 10)
+         return (y_hat - y).pow(2).mean(0) # (10, )
+
+     for i in range(100):
+         losses = opt.step(closure)
+         if i % 10 == 0:
+             print(f'{losses.mean() = }')
+     ```
+     """
+     def __init__(self, reg:float = 1e-8, batched:bool=True, ):
+         super().__init__(defaults=dict(batched=batched, reg=reg))
+
+     @torch.no_grad
+     def update(self, var):
+         params = var.params
+         batched = self.defaults['batched']
+
+         closure = var.closure
+         assert closure is not None
+
+         # gauss newton direction
+         with torch.enable_grad():
+             f = var.get_loss(backward=False) # n_out
+             assert isinstance(f, torch.Tensor)
+             G_list = jacobian_wrt([f.ravel()], params, batched=batched)
+
+         var.loss = f.pow(2).sum()
+
+         G = self.global_state["G"] = flatten_jacobian(G_list) # (n_out, ndim)
+         Gtf = G.T @ f.detach() # (ndim)
+         self.global_state["Gtf"] = Gtf
+         var.grad = vec_to_tensors(Gtf, var.params)
+
+         # set closure to calculate sum of squares for line searches etc
+         if var.closure is not None:
+             def sos_closure(backward=True):
+                 if backward:
+                     var.zero_grad()
+                     with torch.enable_grad():
+                         loss = closure(False).pow(2).sum()
+                         loss.backward()
+                     return loss
+
+                 loss = closure(False).pow(2).sum()
+                 return loss
+
+             var.closure = sos_closure
+
+     @torch.no_grad
+     def apply(self, var):
+         reg = self.defaults['reg']
+
+         G = self.global_state['G']
+         Gtf = self.global_state['Gtf']
+
+         GtG = G.T @ G # (ndim, ndim)
+         if reg != 0:
+             GtG.add_(torch.eye(GtG.size(0), device=GtG.device, dtype=GtG.dtype).mul_(reg))
+
+         v = torch.linalg.lstsq(GtG, Gtf).solution # pylint:disable=not-callable
+
+         var.update = vec_to_tensors(v, var.params)
+         return var
+
+     def get_H(self, var):
+         G = self.global_state['G']
+         return linear_operator.AtA(G)
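For reference, the `apply` step above solves the regularized Gauss-Newton normal equations `(GᵀG + reg·I) v = Gᵀf` built from the residual Jacobian. A minimal standalone sketch of the same linear algebra with plain torch and made-up toy values (not code from the package):

```python
import torch

# toy residual vector f and Jacobian G at the current parameters
G = torch.tensor([[1.0, 0.0],
                  [-20.0, 10.0]])          # (n_out, ndim)
f = torch.tensor([2.1, -5.0])              # (n_out,)

reg = 1e-8
GtG = G.T @ G + reg * torch.eye(2)         # same matrix GaussNewton.apply builds
Gtf = G.T @ f                              # same vector stored in global_state["Gtf"]
v = torch.linalg.lstsq(GtG, Gtf).solution  # Gauss-Newton step direction
```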
torchzero/modules/line_search/__init__.py
@@ -1,5 +1,5 @@
- from .adaptive import AdaptiveLineSearch
+ from .adaptive import AdaptiveTracking
  from .backtracking import AdaptiveBacktracking, Backtracking
  from .line_search import LineSearchBase
  from .scipy import ScipyMinimizeScalar
- from .strong_wolfe import StrongWolfe
+ from .strong_wolfe import StrongWolfe
torchzero/modules/line_search/_polyinterp.py
@@ -0,0 +1,289 @@
+ import numpy as np
+ import torch
+
+ from .line_search import LineSearchBase
+
+
+ # polynomial interpolation
+ # this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+ # PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+ def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
+     """
+     Gives the minimizer and minimum of the interpolating polynomial over given points
+     based on function and derivative information. Defaults to bisection if no critical
+     points are valid.
+
+     Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
+     modifications.
+
+     Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
+     Last edited 12/6/18.
+
+     Inputs:
+         points (nparray): two-dimensional array with each point of form [x f g]
+         x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
+         x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
+         plot (bool): plot interpolating polynomial
+
+     Outputs:
+         x_sol (float): minimizer of interpolating polynomial
+         F_min (float): minimum of interpolating polynomial
+
+     Note:
+         . Set f or g to np.nan if they are unknown
+
+     """
+     no_points = points.shape[0]
+     order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+     x_min = np.min(points[:, 0])
+     x_max = np.max(points[:, 0])
+
+     # compute bounds of interpolation area
+     if x_min_bound is None:
+         x_min_bound = x_min
+     if x_max_bound is None:
+         x_max_bound = x_max
+
+     # explicit formula for quadratic interpolation
+     if no_points == 2 and order == 2 and plot is False:
+         # Solution to quadratic interpolation is given by:
+         # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
+         # x_min = x1 - g1/(2a)
+         # if x1 = 0, then is given by:
+         # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
+
+         if points[0, 0] == 0:
+             x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
+         else:
+             a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
+             x_sol = points[0, 0] - points[0, 2]/(2*a)
+
+         x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+
+     # explicit formula for cubic interpolation
+     elif no_points == 2 and order == 3 and plot is False:
+         # Solution to cubic interpolation is given by:
+         # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
+         # d2 = sqrt(d1^2 - g1*g2)
+         # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
+         d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
+         value = d1 ** 2 - points[0, 2] * points[1, 2]
+         if value > 0:
+             d2 = np.sqrt(value)
+             x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
+             x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+         else:
+             x_sol = (x_max_bound + x_min_bound)/2
+
+     # solve linear system
+     else:
+         # define linear constraints
+         A = np.zeros((0, order + 1))
+         b = np.zeros((0, 1))
+
+         # add linear constraints on function values
+         for i in range(no_points):
+             if not np.isnan(points[i, 1]):
+                 constraint = np.zeros((1, order + 1))
+                 for j in range(order, -1, -1):
+                     constraint[0, order - j] = points[i, 0] ** j
+                 A = np.append(A, constraint, 0)
+                 b = np.append(b, points[i, 1])
+
+         # add linear constraints on gradient values
+         for i in range(no_points):
+             if not np.isnan(points[i, 2]):
+                 constraint = np.zeros((1, order + 1))
+                 for j in range(order):
+                     constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+                 A = np.append(A, constraint, 0)
+                 b = np.append(b, points[i, 2])
+
+         # check if system is solvable
+         if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+             x_sol = (x_min_bound + x_max_bound)/2
+             f_min = np.inf
+         else:
+             # solve linear system for interpolating polynomial
+             coeff = np.linalg.solve(A, b)
+
+             # compute critical points
+             dcoeff = np.zeros(order)
+             for i in range(len(coeff) - 1):
+                 dcoeff[i] = coeff[i] * (order - i)
+
+             crit_pts = np.array([x_min_bound, x_max_bound])
+             crit_pts = np.append(crit_pts, points[:, 0])
+
+             if not np.isinf(dcoeff).any():
+                 roots = np.roots(dcoeff)
+                 crit_pts = np.append(crit_pts, roots)
+
+             # test critical points
+             f_min = np.inf
+             x_sol = (x_min_bound + x_max_bound) / 2 # defaults to bisection
+             for crit_pt in crit_pts:
+                 if np.isreal(crit_pt):
+                     if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+                     if crit_pt >= x_min_bound and crit_pt <= x_max_bound:
+                         F_cp = np.polyval(coeff, crit_pt)
+                         if np.isreal(F_cp) and F_cp < f_min:
+                             x_sol = np.real(crit_pt)
+                             f_min = np.real(F_cp)
+
+             if(plot):
+                 import matplotlib.pyplot as plt
+                 plt.figure()
+                 x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
+                 f = np.polyval(coeff, x)
+                 plt.plot(x, f)
+                 plt.plot(x_sol, f_min, 'x')
+
+     return x_sol
+
+
+ # polynomial interpolation
+ # this code is based on https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+ # PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+ # this one is modified where instead of clipping the solution by bounds, it tries a lower degree polynomial
+ # all the way to bisection
+ def _within_bounds(x, lb, ub):
+     if lb is not None and x < lb: return False
+     if ub is not None and x > ub: return False
+     return True
+
+ def _quad_interp(points):
+     assert points.shape[0] == 2, points.shape
+     if points[0, 0] == 0:
+         denom = 2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0])
+         if abs(denom) > 1e-32:
+             return -points[0, 2] * points[1, 0] ** 2 / denom
+     else:
+         denom = (points[0, 0] - points[1, 0]) ** 2
+         if denom > 1e-32:
+             a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / denom
+             if a > 1e-32:
+                 return points[0, 0] - points[0, 2]/(2*a)
+     return None
+
+ def _cubic_interp(points, lb, ub):
+     assert points.shape[0] == 2, points.shape
+     denom = points[0, 0] - points[1, 0]
+     if abs(denom) > 1e-32:
+         d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / denom)
+         value = d1 ** 2 - points[0, 2] * points[1, 2]
+         if value > 0:
+             d2 = np.sqrt(value)
+             denom = points[1, 2] - points[0, 2] + 2 * d2
+             if abs(denom) > 1e-32:
+                 x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / denom)
+                 if _within_bounds(x_sol, lb, ub): return x_sol
+
+     # try quadratic interpolations
+     x_sol = _quad_interp(points)
+     if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+     return None
+
+ def _poly_interp(points, lb, ub):
+     no_points = points.shape[0]
+     assert no_points > 2, points.shape
+     order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+     # define linear constraints
+     A = np.zeros((0, order + 1))
+     b = np.zeros((0, 1))
+
+     # add linear constraints on function values
+     for i in range(no_points):
+         if not np.isnan(points[i, 1]):
+             constraint = np.zeros((1, order + 1))
+             for j in range(order, -1, -1):
+                 constraint[0, order - j] = points[i, 0] ** j
+             A = np.append(A, constraint, 0)
+             b = np.append(b, points[i, 1])
+
+     # add linear constraints on gradient values
+     for i in range(no_points):
+         if not np.isnan(points[i, 2]):
+             constraint = np.zeros((1, order + 1))
+             for j in range(order):
+                 constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+             A = np.append(A, constraint, 0)
+             b = np.append(b, points[i, 2])
+
+     # check if system is solvable
+     if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+         return None
+
+     # solve linear system for interpolating polynomial
+     coeff = np.linalg.solve(A, b)
+
+     # compute critical points
+     dcoeff = np.zeros(order)
+     for i in range(len(coeff) - 1):
+         dcoeff[i] = coeff[i] * (order - i)
+
+     lower = np.min(points[:, 0]) if lb is None else lb
+     upper = np.max(points[:, 0]) if ub is None else ub
+
+     crit_pts = np.array([lower, upper])
+     crit_pts = np.append(crit_pts, points[:, 0])
+
+     if not np.isinf(dcoeff).any():
+         roots = np.roots(dcoeff)
+         crit_pts = np.append(crit_pts, roots)
+
+     # test critical points
+     f_min = np.inf
+     x_sol = None
+     for crit_pt in crit_pts:
+         if np.isreal(crit_pt):
+             if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+             if _within_bounds(crit_pt, lb, ub):
+                 F_cp = np.polyval(coeff, crit_pt)
+                 if np.isreal(F_cp) and F_cp < f_min:
+                     x_sol = np.real(crit_pt)
+                     f_min = np.real(F_cp)
+
+     return x_sol
+
+ def polyinterp2(points, lb, ub, unbounded: bool = False):
+     no_points = points.shape[0]
+     if no_points <= 1:
+         return (lb + ub)/2
+
+     order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+     x_min = np.min(points[:, 0])
+     x_max = np.max(points[:, 0])
+
+     # compute bounds of interpolation area
+     if not unbounded:
+         if lb is None:
+             lb = x_min
+         if ub is None:
+             ub = x_max
+
+     if no_points == 2 and order == 2:
+         x_sol = _quad_interp(points)
+         if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+         return (lb + ub)/2
+
+     if no_points == 2 and order == 3:
+         x_sol = _cubic_interp(points, lb, ub) # includes fallback on _quad_interp
+         if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+         return (lb + ub)/2
+
+     if no_points <= 2: # order < 2
+         return (lb + ub)/2
+
+     if no_points == 3:
+         for p in (points[:2], points[1:], points[::2]):
+             x_sol = _cubic_interp(p, lb, ub)
+             if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+     x_sol = _poly_interp(points, lb, ub)
+     if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+     return polyinterp2(points[1:], lb, ub)
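As a quick illustration of the helpers above (a hypothetical check, assuming the module is importable from the path listed in the file table): for f(x) = (x - 2)**2, two points with function values and derivatives are enough for the closed-form two-point interpolation to recover the exact minimizer.

```python
import numpy as np
from torchzero.modules.line_search._polyinterp import polyinterp2  # assumed import path

# rows are [x, f(x), f'(x)] for f(x) = (x - 2)**2
points = np.array([
    [0.0, 4.0, -4.0],
    [1.0, 1.0, -2.0],
])

x_sol = polyinterp2(points, lb=0.0, ub=3.0)
print(x_sol)  # 2.0 -- the exact minimizer of f
```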
torchzero/modules/line_search/adaptive.py
@@ -1,58 +1,73 @@
  import math
+ from bisect import insort
+ from collections import deque
  from collections.abc import Callable
  from operator import itemgetter

+ import numpy as np
  import torch

- from .line_search import LineSearchBase
-
+ from .line_search import LineSearchBase, TerminationCondition, termination_condition


  def adaptive_tracking(
      f,
-     x_0,
+     a_init,
      maxiter: int,
      nplus: float = 2,
      nminus: float = 0.5,
+     f_0 = None,
  ):
-     f_0 = f(0)
+     niter = 0
+     if f_0 is None: f_0 = f(0)

-     t = x_0
-     f_t = f(t)
+     a = a_init
+     f_a = f(a)

      # backtrack
-     if f_t > f_0:
-         while f_t > f_0:
+     a_prev = a
+     f_prev = math.inf
+     if (f_a > f_0) or (not math.isfinite(f_a)):
+         while (f_a < f_prev) or not math.isfinite(f_a):
+             a_prev, f_prev = a, f_a
              maxiter -= 1
-             if maxiter < 0: return 0, f_0
-             t = t*nminus
-             f_t = f(t)
-         return t, f_t
+             if maxiter < 0: break
+
+             a = a*nminus
+             f_a = f(a)
+             niter += 1
+
+         if f_prev < f_0: return a_prev, f_prev, niter
+         return 0, f_0, niter

      # forwardtrack
-     f_prev = f_t
-     t *= nplus
-     f_t = f(t)
-     if f_prev < f_t: return t / nplus, f_prev
-     while f_prev >= f_t:
+     a_prev = a
+     f_prev = math.inf
+     while (f_a <= f_prev) and math.isfinite(f_a):
+         a_prev, f_prev = a, f_a
          maxiter -= 1
-         if maxiter < 0: return t, f_t
-         f_prev = f_t
-         t *= nplus
-         f_t = f(t)
-     return t / nplus, f_prev
+         if maxiter < 0: break
+
+         a *= nplus
+         f_a = f(a)
+         niter+= 1
+
+     if f_prev < f_0: return a_prev, f_prev, niter
+     return 0, f_0, niter
+

- class AdaptiveLineSearch(LineSearchBase):
-     """Adaptive line search, similar to backtracking but also has forward tracking mode.
-     Currently doesn't check for weak curvature condition.
+ class AdaptiveTracking(LineSearchBase):
+     """A line search that evaluates previous step size, if value increased, backtracks until the value stops decreasing,
+     otherwise forward-tracks until value stops decreasing.

      Args:
          init (float, optional): initial step size. Defaults to 1.0.
-         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-         maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+         nplus (float, optional): multiplier to step size if initial step size is optimal. Defaults to 2.
+         nminus (float, optional): multiplier to step size if initial step size is too big. Defaults to 0.5.
+         maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
          adaptive (bool, optional):
-             when enabled, if line search failed, beta size is reduced.
-             Otherwise it is reset to initial value. Defaults to True.
+             when enabled, if line search failed, step size will continue decreasing on the next step.
+             Otherwise it will restart the line search from ``init`` step size. Defaults to True.
      """
      def __init__(
          self,
@@ -62,38 +77,48 @@ class AdaptiveLineSearch(LineSearchBase):
          maxiter: int = 10,
          adaptive=True,
      ):
-         defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive,)
+         defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive)
          super().__init__(defaults=defaults)
-         self.global_state['beta_scale'] = 1.0

      def reset(self):
          super().reset()
-         self.global_state['beta_scale'] = 1.0

      @torch.no_grad
      def search(self, update, var):
          init, nplus, nminus, maxiter, adaptive = itemgetter(
-             'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.settings[var.params[0]])
+             'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.defaults)

          objective = self.make_objective(var=var)

-         # # directional derivative
-         # d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+         # scale a_prev
+         a_prev = self.global_state.get('a_prev', init)
+         if adaptive: a_prev = a_prev * self.global_state.get('init_scale', 1)

-         # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
-         beta_scale = self.global_state.get('beta_scale', 1)
-         x_prev = self.global_state.get('prev_x', 1)
+         a_init = a_prev
+         if a_init < torch.finfo(var.params[0].dtype).tiny * 2:
+             a_init = torch.finfo(var.params[0].dtype).max / 2

-         if adaptive: nminus = nminus * beta_scale
-
-
-         step_size, f = adaptive_tracking(objective, x_prev, maxiter, nplus=nplus, nminus=nminus)
+         step_size, f, niter = adaptive_tracking(
+             objective,
+             a_init=a_init,
+             maxiter=maxiter,
+             nplus=nplus,
+             nminus=nminus,
+         )

          # found an alpha that reduces loss
          if step_size != 0:
-             self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+             assert (var.loss is None) or (math.isfinite(f) and f < var.loss)
+             self.global_state['init_scale'] = 1
+
+             # if niter == 1, forward tracking failed to decrease function value compared to f_a_prev
+             if niter == 1 and step_size >= a_init: step_size *= nminus
+
+             self.global_state['a_prev'] = step_size
              return step_size

          # on fail reduce beta scale value
-         self.global_state['beta_scale'] /= 1.5
+         self.global_state['init_scale'] = self.global_state.get('init_scale', 1) * nminus**maxiter
+         self.global_state['a_prev'] = init
          return 0
+
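To see the new tracking loop in isolation, a toy sketch (not from the package) that reuses the `adaptive_tracking` helper added above:

```python
def f(a):  # objective along the search direction, minimized near a = 0.3
    return (a - 0.3) ** 2

# f(1.0) > f(0), so the backtracking branch halves the step until the value stops improving
step, f_step, niter = adaptive_tracking(f, a_init=1.0, maxiter=10)
# step ~= 0.25, f_step ~= 0.0025, niter == 3; on success AdaptiveTracking stores `step`
# as `a_prev` and starts from it on the next optimizer step
```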