torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
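Most of the renames above come from the `torchzero/utils/linalg` package being promoted to a top-level `torchzero/linalg` package (with new `eigh`, `orthogonalize`, `svd` and `torch_linalg` modules), and from `torchzero/optim/wrappers/scipy.py` being split into a `scipy/` subpackage. A minimal sketch of what the move means for imports, assuming the relocated modules keep the public names used elsewhere in this diff:

```py
import torch

# torchzero 0.3.15 (old path):
#   from torchzero.utils.linalg.linear_operator import Dense, DenseWithInverse
# torchzero 0.4.1 (new path, per the renames above):
from torchzero.linalg.linear_operator import Dense, DenseWithInverse

H = torch.eye(3).mul_(2.0)
op = Dense(H)                                   # dense linear-operator wrapper used by Newton below
op_with_inv = DenseWithInverse(H, torch.linalg.inv(H))
```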
torchzero/modules/second_order/newton.py

````diff
@@ -1,147 +1,119 @@
-import warnings
 from collections.abc import Callable
-from
-from typing import Literal
+from typing import Any
 
 import torch
 
-from ...core import Chainable,
-from ...utils import
-from ...
-
-
-
-    hvp_fd_central,
-    hvp_fd_forward,
-    jacobian_and_hessian_wrt,
-)
-from ...utils.linalg.linear_operator import DenseWithInverse, Dense
-
-def _lu_solve(H: torch.Tensor, g: torch.Tensor):
+from ...core import Chainable, Transform, Objective, HessianMethod
+from ...utils import vec_to_tensors_
+from ...linalg.linear_operator import Dense, DenseWithInverse, Eigendecomposition
+from ...linalg import torch_linalg
+
+def _try_lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
-        x, info =
+        x, info = torch_linalg.solve_ex(H, g, retry_float64=True)
         if info == 0: return x
         return None
     except RuntimeError:
         return None
 
-def
-
+def _try_cholesky_solve(H: torch.Tensor, g: torch.Tensor):
+    L, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
     if info == 0:
-        g.
-        return torch.cholesky_solve(g, x)
+        return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
     return None
 
 def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
     return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable
 
-def
-
-
-
-
-
-
-
-
-
-
-    except torch.linalg.LinAlgError:
-        return None
-
-
-def _get_loss_grad_and_hessian(var: Var, hessian_method:str, vectorize:bool):
-    """returns (loss, g_list, H). Also sets var.loss and var.grad.
-    If hessian_method isn't 'autograd', loss is not set and returned as None"""
-    closure = var.closure
-    if closure is None:
-        raise RuntimeError("Second order methods requires a closure to be provided to the `step` method.")
-
-    params = var.params
-
-    # ------------------------ calculate grad and hessian ------------------------ #
-    loss = None
-    if hessian_method == 'autograd':
-        with torch.enable_grad():
-            loss = var.loss = var.loss_approx = closure(False)
-            g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-            g_list = [t[0] for t in g_list] # remove leading dim from loss
-            var.grad = g_list
-            H = flatten_jacobian(H_list)
-
-    elif hessian_method in ('func', 'autograd.functional'):
-        strat = 'forward-mode' if vectorize else 'reverse-mode'
-        with torch.enable_grad():
-            g_list = var.get_grad(retain_graph=True)
-            H = hessian_mat(partial(closure, backward=False), params,
-                            method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-    else:
-        raise ValueError(hessian_method)
-
-    return loss, g_list, H
-
-def _newton_step(var: Var, H: torch.Tensor, damping:float, inner: Module | None, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None) -> torch.Tensor:
-    """returns the update tensor, then do vec_to_tensor(update, params)"""
-    params = var.params
-
+def _newton_update_state_(
+    state: dict,
+    H: torch.Tensor,
+    damping: float,
+    eigval_fn: Callable | None,
+    precompute_inverse: bool,
+    use_lstsq: bool,
+):
+    """used in most hessian-based modules"""
+    # add damping
     if damping != 0:
-
-
-    # -------------------------------- inner step -------------------------------- #
-    update = var.get_update()
-    if inner is not None:
-        update = apply_transform(inner, update, params=params, grads=var.grad, loss=var.loss, var=var)
-
-    g = torch.cat([t.ravel() for t in update])
-    if g_proj is not None: g = g_proj(g)
-
-    # ----------------------------------- solve ---------------------------------- #
-    update = None
-
-    if H_tfm is not None:
-        ret = H_tfm(H, g)
-
-        if isinstance(ret, torch.Tensor):
-            update = ret
-
-        else: # returns (H, is_inv)
-            H, is_inv = ret
-            if is_inv: update = H @ g
+        reg = torch.eye(H.size(0), device=H.device, dtype=H.dtype).mul_(damping)
+        H += reg
 
+    # if eigval_fn is given, we don't need H or H_inv, we store factors
     if eigval_fn is not None:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        L, Q = torch_linalg.eigh(H, retry_float64=True)
+        L = eigval_fn(L)
+        state["L"] = L
+        state["Q"] = Q
+        return
+
+    # pre-compute inverse if requested
+    # store H to as it is needed for trust regions
+    state["H"] = H
+    if precompute_inverse:
+        if use_lstsq:
+            H_inv = torch.linalg.pinv(H) # pylint:disable=not-callable
+        else:
+            H_inv, _ = torch_linalg.inv_ex(H)
+        state["H_inv"] = H_inv
+
+
+def _newton_solve(
+    b: torch.Tensor,
+    state: dict[str, torch.Tensor | Any],
+    use_lstsq: bool = False,
+):
+    """
+    used in most hessian-based modules. state is from ``_newton_update_state_``, in it:
 
-
-
+        H (torch.Tensor): hessian
+        H_inv (torch.Tensor | None): hessian inverse
+        L (torch.Tensor | None): eigenvalues (transformed)
+        Q (torch.Tensor | None): eigenvectors
+    """
+    # use eig if provided
+    if "L" in state:
+        Q = state["Q"]; L = state["L"]
+        assert Q is not None
+        return Q @ ((Q.mH @ b) / L)
+
+    # use inverse if cached
+    if "H_inv" in state:
+        return state["H_inv"] @ b
+
+    # use hessian
+    H = state["H"]
+    if use_lstsq: return _least_squares_solve(H, b)
+
+    dir = None
+    if dir is None: dir = _try_cholesky_solve(H, b)
+    if dir is None: dir = _try_lu_solve(H, b)
+    if dir is None: dir = _least_squares_solve(H, b)
+    return dir
+
+def _newton_get_H(state: dict[str, torch.Tensor | Any]):
+    """used in most hessian-based modules. state is from ``_newton_update_state_``"""
+    if "H_inv" in state:
+        return DenseWithInverse(state["H"], state["H_inv"])
+
+    if "L" in state:
+        # Eigendecomposition has sligthly different solve_plus_diag
+        # I am pretty sure it should be very close and it uses no solves
+        # best way to test is to try cubic regularization with this
+        return Eigendecomposition(state["L"], state["Q"], use_nystrom=False)
+
+    return Dense(state["H"])
+
+class Newton(Transform):
+    """Exact Newton's method via autograd.
 
     Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
     The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian and ``g`` is the gradient, ``y`` is the ``damping`` parameter.
+
     ``g`` can be output of another module, if it is specifed in ``inner`` argument.
 
     Note:
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
     Note:
         This module requires the a closure passed to the optimizer step,
````
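When `eigval_fn` is set, the new `_newton_update_state_` stores the eigendecomposition `H = Q diag(L) Qᴴ`, and `_newton_solve` applies the inverse directly as `Q @ ((Q.mH @ b) / L)`. A standalone sketch (toy code, not torchzero internals) of why that expression solves `Hx = b`, with an eigenvalue clip standing in for `eigval_fn`:

```py
import torch

torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.float64)
H = A @ A.T + 1e-3 * torch.eye(5, dtype=torch.float64)  # SPD stand-in for the hessian
b = torch.randn(5, dtype=torch.float64)

# eigendecomposition path: H = Q diag(L) Qᴴ, so H⁻¹b = Q diag(1/L) Qᴴ b
L, Q = torch.linalg.eigh(H)
L = L.clip(min=1e-8)                  # stand-in for eigval_fn (no-op here, H is already SPD)
x_eig = Q @ ((Q.mH @ b) / L)          # same expression as in _newton_solve

x_ref = torch.linalg.solve(H, b)
print(torch.allclose(x_eig, x_ref, atol=1e-8))  # True for this well-conditioned H
```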
````diff
@@ -149,38 +121,43 @@ class Newton(Module):
         The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
-        damping (float, optional): tikhonov regularizer value.
-        search_negative (bool, Optional):
-            if True, whenever a negative eigenvalue is detected,
-            search direction is proposed along weighted sum of eigenvectors corresponding to negative eigenvalues.
-        use_lstsq (bool, Optional):
-            if True, least squares will be used to solve the linear system, this may generate reasonable directions
-            when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
-            If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
-            argument will be ignored.
-        hessian_method (str):
-            how to calculate hessian. Defaults to "autograd".
-        vectorize (bool, optional):
-            whether to enable vectorized hessian. Defaults to True.
-        H_tfm (Callable | None, optional):
-            optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
-
-            must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
-            which must be True if transform inverted the hessian and False otherwise.
-
-            Or it returns a single tensor which is used as the update.
-
-            Defaults to None.
+        damping (float, optional): tikhonov regularizer value. Defaults to 0.
         eigval_fn (Callable | None, optional):
-
+            function to apply to eigenvalues, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
             If this is specified, eigendecomposition will be used to invert the hessian.
+        update_freq (int, optional):
+            updates hessian every ``update_freq`` steps.
+        precompute_inverse (bool, optional):
+            if ``True``, whenever hessian is computed, also computes the inverse. This is more efficient
+            when ``update_freq`` is large. If ``None``, this is ``True`` if ``update_freq >= 10``.
+        use_lstsq (bool, Optional):
+            if True, least squares will be used to solve the linear system, this can prevent it from exploding
+            when hessian is indefinite. If False, tries cholesky, if it fails tries LU, and then least squares.
+            If ``eigval_fn`` is specified, eigendecomposition is always used and this argument is ignored.
+        hessian_method (str):
+            Determines how hessian is computed.
+
+            - ``"batched_autograd"`` - uses autograd to compute ``ndim`` batched hessian-vector products. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd to compute ``ndim`` hessian-vector products using for loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"functional_revrev"`` - uses ``torch.autograd.functional`` with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to ``"autograd"``.
+            - ``"functional_fwdrev"`` - uses ``torch.autograd.functional`` with vectorized "forward-over-reverse" strategy. Faster than ``"functional_fwdrev"`` but uses more memory (``"batched_autograd"`` seems to be faster)
+            - ``"func"`` - uses ``torch.func.hessian`` which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
+            - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+            - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
+            - ``"thoad"`` - uses ``thoad`` library, can be significantly faster than pytorch but limited operator coverage.
+
+            Defaults to ``"batched_autograd"``.
+        h (float, optional):
+            finite difference step size if hessian is compute via finite-difference.
         inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
 
     # See also
 
-    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products
+    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products.
       useful for large scale problems as it doesn't form the full hessian.
     * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
+    * ``tz.m.ImprovedNewton``: Newton with additional rank one correction to the hessian, can be faster than Newton.
     * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
     * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.
 
````
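The `hessian_method`, `update_freq` and `precompute_inverse` arguments documented above replace the old `vectorize` / `H_tfm` knobs. A hedged usage sketch based on this docstring; the closure convention follows the note above ("must accept a ``backward`` argument"), and anything beyond the names shown in this diff is assumed rather than confirmed:

```py
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X = torch.randn(4, 10)

opt = tz.Optimizer(
    model.parameters(),
    # refresh the hessian every 10 steps and reuse the factorization in between;
    # eigval_fn keeps the direction well-defined on non-convex losses
    tz.m.Newton(
        hessian_method="batched_autograd",
        update_freq=10,
        eigval_fn=lambda L: L.abs().clip(min=1e-4),
    ),
    tz.m.Backtracking(),
)

# Newton needs a closure that accepts a ``backward`` argument (see the docstring note)
def closure(backward=True):
    loss = model(X).square().mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```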
````diff
@@ -189,57 +166,48 @@ class Newton(Module):
     ## Implementation details
 
     ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
-    The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
-    Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.
+    The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares. Least squares can be forced by setting ``use_lstsq=True``.
 
     Additionally, if ``eigval_fn`` is specified, eigendecomposition of the hessian is computed,
-    ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive
-    but not by much
+    ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive but not by much.
 
     ## Handling non-convexity
 
     Standard Newton's method does not handle non-convexity well without some modifications.
     This is because it jumps to the stationary point, which may be the maxima of the quadratic approximation.
 
-
+    A modification to handle non-convexity is to modify the eignevalues to be positive,
     for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.
 
-    Second modification is ``search_negative=True``, which will search along a negative curvature direction if one is detected.
-    This also requires an eigendecomposition.
-
-    The Newton direction can also be forced to be a descent direction by using ``tz.m.GradSign()`` or ``tz.m.Cautious``,
-    but that may be significantly less efficient.
-
     # Examples:
 
     Newton's method with backtracking line search
 
     ```py
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Newton(),
         tz.m.Backtracking()
     )
     ```
 
-    Newton
+    Newton's method for non-convex optimization.
 
     ```py
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
-        tz.m.Newton(
-        tz.m.
+        tz.m.Newton(eigval_fn = lambda L: L.abs().clip(min=1e-4)),
+        tz.m.Backtracking()
     )
     ```
 
-
-    but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+    Newton preconditioning applied to momentum
 
     ```py
-    opt = tz.
+    opt = tz.Optimizer(
         model.parameters(),
-        tz.m.Newton(
-        tz.m.
+        tz.m.Newton(inner=tz.m.EMA(0.9)),
+        tz.m.LR(0.1)
     )
     ```
 
````
````diff
@@ -247,47 +215,48 @@ class Newton(Module):
     def __init__(
         self,
         damping: float = 0,
-        use_lstsq: bool = False,
-        update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        update_freq: int = 1,
+        precompute_inverse: bool | None = None,
+        use_lstsq: bool = False,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
         inner: Chainable | None = None,
     ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['update_freq'], defaults["inner"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def
-
-
-
-        if
-
-
-
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
+
+        precompute_inverse = fs["precompute_inverse"]
+        if precompute_inverse is None:
+            precompute_inverse = fs["__update_freq"] >= 10
+
+        __, _, H = objective.hessian(hessian_method=fs["hessian_method"], h=fs["h"], at_x0=True)
+
+        _newton_update_state_(
+            state = self.global_state,
+            H=H,
+            damping = fs["damping"],
+            eigval_fn = fs["eigval_fn"],
+            precompute_inverse = precompute_inverse,
+            use_lstsq = fs["use_lstsq"]
+        )
 
     @torch.no_grad
-    def
-
-
-            var=var,
-            H = self.global_state["H"],
-            damping=self.defaults["damping"],
-            inner=self.children.get("inner", None),
-            H_tfm=self.defaults["H_tfm"],
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
-        )
+    def apply_states(self, objective, states, settings):
+        updates = objective.get_updates()
+        fs = settings[0]
 
-
+        b = torch.cat([t.ravel() for t in updates])
+        sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])
 
-
+        vec_to_tensors_(sol, updates)
+        return objective
 
-    def get_H(self,
-        return
+    def get_H(self,objective=...):
+        return _newton_get_H(self.global_state)
 
````
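The split into `update_states` / `apply_states` in the hunk above is what makes `update_freq` and `precompute_inverse` pay off: the expensive factorization or inversion happens only when the hessian is refreshed, while every step only pays for a solve or a matvec. A standalone sketch of that amortization (toy code, not the torchzero API):

```py
import torch

torch.manual_seed(0)
n, update_freq = 100, 10
A = torch.randn(n, n, dtype=torch.float64)
H = A @ A.T + torch.eye(n, dtype=torch.float64)   # stand-in hessian, refreshed rarely

# "update" phase: done once per update_freq steps (what precompute_inverse enables)
H_inv = torch.linalg.inv(H)                        # O(n^3), paid once per refresh

# "apply" phase: done every step, now just a matvec instead of a fresh solve
for _ in range(update_freq):
    b = torch.randn(n, dtype=torch.float64)        # stand-in gradient / inner-module output
    step_dir = H_inv @ b                           # O(n^2); apply_states does this via _newton_solve
```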