torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/second_order/newton.py

```diff
@@ -8,16 +8,16 @@ import torch
 from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
-
+    flatten_jacobian,
     hessian_mat,
     hvp,
     hvp_fd_central,
     hvp_fd_forward,
     jacobian_and_hessian_wrt,
 )
+from ...utils.linalg.linear_operator import DenseWithInverse, Dense
 
-
-def lu_solve(H: torch.Tensor, g: torch.Tensor):
+def _lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
         x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
         if info == 0: return x
```
```diff
@@ -25,135 +25,359 @@ def lu_solve(H: torch.Tensor, g: torch.Tensor):
     except RuntimeError:
         return None
 
-def
+def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
     x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
     if info == 0:
         g.unsqueeze_(1)
         return torch.cholesky_solve(g, x)
     return None
 
-def
+def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
     return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable
 
-def
+def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
     try:
         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
         if tfm is not None: L = tfm(L)
         if search_negative and L[0] < 0:
-
-
-            return g.
+            neg_mask = L < 0
+            Q_neg = Q[:, neg_mask] * L[neg_mask]
+            return (Q_neg * (g @ Q_neg).sign()).mean(1)
+
+        return Q @ ((Q.mH @ g) / L)
 
-        L.reciprocal_()
-        return torch.linalg.multi_dot([Q * L.unsqueeze(-2), Q.mH, g]) # pylint:disable=not-callable
     except torch.linalg.LinAlgError:
         return None
 
-def tikhonov_(H: torch.Tensor, reg: float):
-    if reg!=0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(reg))
-    return H
 
-def eig_tikhonov_(H: torch.Tensor, reg: float):
-    v = torch.linalg.eigvalsh(H).min().clamp_(max=0).neg_() + reg # pylint:disable=not-callable
-    return tikhonov_(H, v)
 
 
```
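The `_eigh_solve` helper added above factors the symmetric hessian as `H = Q diag(L) Qᵀ` and solves the Newton system as `Q ((Qᵀ g) / L)`, optionally transforming the eigenvalues first. Below is a minimal standalone sketch of that computation in plain PyTorch; it is not torchzero code, and the matrices, the gradient, and the `1e-8` clamp are illustrative values only.

```python
# Sketch of the eigendecomposition solve for a symmetric, possibly indefinite H.
import torch

torch.manual_seed(0)
Q, _ = torch.linalg.qr(torch.randn(4, 4))      # orthogonal eigenvector matrix
L = torch.tensor([-0.5, 0.1, 1.0, 2.0])        # one negative eigenvalue
H = Q @ torch.diag(L) @ Q.T                    # symmetric indefinite "hessian"
g = torch.randn(4)

L_mod = L.abs().clip(min=1e-8)                 # an example eigenvalue transform (eigval_fn)
x = Q @ ((Q.T @ g) / L_mod)                    # modified Newton direction: Q f(L)^-1 Qᵀ g

H_mod = Q @ torch.diag(L_mod) @ Q.T            # the matrix that was implicitly inverted
print(torch.allclose(H_mod @ x, g, atol=1e-5)) # True
```

When `search_negative` is enabled and negative eigenvalues are present, the helper instead builds the direction from the eigenvectors of those negative eigenvalues, which gives a negative-curvature escape direction rather than a solve.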
```diff
 class Newton(Module):
-    """Exact newton via autograd.
+    """Exact newton's method via autograd.
+
+    Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
+    The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian and ``g`` is the gradient, ``y`` is the ``damping`` parameter.
+    ``g`` can be output of another module, if it is specifed in ``inner`` argument.
+
+    Note:
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    Note:
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating the hessian.
+        The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
-
-        eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
+        damping (float, optional): tikhonov regularizer value. Set this to 0 when using trust region. Defaults to 0.
         search_negative (bool, Optional):
-            if True, whenever a negative eigenvalue is detected,
+            if True, whenever a negative eigenvalue is detected,
+            search direction is proposed along weighted sum of eigenvectors corresponding to negative eigenvalues.
+        use_lstsq (bool, Optional):
+            if True, least squares will be used to solve the linear system, this may generate reasonable directions
+            when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
+            If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
+            argument will be ignored.
         hessian_method (str):
             how to calculate hessian. Defaults to "autograd".
         vectorize (bool, optional):
             whether to enable vectorized hessian. Defaults to True.
-        inner (Chainable | None, optional):
+        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         H_tfm (Callable | None, optional):
             optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
 
-            must return a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
-            which must be True if transform inverted the hessian and False otherwise.
-
-
-
+            must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
+            which must be True if transform inverted the hessian and False otherwise.
+
+            Or it returns a single tensor which is used as the update.
+
+            Defaults to None.
+        eigval_fn (Callable | None, optional):
+            optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
+            If this is specified, eigendecomposition will be used to invert the hessian.
+
+    # See also
+
+    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products,
+      useful for large scale problems as it doesn't form the full hessian.
+    * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
+    * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
+    * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.
+
+    # Notes
+
+    ## Implementation details
+
+    ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
+    The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
+    Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.
+
+    Additionally, if ``eigval_fn`` is specified or ``search_negative`` is ``True``,
+    eigendecomposition of the hessian is computed, ``eigval_fn`` is applied to the eigenvalues,
+    and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues.
+    This is more generally more computationally expensive.
+
+    ## Handling non-convexity
+
+    Standard Newton's method does not handle non-convexity well without some modifications.
+    This is because it jumps to the stationary point, which may be the maxima of the quadratic approximation.
+
+    The first modification to handle non-convexity is to modify the eignevalues to be positive,
+    for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.
+
+    Second modification is ``search_negative=True``, which will search along a negative curvature direction if one is detected.
+    This also requires an eigendecomposition.
+
+    The Newton direction can also be forced to be a descent direction by using ``tz.m.GradSign()`` or ``tz.m.Cautious``,
+    but that may be significantly less efficient.
+
+    # Examples:
+
+    Newton's method with backtracking line search
+
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    Newton preconditioning applied to momentum
+
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(inner=tz.m.EMA(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient,
+    but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
+        tz.m.Backtracking()
+    )
+    ```
 
     """
     def __init__(
         self,
-
-        eig_reg: bool = False,
+        damping: float = 0,
         search_negative: bool = False,
+        use_lstsq: bool = False,
+        update_freq: int = 1,
         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
         vectorize: bool = True,
         inner: Chainable | None = None,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
-
+        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
     ):
-        defaults = dict(
+        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, search_negative=search_negative, update_freq=update_freq)
         super().__init__(defaults)
 
         if inner is not None:
             self.set_child('inner', inner)
 
     @torch.no_grad
-    def
+    def update(self, var):
         params = TensorList(var.params)
         closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
-
-        eig_reg = settings['eig_reg']
-        search_negative = settings['search_negative']
+        damping = settings['damping']
         hessian_method = settings['hessian_method']
         vectorize = settings['vectorize']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        g_list = var.grad
+        H = None
+        if step % update_freq == 0:
+            # ------------------------ calculate grad and hessian ------------------------ #
+            if hessian_method == 'autograd':
+                with torch.enable_grad():
+                    loss = var.loss = var.loss_approx = closure(False)
+                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                    g_list = [t[0] for t in g_list] # remove leading dim from loss
+                    var.grad = g_list
+                    H = flatten_jacobian(H_list)
+
+            elif hessian_method in ('func', 'autograd.functional'):
+                strat = 'forward-mode' if vectorize else 'reverse-mode'
+                with torch.enable_grad():
+                    g_list = var.get_grad(retain_graph=True)
+                    H = hessian_mat(partial(closure, backward=False), params,
+                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+            else:
+                raise ValueError(hessian_method)
+
+            if damping != 0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping))
+            self.global_state['H'] = H
+
+    @torch.no_grad
+    def apply(self, var):
+        H = self.global_state["H"]
+
+        params = var.params
+        settings = self.settings[params[0]]
+        search_negative = settings['search_negative']
         H_tfm = settings['H_tfm']
-
-
-        # ------------------------ calculate grad and hessian ------------------------ #
-        if hessian_method == 'autograd':
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                g_list = [t[0] for t in g_list] # remove leading dim from loss
-                var.grad = g_list
-                H = hessian_list_to_mat(H_list)
-
-        elif hessian_method in ('func', 'autograd.functional'):
-            strat = 'forward-mode' if vectorize else 'reverse-mode'
-            with torch.enable_grad():
-                g_list = var.get_grad(retain_graph=True)
-                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
-                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-        else:
-            raise ValueError(hessian_method)
+        eigval_fn = settings['eigval_fn']
+        use_lstsq = settings['use_lstsq']
 
         # -------------------------------- inner step -------------------------------- #
         update = var.get_update()
         if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=
-        g = torch.cat([t.ravel() for t in update])
+            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
 
-
-        if eig_reg: H = eig_tikhonov_(H, reg)
-        else: H = tikhonov_(H, reg)
+        g = torch.cat([t.ravel() for t in update])
 
         # ----------------------------------- solve ---------------------------------- #
         update = None
         if H_tfm is not None:
-
-
+            ret = H_tfm(H, g)
+
+            if isinstance(ret, torch.Tensor):
+                update = ret
+
+            else: # returns (H, is_inv)
+                H, is_inv = ret
+                if is_inv: update = H @ g
 
-        if search_negative or (
-            update =
+        if search_negative or (eigval_fn is not None):
+            update = _eigh_solve(H, g, eigval_fn, search_negative=search_negative)
 
-        if update is None: update =
-        if update is None: update =
-        if update is None: update =
+        if update is None and use_lstsq: update = _least_squares_solve(H, g)
+        if update is None: update = _cholesky_solve(H, g)
+        if update is None: update = _lu_solve(H, g)
+        if update is None: update = _least_squares_solve(H, g)
 
         var.update = vec_to_tensors(update, params)
+
+        return var
+
+    def get_H(self,var):
+        H = self.global_state["H"]
+        settings = self.defaults
+        if settings['eigval_fn'] is not None:
+            try:
+                L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+                L = settings['eigval_fn'](L)
+                H = Q @ L.diag_embed() @ Q.mH
+                H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
+                return DenseWithInverse(H, H_inv)
+
+            except torch.linalg.LinAlgError:
+                pass
+
+        return Dense(H)
+
+
```
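The docstring above describes the update as solving ``(H + yI)x = g`` with a cholesky → LU → least-squares fallback. The following self-contained sketch performs one such damped Newton step on a toy quadratic in plain PyTorch; the function name `damped_newton_step` and the numbers are illustrative only and independent of the torchzero module system.

```python
import torch

def damped_newton_step(H: torch.Tensor, g: torch.Tensor, damping: float = 1e-2) -> torch.Tensor:
    # solve (H + damping*I) x = g, falling back through progressively more robust solvers
    A = H + damping * torch.eye(H.size(-1), dtype=H.dtype, device=H.device)
    C, info = torch.linalg.cholesky_ex(A)                 # fastest, needs A to be SPD
    if info == 0:
        return torch.cholesky_solve(g.unsqueeze(1), C).squeeze(1)
    x, info = torch.linalg.solve_ex(A, g)                 # LU, needs A to be invertible
    if info == 0:
        return x
    return torch.linalg.lstsq(A, g.unsqueeze(1)).solution.squeeze(1)  # last resort

H = torch.tensor([[3.0, 1.0], [1.0, 2.0]])   # hessian of a convex quadratic
g = torch.tensor([1.0, -1.0])                # gradient at the current point
print(damped_newton_step(H, g))              # direction to subtract from the parameters
```

With `damping=0` and a positive definite hessian this reduces to the plain Newton step; the damping term corresponds to the `damping` argument that replaces the removed `eig_reg`/`tikhonov_` machinery in the diff above.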
```diff
+class InverseFreeNewton(Module):
+    """Inverse-free newton's method
+
+    .. note::
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating the hessian.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference
+        Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
+    """
+    def __init__(
+        self,
+        update_freq: int = 1,
+        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+        vectorize: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def update(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        hessian_method = settings['hessian_method']
+        vectorize = settings['vectorize']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        g_list = var.grad
+        Y = None
+        if step % update_freq == 0:
+            # ------------------------ calculate grad and hessian ------------------------ #
+            if hessian_method == 'autograd':
+                with torch.enable_grad():
+                    loss = var.loss = var.loss_approx = closure(False)
+                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                    g_list = [t[0] for t in g_list] # remove leading dim from loss
+                    var.grad = g_list
+                    H = flatten_jacobian(H_list)
+
+            elif hessian_method in ('func', 'autograd.functional'):
+                strat = 'forward-mode' if vectorize else 'reverse-mode'
+                with torch.enable_grad():
+                    g_list = var.get_grad(retain_graph=True)
+                    H = hessian_mat(partial(closure, backward=False), params,
+                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+            else:
+                raise ValueError(hessian_method)
+
+            self.global_state["H"] = H
+
+            # inverse free part
+            if 'Y' not in self.global_state:
+                num = H.T
+                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+                finfo = torch.finfo(H.dtype)
+                Y = self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
+
+            else:
+                Y = self.global_state['Y']
+                I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+                I -= H @ Y
+                Y = self.global_state['Y'] = Y @ I
+
+
+    def apply(self, var):
+        Y = self.global_state["Y"]
+        params = var.params
+
+        # -------------------------------- inner step -------------------------------- #
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
+
+        g = torch.cat([t.ravel() for t in update])
+
+        # ----------------------------------- solve ---------------------------------- #
+        var.update = vec_to_tensors(Y@g, params)
+
         return var
+
+    def get_H(self,var):
+        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
```