torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/higher_order/higher_order_newton.py

@@ -70,57 +70,99 @@ def _proximal_poly_H(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
 def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
     derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
     x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
-    bounds = None
-    if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))

-    #
-
-
-
+    # notes
+    # 1. since we have exact hessian we use trust methods
+
+    # 2. if len(derivatives) is 1, only gradient is available,
+    # thus use slsqp depending on whether trust region is enabled
+    # this is just so that I can test that trust region works
+    if trust_region is None:
+        if len(derivatives) == 1: raise RuntimeError("trust region must be enabled because 1st order has no minima")
+        method = 'trust-exact'
+        de_bounds = list(zip(x0 - 10, x0 + 10))
+        constraints = None
+
     else:
-        if len(derivatives) == 1: method = '
+        if len(derivatives) == 1: method = 'slsqp'
         else: method = 'trust-constr'
+        de_bounds = list(zip(x0 - trust_region, x0 + trust_region))
+
+        def l2_bound_f(x):
+            if x.ndim == 2: return np.sum((x - x0[:,None])**2, axis=0)[None,:] # DE passes (ndim, batch_size) and expects (M, S)
+            return np.sum((x - x0)**2, axis=0)
+
+        def l2_bound_g(x):
+            return 2 * (x - x0)
+
+        def l2_bound_h(x, v):
+            return v[0] * 2 * np.eye(x0.shape[0])
+
+        constraint = scipy.optimize.NonlinearConstraint(
+            fun=l2_bound_f,
+            lb=0, # 0 <= ||x-x0||^2
+            ub=trust_region**2, # ||x-x0||^2 <= R^2
+            jac=l2_bound_g, # pyright:ignore[reportArgumentType]
+            hess=l2_bound_h,
+            keep_feasible=False
+        )
+        constraints = [constraint]

     x_init = x0.copy()
     v0 = _proximal_poly_v(x0, c, prox, x0, derivatives)
+
+    # ---------------------------------- run DE ---------------------------------- #
     if de_iters is not None and de_iters != 0:
         if de_iters == -1: de_iters = None # let scipy decide
+
+        # DE needs bounds so use linf ig
         res = scipy.optimize.differential_evolution(
             _proximal_poly_v,
-
+            de_bounds,
             args=(c, prox, x0.copy(), derivatives),
             maxiter=de_iters,
             vectorized=True,
+            constraints = constraints,
+            updating='deferred',
         )
-        if res.fun < v0: x_init = res.x
-
-    res = scipy.optimize.minimize(
-        _proximal_poly_v,
-        x_init,
-        method=method,
-        args=(c, prox, x0.copy(), derivatives),
-        jac=_proximal_poly_g,
-        hess=_proximal_poly_H,
-        bounds=bounds
-    )
+        if res.fun < v0 and np.all(np.isfinite(res.x)): x_init = res.x

+    # ------------------------------- run minimize ------------------------------- #
+    try:
+        res = scipy.optimize.minimize(
+            _proximal_poly_v,
+            x_init,
+            method=method,
+            args=(c, prox, x0.copy(), derivatives),
+            jac=_proximal_poly_g,
+            hess=_proximal_poly_H,
+            constraints = constraints,
+        )
+    except ValueError:
+        return x, -float('inf')
     return torch.from_numpy(res.x).to(x), res.fun



 class HigherOrderNewton(Module):
-    """
-    A basic arbitrary order newton's method with optional trust region and proximal penalty.
-    It is recommended to enable at least one of trust region or proximal penalty.
+    """A basic arbitrary order newton's method with optional trust region and proximal penalty.

     This constructs an nth order taylor approximation via autograd and minimizes it with
     scipy.optimize.minimize trust region newton solvers with optional proximal penalty.

-
-
-
-
-
+    .. note::
+        In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating higher order derivatives.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        this uses roughly O(N^order) memory and solving the subproblem can be very expensive.
+
+    .. warning::
+        "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.

     Args:

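For context on the constrained subproblem used in the new code path, here is a minimal standalone sketch (not from the package; variable names are illustrative) of minimizing a quadratic model inside an L2 ball with scipy's `trust-constr` solver, using the same `NonlinearConstraint` pattern as the `l2_bound_f/g/h` helpers above:

    import numpy as np
    import scipy.optimize

    # quadratic model m(x) = g.(x - x0) + 0.5 (x - x0).H.(x - x0) around the center x0
    x0 = np.zeros(3)
    g = np.array([1.0, -2.0, 0.5])
    H = np.diag([1.0, 4.0, 0.25])
    radius = 0.5  # trust region radius

    def model(x):
        d = x - x0
        return g @ d + 0.5 * d @ H @ d

    def model_grad(x):
        return g + H @ (x - x0)

    def model_hess(x):
        return H

    # L2 ball constraint 0 <= ||x - x0||^2 <= radius^2, with analytic jac and hess
    ball = scipy.optimize.NonlinearConstraint(
        fun=lambda x: np.sum((x - x0) ** 2),
        lb=0, ub=radius ** 2,
        jac=lambda x: 2 * (x - x0),
        hess=lambda x, v: v[0] * 2 * np.eye(x0.shape[0]),
    )

    res = scipy.optimize.minimize(model, x0, method='trust-constr',
                                  jac=model_grad, hess=model_hess, constraints=[ball])
    print(res.x, res.fun)  # the step stays inside the ball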
@@ -149,39 +191,38 @@ class HigherOrderNewton(Module):
         self,
         order: int = 4,
         trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-
-
-
-
+        nplus: float = 2,
+        nminus: float = 0.25,
+        init: float | None = None,
+        eta: float = 1e-6,
+        max_attempts = 10,
         de_iters: int | None = None,
         vectorize: bool = True,
     ):
-        if
-        if trust_method == 'bounds':
-        else:
+        if init is None:
+            if trust_method == 'bounds': init = 1
+            else: init = 0.1

-        defaults = dict(order=order, trust_method=trust_method,
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts)
         super().__init__(defaults)

     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
         closure = var.closure
-        if closure is None: raise RuntimeError('
+        if closure is None: raise RuntimeError('HigherOrderNewton requires closure')

         settings = self.settings[params[0]]
         order = settings['order']
-
-
-
-
+        nplus = settings['nplus']
+        nminus = settings['nminus']
+        eta = settings['eta']
+        init = settings['init']
         trust_method = settings['trust_method']
         de_iters = settings['de_iters']
+        max_attempts = settings['max_attempts']
         vectorize = settings['vectorize']

-        trust_value = self.global_state.get('trust_value', trust_init)
-
-
         # ------------------------ calculate grad and hessian ------------------------ #
         with torch.enable_grad():
             loss = var.loss = var.loss_approx = closure(False)
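A hedged usage sketch of the reworked module (assuming it is exposed as `tz.m.HigherOrderNewton` like the other modules in this release, and following the closure-with-``backward``-argument convention described in the docstring above; the toy model and data are illustrative):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)
    X, y = torch.randn(64, 4), torch.randn(64, 1)

    opt = tz.Modular(
        model.parameters(),
        tz.m.HigherOrderNewton(order=3, trust_method='bounds'),
    )

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(10):
        opt.step(closure)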
@@ -205,52 +246,74 @@ class HigherOrderNewton(Module):

         x0 = torch.cat([p.ravel() for p in params])

-
-
-
-
-
-
-
-
-        trust_region =
-
-
-
-
-
-
+        success = False
+        x_star = None
+        while not success:
+            max_attempts -= 1
+            if max_attempts < 0: break
+
+            # load trust region value
+            trust_value = self.global_state.get('trust_region', init)
+            if trust_value < 1e-8 or trust_value > 1e16: trust_value = self.global_state['trust_region'] = settings['init']
+
+            if trust_method is None: trust_method = 'none'
+            else: trust_method = trust_method.lower()
+
+            if trust_method == 'none':
+                trust_region = None
+                prox = 0
+
+            elif trust_method == 'bounds':
+                trust_region = trust_value
+                prox = 0
+
+            elif trust_method == 'proximal':
+                trust_region = None
+                prox = 1 / trust_value
+
+            else:
+                raise ValueError(trust_method)
+
+            # minimize the model
+            x_star, expected_loss = _poly_minimize(
+                trust_region=trust_region,
+                prox=prox,
+                de_iters=de_iters,
+                c=loss.item(),
+                x=x0,
+                derivatives=derivatives,
+            )
+
+            # update trust region
+            if trust_method == 'none':
+                success = True
+            else:
+                pred_reduction = loss - expected_loss
+
+                vec_to_tensors_(x_star, params)
+                loss_star = closure(False)
+                vec_to_tensors_(x0, params)
+                reduction = loss - loss_star
+
+                rho = reduction / (max(pred_reduction, 1e-8))
+                # failed step
+                if rho < 0.25:
+                    self.global_state['trust_region'] = trust_value * nminus
+
+                # very good step
+                elif rho > 0.75:
+                    diff = trust_value - (x0 - x_star).abs_()
+                    if (diff.amin() / trust_value) > 1e-4: # hits boundary
+                        self.global_state['trust_region'] = trust_value * nplus
+
+                # if the ratio is high enough then accept the proposed step
+                success = rho > eta
+
+        assert x_star is not None
+        if success:
+            difference = vec_to_tensors(x0 - x_star, params)
+            var.update = list(difference)
         else:
-
-
-        x_star, expected_loss = _poly_minimize(
-            trust_region=trust_region,
-            prox=prox,
-            de_iters=de_iters,
-            c=loss.item(),
-            x=x0,
-            derivatives=derivatives,
-        )
-
-        # trust region
-        if trust_method != 'none':
-            expected_reduction = loss - expected_loss
-
-            vec_to_tensors_(x_star, params)
-            loss_star = closure(False)
-            vec_to_tensors_(x0, params)
-            reduction = loss - loss_star
-
-            # failed step
-            if reduction <= 0:
-                x_star = x0
-                self.global_state['trust_value'] = trust_value * decrease
-
-            # very good step
-            elif expected_reduction / reduction <= trust_tol:
-                self.global_state['trust_value'] = trust_value * increase
-
-        difference = vec_to_tensors(x0 - x_star, params)
-        var.update = list(difference)
+            var.update = params.zeros_like()
         return var

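In equation form, the loop above implements the standard trust-region acceptance test. With m the Taylor-polynomial model minimized by `_poly_minimize`, x0 the current point and x* the candidate step:

    \rho = \frac{f(x_0) - f(x^\ast)}{\max\!\big(m(x_0) - m(x^\ast),\ 10^{-8}\big)},
    \qquad
    R \leftarrow \begin{cases}
        n_{-}\, R & \text{if } \rho < 0.25,\\
        n_{+}\, R & \text{if } \rho > 0.75 \text{ and } x^\ast \text{ is near the trust-region boundary},\\
        R & \text{otherwise,}
    \end{cases}

and the candidate is accepted when \rho > \eta (1e-6 by default); otherwise the model is re-minimized with the shrunken radius, up to `max_attempts` times.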
torchzero/modules/line_search/__init__.py

@@ -1,5 +1,5 @@
-from .
-from .backtracking import
-from .
+from .adaptive import AdaptiveLineSearch
+from .backtracking import AdaptiveBacktracking, Backtracking
+from .line_search import LineSearchBase
 from .scipy import ScipyMinimizeScalar
-from .
+from .strong_wolfe import StrongWolfe
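For downstream code, the net effect of this hunk is that the line-search base class is exported under its new name and `AdaptiveLineSearch` becomes importable. A quick hedged import check (assuming nothing else is re-exported from this subpackage):

    from torchzero.modules.line_search import (
        AdaptiveLineSearch,
        AdaptiveBacktracking,
        Backtracking,
        LineSearchBase,
        ScipyMinimizeScalar,
        StrongWolfe,
    )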
torchzero/modules/line_search/adaptive.py (new file)

@@ -0,0 +1,99 @@
+import math
+from collections.abc import Callable
+from operator import itemgetter
+
+import torch
+
+from .line_search import LineSearchBase
+
+
+
+def adaptive_tracking(
+    f,
+    x_0,
+    maxiter: int,
+    nplus: float = 2,
+    nminus: float = 0.5,
+):
+    f_0 = f(0)
+
+    t = x_0
+    f_t = f(t)
+
+    # backtrack
+    if f_t > f_0:
+        while f_t > f_0:
+            maxiter -= 1
+            if maxiter < 0: return 0, f_0
+            t = t*nminus
+            f_t = f(t)
+        return t, f_t
+
+    # forwardtrack
+    f_prev = f_t
+    t *= nplus
+    f_t = f(t)
+    if f_prev < f_t: return t / nplus, f_prev
+    while f_prev >= f_t:
+        maxiter -= 1
+        if maxiter < 0: return t, f_t
+        f_prev = f_t
+        t *= nplus
+        f_t = f(t)
+    return t / nplus, f_prev
+
+class AdaptiveLineSearch(LineSearchBase):
+    """Adaptive line search, similar to backtracking but also has forward tracking mode.
+    Currently doesn't check for weak curvature condition.
+
+    Args:
+        init (float, optional): initial step size. Defaults to 1.0.
+        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
+        maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
+        adaptive (bool, optional):
+            when enabled, if line search failed, beta size is reduced.
+            Otherwise it is reset to initial value. Defaults to True.
+    """
+    def __init__(
+        self,
+        init: float = 1.0,
+        nplus: float = 2,
+        nminus: float = 0.5,
+        maxiter: int = 10,
+        adaptive=True,
+    ):
+        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive,)
+        super().__init__(defaults=defaults)
+        self.global_state['beta_scale'] = 1.0
+
+    def reset(self):
+        super().reset()
+        self.global_state['beta_scale'] = 1.0
+
+    @torch.no_grad
+    def search(self, update, var):
+        init, nplus, nminus, maxiter, adaptive = itemgetter(
+            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.settings[var.params[0]])
+
+        objective = self.make_objective(var=var)
+
+        # # directional derivative
+        # d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+
+        # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
+        beta_scale = self.global_state.get('beta_scale', 1)
+        x_prev = self.global_state.get('prev_x', 1)
+
+        if adaptive: nminus = nminus * beta_scale
+
+
+        step_size, f = adaptive_tracking(objective, x_prev, maxiter, nplus=nplus, nminus=nminus)
+
+        # found an alpha that reduces loss
+        if step_size != 0:
+            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            return step_size
+
+        # on fail reduce beta scale value
+        self.global_state['beta_scale'] /= 1.5
+        return 0
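A hedged usage sketch for the new `AdaptiveLineSearch` module, mirroring the `tz.m.Backtracking` examples that appear later in this diff (the surrounding model is illustrative, and the module is assumed to be re-exported under `tz.m`):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)

    # L-BFGS direction followed by the forward-/back-tracking step size search
    opt = tz.Modular(
        model.parameters(),
        tz.m.LBFGS(),
        tz.m.AdaptiveLineSearch(init=1.0, nplus=2, nminus=0.5, maxiter=10),
    )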
torchzero/modules/line_search/backtracking.py

@@ -4,7 +4,7 @@ from operator import itemgetter

 import torch

-from .line_search import
+from .line_search import LineSearchBase


 def backtracking_line_search(
@@ -19,12 +19,12 @@ def backtracking_line_search(
     """

     Args:
-
-
-
+        f: evaluates step size along some descent direction.
+        g_0: directional derivative along the descent direction.
+        init: initial step size.
         beta: The factor by which to decrease alpha in each iteration
         c: The constant for the Armijo sufficient decrease condition
-
+        maxiter: Maximum number of backtracking iterations (default: 10).

     Returns:
         step size
@@ -32,11 +32,15 @@ def backtracking_line_search(

     a = init
     f_x = f(0)
+    f_prev = None

     for iteration in range(maxiter):
         f_a = f(a)

-        if
+        if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_x): return a / beta
+        f_prev = f_a
+
+        if f_a < f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
             # found an acceptable alpha
             return a

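The sufficient-decrease test in the rewritten loop above is the usual Armijo condition, with the directional derivative clamped to be non-positive so the test remains meaningful on non-descent directions:

    f(x + a\,d) < f(x) + c\,a\,\min\!\big(g_0,\ 0\big), \qquad g_0 = \nabla f(x)^\top d,

where a is the candidate step size (`a` in the code) and d the search direction; a step is accepted once it achieves at least a c-fraction of the decrease predicted by the directional derivative.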
@@ -59,7 +63,7 @@ def backtracking_line_search(

     return None

-class Backtracking(LineSearch):
+class Backtracking(LineSearchBase):
     """Backtracking line search satisfying the Armijo condition.

     Args:
@@ -68,9 +72,30 @@ class Backtracking(LineSearch):
         c (float, optional): acceptance value for Armijo condition. Defaults to 1e-4.
         maxiter (int, optional): Maximum line search function evaluations. Defaults to 10.
         adaptive (bool, optional):
-            when enabled, if line search failed,
+            when enabled, if line search failed, beta is reduced.
             Otherwise it is reset to initial value. Defaults to True.
         try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+
+    Examples:
+        Gradient descent with backtracking line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Backtracking()
+            )
+
+        LBFGS with backtracking line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(),
+                tz.m.Backtracking()
+            )
+
     """
     def __init__(
         self,
@@ -117,7 +142,7 @@ class Backtracking(LineSearch):
 def _lerp(start,end,weight):
     return start + weight * (end - start)

-class AdaptiveBacktracking(
+class AdaptiveBacktracking(LineSearchBase):
     """Adaptive backtracking line search. After each line search procedure, a new initial step size is set
     such that optimal step size in the procedure would be found on the second line search iteration.

torchzero/modules/line_search/line_search.py

@@ -15,8 +15,9 @@ from ...utils import tofloat
 class MaxLineSearchItersReached(Exception): pass


-class LineSearch(Module, ABC):
+class LineSearchBase(Module, ABC):
     """Base class for line searches.
+
     This is an abstract class, to use it, subclass it and override `search`.

     Args:
@@ -26,6 +27,62 @@ class LineSearch(Module, ABC):
             the objective this many times, and step size with the lowest loss value will be used.
             This is useful when passing `make_objective` to an external library which
             doesn't have a maxiter option. Defaults to None.
+
+    Other useful methods:
+        * `evaluate_step_size` - returns loss with a given scalar step size
+        * `evaluate_step_size_loss_and_derivative` - returns loss and directional derivative with a given scalar step size
+        * `make_objective` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
+        * `make_objective_with_derivative` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
+
+    Examples:
+        #### Basic line search
+
+        This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+
+        .. code-block:: python
+
+            class GridLineSearch(LineSearch):
+                def __init__(self, start, end, num):
+                    defaults = dict(start=start,end=end,num=num)
+                    super().__init__(defaults)
+
+                @torch.no_grad
+                def search(self, update, var):
+                    settings = self.settings[var.params[0]]
+                    start = settings["start"]
+                    end = settings["end"]
+                    num = settings["num"]
+
+                    lowest_loss = float("inf")
+                    best_step_size = best_step_size
+
+                    for step_size in torch.linspace(start,end,num):
+                        loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
+                        if loss < lowest_loss:
+                            lowest_loss = loss
+                            best_step_size = step_size
+
+                    return best_step_size
+
+        #### Using external solver via self.make_objective
+
+        Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
+
+        .. code-block:: python
+
+            class ScipyMinimizeScalar(LineSearch):
+                def __init__(self, method: str | None = None):
+                    defaults = dict(method=method)
+                    super().__init__(defaults)
+
+                @torch.no_grad
+                def search(self, update, var):
+                    objective = self.make_objective(var=var)
+                    method = self.settings[var.params[0]]["method"]
+
+                    res = self.scopt.minimize_scalar(objective, method=method)
+                    return res.x
+
     """
     def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
         super().__init__(defaults)
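Note that, as quoted, the `GridLineSearch` docstring example initializes `best_step_size = best_step_size`, which would raise a `NameError`, and still subclasses the old `LineSearch` name. A corrected sketch of the same idea (assuming the intent was to start from a zero step size and to subclass the renamed `LineSearchBase`):

    class GridLineSearch(LineSearchBase):
        def __init__(self, start, end, num):
            defaults = dict(start=start, end=end, num=num)
            super().__init__(defaults)

        @torch.no_grad
        def search(self, update, var):
            settings = self.settings[var.params[0]]

            lowest_loss = float("inf")
            best_step_size = 0.0  # assumed intended initial value

            for step_size in torch.linspace(settings["start"], settings["end"], settings["num"]):
                loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
                if loss < lowest_loss:
                    lowest_loss = loss
                    best_step_size = step_size.item()

            return best_step_size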
@@ -165,17 +222,18 @@ class LineSearch(Module, ABC):
         return var


-class GridLineSearch(LineSearch):
-    """Mostly for testing, this is not practical"""
-    def __init__(self, start, end, num):
-        defaults = dict(start=start,end=end,num=num)
-        super().__init__(defaults)

-
-
-
+# class GridLineSearch(LineSearch):
+#     """Mostly for testing, this is not practical"""
+#     def __init__(self, start, end, num):
+#         defaults = dict(start=start,end=end,num=num)
+#         super().__init__(defaults)
+
+#     @torch.no_grad
+#     def search(self, update, var):
+#         start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])

-
-
+#         for lr in torch.linspace(start,end,num):
+#             self.evaluate_step_size(lr.item(), var=var, backward=False)

-
+#         return self._best_step_size