torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/line_search/polynomial.py
@@ -0,0 +1,233 @@
+import numpy as np
+import torch
+
+from .line_search import LineSearchBase
+
+
+# polynomial interpolation
+# this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
+    """
+    Gives the minimizer and minimum of the interpolating polynomial over given points
+    based on function and derivative information. Defaults to bisection if no critical
+    points are valid.
+
+    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
+    modifications.
+
+    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
+    Last edited 12/6/18.
+
+    Inputs:
+        points (nparray): two-dimensional array with each point of form [x f g]
+        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
+        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
+        plot (bool): plot interpolating polynomial
+
+    Outputs:
+        x_sol (float): minimizer of interpolating polynomial
+        F_min (float): minimum of interpolating polynomial
+
+    Note:
+        . Set f or g to np.nan if they are unknown
+
+    """
+    no_points = points.shape[0]
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if x_min_bound is None:
+        x_min_bound = x_min
+    if x_max_bound is None:
+        x_max_bound = x_max
+
+    # explicit formula for quadratic interpolation
+    if no_points == 2 and order == 2 and plot is False:
+        # Solution to quadratic interpolation is given by:
+        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
+        # x_min = x1 - g1/(2a)
+        # if x1 = 0, then is given by:
+        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
+
+        if points[0, 0] == 0:
+            x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
+        else:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
+            x_sol = points[0, 0] - points[0, 2]/(2*a)
+
+        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+
+    # explicit formula for cubic interpolation
+    elif no_points == 2 and order == 3 and plot is False:
+        # Solution to cubic interpolation is given by:
+        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
+        # d2 = sqrt(d1^2 - g1*g2)
+        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
+        d2 = np.sqrt(d1 ** 2 - points[0, 2] * points[1, 2])
+        if np.isreal(d2):
+            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
+            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+        else:
+            x_sol = (x_max_bound + x_min_bound)/2
+
+    # solve linear system
+    else:
+        # define linear constraints
+        A = np.zeros((0, order + 1))
+        b = np.zeros((0, 1))
+
+        # add linear constraints on function values
+        for i in range(no_points):
+            if not np.isnan(points[i, 1]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order, -1, -1):
+                    constraint[0, order - j] = points[i, 0] ** j
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 1])
+
+        # add linear constraints on gradient values
+        for i in range(no_points):
+            if not np.isnan(points[i, 2]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order):
+                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 2])
+
+        # check if system is solvable
+        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+            x_sol = (x_min_bound + x_max_bound)/2
+            f_min = np.inf
+        else:
+            # solve linear system for interpolating polynomial
+            coeff = np.linalg.solve(A, b)
+
+            # compute critical points
+            dcoeff = np.zeros(order)
+            for i in range(len(coeff) - 1):
+                dcoeff[i] = coeff[i] * (order - i)
+
+            crit_pts = np.array([x_min_bound, x_max_bound])
+            crit_pts = np.append(crit_pts, points[:, 0])
+
+            if not np.isinf(dcoeff).any():
+                roots = np.roots(dcoeff)
+                crit_pts = np.append(crit_pts, roots)
+
+            # test critical points
+            f_min = np.inf
+            x_sol = (x_min_bound + x_max_bound) / 2 # defaults to bisection
+            for crit_pt in crit_pts:
+                if np.isreal(crit_pt) and crit_pt >= x_min_bound and crit_pt <= x_max_bound:
+                    F_cp = np.polyval(coeff, crit_pt)
+                    if np.isreal(F_cp) and F_cp < f_min:
+                        x_sol = np.real(crit_pt)
+                        f_min = np.real(F_cp)
+
+            if(plot):
+                import matplotlib.pyplot as plt
+                plt.figure()
+                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
+                f = np.polyval(coeff, x)
+                plt.plot(x, f)
+                plt.plot(x_sol, f_min, 'x')
+
+    return x_sol
+
+
+
+# class PolynomialLineSearch(LineSearch):
+#     """TODO
+
+#     Line search via polynomial interpolation.
+
+#     Args:
+#         init (float, optional): Initial step size. Defaults to 1.0.
+#         c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
+#         c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
+#         maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
+#         maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
+#         expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
+#         adaptive (bool, optional):
+#             when enabled, if line search failed, initial step size is reduced.
+#             Otherwise it is reset to initial value. Defaults to True.
+#         plus_minus (bool, optional):
+#             If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
+
+
+#     Examples:
+#         Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
+
+#         .. code-block:: python
+
+#             opt = tz.Modular(
+#                 model.parameters(),
+#                 tz.m.PolakRibiere(),
+#                 tz.m.StrongWolfe(c2=0.1)
+#             )
+
+#         LBFGS strong wolfe line search:
+
+#         .. code-block:: python
+
+#             opt = tz.Modular(
+#                 model.parameters(),
+#                 tz.m.LBFGS(),
+#                 tz.m.StrongWolfe()
+#             )
+
+#     """
+#     def __init__(
+#         self,
+#         init: float = 1.0,
+#         c1: float = 1e-4,
+#         c2: float = 0.9,
+#         maxiter: int = 25,
+#         maxzoom: int = 10,
+#         # a_max: float = 1e10,
+#         expand: float = 2.0,
+#         adaptive = True,
+#         plus_minus = False,
+#     ):
+#         defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
+#                       expand=expand, adaptive=adaptive, plus_minus=plus_minus)
+#         super().__init__(defaults=defaults)
+
+#         self.global_state['initial_scale'] = 1.0
+#         self.global_state['beta_scale'] = 1.0
+
+#     @torch.no_grad
+#     def search(self, update, var):
+#         objective = self.make_objective_with_derivative(var=var)
+
+#         init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
+#             'init', 'c1', 'c2', 'maxiter', 'maxzoom',
+#             'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])
+
+#         f_0, g_0 = objective(0)
+
+#         step_size,f_a = strong_wolfe(
+#             objective,
+#             f_0=f_0, g_0=g_0,
+#             init=init * self.global_state.setdefault("initial_scale", 1),
+#             c1=c1,
+#             c2=c2,
+#             maxiter=maxiter,
+#             maxzoom=maxzoom,
+#             expand=expand,
+#             plus_minus=plus_minus,
+#         )
+
+#         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
+#         if step_size is not None and step_size != 0 and not _notfinite(step_size):
+#             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
+#             return step_size
+
+#         # fallback to backtracking on fail
+#         if adaptive: self.global_state['initial_scale'] *= 0.5
+#         return 0
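
The new polynomial.py module is essentially the polyinterp helper shown above. As a quick illustration of what it computes (this sketch is not part of the diff; it only assumes the function is importable from the path of the new file), the quadratic branch recovers the exact minimizer of a parabola from two function values and one derivative:

    import numpy as np
    from torchzero.modules.line_search.polynomial import polyinterp

    # Two points in [x, f, g] form for f(x) = (x - 2)**2:
    # at x=0 we know f=4 and g=-4; at x=5 we know f=9 but not the derivative (np.nan).
    points = np.array([
        [0.0, 4.0, -4.0],
        [5.0, 9.0, np.nan],
    ])

    # Three of the four f/g entries are known, so order == 2 and the closed-form
    # quadratic formula applies: x_sol = -(g1 * x2**2) / (2 * (f2 - f1 - g1 * x2)) = 2.0
    x_sol = polyinterp(points)
    print(x_sol)  # 2.0, the exact minimizer of (x - 2)**2, clipped to [0, 5]
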
torchzero/modules/line_search/scipy.py
@@ -3,10 +3,10 @@ from operator import itemgetter

 import torch

-from .line_search import
+from .line_search import LineSearchBase


-class ScipyMinimizeScalar(
+class ScipyMinimizeScalar(LineSearchBase):
     """Line search via :code:`scipy.optimize.minimize_scalar` which implements brent, golden search and bounded brent methods.

     Args:
torchzero/modules/line_search/strong_wolfe.py
@@ -1,3 +1,4 @@
+"""this needs to be reworked maybe but it also works"""
 import math
 import warnings
 from operator import itemgetter
@@ -5,8 +6,7 @@ from operator import itemgetter
 import torch
 from torch.optim.lbfgs import _cubic_interpolate

-from .line_search import
-from .backtracking import backtracking_line_search
+from .line_search import LineSearchBase
 from ...utils import totensor


@@ -182,7 +182,7 @@ def _notfinite(x):
     if isinstance(x, torch.Tensor): return not torch.isfinite(x).all()
     return not math.isfinite(x)

-class StrongWolfe(
+class StrongWolfe(LineSearchBase):
     """Cubic interpolation line search satisfying Strong Wolfe condition.

     Args:
@@ -192,11 +192,36 @@ class StrongWolfe(LineSearch):
         maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
         maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
         expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
+        use_prev (bool, optional):
+            if True, previous step size is used as the initial step size on the next step.
         adaptive (bool, optional):
             when enabled, if line search failed, initial step size is reduced.
             Otherwise it is reset to initial value. Defaults to True.
         plus_minus (bool, optional):
             If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
+
+
+    Examples:
+        Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.PolakRibiere(),
+                tz.m.StrongWolfe(c2=0.1)
+            )
+
+        LBFGS strong wolfe line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(),
+                tz.m.StrongWolfe()
+            )
+
     """
     def __init__(
         self,
@@ -207,11 +232,12 @@ class StrongWolfe(LineSearch):
         maxzoom: int = 10,
         # a_max: float = 1e10,
         expand: float = 2.0,
+        use_prev: bool = False,
         adaptive = True,
         plus_minus = False,
     ):
         defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
-                      expand=expand, adaptive=adaptive, plus_minus=plus_minus)
+                      expand=expand, adaptive=adaptive, plus_minus=plus_minus,use_prev=use_prev)
         super().__init__(defaults=defaults)

         self.global_state['initial_scale'] = 1.0
@@ -221,11 +247,12 @@ class StrongWolfe(LineSearch):
     def search(self, update, var):
         objective = self.make_objective_with_derivative(var=var)

-        init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
+        init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus, use_prev = itemgetter(
             'init', 'c1', 'c2', 'maxiter', 'maxzoom',
-            'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])
+            'expand', 'adaptive', 'plus_minus', 'use_prev')(self.settings[var.params[0]])

         f_0, g_0 = objective(0)
+        if use_prev: init = self.global_state.get('prev_alpha', init)

         step_size,f_a = strong_wolfe(
             objective,
@@ -242,8 +269,8 @@ class StrongWolfe(LineSearch):
         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
         if step_size is not None and step_size != 0 and not _notfinite(step_size):
             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
+            self.global_state['prev_alpha'] = step_size
             return step_size

-        # fallback to backtracking on fail
         if adaptive: self.global_state['initial_scale'] *= 0.5
         return 0
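
The functional change to StrongWolfe in this release is the new use_prev flag: when a search succeeds, the accepted step size is stored in global_state['prev_alpha'] and reused as the initial step size of the next search. A hedged construction sketch, following the pattern of the docstring examples above (whether warm-starting helps is problem-dependent):

    import torch.nn as nn
    import torchzero as tz

    model = nn.Linear(10, 1)

    opt = tz.Modular(
        model.parameters(),
        tz.m.LBFGS(),
        # new in 0.3.11: start each line search from the previously accepted step size
        tz.m.StrongWolfe(use_prev=True),
    )
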
torchzero/modules/misc/__init__.py
@@ -0,0 +1,27 @@
+from .debug import PrintLoss, PrintParams, PrintShape, PrintUpdate
+from .escape import EscapeAnnealing
+from .gradient_accumulation import GradientAccumulation
+from .misc import (
+    DivByLoss,
+    FillLoss,
+    GradSign,
+    GraftGradToUpdate,
+    GraftToGrad,
+    GraftToParams,
+    HpuEstimate,
+    LastAbsoluteRatio,
+    LastDifference,
+    LastGradDifference,
+    LastProduct,
+    LastRatio,
+    MulByLoss,
+    NoiseSign,
+    Previous,
+    RandomHvp,
+    Relative,
+    UpdateSign,
+)
+from .multistep import Multistep, NegateOnLossIncrease, Online, Sequential
+from .regularization import Dropout, PerturbWeights, WeightDropout
+from .split import Split
+from .switch import Alternate, Switch
torchzero/modules/{ops → misc}/debug.py
@@ -6,6 +6,7 @@ from ...core import Module
 from ...utils.tensorlist import Distributions

 class PrintUpdate(Module):
+    """Prints current update."""
     def __init__(self, text = 'update = ', print_fn = print):
         defaults = dict(text=text, print_fn=print_fn)
         super().__init__(defaults)
@@ -15,6 +16,7 @@ class PrintUpdate(Module):
         return var

 class PrintShape(Module):
+    """Prints shapes of the update."""
     def __init__(self, text = 'shapes = ', print_fn = print):
         defaults = dict(text=text, print_fn=print_fn)
         super().__init__(defaults)
@@ -22,4 +24,25 @@ class PrintShape(Module):
     def step(self, var):
         shapes = [u.shape for u in var.update] if var.update is not None else None
         self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{shapes}')
-        return var
+        return var
+
+class PrintParams(Module):
+    """Prints current update."""
+    def __init__(self, text = 'params = ', print_fn = print):
+        defaults = dict(text=text, print_fn=print_fn)
+        super().__init__(defaults)
+
+    def step(self, var):
+        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.params}')
+        return var
+
+
+class PrintLoss(Module):
+    """Prints var.get_loss()."""
+    def __init__(self, text = 'loss = ', print_fn = print):
+        defaults = dict(text=text, print_fn=print_fn)
+        super().__init__(defaults)
+
+    def step(self, var):
+        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.get_loss(False)}')
+        return var
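
The new PrintParams and PrintLoss follow the same pass-through pattern as the existing PrintUpdate and PrintShape: they print and return var unchanged, so they can be dropped anywhere in a chain for debugging. A hedged sketch of one possible placement (it assumes these classes are re-exported under tz.m like the other modules referenced in this diff):

    import torch.nn as nn
    import torchzero as tz

    model = nn.Linear(10, 1)

    opt = tz.Modular(
        model.parameters(),
        tz.m.PrintLoss(),    # prints var.get_loss(False) at the start of the chain
        tz.m.Adam(),
        tz.m.PrintUpdate(),  # prints the update produced by Adam
        tz.m.LR(1e-2),
    )
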
torchzero/modules/misc/escape.py
@@ -0,0 +1,60 @@
+import torch
+
+from ...core import Module
+from ...utils import TensorList, NumberList
+
+
+class EscapeAnnealing(Module):
+    """If parameters stop changing, this runs a backward annealing random search"""
+    def __init__(self, max_region:float = 1, max_iter:int = 1000, tol=1e-6, n_tol: int = 10):
+        defaults = dict(max_region=max_region, max_iter=max_iter, tol=tol, n_tol=n_tol)
+        super().__init__(defaults)
+
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError("Escape requries closure")
+
+        params = TensorList(var.params)
+        settings = self.settings[params[0]]
+        max_region = self.get_settings(params, 'max_region', cls=NumberList)
+        max_iter = settings['max_iter']
+        tol = settings['tol']
+        n_tol = settings['n_tol']
+
+        n_bad = self.global_state.get('n_bad', 0)
+
+        prev_params = self.get_state(params, 'prev_params', cls=TensorList)
+        diff = params-prev_params
+        prev_params.copy_(params)
+
+        if diff.abs().global_max() <= tol:
+            n_bad += 1
+
+        else:
+            n_bad = 0
+
+        self.global_state['n_bad'] = n_bad
+
+        # no progress
+        f_0 = var.get_loss(False)
+        if n_bad >= n_tol:
+            for i in range(1, max_iter+1):
+                alpha = max_region * (i / max_iter)
+                pert = params.sample_like(distribution='sphere').mul_(alpha)
+
+                params.add_(pert)
+                f_star = closure(False)
+
+                if f_star < f_0-1e-10:
+                    var.update = None
+                    var.stop = True
+                    var.skip_update = True
+                    return var
+
+                else:
+                    params.sub_(pert)
+
+            self.global_state['n_bad'] = 0
+        return var
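
EscapeAnnealing only acts once the parameters have moved by at most tol for n_tol consecutive steps; it then tries random spherical perturbations of linearly growing radius (up to max_region) and keeps the first one that lowers the loss, skipping the normal update for that step. A hedged sketch of chaining it after a regular optimizer (it assumes the class is re-exported as tz.m.EscapeAnnealing, and since the module evaluates the closure itself, the optimizer has to be stepped with a closure, as with the line-search modules in this diff):

    import torch.nn as nn
    import torchzero as tz

    model = nn.Linear(10, 1)

    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-2),
        # kicks in only after parameters change by <= 1e-6 for 10 consecutive steps
        tz.m.EscapeAnnealing(max_region=1, max_iter=1000, tol=1e-6, n_tol=10),
    )
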
torchzero/modules/misc/gradient_accumulation.py
@@ -0,0 +1,70 @@
+import torch
+
+from ...core import Chainable, Module
+
+
+class GradientAccumulation(Module):
+    """Uses :code:`n` steps to accumulate gradients, after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
+
+    Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
+    is more computationally efficient, but sometimes it is not feasible due to memory constraints.
+
+    .. note::
+        Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
+
+    Args:
+        modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
+        n (int): number of gradients to accumulate.
+        mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
+        stop (bool, optional):
+            this module prevents next modules from stepping unless :code:`n` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
+
+    Examples:
+        Adam with gradients accumulated for 16 batches.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GradientAccumulation(
+                    modules=[tz.m.Adam(), tz.m.LR(1e-2)],
+                    n=16
+                )
+            )
+
+    """
+    def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+        defaults = dict(n=n, mean=mean, stop=stop)
+        super().__init__(defaults)
+        self.set_child('modules', modules)
+
+
+    @torch.no_grad
+    def step(self, var):
+        accumulator = self.get_state(var.params, 'accumulator')
+        settings = self.settings[var.params[0]]
+        n = settings['n']; mean = settings['mean']; stop = settings['stop']
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        # add update to accumulator
+        torch._foreach_add_(accumulator, var.get_update())
+
+        # step with accumulated updates
+        if step % n == 0:
+            if mean:
+                torch._foreach_div_(accumulator, n)
+
+            var.update = [a.clone() for a in accumulator]
+            var = self.children['modules'].step(var)
+
+            # zero accumulator
+            torch._foreach_zero_(accumulator)
+
+        else:
+            # prevent update
+            if stop:
+                var.stop=True
+                var.skip_update=True
+
+        return var
+