torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
--- torchzero/modules/line_search/polynomial.py
+++ /dev/null
@@ -1,233 +0,0 @@
-import numpy as np
-import torch
-
-from .line_search import LineSearchBase
-
-
-# polynomial interpolation
-# this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
-# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
-def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
-    """
-    Gives the minimizer and minimum of the interpolating polynomial over given points
-    based on function and derivative information. Defaults to bisection if no critical
-    points are valid.
-
-    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
-    modifications.
-
-    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
-    Last edited 12/6/18.
-
-    Inputs:
-        points (nparray): two-dimensional array with each point of form [x f g]
-        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
-        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
-        plot (bool): plot interpolating polynomial
-
-    Outputs:
-        x_sol (float): minimizer of interpolating polynomial
-        F_min (float): minimum of interpolating polynomial
-
-    Note:
-        . Set f or g to np.nan if they are unknown
-
-    """
-    no_points = points.shape[0]
-    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
-
-    x_min = np.min(points[:, 0])
-    x_max = np.max(points[:, 0])
-
-    # compute bounds of interpolation area
-    if x_min_bound is None:
-        x_min_bound = x_min
-    if x_max_bound is None:
-        x_max_bound = x_max
-
-    # explicit formula for quadratic interpolation
-    if no_points == 2 and order == 2 and plot is False:
-        # Solution to quadratic interpolation is given by:
-        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
-        # x_min = x1 - g1/(2a)
-        # if x1 = 0, then is given by:
-        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
-
-        if points[0, 0] == 0:
-            x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
-        else:
-            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
-            x_sol = points[0, 0] - points[0, 2]/(2*a)
-
-        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
-
-    # explicit formula for cubic interpolation
-    elif no_points == 2 and order == 3 and plot is False:
-        # Solution to cubic interpolation is given by:
-        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
-        # d2 = sqrt(d1^2 - g1*g2)
-        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
-        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
-        d2 = np.sqrt(d1 ** 2 - points[0, 2] * points[1, 2])
-        if np.isreal(d2):
-            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
-            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
-        else:
-            x_sol = (x_max_bound + x_min_bound)/2
-
-    # solve linear system
-    else:
-        # define linear constraints
-        A = np.zeros((0, order + 1))
-        b = np.zeros((0, 1))
-
-        # add linear constraints on function values
-        for i in range(no_points):
-            if not np.isnan(points[i, 1]):
-                constraint = np.zeros((1, order + 1))
-                for j in range(order, -1, -1):
-                    constraint[0, order - j] = points[i, 0] ** j
-                A = np.append(A, constraint, 0)
-                b = np.append(b, points[i, 1])
-
-        # add linear constraints on gradient values
-        for i in range(no_points):
-            if not np.isnan(points[i, 2]):
-                constraint = np.zeros((1, order + 1))
-                for j in range(order):
-                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
-                A = np.append(A, constraint, 0)
-                b = np.append(b, points[i, 2])
-
-        # check if system is solvable
-        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
-            x_sol = (x_min_bound + x_max_bound)/2
-            f_min = np.inf
-        else:
-            # solve linear system for interpolating polynomial
-            coeff = np.linalg.solve(A, b)
-
-            # compute critical points
-            dcoeff = np.zeros(order)
-            for i in range(len(coeff) - 1):
-                dcoeff[i] = coeff[i] * (order - i)
-
-            crit_pts = np.array([x_min_bound, x_max_bound])
-            crit_pts = np.append(crit_pts, points[:, 0])
-
-            if not np.isinf(dcoeff).any():
-                roots = np.roots(dcoeff)
-                crit_pts = np.append(crit_pts, roots)
-
-            # test critical points
-            f_min = np.inf
-            x_sol = (x_min_bound + x_max_bound) / 2 # defaults to bisection
-            for crit_pt in crit_pts:
-                if np.isreal(crit_pt) and crit_pt >= x_min_bound and crit_pt <= x_max_bound:
-                    F_cp = np.polyval(coeff, crit_pt)
-                    if np.isreal(F_cp) and F_cp < f_min:
-                        x_sol = np.real(crit_pt)
-                        f_min = np.real(F_cp)
-
-            if(plot):
-                import matplotlib.pyplot as plt
-                plt.figure()
-                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
-                f = np.polyval(coeff, x)
-                plt.plot(x, f)
-                plt.plot(x_sol, f_min, 'x')
-
-    return x_sol
-
-
-
-
-# class PolynomialLineSearch(LineSearch):
-#     """TODO
-
-#     Line search via polynomial interpolation.
-
-#     Args:
-#         init (float, optional): Initial step size. Defaults to 1.0.
-#         c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
-#         c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
-#         maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
-#         maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
-#         expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
-#         adaptive (bool, optional):
-#             when enabled, if line search failed, initial step size is reduced.
-#             Otherwise it is reset to initial value. Defaults to True.
-#         plus_minus (bool, optional):
-#             If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
-
-
-#     Examples:
-#         Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
-
-#         .. code-block:: python
-
-#             opt = tz.Modular(
-#                 model.parameters(),
-#                 tz.m.PolakRibiere(),
-#                 tz.m.StrongWolfe(c2=0.1)
-#             )
-
-#         LBFGS strong wolfe line search:
-
-#         .. code-block:: python
-
-#             opt = tz.Modular(
-#                 model.parameters(),
-#                 tz.m.LBFGS(),
-#                 tz.m.StrongWolfe()
-#             )
-
-#     """
-#     def __init__(
-#         self,
-#         init: float = 1.0,
-#         c1: float = 1e-4,
-#         c2: float = 0.9,
-#         maxiter: int = 25,
-#         maxzoom: int = 10,
-#         # a_max: float = 1e10,
-#         expand: float = 2.0,
-#         adaptive = True,
-#         plus_minus = False,
-#     ):
-#         defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
-#                       expand=expand, adaptive=adaptive, plus_minus=plus_minus)
-#         super().__init__(defaults=defaults)
-
-#         self.global_state['initial_scale'] = 1.0
-#         self.global_state['beta_scale'] = 1.0
-
-#     @torch.no_grad
-#     def search(self, update, var):
-#         objective = self.make_objective_with_derivative(var=var)
-
-#         init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
-#             'init', 'c1', 'c2', 'maxiter', 'maxzoom',
-#             'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])
-
-#         f_0, g_0 = objective(0)
-
-#         step_size,f_a = strong_wolfe(
-#             objective,
-#             f_0=f_0, g_0=g_0,
-#             init=init * self.global_state.setdefault("initial_scale", 1),
-#             c1=c1,
-#             c2=c2,
-#             maxiter=maxiter,
-#             maxzoom=maxzoom,
-#             expand=expand,
-#             plus_minus=plus_minus,
-#         )
-
-#         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
-#         if step_size is not None and step_size != 0 and not _notfinite(step_size):
-#             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
-#             return step_size
-
-#         # fallback to backtracking on fail
-#         if adaptive: self.global_state['initial_scale'] *= 0.5
-#         return 0
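Note: the deleted `polyinterp` above fits a polynomial through points of the form `[x, f, g]` and returns its bounded minimizer; 0.3.14 adds `torchzero/modules/line_search/_polyinterp.py` (see the file list), which appears to take its place. The snippet below is only a hedged, standalone sketch of the two-point quadratic case spelled out in the deleted comments; `quad_min` is a hypothetical name for illustration, not part of the torchzero API.

```python
import numpy as np

def quad_min(x1, f1, g1, x2, f2):
    # quadratic through (x1, f1) with slope g1 at x1 and through (x2, f2):
    # a = -(f1 - f2 - g1*(x1 - x2)) / (x1 - x2)^2, minimizer at x1 - g1 / (2a)
    a = -(f1 - f2 - g1 * (x1 - x2)) / (x1 - x2) ** 2
    return x1 - g1 / (2 * a)

# sanity check on f(x) = (x - 3)^2, whose exact minimizer is x = 3
f = lambda x: (x - 3.0) ** 2
g = lambda x: 2.0 * (x - 3.0)
print(quad_min(0.0, f(0.0), g(0.0), 1.0, f(1.0)))  # prints 3.0
```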
--- torchzero/modules/momentum/matrix_momentum.py
+++ /dev/null
@@ -1,193 +0,0 @@
-from typing import Literal
-
-import torch
-
-from ...core import Module, apply_transform, Chainable
-from ...utils import NumberList, TensorList, as_tensorlist
-from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-
-class MatrixMomentum(Module):
-    """Second order momentum method.
-
-    Matrix momentum is useful for convex objectives, also for some reason it has very really good generalization on elastic net logistic regression.
-
-    .. note::
-        :code:`mu` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
-
-    .. note::
-        I have devised an adaptive version of this - :code:`tz.m.AdaptiveMatrixMomentum`, and it works well
-        without having to tune :code:`mu`.
-
-    .. note::
-        In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
-
-    Args:
-        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
-        beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
-        hvp_method (str, optional):
-            Determines how Hessian-vector products are evaluated.
-
-            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-              This requires creating a graph for the gradient.
-            - ``"forward"``: Use a forward finite difference formula to
-              approximate the HVP. This requires one extra gradient evaluation.
-            - ``"central"``: Use a central finite difference formula for a
-              more accurate HVP approximation. This requires two extra
-              gradient evaluations.
-            Defaults to "autograd".
-        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
-        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
-
-    Reference:
-        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
-    """
-
-    def __init__(
-        self,
-        mu=0.1,
-        beta: float = 1,
-        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
-        h: float = 1e-3,
-        hvp_tfm: Chainable | None = None,
-    ):
-        defaults = dict(mu=mu, beta=beta, hvp_method=hvp_method, h=h)
-        super().__init__(defaults)
-
-        if hvp_tfm is not None:
-            self.set_child('hvp_tfm', hvp_tfm)
-
-    def reset_for_online(self):
-        super().reset_for_online()
-        self.clear_state_keys('prev_update')
-
-    @torch.no_grad
-    def update(self, var):
-        assert var.closure is not None
-        prev_update = self.get_state(var.params, 'prev_update')
-        hvp_method = self.settings[var.params[0]]['hvp_method']
-        h = self.settings[var.params[0]]['h']
-
-        Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
-        Hvp = [t.detach() for t in Hvp]
-
-        if 'hvp_tfm' in self.children:
-            Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
-
-        self.store(var.params, "Hvp", Hvp)
-
-
-    @torch.no_grad
-    def apply(self, var):
-        update = TensorList(var.get_update())
-        Hvp, prev_update = self.get_state(var.params, 'Hvp', 'prev_update', cls=TensorList)
-        mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
-
-        update.add_(prev_update - Hvp*mu)
-        prev_update.set_(update * beta)
-        var.update = update
-        return var
-
-
-class AdaptiveMatrixMomentum(Module):
-    """Second order momentum method.
-
-    Matrix momentum is useful for convex objectives, also for some reason it has very good generalization on elastic net logistic regression.
-
-    .. note::
-        In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
-
-
-    Args:
-        mu_mul (float, optional): multiplier to the estimated mu. Defaults to 1.
-        beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
-        hvp_method (str, optional):
-            Determines how Hessian-vector products are evaluated.
-
-            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-              This requires creating a graph for the gradient.
-            - ``"forward"``: Use a forward finite difference formula to
-              approximate the HVP. This requires one extra gradient evaluation.
-            - ``"central"``: Use a central finite difference formula for a
-              more accurate HVP approximation. This requires two extra
-              gradient evaluations.
-            Defaults to "autograd".
-        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
-        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
-
-    Reference:
-        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
-    """
-
-    def __init__(
-        self,
-        mu_mul: float = 1,
-        beta: float = 1,
-        eps=1e-4,
-        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
-        h: float = 1e-3,
-        hvp_tfm: Chainable | None = None,
-    ):
-        defaults = dict(mu_mul=mu_mul, beta=beta, hvp_method=hvp_method, h=h, eps=eps)
-        super().__init__(defaults)
-
-        if hvp_tfm is not None:
-            self.set_child('hvp_tfm', hvp_tfm)
-
-    def reset_for_online(self):
-        super().reset_for_online()
-        self.clear_state_keys('prev_params', 'prev_grad')
-
-    @torch.no_grad
-    def update(self, var):
-        assert var.closure is not None
-        prev_update, prev_params, prev_grad = self.get_state(var.params, 'prev_update', 'prev_params', 'prev_grad', cls=TensorList)
-
-        settings = self.settings[var.params[0]]
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-        eps = settings['eps']
-
-        mu_mul = NumberList(self.settings[p]['mu_mul'] for p in var.params)
-
-        Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
-        Hvp = [t.detach() for t in Hvp]
-
-        if 'hvp_tfm' in self.children:
-            Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
-
-        # adaptive part
-        s_k = var.params - prev_params
-        prev_params.copy_(var.params)
-
-        if hvp_method != 'central': assert var.grad is not None
-        grad = var.get_grad()
-        y_k = grad - prev_grad
-        prev_grad.copy_(grad)
-
-        ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
-
-        self.store(var.params, ['Hvp', 'ada_mu'], [Hvp, ada_mu])
-
-    @torch.no_grad
-    def apply(self, var):
-        Hvp, ada_mu = self.get_state(var.params, 'Hvp', 'ada_mu')
-        Hvp = as_tensorlist(Hvp)
-        beta = NumberList(self.settings[p]['beta'] for p in var.params)
-        update = TensorList(var.get_update())
-        prev_update = TensorList(self.state[p]['prev_update'] for p in var.params)
-
-        update.add_(prev_update - Hvp*ada_mu)
-        prev_update.set_(update * beta)
-        var.update = update
-        return var
-
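Note: the deleted `MatrixMomentum.apply` above computes `update += prev_update - mu * Hvp(prev_update)`, and the adaptive variant estimates `mu` from `||s_k|| / ||y_k||`; per the file list, a rewritten module now lives at `torchzero/modules/adaptive/matrix_momentum.py`. The sketch below replays that update rule on a toy quadratic where the Hessian-vector product is exact; it is a standalone illustration under those assumptions, not the torchzero API.

```python
import torch

# Toy quadratic f(x) = 0.5 * x^T A x, so grad = A x and the Hessian is exactly A.
A = torch.diag(torch.tensor([1.0, 10.0]))
x = torch.tensor([5.0, 5.0])
prev_update = torch.zeros(2)
mu, lr = 0.09, 0.09  # the deleted docstring warns mu should stay below 1 / (largest eigenvalue), here 1/10

for _ in range(300):
    grad = A @ x                             # gradient at the current point
    hvp = A @ prev_update                    # Hessian-vector product with the previous update
    update = grad + prev_update - mu * hvp   # matrix momentum: update += prev_update - mu * H @ prev_update
    x = x - lr * update
    prev_update = update                     # beta = 1, as in the deleted module's default

print(x)  # converges toward the minimizer [0., 0.]
```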
--- torchzero/modules/optimizers/adagrad.py
+++ /dev/null
@@ -1,165 +0,0 @@
-from operator import itemgetter
-from typing import Literal
-
-import torch
-from ...core import (
-    Chainable,
-    Module,
-    Target,
-    TensorwiseTransform,
-    Transform,
-    Var,
-    apply_transform,
-)
-from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ...utils.linalg import matrix_power_eigh
-from ..functional import add_power_, lerp_power_, root
-
-
-def adagrad_(
-    tensors_: TensorList,
-    sq_sum_: TensorList,
-    alpha: float | NumberList,
-    lr_decay: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    use_sqrt: bool = True,
-    divide: bool = False,
-
-    # inner args
-    inner: Module | None = None,
-    params: list[torch.Tensor] | None = None,
-    grads: list[torch.Tensor] | None = None,
-):
-    """returns `tensors_`"""
-    clr = alpha / (1 + step * lr_decay)
-
-    sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
-
-    if inner is not None:
-        assert params is not None
-        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
-
-    if divide: sq_sum_ = sq_sum_ / max(step, 1)
-
-    if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
-    else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
-
-    return tensors_
-
-
-
-
-class Adagrad(Transform):
-    """Adagrad, divides by sum of past squares of gradients.
-
-    This implementation is identical to :code:`torch.optim.Adagrad`.
-
-    Args:
-        lr_decay (float, optional): learning rate decay. Defaults to 0.
-        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
-        eps (float, optional): division epsilon. Defaults to 1e-10.
-        alpha (float, optional): step size. Defaults to 1.
-        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
-        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
-        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
-    """
-    def __init__(
-        self,
-        lr_decay: float = 0,
-        initial_accumulator_value: float = 0,
-        eps: float = 1e-10,
-        alpha: float = 1,
-        pow: float = 2,
-        use_sqrt: bool = True,
-        divide: bool=False,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
-                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide)
-        super().__init__(defaults=defaults, uses_grad=False)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        tensors = TensorList(tensors)
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
-
-        pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])
-
-        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
-
-        # initialize accumulator on 1st step
-        if step == 1:
-            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
-
-        return adagrad_(
-            tensors,
-            sq_sum_=sq_sum,
-            alpha=alpha,
-            lr_decay=lr_decay,
-            eps=eps,
-            step=self.global_state["step"],
-            pow=pow,
-            use_sqrt=use_sqrt,
-            divide=divide,
-
-            # inner args
-            inner=self.children.get("inner", None),
-            params=params,
-            grads=grads,
-        )
-
-
-
-
-class FullMatrixAdagrad(TensorwiseTransform):
-    def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=True, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', divide: bool=False, inner: Chainable | None = None):
-        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init, divide=divide)
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner,)
-
-    @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state, setting):
-        G = tensor.ravel()
-        GG = torch.outer(G, G)
-        decay = setting['decay']
-        beta = setting['beta']
-        init = setting['init']
-
-        if 'GG' not in state:
-            if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
-            elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
-            elif init == 'ones': state['GG'] = torch.ones_like(GG)
-            elif init == 'GGT': state['GG'] = GG.clone()
-            else: raise ValueError(init)
-        if decay is not None: state['GG'].mul_(decay)
-
-        if beta is not None: state['GG'].lerp_(GG, 1-beta)
-        else: state['GG'].add_(GG)
-        state['i'] = state.get('i', 0) + 1 # number of GGTs in sum
-
-    @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, setting):
-        GG = state['GG']
-        sqrt = setting['sqrt']
-        divide = setting['divide']
-        if divide: GG = GG/state.get('i', 1)
-
-        if tensor.numel() == 1:
-            GG = GG.squeeze()
-            if sqrt: return tensor / GG.sqrt()
-            return tensor / GG
-
-        try:
-            if sqrt: B = matrix_power_eigh(GG, -1/2)
-            else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
-
-        except torch.linalg.LinAlgError:
-            scale = 1 / tensor.abs().max()
-            return tensor.mul_(scale.clip(min=torch.finfo(tensor.dtype).eps, max=1)) # conservative scaling
-
-        return (B @ tensor.ravel()).view_as(tensor)
-
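Note: the deleted `FullMatrixAdagrad` above accumulates outer products `g g^T` and preconditions the update with the accumulator's inverse square root (`matrix_power_eigh(GG, -1/2)`); the diagonal `Adagrad` in the same file does the elementwise analogue. Per the file list, the rewritten module lives at `torchzero/modules/adaptive/adagrad.py`. Below is a standalone sketch of that full-matrix preconditioning; `inv_sqrt_psd` is a hypothetical helper for illustration, not torchzero's `matrix_power_eigh`.

```python
import torch

def inv_sqrt_psd(M: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # M^(-1/2) for a symmetric positive semi-definite matrix via eigendecomposition
    vals, vecs = torch.linalg.eigh(M)
    return vecs @ torch.diag(vals.clamp_min(eps).rsqrt()) @ vecs.T

GG = torch.eye(3)                       # accumulator, 'identity' init as in the deleted code
for step in range(5):
    g = torch.randn(3)                  # stand-in for a flattened gradient
    GG += torch.outer(g, g)             # accumulate the outer product g g^T
    direction = inv_sqrt_psd(GG) @ g    # precondition the gradient with GG^(-1/2)
    print(step, direction.norm().item())
```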