torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/_polyinterp.py (new file)
@@ -0,0 +1,289 @@
+import numpy as np
+import torch
+
+from .line_search import LineSearchBase
+
+
+# polynomial interpolation
+# this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
+    """
+    Gives the minimizer and minimum of the interpolating polynomial over given points
+    based on function and derivative information. Defaults to bisection if no critical
+    points are valid.
+
+    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
+    modifications.
+
+    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
+    Last edited 12/6/18.
+
+    Inputs:
+        points (nparray): two-dimensional array with each point of form [x f g]
+        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
+        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
+        plot (bool): plot interpolating polynomial
+
+    Outputs:
+        x_sol (float): minimizer of interpolating polynomial
+        F_min (float): minimum of interpolating polynomial
+
+    Note:
+        Set f or g to np.nan if they are unknown
+
+    """
+    no_points = points.shape[0]
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if x_min_bound is None:
+        x_min_bound = x_min
+    if x_max_bound is None:
+        x_max_bound = x_max
+
+    # explicit formula for quadratic interpolation
+    if no_points == 2 and order == 2 and plot is False:
+        # Solution to quadratic interpolation is given by:
+        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
+        # x_min = x1 - g1/(2a)
+        # if x1 = 0, then is given by:
+        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
+
+        if points[0, 0] == 0:
+            x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
+        else:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
+            x_sol = points[0, 0] - points[0, 2]/(2*a)
+
+        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+
+    # explicit formula for cubic interpolation
+    elif no_points == 2 and order == 3 and plot is False:
+        # Solution to cubic interpolation is given by:
+        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
+        # d2 = sqrt(d1^2 - g1*g2)
+        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
+        value = d1 ** 2 - points[0, 2] * points[1, 2]
+        if value > 0:
+            d2 = np.sqrt(value)
+            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
+            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+        else:
+            x_sol = (x_max_bound + x_min_bound)/2
+
+    # solve linear system
+    else:
+        # define linear constraints
+        A = np.zeros((0, order + 1))
+        b = np.zeros((0, 1))
+
+        # add linear constraints on function values
+        for i in range(no_points):
+            if not np.isnan(points[i, 1]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order, -1, -1):
+                    constraint[0, order - j] = points[i, 0] ** j
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 1])
+
+        # add linear constraints on gradient values
+        for i in range(no_points):
+            if not np.isnan(points[i, 2]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order):
+                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 2])
+
+        # check if system is solvable
+        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+            x_sol = (x_min_bound + x_max_bound)/2
+            f_min = np.inf
+        else:
+            # solve linear system for interpolating polynomial
+            coeff = np.linalg.solve(A, b)
+
+            # compute critical points
+            dcoeff = np.zeros(order)
+            for i in range(len(coeff) - 1):
+                dcoeff[i] = coeff[i] * (order - i)
+
+            crit_pts = np.array([x_min_bound, x_max_bound])
+            crit_pts = np.append(crit_pts, points[:, 0])
+
+            if not np.isinf(dcoeff).any():
+                roots = np.roots(dcoeff)
+                crit_pts = np.append(crit_pts, roots)
+
+            # test critical points
+            f_min = np.inf
+            x_sol = (x_min_bound + x_max_bound) / 2  # defaults to bisection
+            for crit_pt in crit_pts:
+                if np.isreal(crit_pt):
+                    if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+                    if crit_pt >= x_min_bound and crit_pt <= x_max_bound:
+                        F_cp = np.polyval(coeff, crit_pt)
+                        if np.isreal(F_cp) and F_cp < f_min:
+                            x_sol = np.real(crit_pt)
+                            f_min = np.real(F_cp)
+
+            if(plot):
+                import matplotlib.pyplot as plt
+                plt.figure()
+                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
+                f = np.polyval(coeff, x)
+                plt.plot(x, f)
+                plt.plot(x_sol, f_min, 'x')
+
+    return x_sol
+
+
+# polynomial interpolation
+# this code is based on https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+# this version is modified so that instead of clipping the solution by bounds, it tries a lower degree polynomial,
+# all the way down to bisection
+def _within_bounds(x, lb, ub):
+    if lb is not None and x < lb: return False
+    if ub is not None and x > ub: return False
+    return True
+
+def _quad_interp(points):
+    assert points.shape[0] == 2, points.shape
+    if points[0, 0] == 0:
+        denom = 2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0])
+        if abs(denom) > 1e-32:
+            return -points[0, 2] * points[1, 0] ** 2 / denom
+    else:
+        denom = (points[0, 0] - points[1, 0]) ** 2
+        if denom > 1e-32:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / denom
+            if a > 1e-32:
+                return points[0, 0] - points[0, 2]/(2*a)
+    return None
+
+def _cubic_interp(points, lb, ub):
+    assert points.shape[0] == 2, points.shape
+    denom = points[0, 0] - points[1, 0]
+    if abs(denom) > 1e-32:
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / denom)
+        value = d1 ** 2 - points[0, 2] * points[1, 2]
+        if value > 0:
+            d2 = np.sqrt(value)
+            denom = points[1, 2] - points[0, 2] + 2 * d2
+            if abs(denom) > 1e-32:
+                x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / denom)
+                if _within_bounds(x_sol, lb, ub): return x_sol
+
+    # try quadratic interpolation
+    x_sol = _quad_interp(points)
+    if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+    return None
+
+def _poly_interp(points, lb, ub):
+    no_points = points.shape[0]
+    assert no_points > 2, points.shape
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    # define linear constraints
+    A = np.zeros((0, order + 1))
+    b = np.zeros((0, 1))
+
+    # add linear constraints on function values
+    for i in range(no_points):
+        if not np.isnan(points[i, 1]):
+            constraint = np.zeros((1, order + 1))
+            for j in range(order, -1, -1):
+                constraint[0, order - j] = points[i, 0] ** j
+            A = np.append(A, constraint, 0)
+            b = np.append(b, points[i, 1])
+
+    # add linear constraints on gradient values
+    for i in range(no_points):
+        if not np.isnan(points[i, 2]):
+            constraint = np.zeros((1, order + 1))
+            for j in range(order):
+                constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+            A = np.append(A, constraint, 0)
+            b = np.append(b, points[i, 2])
+
+    # check if system is solvable
+    if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+        return None
+
+    # solve linear system for interpolating polynomial
+    coeff = np.linalg.solve(A, b)
+
+    # compute critical points
+    dcoeff = np.zeros(order)
+    for i in range(len(coeff) - 1):
+        dcoeff[i] = coeff[i] * (order - i)
+
+    lower = np.min(points[:, 0]) if lb is None else lb
+    upper = np.max(points[:, 0]) if ub is None else ub
+
+    crit_pts = np.array([lower, upper])
+    crit_pts = np.append(crit_pts, points[:, 0])
+
+    if not np.isinf(dcoeff).any():
+        roots = np.roots(dcoeff)
+        crit_pts = np.append(crit_pts, roots)
+
+    # test critical points
+    f_min = np.inf
+    x_sol = None
+    for crit_pt in crit_pts:
+        if np.isreal(crit_pt):
+            if not np.isrealobj(crit_pt): crit_pt = crit_pt.real
+            if _within_bounds(crit_pt, lb, ub):
+                F_cp = np.polyval(coeff, crit_pt)
+                if np.isreal(F_cp) and F_cp < f_min:
+                    x_sol = np.real(crit_pt)
+                    f_min = np.real(F_cp)
+
+    return x_sol
+
+def polyinterp2(points, lb, ub, unbounded: bool = False):
+    no_points = points.shape[0]
+    if no_points <= 1:
+        return (lb + ub)/2
+
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if not unbounded:
+        if lb is None:
+            lb = x_min
+        if ub is None:
+            ub = x_max
+
+    if no_points == 2 and order == 2:
+        x_sol = _quad_interp(points)
+        if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+        return (lb + ub)/2
+
+    if no_points == 2 and order == 3:
+        x_sol = _cubic_interp(points, lb, ub)  # includes fallback on _quad_interp
+        if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+        return (lb + ub)/2
+
+    if no_points <= 2:  # order < 2
+        return (lb + ub)/2
+
+    if no_points == 3:
+        for p in (points[:2], points[1:], points[::2]):
+            x_sol = _cubic_interp(p, lb, ub)
+            if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+
+    x_sol = _poly_interp(points, lb, ub)
+    if x_sol is not None and _within_bounds(x_sol, lb, ub): return x_sol
+    return polyinterp2(points[1:], lb, ub)
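The quadratic branch of `polyinterp` can be checked by hand: for x1 = 0 the minimizer of the interpolating parabola is -(g1*x2^2)/(2(f2 - f1 - g1*x2)). A minimal sketch (the import path is an assumption based on the file location above):

```python
import numpy as np
from torchzero.modules.line_search._polyinterp import polyinterp  # assumed path

# two points of form [x, f, g]; the derivative at the second point is unknown
points = np.array([
    [0.0, 1.0, -2.0],    # x1 = 0, f1 = 1, g1 = -2
    [1.0, 0.5, np.nan],  # x2 = 1, f2 = 0.5
])
# 3 known values -> order 2 -> quadratic formula: x_sol = 2/(2*1.5) = 2/3
print(polyinterp(points))  # ~0.6667, clipped to [min(x), max(x)] by default
```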
torchzero/modules/line_search/adaptive.py (new file)
@@ -0,0 +1,124 @@
+import math
+from bisect import insort
+from collections import deque
+from collections.abc import Callable
+from operator import itemgetter
+
+import numpy as np
+import torch
+
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
+
+
+def adaptive_tracking(
+    f,
+    a_init,
+    maxiter: int,
+    nplus: float = 2,
+    nminus: float = 0.5,
+    f_0 = None,
+):
+    niter = 0
+    if f_0 is None: f_0 = f(0)
+
+    a = a_init
+    f_a = f(a)
+
+    # backtrack
+    a_prev = a
+    f_prev = math.inf
+    if (f_a > f_0) or (not math.isfinite(f_a)):
+        while (f_a < f_prev) or not math.isfinite(f_a):
+            a_prev, f_prev = a, f_a
+            maxiter -= 1
+            if maxiter < 0: break
+
+            a = a*nminus
+            f_a = f(a)
+            niter += 1
+
+        if f_prev < f_0: return a_prev, f_prev, niter
+        return 0, f_0, niter
+
+    # forward-track
+    a_prev = a
+    f_prev = math.inf
+    while (f_a <= f_prev) and math.isfinite(f_a):
+        a_prev, f_prev = a, f_a
+        maxiter -= 1
+        if maxiter < 0: break
+
+        a *= nplus
+        f_a = f(a)
+        niter += 1
+
+    if f_prev < f_0: return a_prev, f_prev, niter
+    return 0, f_0, niter
+
+
+class AdaptiveTracking(LineSearchBase):
+    """A line search that evaluates the previous step size; if the value increased, it backtracks until
+    the value stops decreasing, otherwise it forward-tracks until the value stops decreasing.
+
+    Args:
+        init (float, optional): initial step size. Defaults to 1.0.
+        nplus (float, optional): multiplier to step size if initial step size is optimal. Defaults to 2.
+        nminus (float, optional): multiplier to step size if initial step size is too big. Defaults to 0.5.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
+        adaptive (bool, optional):
+            when enabled, if the line search failed, the step size will continue decreasing on the next step.
+            Otherwise it will restart the line search from ``init`` step size. Defaults to True.
+    """
+    def __init__(
+        self,
+        init: float = 1.0,
+        nplus: float = 2,
+        nminus: float = 0.5,
+        maxiter: int = 10,
+        adaptive=True,
+    ):
+        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive)
+        super().__init__(defaults=defaults)
+
+    def reset(self):
+        super().reset()
+
+    @torch.no_grad
+    def search(self, update, var):
+        init, nplus, nminus, maxiter, adaptive = itemgetter(
+            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.defaults)
+
+        objective = self.make_objective(var=var)
+
+        # scale a_prev
+        a_prev = self.global_state.get('a_prev', init)
+        if adaptive: a_prev = a_prev * self.global_state.get('init_scale', 1)
+
+        a_init = a_prev
+        if a_init < torch.finfo(var.params[0].dtype).tiny * 2:
+            a_init = torch.finfo(var.params[0].dtype).max / 2
+
+        step_size, f, niter = adaptive_tracking(
+            objective,
+            a_init=a_init,
+            maxiter=maxiter,
+            nplus=nplus,
+            nminus=nminus,
+        )
+
+        # found an alpha that reduces loss
+        if step_size != 0:
+            assert (var.loss is None) or (math.isfinite(f) and f < var.loss)
+            self.global_state['init_scale'] = 1
+
+            # if niter == 1, forward tracking failed to decrease the function value compared to f_a_prev
+            if niter == 1 and step_size >= a_init: step_size *= nminus
+
+            self.global_state['a_prev'] = step_size
+            return step_size
+
+        # on fail, reduce the init scale value
+        self.global_state['init_scale'] = self.global_state.get('init_scale', 1) * nminus**maxiter
+        self.global_state['a_prev'] = init
+        return 0
+
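`adaptive_tracking` operates on a plain scalar objective, so its bracketing behaviour is easy to trace. A minimal sketch (the import path is an assumption; the returned triple is `(step, value, evaluations)` per the code above):

```python
from torchzero.modules.line_search.adaptive import adaptive_tracking  # assumed path

f = lambda a: (a - 0.3) ** 2  # 1-D objective with minimizer at a = 0.3

# f(1.0) = 0.49 exceeds f(0) = 0.09, so the search backtracks 1.0 -> 0.5 -> 0.25 -> 0.125
# and returns the best bracketed point once the value rises again
a, f_a, niter = adaptive_tracking(f, a_init=1.0, maxiter=10)
print(a, f_a, niter)  # 0.25 0.0025 3
```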
torchzero/modules/line_search/backtracking.py
@@ -4,7 +4,7 @@ from operator import itemgetter
 
 import torch
 
-from .line_search import LineSearch
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
 
 
 def backtracking_line_search(
@@ -14,29 +14,37 @@ def backtracking_line_search(
     beta: float = 0.5,
     c: float = 1e-4,
     maxiter: int = 10,
-
+    condition: TerminationCondition = 'armijo',
 ) -> float | None:
     """
 
     Args:
-
-
-
+        f: evaluates step size along some descent direction.
+        g_0: directional derivative along the descent direction.
+        init: initial step size.
         beta: The factor by which to decrease alpha in each iteration
         c: The constant for the Armijo sufficient decrease condition
-
+        maxiter: Maximum number of backtracking iterations (default: 10).
 
     Returns:
         step size
     """
 
     a = init
-
+    f_0 = f(0)
+    f_prev = None
 
     for iteration in range(maxiter):
         f_a = f(a)
+        if not math.isfinite(f_a):
+            a *= beta
+            continue
 
-        if
+        if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_0):
+            return a / beta  # new value is larger than previous value
+        f_prev = f_a
+
+        if termination_condition(condition, f_0=f_0, g_0=g_0, f_a=f_a, g_a=None, a=a, c=c):
            # found an acceptable alpha
            return a
 
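The two gradient-free termination conditions referenced here are standard; a minimal standalone sketch of what 'armijo' and 'decrease' test (illustrative helper names, not torchzero's `termination_condition` signature):

```python
def armijo(f_0: float, g_0: float, f_a: float, a: float, c: float = 1e-4) -> bool:
    # sufficient decrease: f(a) must fall below the line f_0 + c*a*g_0,
    # where g_0 is the directional derivative at a = 0 (negative for descent)
    return f_a <= f_0 + c * a * g_0

def decrease(f_0: float, f_a: float) -> bool:
    # any decrease in the objective value is accepted
    return f_a < f_0
```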
@@ -44,108 +52,134 @@ def backtracking_line_search(
         a *= beta
 
     # fail
-    if try_negative:
-        def inv_objective(alpha): return f(-alpha)
-
-        v = backtracking_line_search(
-            inv_objective,
-            g_0=-g_0,
-            beta=beta,
-            c=c,
-            maxiter=maxiter,
-            try_negative=False,
-        )
-        if v is not None: return -v
-
     return None
 
-class Backtracking(LineSearch):
-    """Backtracking line search
+class Backtracking(LineSearchBase):
+    """Backtracking line search.
 
     Args:
         init (float, optional): initial step size. Defaults to 1.0.
         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-        c (float, optional):
-
+        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
+        condition (TerminationCondition, optional):
+            termination condition; only ones that do not use the gradient at f(x+a*d) can be specified.
+            - "armijo" - sufficient decrease condition.
+            - "decrease" - any decrease in objective function value satisfies the condition.
+
+            "goldstein" can technically be specified but it doesn't make sense because there is no zoom stage.
+            Defaults to 'armijo'.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         adaptive (bool, optional):
-            when enabled, if line search failed,
-            Otherwise it
-
+            when enabled, if the line search failed, the step size will continue decreasing on the next step.
+            Otherwise it will restart the line search from ``init`` step size. Defaults to True.
+
+    Examples:
+        Gradient descent with backtracking line search:
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Backtracking()
+        )
+        ```
+
+        L-BFGS with backtracking line search:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.LBFGS(),
+            tz.m.Backtracking()
+        )
+        ```
+
    """
    def __init__(
        self,
        init: float = 1.0,
        beta: float = 0.5,
        c: float = 1e-4,
+        condition: TerminationCondition = 'armijo',
        maxiter: int = 10,
        adaptive=True,
-        try_negative: bool = False,
    ):
-        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,adaptive=adaptive)
+        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,adaptive=adaptive)
        super().__init__(defaults=defaults)
-        self.global_state['beta_scale'] = 1.0
 
    def reset(self):
        super().reset()
-        self.global_state['beta_scale'] = 1.0
 
    @torch.no_grad
    def search(self, update, var):
-        init, beta, c, maxiter, adaptive = itemgetter(
-            'init', 'beta', 'c', 'maxiter', 'adaptive')(self.defaults)
+        init, beta, c, condition, maxiter, adaptive = itemgetter(
+            'init', 'beta', 'c', 'condition', 'maxiter', 'adaptive')(self.defaults)
 
        objective = self.make_objective(var=var)
 
        # # directional derivative
-
+        if c == 0: d = 0
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
 
-        # scale
-
+        # scale init
+        init_scale = self.global_state.get('init_scale', 1)
+        if adaptive: init = init * init_scale
 
-        step_size = backtracking_line_search(objective, d, init=init,beta=beta,
-            c=c,maxiter=maxiter, try_negative=try_negative)
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta,c=c, condition=condition, maxiter=maxiter)
 
        # found an alpha that reduces loss
        if step_size is not None:
-            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            #self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            self.global_state['init_scale'] = 1
            return step_size
 
-        # on fail
-
+        # on fail set init_scale to continue decreasing the step size
+        # or set to a large step size when it becomes too small
+        if adaptive:
+            finfo = torch.finfo(var.params[0].dtype)
+            if init_scale <= finfo.tiny * 2:
+                self.global_state["init_scale"] = finfo.max / 2
+            else:
+                self.global_state['init_scale'] = init_scale * beta**maxiter
        return 0
 
 def _lerp(start,end,weight):
     return start + weight * (end - start)
 
-class AdaptiveBacktracking(LineSearch):
+class AdaptiveBacktracking(LineSearchBase):
    """Adaptive backtracking line search. After each line search procedure, a new initial step size is set
    such that the optimal step size in the procedure would be found on the second line search iteration.
 
    Args:
-        init (float, optional): step size
+        init (float, optional): initial step size. Defaults to 1.0.
        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-        c (float, optional):
-
+        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
+        condition (TerminationCondition, optional):
+            termination condition; only ones that do not use the gradient at f(x+a*d) can be specified.
+            - "armijo" - sufficient decrease condition.
+            - "decrease" - any decrease in objective function value satisfies the condition.
+
+            "goldstein" can technically be specified but it doesn't make sense because there is no zoom stage.
+            Defaults to 'armijo'.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
        target_iters (int, optional):
-
+            sets the next step size such that this number of iterations is expected
+            to be performed until the optimal step size is found. Defaults to 1.
        nplus (float, optional):
-
+            if initial step size is optimal, it is multiplied by this value. Defaults to 2.0.
        scale_beta (float, optional):
-
-        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+            momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
    """
    def __init__(
        self,
        init: float = 1.0,
        beta: float = 0.5,
        c: float = 1e-4,
+        condition: TerminationCondition = 'armijo',
        maxiter: int = 20,
        target_iters = 1,
        nplus = 2.0,
        scale_beta = 0.0,
-        try_negative: bool = False,
    ):
-        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta)
+        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta)
        super().__init__(defaults=defaults)
 
        self.global_state['beta_scale'] = 1.0
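The failure path above compounds geometrically: every failed search multiplies `init_scale` by `beta**maxiter`, so the next search starts far below the last trial step. A standalone arithmetic sketch of that schedule:

```python
beta, maxiter = 0.5, 10      # defaults from Backtracking
init_scale = 1.0
for _ in range(3):           # three consecutive failed searches
    init_scale *= beta ** maxiter
print(init_scale)            # 2**-30 ~ 9.3e-10; reset to finfo.max/2 once it nears underflow
```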
@@ -158,8 +192,8 @@ class AdaptiveBacktracking(LineSearch):
 
     @torch.no_grad
     def search(self, update, var):
-        init, beta, c, maxiter, target_iters, nplus, scale_beta = itemgetter(
-            'init','beta','c','maxiter','target_iters','nplus','scale_beta')(self.defaults)
+        init, beta, c,condition, maxiter, target_iters, nplus, scale_beta=itemgetter(
+            'init','beta','c','condition', 'maxiter','target_iters','nplus','scale_beta')(self.defaults)
 
         objective = self.make_objective(var=var)
 
@@ -173,8 +207,7 @@ class AdaptiveBacktracking(LineSearch):
         # scale step size so that decrease is expected at target_iters
         init = init * self.global_state['initial_scale']
 
-        step_size = backtracking_line_search(objective, d, init=init, beta=beta,
-            c=c,maxiter=maxiter, try_negative=try_negative)
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta, c=c, condition=condition, maxiter=maxiter)
 
         # found an alpha that reduces loss
         if step_size is not None:
@@ -183,7 +216,12 @@ class AdaptiveBacktracking(LineSearch):
             # initial step size satisfied conditions, increase initial_scale by nplus
             if step_size == init and target_iters > 0:
                 self.global_state['initial_scale'] *= nplus ** target_iters
-
+
+                # clip by maximum possible value to avoid overflow exception
+                self.global_state['initial_scale'] = min(
+                    self.global_state['initial_scale'],
+                    torch.finfo(var.params[0].dtype).max / 2,
+                )
 
             else:
                 # otherwise make initial_scale such that target_iters iterations will satisfy armijo
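For reference, the adaptive variant drops into a modular optimizer the same way as `Backtracking` in the docstring examples above (assuming `AdaptiveBacktracking` is exported under `tz.m` like `Backtracking`):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(),
    tz.m.AdaptiveBacktracking(),  # in place of tz.m.Backtracking()
)
```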
|