torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/second_order/newton_cg.py
(Several removed lines below are shown truncated, as in the source rendering.)

@@ -1,11 +1,14 @@
-
+import warnings
+import math
+from typing import Literal, cast
+from operator import itemgetter
 import torch

-from ...
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, as_tensorlist, tofloat
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-
-from
-from ...utils.linalg.solve import cg, steihaug_toint_cg, minres
+from ...utils.linalg.solve import cg, minres, find_within_trust_radius
+from ..trust_region.trust_region import default_radius

 class NewtonCG(Module):
     """Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
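As the docstring above says, NewtonCG solves the Newton system H d = g while touching the Hessian only through Hessian-vector products. Below is a minimal, self-contained sketch of that idea in plain PyTorch; it is not torchzero's API, and the helper names (`hvp_autograd`, `newton_cg_direction`) are illustrative only. Negative-curvature handling, preconditioning, and trust regions (which the modules in this diff provide) are omitted.

```python
import torch

def hvp_autograd(loss, params, vec):
    # Hessian-vector product via double backprop: d/dparams (grad . vec)
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat = torch.cat([g.reshape(-1) for g in grads])
    prod = torch.autograd.grad(flat @ vec, params, retain_graph=True)
    return torch.cat([p.reshape(-1) for p in prod])

def newton_cg_direction(loss, params, b, maxiter=20, tol=1e-8, reg=1e-8):
    # Plain CG on (H + reg*I) d = b; b is the flattened gradient.
    d = torch.zeros_like(b)
    r = b.clone()               # residual b - A d for d = 0
    p = r.clone()
    rs = r @ r
    for _ in range(maxiter):
        Ap = hvp_autograd(loss, params, p) + reg * p
        alpha = rs / (p @ Ap)   # assumes positive curvature along p
        d += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return d

# toy usage: one Newton-CG step on a convex objective
x = torch.randn(5, requires_grad=True)
loss = (x ** 4).sum() + (x ** 2).sum()
(g,) = torch.autograd.grad(loss, [x], create_graph=True)
step = newton_cg_direction(loss, [x], g.detach())
with torch.no_grad():
    x -= step
```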
@@ -88,20 +91,25 @@ class NewtonCG(Module):
     def __init__(
         self,
         maxiter: int | None = None,
-        tol: float = 1e-
+        tol: float = 1e-8,
         reg: float = 1e-8,
         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
         solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
         h: float = 1e-3,
+        miniter:int = 1,
         warm_start=False,
         inner: Chainable | None = None,
     ):
-        defaults =
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
         super().__init__(defaults,)

         if inner is not None:
             self.set_child('inner', inner)

+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
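The `defaults = locals().copy()` pattern introduced in this hunk captures every constructor argument as the module's defaults in one step. A tiny illustration with a hypothetical class (not torchzero's `Module`):

```python
class Example:
    def __init__(self, tol: float = 1e-8, reg: float = 1e-8, miniter: int = 1, inner=None):
        # locals() at the top of __init__ holds exactly the arguments,
        # so copying it and dropping `self` and child modules gives the defaults dict
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        self.defaults = defaults

print(Example().defaults)  # {'tol': 1e-08, 'reg': 1e-08, 'miniter': 1}
```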
@@ -117,11 +125,13 @@ class NewtonCG(Module):
         h = settings['h']
         warm_start = settings['warm_start']

+        self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
             grad = var.get_grad(create_graph=True)

             def H_mm(x):
+                self._num_hvps_last_step += 1
                 with torch.enable_grad():
                     return TensorList(hvp(params, grad, x, retain_graph=True))

@@ -132,10 +142,12 @@ class NewtonCG(Module):

         if hvp_method == 'forward':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])

         elif hvp_method == 'central':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])

         else:
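Both classes fall back to finite-difference Hessian-vector products when `hvp_method` is "forward" or "central". The formulas behind `hvp_fd_forward` and `hvp_fd_central` are the standard ones; here is a minimal sketch of those formulas only (torchzero's helpers take a closure and a parameter list and normalize the direction, so the signatures below are illustrative, not the library's):

```python
import torch

def grad_at(f, x):
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(f(x), [x])
    return g

def fd_hvp_forward(f, x, v, h=1e-3, g0=None):
    # forward difference: (g(x + h v) - g(x)) / h, one extra gradient if g0 is known
    if g0 is None:
        g0 = grad_at(f, x)
    return (grad_at(f, x + h * v) - g0) / h

def fd_hvp_central(f, x, v, h=1e-3):
    # central difference: (g(x + h v) - g(x - h v)) / (2 h), two extra gradients
    return (grad_at(f, x + h * v) - grad_at(f, x - h * v)) / (2 * h)

f = lambda x: (x ** 3).sum()          # Hessian is diag(6 x)
x = torch.tensor([1.0, 2.0]); v = torch.tensor([1.0, 0.0])
print(fd_hvp_central(f, x, v))        # ~ tensor([6., 0.])
```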
@@ -153,26 +165,28 @@ class NewtonCG(Module):
         if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway

         if solver == 'cg':
-
+            d, _ = cg(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, miniter=self.defaults["miniter"],reg=reg)

         elif solver == 'minres':
-
+            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)

         elif solver == 'minres_npc':
-
+            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)

         else:
             raise ValueError(f"Unknown solver {solver}")

         if warm_start:
             assert x0 is not None
-            x0.copy_(
+            x0.copy_(d)
+
+        var.update = d

-
+        self._num_hvps += self._num_hvps_last_step
         return var


-class
+class NewtonCGSteihaug(Module):
     """Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.

     This optimizer implements Newton's method using a matrix-free conjugate
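The docstring above names the Steihaug-Toint variant: CG is truncated when it meets negative curvature or when the iterate leaves the trust region, stepping to the boundary in either case. A hedged, self-contained sketch of that truncation rule (not torchzero's `cg`/`minres`, whose signatures and extras such as `reg`, `miniter`, and history differ):

```python
import torch

def steihaug_cg(Hv, g, trust_radius, maxiter=50, tol=1e-8):
    d = torch.zeros_like(g)
    r = g.clone()                 # residual of H d = g at d = 0
    p = r.clone()
    for _ in range(maxiter):
        Hp = Hv(p)
        pHp = p @ Hp
        if pHp <= 0:              # negative curvature: follow p to the boundary
            return d + _to_boundary(d, p, trust_radius) * p
        alpha = (r @ r) / pHp
        d_next = d + alpha * p
        if d_next.norm() >= trust_radius:   # would cross the boundary: clip to it
            return d + _to_boundary(d, p, trust_radius) * p
        r_next = r - alpha * Hp
        if r_next.norm() <= tol:
            return d_next
        beta = (r_next @ r_next) / (r @ r)
        d, r, p = d_next, r_next, r_next + beta * p
    return d

def _to_boundary(d, p, radius):
    # positive tau with ||d + tau p|| = radius (quadratic in tau)
    a, b, c = p @ p, 2 * (d @ p), d @ d - radius ** 2
    return (-b + torch.sqrt(b * b - 4 * a * c)) / (2 * a)

# toy usage with an explicit symmetric (indefinite) matrix
H = torch.tensor([[2.0, 0.0], [0.0, -1.0]])
g = torch.tensor([1.0, 1.0])
print(steihaug_cg(lambda v: H @ v, g, trust_radius=1.0))
```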
@@ -245,49 +259,61 @@ class TruncatedNewtonCG(Module):
     def __init__(
         self,
         maxiter: int | None = None,
-        eta: float=
-        nplus: float =
+        eta: float= 0.0,
+        nplus: float = 3.5,
         nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
         init: float = 1,
-        tol: float = 1e-
+        tol: float = 1e-8,
         reg: float = 1e-8,
-        hvp_method: Literal["forward", "central"
-        solver: Literal['cg',
+        hvp_method: Literal["forward", "central"] = "forward",
+        solver: Literal['cg', "minres"] = 'cg',
         h: float = 1e-3,
-        max_attempts: int =
+        max_attempts: int = 100,
+        max_history: int = 100,
+        boundary_tol: float = 1e-1,
+        miniter: int = 1,
+        rms_beta: float | None = None,
+        adapt_tol: bool = True,
+        npc_terminate: bool = False,
         inner: Chainable | None = None,
     ):
-        defaults =
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
         super().__init__(defaults,)

         if inner is not None:
             self.set_child('inner', inner)

+        self._num_hvps = 0
+        self._num_hvps_last_step = 0
+
     @torch.no_grad
     def step(self, var):
         params = TensorList(var.params)
         closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')

-
-
-        reg = settings['reg']
-        maxiter = settings['maxiter']
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-        max_attempts = settings['max_attempts']
-        solver = settings['solver'].lower().strip()
+        tol = self.defaults['tol'] * self.global_state.get('tol_mul', 1)
+        solver = self.defaults['solver'].lower().strip()

-
-
-
-
+        (reg, maxiter, hvp_method, h, max_attempts, boundary_tol,
+         eta, nplus, nminus, rho_good, rho_bad, init, npc_terminate,
+         miniter, max_history, adapt_tol) = itemgetter(
+            "reg", "maxiter", "hvp_method", "h", "max_attempts", "boundary_tol",
+            "eta", "nplus", "nminus", "rho_good", "rho_bad", "init", "npc_terminate",
+            "miniter", "max_history", "adapt_tol",
+        )(self.defaults)
+
+        self._num_hvps_last_step = 0

         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
             grad = var.get_grad(create_graph=True)

             def H_mm(x):
+                self._num_hvps_last_step += 1
                 with torch.enable_grad():
                     return TensorList(hvp(params, grad, x, retain_graph=True))

@@ -298,77 +324,112 @@ class TruncatedNewtonCG(Module):

         if hvp_method == 'forward':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])

         elif hvp_method == 'central':
             def H_mm(x):
+                self._num_hvps_last_step += 1
                 return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])

         else:
             raise ValueError(hvp_method)


-        #
+        # ------------------------- update RMS preconditioner ------------------------ #
         b = var.get_update()
+        P_mm = None
+        rms_beta = self.defaults["rms_beta"]
+        if rms_beta is not None:
+            exp_avg_sq = self.get_state(params, "exp_avg_sq", init=b, cls=TensorList)
+            exp_avg_sq.mul_(rms_beta).addcmul(b, b, value=1-rms_beta)
+            exp_avg_sq_sqrt = exp_avg_sq.sqrt().add_(1e-8)
+            def _P_mm(x):
+                return x / exp_avg_sq_sqrt
+            P_mm = _P_mm
+
+        # -------------------------------- inner step -------------------------------- #
         if 'inner' in self.children:
             b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
         b = as_tensorlist(b)

-        #
+        # ------------------------------- trust region ------------------------------- #
         success = False
-
+        d = None
+        x0 = [p.clone() for p in params]
+        solution = None
+
         while not success:
             max_attempts -= 1
             if max_attempts < 0: break

-
-
-
-
-            if
-            … (remaining removed lines are truncated in the source rendering)
+            trust_radius = self.global_state.get('trust_radius', init)
+
+            # -------------- make sure trust radius isn't too small or large ------------- #
+            finfo = torch.finfo(x0[0].dtype)
+            if trust_radius < finfo.tiny * 2:
+                trust_radius = self.global_state['trust_radius'] = init
+                if adapt_tol:
+                    self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1
+
+            elif trust_radius > finfo.max / 2:
+                trust_radius = self.global_state['trust_radius'] = init
+
+            # ----------------------------------- solve ---------------------------------- #
+            d = None
+            if solution is not None and solution.history is not None:
+                d = find_within_trust_radius(solution.history, trust_radius)
+
+            if d is None:
+                if solver == 'cg':
+                    d, solution = cg(
+                        A_mm=H_mm,
+                        b=b,
+                        tol=tol,
+                        maxiter=maxiter,
+                        reg=reg,
+                        trust_radius=trust_radius,
+                        miniter=miniter,
+                        npc_terminate=npc_terminate,
+                        history_size=max_history,
+                        P_mm=P_mm,
+                    )
+
+                elif solver == 'minres':
+                    d = minres(A_mm=H_mm, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
+
+                else:
+                    raise ValueError(f"unknown solver {solver}")
+
+            # ---------------------------- update trust radius --------------------------- #
+            self.global_state["trust_radius"], success = default_radius(
+                params=params,
+                closure=closure,
+                f=tofloat(var.get_loss(False)),
+                g=b,
+                H=H_mm,
+                d=d,
+                trust_radius=trust_radius,
+                eta=eta,
+                nplus=nplus,
+                nminus=nminus,
+                rho_good=rho_good,
+                rho_bad=rho_bad,
+                boundary_tol=boundary_tol,
+
+                init=init, # init isn't used because check_overflow=False
+                state=self.global_state, # not used
+                settings=self.defaults, # not used
+                check_overflow=False, # this is checked manually to adapt tolerance
+            )

-
+        # --------------------------- assign new direction --------------------------- #
+        assert d is not None
         if success:
-            var.update =
+            var.update = d

         else:
             var.update = params.zeros_like()

-
-
-
+        self._num_hvps += self._num_hvps_last_step
+        return var
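The `default_radius` call above adjusts the trust radius from the ratio between the actual reduction and the reduction predicted by the quadratic model. As a rough illustration of that kind of rule, using the default values that appear in this diff (torchzero's actual implementation, including `boundary_tol`, overflow checks, and the exact `rho_good`/`rho_bad` semantics, may differ):

```python
import torch

def update_radius(f_old, f_new, g, Hd, d, trust_radius,
                  eta=0.0, nplus=3.5, nminus=0.25,
                  rho_good=0.99, rho_bad=1e-4):
    # model decrease for the step -d under m(s) = f + g.s + 0.5 s.H s
    predicted = (g @ d) - 0.5 * (d @ Hd)
    actual = f_old - f_new
    rho = actual / predicted if predicted > 0 else -1.0

    if rho > rho_good:        # model was very accurate: expand the region
        trust_radius = trust_radius * nplus
    elif rho < rho_bad:       # model was poor: shrink the region
        trust_radius = trust_radius * nminus

    accepted = rho > eta      # accept the step only if it made enough progress
    return trust_radius, accepted
```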
torchzero/modules/smoothing/__init__.py

@@ -1,2 +1,2 @@
 from .laplacian import LaplacianSmoothing
-from .
+from .sampling import GradientSampling
torchzero/modules/smoothing/sampling.py

@@ -0,0 +1,300 @@
+import math
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Sequence
+from contextlib import nullcontext
+from functools import partial
+from typing import Literal, cast
+
+import torch
+
+from ...core import Chainable, Modular, Module, Var
+from ...core.reformulation import Reformulation
+from ...utils import Distributions, NumberList, TensorList
+from ..termination import TerminationCriteriaBase, make_termination_criteria
+
+
+def _reset_except_self(optimizer: Modular, var: Var, self: Module):
+    for m in optimizer.unrolled_modules:
+        if m is not self:
+            m.reset()
+
+
+class GradientSampling(Reformulation):
+    """Samples and aggregates gradients and values at perturbed points.
+
+    This module can be used for gaussian homotopy and gradient sampling methods.
+
+    Args:
+        modules (Chainable | None, optional):
+            modules that will be optimizing the modified objective.
+            if None, returns gradient of the modified objective as the update. Defaults to None.
+        sigma (float, optional): initial magnitude of the perturbations. Defaults to 1.
+        n (int, optional): number of perturbations per step. Defaults to 100.
+        aggregate (str, optional):
+            how to aggregate values and gradients
+            - "mean" - uses mean of the gradients, as in gaussian homotopy.
+            - "max" - uses element-wise maximum of the gradients.
+            - "min" - uses element-wise minimum of the gradients.
+            - "min-norm" - picks gradient with the lowest norm.
+
+            Defaults to 'mean'.
+        distribution (Distributions, optional): distribution for random perturbations. Defaults to 'gaussian'.
+        include_x0 (bool, optional): whether to include gradient at un-perturbed point. Defaults to True.
+        fixed (bool, optional):
+            if True, perturbations do not get replaced by new random perturbations until termination criteria is satisfied. Defaults to True.
+        pre_generate (bool, optional):
+            if True, perturbations are pre-generated before each step.
+            This requires more memory to store all of them,
+            but ensures they do not change when closure is evaluated multiple times.
+            Defaults to True.
+        termination (TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None, optional):
+            a termination criteria module, sigma will be multiplied by ``decay`` when termination criteria is satisfied,
+            and new perturbations will be generated if ``fixed``. Defaults to None.
+        decay (float, optional): sigma multiplier on termination criteria. Defaults to 2/3.
+        reset_on_termination (bool, optional): whether to reset states of all other modules on termination. Defaults to True.
+        sigma_strategy (str | None, optional):
+            strategy for adapting sigma. If condition is satisfied, sigma is multiplied by ``sigma_nplus``,
+            otherwise it is multiplied by ``sigma_nminus``.
+            - "grad-norm" - at least ``sigma_target`` gradients should have lower norm than at un-perturbed point.
+            - "value" - at least ``sigma_target`` values (losses) should be lower than at un-perturbed point.
+            - None - doesn't use adaptive sigma.
+
+            This introduces a side-effect to the closure, so it should be left at None of you use
+            trust region or line search to optimize the modified objective.
+            Defaults to None.
+        sigma_target (int, optional):
+            number of elements to satisfy the condition in ``sigma_strategy``. Defaults to 1.
+        sigma_nplus (float, optional): sigma multiplier when ``sigma_strategy`` condition is satisfied. Defaults to 4/3.
+        sigma_nminus (float, optional): sigma multiplier when ``sigma_strategy`` condition is not satisfied. Defaults to 2/3.
+        seed (int | None, optional): seed. Defaults to None.
+    """
+    def __init__(
+        self,
+        modules: Chainable | None = None,
+        sigma: float = 1.,
+        n:int = 100,
+        aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = 'mean',
+        distribution: Distributions = 'gaussian',
+        include_x0: bool = True,
+
+        fixed: bool=True,
+        pre_generate: bool = True,
+        termination: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
+        decay: float = 2/3,
+        reset_on_termination: bool = True,
+
+        sigma_strategy: Literal['grad-norm', 'value'] | None = None,
+        sigma_target: int | float = 0.2,
+        sigma_nplus: float = 4/3,
+        sigma_nminus: float = 2/3,
+
+        seed: int | None = None,
+    ):
+
+        defaults = dict(sigma=sigma, n=n, aggregate=aggregate, distribution=distribution, seed=seed, include_x0=include_x0, fixed=fixed, decay=decay, reset_on_termination=reset_on_termination, sigma_strategy=sigma_strategy, sigma_target=sigma_target, sigma_nplus=sigma_nplus, sigma_nminus=sigma_nminus, pre_generate=pre_generate)
+        super().__init__(defaults, modules)
+
+        if termination is not None:
+            self.set_child('termination', make_termination_criteria(extra=termination))
+
+    @torch.no_grad
+    def pre_step(self, var):
+        params = TensorList(var.params)
+
+        fixed = self.defaults['fixed']
+
+        # check termination criteria
+        if 'termination' in self.children:
+            termination = cast(TerminationCriteriaBase, self.children['termination'])
+            if termination.should_terminate(var):
+
+                # decay sigmas
+                states = [self.state[p] for p in params]
+                settings = [self.settings[p] for p in params]
+
+                for state, setting in zip(states, settings):
+                    if 'sigma' not in state: state['sigma'] = setting['sigma']
+                    state['sigma'] *= setting['decay']
+
+                # reset on sigmas decay
+                if self.defaults['reset_on_termination']:
+                    var.post_step_hooks.append(partial(_reset_except_self, self=self))
+
+                # clear perturbations
+                self.global_state.pop('perts', None)
+
+        # pre-generate perturbations if not already pre-generated or not fixed
+        if self.defaults['pre_generate'] and (('perts' not in self.global_state) or (not fixed)):
+            states = [self.state[p] for p in params]
+            settings = [self.settings[p] for p in params]
+
+            n = self.defaults['n'] - self.defaults['include_x0']
+            generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+            perts = [params.sample_like(self.defaults['distribution'], generator=generator) for _ in range(n)]
+
+            self.global_state['perts'] = perts
+
+    @torch.no_grad
+    def closure(self, backward, closure, params, var):
+        params = TensorList(params)
+        loss_agg = None
+        grad_agg = None
+
+        states = [self.state[p] for p in params]
+        settings = [self.settings[p] for p in params]
+        sigma_inits = [s['sigma'] for s in settings]
+        sigmas = [s.setdefault('sigma', si) for s, si in zip(states, sigma_inits)]
+
+        include_x0 = self.defaults['include_x0']
+        pre_generate = self.defaults['pre_generate']
+        aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = self.defaults['aggregate']
+        sigma_strategy: Literal['grad-norm', 'value'] | None = self.defaults['sigma_strategy']
+        distribution = self.defaults['distribution']
+        generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+
+        n_finite = 0
+        n_good = 0
+        f_0 = None; g_0 = None
+
+        # evaluate at x_0
+        if include_x0:
+            f_0 = cast(torch.Tensor, var.get_loss(backward=backward))
+
+            isfinite = math.isfinite(f_0)
+            if isfinite:
+                n_finite += 1
+                loss_agg = f_0
+
+            if backward:
+                g_0 = var.get_grad()
+                if isfinite: grad_agg = g_0
+
+        # evaluate at x_0 + p for each perturbation
+        if pre_generate:
+            perts = self.global_state['perts']
+        else:
+            perts = [None] * (self.defaults['n'] - include_x0)
+
+        x_0 = [p.clone() for p in params]
+
+        for pert in perts:
+            loss = None; grad = None
+
+            # generate if not pre-generated
+            if pert is None:
+                pert = params.sample_like(distribution, generator=generator)
+
+            # add perturbation and evaluate
+            pert = pert * sigmas
+            torch._foreach_add_(params, pert)
+
+            with torch.enable_grad() if backward else nullcontext():
+                loss = closure(backward)
+
+            if math.isfinite(loss):
+                n_finite += 1
+
+                # add loss
+                if loss_agg is None:
+                    loss_agg = loss
+                else:
+                    if aggregate == 'mean':
+                        loss_agg += loss
+
+                    elif (aggregate=='min') or (aggregate=='min-value') or (aggregate=='min-norm' and not backward):
+                        loss_agg = loss_agg.clamp(max=loss)
+
+                    elif aggregate == 'max':
+                        loss_agg = loss_agg.clamp(min=loss)
+
+                # add grad
+                if backward:
+                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                    if grad_agg is None:
+                        grad_agg = grad
+                    else:
+                        if aggregate == 'mean':
+                            torch._foreach_add_(grad_agg, grad)
+
+                        elif aggregate == 'min':
+                            grad_agg_abs = torch._foreach_abs(grad_agg)
+                            torch._foreach_minimum_(grad_agg_abs, torch._foreach_abs(grad))
+                            grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]
+
+                        elif aggregate == 'max':
+                            grad_agg_abs = torch._foreach_abs(grad_agg)
+                            torch._foreach_maximum_(grad_agg_abs, torch._foreach_abs(grad))
+                            grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]
+
+                        elif aggregate == 'min-norm':
+                            if TensorList(grad).global_vector_norm() < TensorList(grad_agg).global_vector_norm():
+                                grad_agg = grad
+                                loss_agg = loss
+
+                        elif aggregate == 'min-value':
+                            if loss < loss_agg:
+                                grad_agg = grad
+                                loss_agg = loss
+
+            # undo perturbation
+            torch._foreach_copy_(params, x_0)
+
+            # adaptive sigma
+            # by value
+            if sigma_strategy == 'value':
+                if f_0 is None:
+                    with torch.enable_grad() if backward else nullcontext():
+                        f_0 = closure(False)
+
+                if loss < f_0:
+                    n_good += 1
+
+            # by gradient norm
+            elif sigma_strategy == 'grad-norm' and backward and math.isfinite(loss):
+                assert grad is not None
+                if g_0 is None:
+                    with torch.enable_grad() if backward else nullcontext():
+                        closure()
+                    g_0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+                if TensorList(grad).global_vector_norm() < TensorList(g_0).global_vector_norm():
+                    n_good += 1
+
+        # update sigma if strategy is enabled
+        if sigma_strategy is not None:
+
+            sigma_target = self.defaults['sigma_target']
+            if isinstance(sigma_target, float):
+                sigma_target = int(max(1, n_finite * sigma_target))
+
+            if n_good >= sigma_target:
+                key = 'sigma_nplus'
+            else:
+                key = 'sigma_nminus'
+
+            for p in params:
+                self.state[p]['sigma'] *= self.settings[p][key]
+
+        # if no finite losses, just return inf
+        if n_finite == 0:
+            assert loss_agg is None and grad_agg is None
+            loss = torch.tensor(torch.inf, dtype=params[0].dtype, device=params[0].device)
+            grad = [torch.full_like(p, torch.inf) for p in params]
+            return loss, grad
+
+        assert loss_agg is not None
+
+        # no post processing needed when aggregate is 'max', 'min', 'min-norm', 'min-value'
+        if aggregate != 'mean':
+            return loss_agg, grad_agg
+
+        # on mean divide by number of evals
+        loss_agg /= n_finite
+
+        if backward:
+            assert grad_agg is not None
+            torch._foreach_div_(grad_agg, n_finite)
+
+        return loss_agg, grad_agg
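For reference, the "mean" aggregation that the GradientSampling docstring associates with gaussian homotopy amounts to averaging losses and gradients over randomly perturbed copies of the parameters. A framework-free sketch under that assumption (the function and variable names below are illustrative, not torchzero's API, and the closure is assumed to call `backward()` itself):

```python
import torch

def sampled_loss_and_grad(closure, params, sigma=1.0, n=100, include_x0=True):
    # closure() must compute the loss, call backward(), and return the loss
    x0 = [p.detach().clone() for p in params]
    loss_sum, grad_sum, k = 0.0, [torch.zeros_like(p) for p in params], 0

    perts = [None] if include_x0 else []                       # None = unperturbed point
    perts += [[torch.randn_like(p) * sigma for p in params] for _ in range(n - include_x0)]

    for pert in perts:
        with torch.no_grad():
            if pert is not None:
                for p, e in zip(params, pert):
                    p.add_(e)                                   # move to x0 + e
        for p in params:
            p.grad = None
        loss = closure()
        loss_sum += float(loss)
        for gs, p in zip(grad_sum, params):
            if p.grad is not None:
                gs += p.grad
        with torch.no_grad():
            for p, x in zip(params, x0):
                p.copy_(x)                                      # undo the perturbation
        k += 1

    return loss_sum / k, [gs / k for gs in grad_sum]

# usage with a tiny model
model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
def closure():
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss
loss, grads = sampled_loss_and_grad(closure, list(model.parameters()), sigma=0.1, n=10)
```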
torchzero/modules/step_size/__init__.py

@@ -1,2 +1,2 @@
 from .lr import LR, StepSize, Warmup, WarmupNormClip, RandomStepSize
-from .adaptive import PolyakStepSize, BarzilaiBorwein
+from .adaptive import PolyakStepSize, BarzilaiBorwein, BBStab, AdGD
|