torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
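The renames above move the linear-algebra helpers out of torchzero/utils/linalg into a top-level torchzero/linalg package, split the single optim/wrappers/scipy.py module into a scipy/ subpackage, and replace core/var.py with core/objective.py. A hedged sketch of what the import-path change could look like for downstream code; the module paths follow the rename entries above, but the exact public symbols in each module are assumptions and are not verified here.

# 0.3.x (removed):  from torchzero.utils.linalg.solve import cg
# 0.4.x layout implied by the {utils/linalg -> linalg} moves above:
from torchzero.linalg.solve import cg                  # symbol assumed, path from the rename entry
from torchzero.linalg.linear_operator import Dense     # import that newtonnewton.py uses below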
torchzero/modules/experimental/higher_order_newton.py

@@ -1,21 +1,12 @@
-import itertools
 import math
-import warnings
-from collections.abc import Callable
-from contextlib import nullcontext
-from functools import partial
 from typing import Any, Literal
 
 import numpy as np
 import scipy.optimize
 import torch
 
-from ...core import
+from ...core import DerivativesMethod, Module
 from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
-from ...utils.derivatives import (
-    flatten_jacobian,
-    jacobian_wrt,
-)
 
 _LETTERS = 'abcdefghijklmnopqrstuvwxyz'
 def _poly_eval(s: np.ndarray, c, derivatives):
@@ -195,22 +186,22 @@ class HigherOrderNewton(Module):
         max_attempts = 10,
         boundary_tol: float = 1e-2,
         de_iters: int | None = None,
-
+        derivatives_method: DerivativesMethod = "batched_autograd",
     ):
         if init is None:
             if trust_method == 'bounds': init = 1
             else: init = 0.1
 
-        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init,
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad, derivatives_method=derivatives_method)
         super().__init__(defaults)
 
     @torch.no_grad
-    def
-        params = TensorList(
-        closure =
+    def apply(self, objective):
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('HigherOrderNewton requires closure')
 
-        settings = self.
+        settings = self.defaults
         order = settings['order']
         nplus = settings['nplus']
         nminus = settings['nminus']
@@ -219,31 +210,12 @@ class HigherOrderNewton(Module):
         trust_method = settings['trust_method']
         de_iters = settings['de_iters']
         max_attempts = settings['max_attempts']
-        vectorize = settings['vectorize']
         boundary_tol = settings['boundary_tol']
         rho_good = settings['rho_good']
        rho_bad = settings['rho_bad']
 
         # ------------------------ calculate grad and hessian ------------------------ #
-
-        loss = var.loss = var.loss_approx = closure(False)
-
-        g_list = torch.autograd.grad(loss, params, create_graph=True)
-        var.grad = list(g_list)
-
-        g = torch.cat([t.ravel() for t in g_list])
-        n = g.numel()
-        derivatives = [g]
-        T = g # current derivatives tensor
-
-        # get all derivative up to order
-        for o in range(2, order + 1):
-            is_last = o == order
-            T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
-            with torch.no_grad() if is_last else nullcontext():
-                # the shape is (ndim, ) * order
-                T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
-                derivatives.append(T)
+        loss, *derivatives = objective.derivatives(order=order, at_x0=True, method=self.defaults["derivatives_method"])
 
         x0 = torch.cat([p.ravel() for p in params])
 
@@ -301,7 +273,8 @@ class HigherOrderNewton(Module):
             vec_to_tensors_(x0, params)
             reduction = loss - loss_star
 
-            rho = reduction / (max(pred_reduction,
+            rho = reduction / (max(pred_reduction, finfo.tiny * 2)) # pyright:ignore[reportArgumentType]
+
             # failed step
             if rho < rho_bad:
                 self.global_state['trust_region'] = trust_value * nminus
@@ -320,8 +293,9 @@ class HigherOrderNewton(Module):
         assert x_star is not None
         if success:
             difference = vec_to_tensors(x0 - x_star, params)
-
+            objective.updates = list(difference)
         else:
-
-
+            objective.updates = params.zeros_like()
+
+        return objective
 
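The hunks above show the 0.4.x module interface: the per-step entry point is now apply(self, objective), hyperparameters live in self.defaults, results are written to objective.updates, and the objective is returned; higher derivatives come from objective.derivatives(...) instead of a manual jacobian_wrt loop. A minimal sketch of a custom module written against that interface, using only attributes that appear in these hunks (params, closure, updates, get_updates, defaults); it is an illustration of the pattern, not torchzero code, and the exact semantics of these attributes are assumed.

import torch
from torchzero.core import Module

class ScaleByLoss(Module):
    # Illustrative only: rescales the incoming update by the current loss value.
    def __init__(self, eps: float = 1e-8):
        super().__init__(dict(eps=eps))

    @torch.no_grad
    def apply(self, objective):
        closure = objective.closure
        if closure is None: raise RuntimeError('ScaleByLoss requires closure')
        loss = closure(False)                       # evaluate loss without backward, as in the hunk above
        eps = self.defaults['eps']
        objective.updates = [u / (loss.abs() + eps) for u in objective.get_updates()]
        return objective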
torchzero/modules/experimental/matrix_nag.py (new file)

@@ -0,0 +1,122 @@
+from collections.abc import Callable
+from typing import Literal
+
+import torch
+from torchzero.core import Chainable, Transform, HVPMethod
+from torchzero.utils import NumberList, TensorList
+
+
+def matrix_nag_(
+    tensors_: TensorList,
+    s: TensorList,
+    Hvp_fn: Callable,
+    mu: float | NumberList,
+):
+    s += tensors_
+    Hv = TensorList(Hvp_fn(s))
+    s -= Hv.mul_(mu)
+    return tensors_.add_(s)
+
+
+class MatrixNAG(Transform):
+    """nesterov momentum version of matrix momentum. It seemed to work really well but adapting doesn't work,
+    I need to test more"""
+    def __init__(
+        self,
+        mu=0.1,
+        hvp_method: HVPMethod = "autograd",
+        h: float = 1e-3,
+        adaptive:bool = False,
+        adapt_freq: int | None = None,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
+        super().__init__(defaults)
+
+        if hvp_tfm is not None:
+            self.set_child('hvp_tfm', hvp_tfm)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('p_prev')
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        assert objective.closure is not None
+        step = self.global_state.get("step", 0)
+        self.global_state["step"] = step + 1
+
+        p = TensorList(objective.params)
+        g = TensorList(objective.get_grads(create_graph=self.defaults["hvp_method"] == "autograd"))
+        p_prev = self.get_state(p, "p_prev", init=p, cls=TensorList)
+        s = p - p_prev
+        p_prev.copy_(p)
+
+        # -------------------------------- adaptive mu ------------------------------- #
+        if self.defaults["adaptive"]:
+
+            if step == 1:
+                self.global_state["mu_mul"] = 0
+
+            else:
+                # ---------------------------- deterministic case ---------------------------- #
+                if self.defaults["adapt_freq"] is None:
+                    g_prev = self.get_state(objective.params, "g_prev", cls=TensorList)
+                    y = g - g_prev
+                    g_prev.copy_(g)
+
+                    denom = y.global_vector_norm()
+                    denom = denom.clip(min = torch.finfo(denom.dtype).tiny * 2)
+                    self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+                # -------------------------------- stochastic -------------------------------- #
+                else:
+                    adapt_freq = self.defaults["adapt_freq"]
+
+                    # we start on 1nd step, and want to adapt when we start, so use (step - 1)
+                    if (step - 1) % adapt_freq == 0:
+                        assert objective.closure is not None
+                        p_cur = p.clone()
+
+                        # move to previous params and evaluate p_prev with current mini-batch
+                        p.copy_(self.get_state(objective.params, 'p_prev'))
+                        with torch.enable_grad():
+                            objective.closure()
+                        g_prev = [t.grad if t.grad is not None else torch.zeros_like(t) for t in p]
+                        y = g - g_prev
+
+                        # move back to current params
+                        p.copy_(p_cur)
+
+                        denom = y.global_vector_norm()
+                        denom = denom.clip(min = torch.finfo(denom.dtype).tiny * 2)
+                        self.global_state["mu_mul"] = s.global_vector_norm() / denom
+
+        # -------------------------- matrix momentum update -------------------------- #
+        mu = self.get_settings(p, "mu", cls=NumberList)
+        if "mu_mul" in self.global_state:
+            mu = mu * self.global_state["mu_mul"]
+
+        # def Hvp_fn(v):
+        #     Hv, _ = self.Hvp(
+        #         v=v,
+        #         at_x0=True,
+        #         var=objective,
+        #         rgrad=g,
+        #         hvp_method=self.defaults["hvp_method"],
+        #         h=self.defaults["h"],
+        #         normalize=True,
+        #         retain_grad=False,
+        #     )
+        #     return Hv
+
+        _, Hvp_fn = objective.list_Hvp_function(hvp_method=self.defaults["hvp_method"], h=self.defaults["h"], at_x0=True)
+
+        objective.updates = matrix_nag_(
+            tensors_=TensorList(objective.get_updates()),
+            s=s,
+            Hvp_fn=Hvp_fn,
+            mu=mu,
+        )
+
+        return objective
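Written on flat vectors with an explicit Hessian, the matrix_nag_ update above is roughly the recurrence below. This is a toy paraphrase for illustration only; the module itself works on TensorLists and only ever calls a Hessian-vector product, never a dense H.

import torch

def matrix_nag_step(update, s, H, mu):
    s = s + update            # s += tensors_
    s = s - mu * (H @ s)      # s -= Hvp_fn(s) * mu
    return update + s, s      # tensors_.add_(s)

# toy quadratic f(x) = 0.5 x^T H x, incoming update = -lr * grad
H = torch.tensor([[3.0, 0.4], [0.4, 1.0]])
x = torch.tensor([1.0, -2.0])
s = torch.zeros(2)
for _ in range(50):
    u, s = matrix_nag_step(-0.1 * (H @ x), s, H, mu=0.1)
    x = x + u
print(x)   # approaches the minimizer at the origin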
torchzero/modules/experimental/newton_solver.py

@@ -1,11 +1,10 @@
-from collections.abc import Callable
-from typing import Any
+from collections.abc import Callable
+from typing import Any
 
 import torch
 
-from ...core import Chainable,
-from ...utils import TensorList
-from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
+from ...core import Chainable, Optimizer, Module, step, HVPMethod
+from ...utils import TensorList
 from ..quasi_newton import LBFGS
 
 
@@ -13,30 +12,32 @@ class NewtonSolver(Module):
     """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)."""
     def __init__(
         self,
-        solver: Callable[[list[torch.Tensor]], Any] = lambda p:
+        solver: Callable[[list[torch.Tensor]], Any] = lambda p: Optimizer(p, LBFGS()),
         maxiter=None,
         maxiter1=None,
         tol:float | None=1e-3,
         reg: float = 0,
         warm_start=True,
-        hvp_method:
+        hvp_method: HVPMethod = "autograd",
         reset_solver: bool = False,
         h: float= 1e-3,
+
         inner: Chainable | None = None,
     ):
-        defaults =
-
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
+        super().__init__(defaults)
 
-
-        self.set_child('inner', inner)
+        self.set_child("inner", inner)
 
         self._num_hvps = 0
         self._num_hvps_last_step = 0
 
     @torch.no_grad
-    def
-
-
+    def apply(self, objective):
+
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
@@ -44,51 +45,19 @@ class NewtonSolver(Module):
         maxiter = settings['maxiter']
         maxiter1 = settings['maxiter1']
         tol = settings['tol']
-        reg = settings['reg']
         hvp_method = settings['hvp_method']
         warm_start = settings['warm_start']
         h = settings['h']
         reset_solver = settings['reset_solver']
 
         self._num_hvps_last_step = 0
-        # ---------------------- Hessian vector product function --------------------- #
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-
-            def H_mm(x):
-                self._num_hvps_last_step += 1
-                with torch.enable_grad():
-                    Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
-                if reg != 0: Hvp = Hvp + (x*reg)
-                return Hvp
-
-        else:
-
-            with torch.enable_grad():
-                grad = var.get_grad()
-
-            if hvp_method == 'forward':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
-                    if reg != 0: Hvp = Hvp + (x*reg)
-                    return Hvp
-
-            elif hvp_method == 'central':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
-                    if reg != 0: Hvp = Hvp + (x*reg)
-                    return Hvp
-
-            else:
-                raise ValueError(hvp_method)
 
+        # ---------------------- Hessian vector product function --------------------- #
+        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
 
         # -------------------------------- inner step -------------------------------- #
-
-
-        b = as_tensorlist(apply_transform(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, var=var))
+        objective = self.inner_step("inner", objective, must_exist=False)
+        b = TensorList(objective.get_updates())
 
         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
@@ -112,7 +81,7 @@ class NewtonSolver(Module):
        solver = self.global_state['solver']
 
        def lstsq_closure(backward=True):
-            Hx =
+            Hx = H_mv(x).detach()
            # loss = (Hx-b).pow(2).global_mean()
            # if backward:
            #     solver.zero_grad()
@@ -122,7 +91,7 @@ class NewtonSolver(Module):
            loss = residual.pow(2).global_mean()
            if backward:
                with torch.no_grad():
-                    H_residual =
+                    H_residual = H_mv(residual)
                    n = residual.global_numel()
                    x.set_grad_((2.0 / n) * H_residual)
 
@@ -143,8 +112,8 @@ class NewtonSolver(Module):
         assert x0 is not None
         x0.copy_(x)
 
-
+        objective.updates = x.detach()
         self._num_hvps += self._num_hvps_last_step
-        return
+        return objective
 
 
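The rewritten NewtonSolver delegates the Hessian-vector product to objective.list_Hvp_function(...) instead of branching on hvp_method inline. For reference, a standalone sketch of the two strategies the removed branches spelled out, double-backward autograd and central finite differences; the function names below are local to the sketch, not torchzero API.

import torch

def hvp_double_backward(f, x, v):
    # Hessian-vector product via autograd: differentiate grad(f) in the direction v.
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(f(x), x, create_graph=True)
    (Hv,) = torch.autograd.grad(g, x, grad_outputs=v)
    return Hv

def hvp_central_difference(f, x, v, h=1e-3):
    # Hv ~= (grad f(x + h v) - grad f(x - h v)) / (2h)
    def grad_at(y):
        y = y.detach().requires_grad_(True)
        (g,) = torch.autograd.grad(f(y), y)
        return g
    return (grad_at(x + h * v) - grad_at(x - h * v)) / (2 * h)

x = torch.tensor([1.0, 2.0])
v = torch.tensor([0.5, -1.0])
f = lambda y: (y ** 4).sum()
print(hvp_double_backward(f, x, v), hvp_central_difference(f, x, v))   # both near (6, -48)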
torchzero/modules/experimental/newtonnewton.py

@@ -7,21 +7,21 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable,
-from ...
+from ...core import Chainable, Transform, step
+from ...linalg.linear_operator import Dense
+from ...utils import TensorList, vec_to_tensors_
 from ...utils.derivatives import (
     flatten_jacobian,
     jacobian_wrt,
 )
 from ..second_order.newton import (
-
-    _eigh_solve,
+    _try_cholesky_solve,
     _least_squares_solve,
-
+    _try_lu_solve,
 )
-from ...utils.linalg.linear_operator import Dense
 
-
+
+class NewtonNewton(Transform):
     """Applies Newton-like preconditioning to Newton step.
 
     This is a method that I thought of and then it worked. Here is how it works:
@@ -33,42 +33,36 @@ class NewtonNewton(Module):
     3. Solve H2 x2 = x for x2.
 
     4. Optionally, repeat (if order is higher than 3.)
-
-    Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unsable and have little benefit.
-
-    3rd order variant can minimize some convex functions with up to 100 variables in less time than Newton's method,
-    this is if pytorch can vectorize hessian computation efficiently.
     """
     def __init__(
         self,
         reg: float = 1e-6,
         order: int = 3,
-        search_negative: bool = False,
         vectorize: bool = True,
-
+        update_freq: int = 1,
+        inner: Chainable | None = None,
     ):
-        defaults = dict(order=order, reg=reg, vectorize=vectorize
-        super().__init__(defaults)
+        defaults = dict(order=order, reg=reg, vectorize=vectorize)
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def
-
-
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonNewton requires closure')
 
-
-
-
-        order = settings['order']
-        search_negative = settings['search_negative']
-        eigval_fn = settings['eigval_fn']
+        reg = fs['reg']
+        vectorize = fs['vectorize']
+        order = fs['order']
 
         # ------------------------ calculate grad and hessian ------------------------ #
-
+        P = None
         with torch.enable_grad():
-            loss =
+            loss = objective.loss = objective.loss_approx = closure(False)
             g_list = torch.autograd.grad(loss, params, create_graph=True)
-
+            objective.grads = list(g_list)
 
         xp = torch.cat([t.ravel() for t in g_list])
         I = torch.eye(xp.numel(), dtype=xp.dtype, device=xp.device)
@@ -79,27 +73,30 @@ class NewtonNewton(Module):
            with torch.no_grad() if is_last else nullcontext():
                H = flatten_jacobian(H_list)
                if reg != 0: H = H + I * reg
-
+                if P is None: P = H
+                else: P = P @ H
+
+                if not is_last:
+                    x = _try_cholesky_solve(H, xp)
+                    if x is None: x = _try_lu_solve(H, xp)
+                    if x is None: x = _least_squares_solve(H, xp)
+                    xp = x.squeeze()
+
+        self.global_state["P"] = P
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        updates = objective.get_updates()
+        P = self.global_state['P']
+        b = torch.cat([t.ravel() for t in updates])
 
-
-
-
-                if x is None: x = _cholesky_solve(H, xp)
-                if x is None: x = _lu_solve(H, xp)
-                if x is None: x = _least_squares_solve(H, xp)
-                xp = x.squeeze()
+        sol = _try_cholesky_solve(P, b)
+        if sol is None: sol = _try_lu_solve(P, b)
+        if sol is None: sol = _least_squares_solve(P, b)
 
-
-
+        vec_to_tensors_(sol, updates)
+        return objective
 
     @torch.no_grad
-    def
-
-        xp = self.global_state['xp']
-        var.update = vec_to_tensors(xp, params)
-        return var
-
-    def get_H(self, var):
-        Hs = self.global_state["Hs"]
-        if len(Hs) == 1: return Dense(Hs[0])
-        return Dense(torch.linalg.multi_dot(self.global_state["Hs"])) # pylint:disable=not-callable
+    def get_H(self, objective=...):
+        return Dense(self.global_state["P"])
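The accumulated preconditioner P in update_states relies on a simple identity: solving H1 x = g and then H2 x2 = x gives the same x2 as solving (H1 @ H2) x2 = g, which is why apply_states only needs to keep the product P. A small dense check of that identity with toy SPD matrices (not torchzero code):

import torch

torch.manual_seed(0)
A = torch.randn(5, 5); H1 = A @ A.T + 5 * torch.eye(5)
B = torch.randn(5, 5); H2 = B @ B.T + 5 * torch.eye(5)
g = torch.randn(5)

x  = torch.linalg.solve(H1, g)          # first Newton-like solve
x2 = torch.linalg.solve(H2, x)          # second solve on the result
x2_direct = torch.linalg.solve(H1 @ H2, g)
print(torch.allclose(x2, x2_direct, atol=1e-4))   # True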
torchzero/modules/experimental/reduce_outward_lr.py

@@ -1,28 +1,28 @@
 import torch
 
-from ...core import
+from ...core import TensorTransform
 from ...utils import TensorList, unpack_states, unpack_dicts
 
-class ReduceOutwardLR(
+class ReduceOutwardLR(TensorTransform):
     """When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
 
     This means updates that move weights towards zero have higher learning rates.
 
-
+    Warning:
         This sounded good but after testing turns out it sucks.
     """
-    def __init__(self, mul = 0.5, use_grad=False, invert=False
+    def __init__(self, mul = 0.5, use_grad=False, invert=False):
         defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
-        super().__init__(defaults, uses_grad=use_grad
+        super().__init__(defaults, uses_grad=use_grad)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         params = TensorList(params)
         tensors = TensorList(tensors)
 
         mul = [s['mul'] for s in settings]
         s = settings[0]
-        use_grad =
+        use_grad = self._uses_grad
         invert = s['invert']
 
         if use_grad: cur = grads
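Element-wise, the rule in the ReduceOutwardLR docstring is a masked scale: entries whose update sign matches the weight sign get multiplied by mul. A small sketch of that mask; whether a matching sign ends up moving the weight inward or outward depends on the sign convention applied when the update is finally used, so no direction is claimed here.

import torch

def scale_matching_signs(update: torch.Tensor, param: torch.Tensor, mul: float = 0.5) -> torch.Tensor:
    match = torch.sign(update) == torch.sign(param)
    return torch.where(match, update * mul, update)

p = torch.tensor([0.8, -0.3, 0.1])
u = torch.tensor([0.2, 0.5, -0.4])
print(scale_matching_signs(u, p))   # tensor([ 0.1000,  0.5000, -0.4000])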
torchzero/modules/experimental/scipy_newton_cg.py

@@ -3,10 +3,9 @@ from typing import Literal, overload
 import torch
 from scipy.sparse.linalg import LinearOperator, gcrotmk
 
-from ...core import Chainable, Module,
-from ...utils import
-from ...utils.derivatives import
-from ...utils.linalg.solve import cg, minres
+from ...core import Chainable, Module, step
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import hvp_fd_central, hvp_fd_forward
 
 
 class ScipyNewtonCG(Module):
@@ -14,7 +13,7 @@ class ScipyNewtonCG(Module):
     def __init__(
         self,
         solver = gcrotmk,
-        hvp_method: Literal["
+        hvp_method: Literal["fd_forward", "fd_central", "autograd"] = "autograd",
         h: float = 1e-3,
         warm_start=False,
         inner: Chainable | None = None,
@@ -33,47 +32,47 @@ class ScipyNewtonCG(Module):
         self._kwargs = kwargs
 
     @torch.no_grad
-    def
-        params = TensorList(
-        closure =
+    def apply(self, objective):
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
-
-        hvp_method =
-        solver =
-        h =
-        warm_start =
+        fs = self.settings[params[0]]
+        hvp_method = fs['hvp_method']
+        solver = fs['solver']
+        h = fs['h']
+        warm_start = fs['warm_start']
 
         self._num_hvps_last_step = 0
         # ---------------------- Hessian vector product function --------------------- #
         device = params[0].device; dtype=params[0].dtype
         if hvp_method == 'autograd':
-            grad =
+            grad = objective.get_grads(create_graph=True)
 
             def H_mm(x_np):
                 self._num_hvps_last_step += 1
                 x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
                 with torch.enable_grad():
-                    Hvp = TensorList(
+                    Hvp = TensorList(torch.autograd.grad(grad, params, x, retain_graph=True))
                 return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
 
         else:
 
             with torch.enable_grad():
-                grad =
+                grad = objective.get_grads()
 
             if hvp_method == 'forward':
                 def H_mm(x_np):
                     self._num_hvps_last_step += 1
                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
-                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad
+                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad)[1])
                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
 
             elif hvp_method == 'central':
                 def H_mm(x_np):
                     self._num_hvps_last_step += 1
                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
-                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h
+                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h)[1])
                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
 
             else:
@@ -83,10 +82,8 @@ class ScipyNewtonCG(Module):
         H = LinearOperator(shape=(ndim,ndim), matvec=H_mm, rmatvec=H_mm) # type:ignore
 
         # -------------------------------- inner step -------------------------------- #
-
-
-        b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
-        b = as_tensorlist(b)
+        objective = self.inner_step("inner", objective, must_exist=False)
+        b = TensorList(objective.get_updates())
 
         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
@@ -98,8 +95,8 @@ class ScipyNewtonCG(Module):
         if warm_start:
             self.global_state['x_prev'] = x_np
 
-
+        objective.updates = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)
 
         self._num_hvps += self._num_hvps_last_step
-        return
+        return objective
 
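ScipyNewtonCG wraps a torch Hessian-vector product as a scipy.sparse.linalg.LinearOperator and hands it to the chosen Krylov solver (gcrotmk by default). A standalone sketch of that pattern on a toy problem; only the LinearOperator/gcrotmk calls mirror the hunk above, everything else is local to the sketch.

import torch
from scipy.sparse.linalg import LinearOperator, gcrotmk

x = torch.tensor([1.0, -2.0, 0.5], requires_grad=True)
loss = (x ** 4).sum()
(grad,) = torch.autograd.grad(loss, x, create_graph=True)

def H_mm(v_np):
    # autograd Hessian-vector product exposed with a numpy in / numpy out signature
    v = torch.as_tensor(v_np, dtype=x.dtype)
    (Hv,) = torch.autograd.grad(grad, x, v, retain_graph=True)
    return Hv.detach().numpy()

n = x.numel()
H = LinearOperator(shape=(n, n), matvec=H_mm, rmatvec=H_mm)
b = grad.detach().numpy()
step_dir, info = gcrotmk(H, b)      # Newton direction: solves H step = grad
print(info, step_dir)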