torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/newton.py

@@ -1,22 +1,22 @@
from collections.abc import Callable
-from typing import
+from typing import Any

import torch

-from ...core import Chainable, Transform, Objective, HessianMethod
-from ...utils import
-from ...linalg.linear_operator import Dense, DenseWithInverse
+from ...core import Chainable, Transform, Objective, HessianMethod
+from ...utils import vec_to_tensors_
+from ...linalg.linear_operator import Dense, DenseWithInverse, Eigendecomposition
+from ...linalg import torch_linalg

-
-def _lu_solve(H: torch.Tensor, g: torch.Tensor):
+def _try_lu_solve(H: torch.Tensor, g: torch.Tensor):
    try:
-        x, info =
+        x, info = torch_linalg.solve_ex(H, g, retry_float64=True)
        if info == 0: return x
        return None
    except RuntimeError:
        return None

-def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
+def _try_cholesky_solve(H: torch.Tensor, g: torch.Tensor):
    L, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
    if info == 0:
        return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
@@ -25,77 +25,91 @@ def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
    return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable

-def
-
-
-
-
-
-
-
-
-
-
-    except torch.linalg.LinAlgError:
-        return None
-
-def _newton_step(objective: Objective, H: torch.Tensor, damping:float, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None, no_inner: Module | None = None) -> torch.Tensor:
-    """INNER SHOULD BE NONE IN MOST CASES! Because Transform already has inner.
-    Returns the update tensor, then do vec_to_tensor(update, params)"""
-    # -------------------------------- inner step -------------------------------- #
-    if no_inner is not None:
-        objective = no_inner.step(objective)
-
-    update = objective.get_updates()
-
-    g = torch.cat([t.ravel() for t in update])
-    if g_proj is not None: g = g_proj(g)
-
-    # ----------------------------------- solve ---------------------------------- #
-    update = None
-
+def _newton_update_state_(
+    state: dict,
+    H: torch.Tensor,
+    damping: float,
+    eigval_fn: Callable | None,
+    precompute_inverse: bool,
+    use_lstsq: bool,
+):
+    """used in most hessian-based modules"""
+    # add damping
    if damping != 0:
-
-
-        if H_tfm is not None:
-            ret = H_tfm(H, g)
-
-            if isinstance(ret, torch.Tensor):
-                update = ret
+        reg = torch.eye(H.size(0), device=H.device, dtype=H.dtype).mul_(damping)
+        H += reg

-
-            H, is_inv = ret
-            if is_inv: update = H @ g
-
-    if eigval_fn is not None:
-        update = _eigh_solve(H, g, eigval_fn, search_negative=False)
-
-    if update is None and use_lstsq: update = _least_squares_solve(H, g)
-    if update is None: update = _cholesky_solve(H, g)
-    if update is None: update = _lu_solve(H, g)
-    if update is None: update = _least_squares_solve(H, g)
-
-    return update
-
-def _get_H(H: torch.Tensor, eigval_fn):
+    # if eigval_fn is given, we don't need H or H_inv, we store factors
    if eigval_fn is not None:
-
-
-
-
-
-
-
-
-
+        L, Q = torch_linalg.eigh(H, retry_float64=True)
+        L = eigval_fn(L)
+        state["L"] = L
+        state["Q"] = Q
+        return
+
+    # pre-compute inverse if requested
+    # store H to as it is needed for trust regions
+    state["H"] = H
+    if precompute_inverse:
+        if use_lstsq:
+            H_inv = torch.linalg.pinv(H) # pylint:disable=not-callable
+        else:
+            H_inv, _ = torch_linalg.inv_ex(H)
+        state["H_inv"] = H_inv
+
+
+def _newton_solve(
+    b: torch.Tensor,
+    state: dict[str, torch.Tensor | Any],
+    use_lstsq: bool = False,
+):
+    """
+    used in most hessian-based modules. state is from ``_newton_update_state_``, in it:

-
+        H (torch.Tensor): hessian
+        H_inv (torch.Tensor | None): hessian inverse
+        L (torch.Tensor | None): eigenvalues (transformed)
+        Q (torch.Tensor | None): eigenvectors
+    """
+    # use eig if provided
+    if "L" in state:
+        Q = state["Q"]; L = state["L"]
+        assert Q is not None
+        return Q @ ((Q.mH @ b) / L)
+
+    # use inverse if cached
+    if "H_inv" in state:
+        return state["H_inv"] @ b
+
+    # use hessian
+    H = state["H"]
+    if use_lstsq: return _least_squares_solve(H, b)
+
+    dir = None
+    if dir is None: dir = _try_cholesky_solve(H, b)
+    if dir is None: dir = _try_lu_solve(H, b)
+    if dir is None: dir = _least_squares_solve(H, b)
+    return dir
+
+def _newton_get_H(state: dict[str, torch.Tensor | Any]):
+    """used in most hessian-based modules. state is from ``_newton_update_state_``"""
+    if "H_inv" in state:
+        return DenseWithInverse(state["H"], state["H_inv"])
+
+    if "L" in state:
+        # Eigendecomposition has sligthly different solve_plus_diag
+        # I am pretty sure it should be very close and it uses no solves
+        # best way to test is to try cubic regularization with this
+        return Eigendecomposition(state["L"], state["Q"], use_nystrom=False)
+
+    return Dense(state["H"])

class Newton(Transform):
-    """Exact
+    """Exact Newton's method via autograd.

    Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
    The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian and ``g`` is the gradient, ``y`` is the ``damping`` parameter.
+
    ``g`` can be output of another module, if it is specifed in ``inner`` argument.

    Note:
@@ -107,27 +121,19 @@ class Newton(Transform):
    The closure must accept a ``backward`` argument (refer to documentation).

    Args:
-        damping (float, optional): tikhonov regularizer value.
-        search_negative (bool, Optional):
-            if True, whenever a negative eigenvalue is detected,
-            search direction is proposed along weighted sum of eigenvectors corresponding to negative eigenvalues.
-        use_lstsq (bool, Optional):
-            if True, least squares will be used to solve the linear system, this may generate reasonable directions
-            when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
-            If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
-            argument will be ignored.
-        H_tfm (Callable | None, optional):
-            optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
-
-            must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
-            which must be True if transform inverted the hessian and False otherwise.
-
-            Or it returns a single tensor which is used as the update.
-
-            Defaults to None.
+        damping (float, optional): tikhonov regularizer value. Defaults to 0.
        eigval_fn (Callable | None, optional):
-
+            function to apply to eigenvalues, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
            If this is specified, eigendecomposition will be used to invert the hessian.
+        update_freq (int, optional):
+            updates hessian every ``update_freq`` steps.
+        precompute_inverse (bool, optional):
+            if ``True``, whenever hessian is computed, also computes the inverse. This is more efficient
+            when ``update_freq`` is large. If ``None``, this is ``True`` if ``update_freq >= 10``.
+        use_lstsq (bool, Optional):
+            if True, least squares will be used to solve the linear system, this can prevent it from exploding
+            when hessian is indefinite. If False, tries cholesky, if it fails tries LU, and then least squares.
+            If ``eigval_fn`` is specified, eigendecomposition is always used and this argument is ignored.
        hessian_method (str):
            Determines how hessian is computed.

@@ -139,17 +145,19 @@ class Newton(Transform):
            - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
            - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
+            - ``"thoad"`` - uses ``thoad`` library, can be significantly faster than pytorch but limited operator coverage.

            Defaults to ``"batched_autograd"``.
        h (float, optional):
-            finite difference step size
+            finite difference step size if hessian is compute via finite-difference.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.

    # See also

-    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products
+    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products.
      useful for large scale problems as it doesn't form the full hessian.
    * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
+    * ``tz.m.ImprovedNewton``: Newton with additional rank one correction to the hessian, can be faster than Newton.
    * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
    * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.

@@ -158,57 +166,48 @@ class Newton(Transform):
    ## Implementation details

    ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
-    The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
-    Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.
+    The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares. Least squares can be forced by setting ``use_lstsq=True``.

    Additionally, if ``eigval_fn`` is specified, eigendecomposition of the hessian is computed,
-    ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive
-    but not by much
+    ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive but not by much.

    ## Handling non-convexity

    Standard Newton's method does not handle non-convexity well without some modifications.
    This is because it jumps to the stationary point, which may be the maxima of the quadratic approximation.

-
+    A modification to handle non-convexity is to modify the eignevalues to be positive,
    for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.

-    Second modification is ``search_negative=True``, which will search along a negative curvature direction if one is detected.
-    This also requires an eigendecomposition.
-
-    The Newton direction can also be forced to be a descent direction by using ``tz.m.GradSign()`` or ``tz.m.Cautious``,
-    but that may be significantly less efficient.
-
    # Examples:

    Newton's method with backtracking line search

    ```py
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Newton(),
        tz.m.Backtracking()
    )
    ```

-    Newton
+    Newton's method for non-convex optimization.

    ```py
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
-        tz.m.Newton(
-        tz.m.
+        tz.m.Newton(eigval_fn = lambda L: L.abs().clip(min=1e-4)),
+        tz.m.Backtracking()
    )
    ```

-
-    but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+    Newton preconditioning applied to momentum

    ```py
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
-        tz.m.Newton(
-        tz.m.
+        tz.m.Newton(inner=tz.m.EMA(0.9)),
+        tz.m.LR(0.1)
    )
    ```

@@ -216,10 +215,10 @@ class Newton(Transform):
    def __init__(
        self,
        damping: float = 0,
-        use_lstsq: bool = False,
-        update_freq: int = 1,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        update_freq: int = 1,
+        precompute_inverse: bool | None = None,
+        use_lstsq: bool = False,
        hessian_method: HessianMethod = "batched_autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
@@ -232,29 +231,32 @@ class Newton(Transform):
    def update_states(self, objective, states, settings):
        fs = settings[0]

-
-
-
-
+        precompute_inverse = fs["precompute_inverse"]
+        if precompute_inverse is None:
+            precompute_inverse = fs["__update_freq"] >= 10
+
+        __, _, H = objective.hessian(hessian_method=fs["hessian_method"], h=fs["h"], at_x0=True)
+
+        _newton_update_state_(
+            state = self.global_state,
+            H=H,
+            damping = fs["damping"],
+            eigval_fn = fs["eigval_fn"],
+            precompute_inverse = precompute_inverse,
+            use_lstsq = fs["use_lstsq"]
        )

    @torch.no_grad
    def apply_states(self, objective, states, settings):
-
+        updates = objective.get_updates()
        fs = settings[0]

-
-
-            H = self.global_state["H"],
-            damping = fs["damping"],
-            H_tfm = fs["H_tfm"],
-            eigval_fn = fs["eigval_fn"],
-            use_lstsq = fs["use_lstsq"],
-        )
+        b = torch.cat([t.ravel() for t in updates])
+        sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])

-
+        vec_to_tensors_(sol, updates)
        return objective

    def get_H(self,objective=...):
-        return
+        return _newton_get_H(self.global_state)

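The newton.py changes above split the old monolithic `_newton_step` into a cache-then-solve pair: `_newton_update_state_` stores either transformed eigenfactors (when `eigval_fn` is given) or the damped hessian plus an optional precomputed inverse, and `_newton_solve` reuses whatever was cached. A minimal, self-contained sketch of that pattern follows; `make_state` and `solve` are illustrative stand-ins, not torchzero's actual helpers.

```python
# Sketch of the cache-then-solve pattern, assuming plain PyTorch (not torchzero's API).
import torch

def make_state(H: torch.Tensor, damping: float = 0.0,
               eigval_fn=None, precompute_inverse: bool = False) -> dict:
    """Cache whichever factorization of (H + damping*I) later solves will reuse."""
    state = {}
    if damping != 0:
        H = H + damping * torch.eye(H.size(0), device=H.device, dtype=H.dtype)
    if eigval_fn is not None:
        # H = Q diag(L) Q^T, so H^-1 b = Q ((Q^T b) / L); eigval_fn can abs/clip L
        L, Q = torch.linalg.eigh(H)
        state["L"], state["Q"] = eigval_fn(L), Q
        return state
    state["H"] = H
    if precompute_inverse:  # pays off when the same hessian is reused for many steps
        state["H_inv"] = torch.linalg.inv(H)
    return state

def solve(state: dict, b: torch.Tensor) -> torch.Tensor:
    """Solve the cached system for b using whatever make_state stored."""
    if "L" in state:
        Q, L = state["Q"], state["L"]
        return Q @ ((Q.mH @ b) / L)
    if "H_inv" in state:
        return state["H_inv"] @ b
    H = state["H"]
    C, info = torch.linalg.cholesky_ex(H)
    if info == 0:  # SPD fast path
        return torch.cholesky_solve(b.unsqueeze(-1), C).squeeze(-1)
    return torch.linalg.lstsq(H, b).solution  # fallback for indefinite/singular H

# usage: x = solve(make_state(H, damping=1e-4, eigval_fn=torch.abs), g)
```

This is also why `eigval_fn` keeps repeated solves cheap: once `H = Q diag(L) Qᵀ` is stored, each solve reduces to two matrix-vector products and an elementwise division by the transformed eigenvalues.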
torchzero/modules/second_order/newton_cg.py

@@ -57,7 +57,7 @@ class NewtonCG(Transform):
    Newton-CG with a backtracking line search:

    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NewtonCG(),
        tz.m.Backtracking()
@@ -66,7 +66,7 @@ class NewtonCG(Transform):

    Truncated Newton method (useful for large-scale problems):
    ```
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NewtonCG(maxiter=10),
        tz.m.Backtracking()
@@ -198,7 +198,7 @@ class NewtonCGSteihaug(Transform):
    Trust-region Newton-CG:

    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NewtonCGSteihaug(),
    )
torchzero/modules/second_order/nystrom.py

@@ -1,10 +1,11 @@
+import warnings
from typing import Literal

import torch

from ...core import Chainable, Transform, HVPMethod
from ...utils import TensorList, vec_to_tensors
-from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg
+from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg, regularize_eigh, OrthogonalizeMethod
from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity

class NystromSketchAndSolve(Transform):
@@ -19,7 +20,18 @@ class NystromSketchAndSolve(Transform):

    Args:
        rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
-        reg (float, optional):
+        reg (float | None, optional):
+            scale of identity matrix added to hessian. Note that if this is specified, nystrom sketch-and-solve
+            is used to compute ``(Q diag(L) Q.T + reg*I)x = b``. It is very unstable when ``reg`` is small,
+            i.e. smaller than 1e-4. If this is None,``(Q diag(L) Q.T)x = b`` is computed by simply taking
+            reciprocal of eigenvalues. Defaults to 1e-3.
+        eigv_tol (float, optional):
+            all eigenvalues smaller than largest eigenvalue times ``eigv_tol`` are removed. Defaults to None.
+        truncate (int | None, optional):
+            keeps top ``truncate`` eigenvalues. Defaults to None.
+        damping (float, optional): scalar added to eigenvalues. Defaults to 0.
+        rdamping (float, optional): scalar multiplied by largest eigenvalue and added to eigenvalues. Defaults to 0.
+        update_freq (int, optional): frequency of updating preconditioner. Defaults to 1.
        hvp_method (str, optional):
            Determines how Hessian-vector products are computed.

@@ -40,7 +52,7 @@ class NystromSketchAndSolve(Transform):
    NystromSketchAndSolve with backtracking line search

    ```py
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NystromSketchAndSolve(100),
        tz.m.Backtracking()
@@ -50,7 +62,7 @@ class NystromSketchAndSolve(Transform):
    Trust region NystromSketchAndSolve

    ```py
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
    )
@@ -64,10 +76,15 @@ class NystromSketchAndSolve(Transform):
    def __init__(
        self,
        rank: int,
-        reg: float = 1e-
+        reg: float | None = 1e-2,
+        eigv_tol: float = 0,
+        truncate: int | None = None,
+        damping: float = 0,
+        rdamping: float = 0,
+        update_freq: int = 1,
+        orthogonalize_method: OrthogonalizeMethod = 'qr',
        hvp_method: HVPMethod = "batched_autograd",
        h: float = 1e-3,
-        update_freq: int = 1,
        inner: Chainable | None = None,
        seed: int | None = None,
    ):
@@ -92,25 +109,53 @@ class NystromSketchAndSolve(Transform):

        generator = self.get_generator(params[0].device, seed=fs['seed'])
        try:
-
-
+            # compute the approximation
+            L, Q = nystrom_approximation(
+                A_mv=H_mv,
+                A_mm=H_mm,
+                ndim=ndim,
+                rank=min(fs["rank"], ndim),
+                eigv_tol=fs["eigv_tol"],
+                orthogonalize_method=fs["orthogonalize_method"],
+                dtype=dtype,
+                device=device,
+                generator=generator,
+            )
+
+            # regularize
+            L, Q = regularize_eigh(
+                L=L,
+                Q=Q,
+                truncate=fs["truncate"],
+                tol=fs["eigv_tol"],
+                damping=fs["damping"],
+                rdamping=fs["rdamping"],
+            )
+
+            # store
+            if L is not None:
+                self.global_state["L"] = L
+                self.global_state["Q"] = Q

-
-
-        except torch.linalg.LinAlgError:
-            pass
+        except torch.linalg.LinAlgError as e:
+            warnings.warn(f"Nystrom approximation failed with: {e}")

    def apply_states(self, objective, states, settings):
-        fs = settings[0]
-        b = objective.get_updates()
-
-        # ----------------------------------- solve ---------------------------------- #
        if "L" not in self.global_state:
            return objective

+        fs = settings[0]
+        updates = objective.get_updates()
+        b=torch.cat([t.ravel() for t in updates])
+
+        # ----------------------------------- solve ---------------------------------- #
        L = self.global_state["L"]
        Q = self.global_state["Q"]
-
+
+        if fs["reg"] is None:
+            x = Q @ ((Q.mH @ b) / L)
+        else:
+            x = nystrom_sketch_and_solve(L=L, Q=Q, b=b, reg=fs["reg"])

        # -------------------------------- set update -------------------------------- #
        objective.updates = vec_to_tensors(x, reference=objective.params)
@@ -127,8 +172,6 @@ class NystromSketchAndSolve(Transform):

class NystromPCG(Transform):
    """Newton's method with a Nyström-preconditioned conjugate gradient solver.
-    This tends to outperform NewtonCG but requires tuning sketch size.
-    An adaptive version exists in https://arxiv.org/abs/2110.02820, I might implement it too at some point.

    Notes:
        - This module requires the a closure passed to the optimizer step,
@@ -138,7 +181,7 @@ class NystromPCG(Transform):
        - In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

    Args:
-
+        rank (int):
            size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
            running the conjugate gradient solver. Larger value improves the preconditioning and speeds up
            conjugate gradient.
@@ -169,7 +212,7 @@ class NystromPCG(Transform):
    NystromPCG with backtracking line search

    ```python
-    opt = tz.
+    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NystromPCG(10),
        tz.m.Backtracking()
@@ -187,6 +230,8 @@ class NystromPCG(Transform):
        tol=1e-8,
        reg: float = 1e-6,
        update_freq: int = 1, # here update_freq is within update_states
+        eigv_tol: float = 0,
+        orthogonalize_method: OrthogonalizeMethod = 'qr',
        hvp_method: HVPMethod = "batched_autograd",
        h=1e-3,
        inner: Chainable | None = None,
@@ -202,31 +247,36 @@ class NystromPCG(Transform):

        # ---------------------- Hessian vector product function --------------------- #
        # this should run on every update_states
-
-        h = fs['h']
-        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
+        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=fs['hvp_method'], h=fs['h'], at_x0=True)
        objective.temp = H_mv

        # --------------------------- update preconditioner -------------------------- #
        step = self.increment_counter("step", 0)
-
-
-        if step % update_freq == 0:
+        if step % fs["update_freq"] == 0:

-            rank = fs['rank']
            ndim = sum(t.numel() for t in objective.params)
            device = objective.params[0].device
            dtype = objective.params[0].dtype
            generator = self.get_generator(device, seed=fs['seed'])

            try:
-                L, Q = nystrom_approximation(
-
+                L, Q = nystrom_approximation(
+                    A_mv=None,
+                    A_mm=H_mm,
+                    ndim=ndim,
+                    rank=min(fs["rank"], ndim),
+                    eigv_tol=fs["eigv_tol"],
+                    orthogonalize_method=fs["orthogonalize_method"],
+                    dtype=dtype,
+                    device=device,
+                    generator=generator,
+                )

                self.global_state["L"] = L
                self.global_state["Q"] = Q
-
-
+
+            except torch.linalg.LinAlgError as e:
+                warnings.warn(f"Nystrom approximation failed with: {e}")

    @torch.no_grad
    def apply_states(self, objective, states, settings):
@@ -243,6 +293,7 @@ class NystromPCG(Transform):

        L = self.global_state["L"]
        Q = self.global_state["Q"]
+
        x = nystrom_pcg(L=L, Q=Q, A_mv=H_mv, b=torch.cat([t.ravel() for t in b]),
                        reg=fs['reg'], tol=fs["tol"], maxiter=fs["maxiter"])

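The new `eigv_tol`, `truncate`, `damping`, and `rdamping` arguments documented above all act on the eigenvalues of the Nyström approximation before the solve. A rough sketch of those documented semantics, using a hypothetical `regularize_eigenpairs` helper rather than torchzero's actual `regularize_eigh`:

```python
# Illustrative only: spells out the documented semantics of truncate / eigv_tol /
# damping / rdamping; not torchzero's regularize_eigh implementation.
import torch

def regularize_eigenpairs(L: torch.Tensor, Q: torch.Tensor,
                          truncate: int | None = None, tol: float = 0.0,
                          damping: float = 0.0, rdamping: float = 0.0):
    # keep only the top-`truncate` eigenpairs
    if truncate is not None and truncate < L.numel():
        idx = torch.argsort(L, descending=True)[:truncate]
        L, Q = L[idx], Q[:, idx]
    # drop eigenvalues below tol * (largest eigenvalue)
    if tol > 0:
        keep = L >= tol * L.max()
        L, Q = L[keep], Q[:, keep]
    # absolute damping plus damping relative to the largest eigenvalue
    if damping != 0 or rdamping != 0:
        L = L + damping + rdamping * L.max()
    return L, Q

# With reg=None the subsequent solve is just the reciprocal of the spectrum:
# x = Q @ ((Q.mH @ b) / L)
```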