torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/second_order/newton.py

````diff
@@ -8,16 +8,16 @@ import torch
 from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
-
+    flatten_jacobian,
     hessian_mat,
     hvp,
     hvp_fd_central,
     hvp_fd_forward,
     jacobian_and_hessian_wrt,
 )
+from ...utils.linalg.linear_operator import DenseWithInverse, Dense

-
-def lu_solve(H: torch.Tensor, g: torch.Tensor):
+def _lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
         x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
         if info == 0: return x
````
````diff
@@ -25,55 +25,58 @@ def lu_solve(H: torch.Tensor, g: torch.Tensor):
     except RuntimeError:
         return None

-def
+def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
     x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
     if info == 0:
         g.unsqueeze_(1)
         return torch.cholesky_solve(g, x)
     return None

-def
+def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
     return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable

-def
+def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
     try:
         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
         if tfm is not None: L = tfm(L)
         if search_negative and L[0] < 0:
-
-
-            return g.
+            neg_mask = L < 0
+            Q_neg = Q[:, neg_mask] * L[neg_mask]
+            return (Q_neg * (g @ Q_neg).sign()).mean(1)

         return Q @ ((Q.mH @ g) / L)

     except torch.linalg.LinAlgError:
         return None

-
-    if reg!=0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(reg))
-    return H
+


 class Newton(Module):
     """Exact newton's method via autograd.

-
+    Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
+    The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian and ``g`` is the gradient, ``y`` is the ``damping`` parameter.
+    ``g`` can be output of another module, if it is specifed in ``inner`` argument.
+
+    Note:
         In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.

-
+    Note:
         This module requires the a closure passed to the optimizer step,
         as it needs to re-evaluate the loss and gradients for calculating the hessian.
         The closure must accept a ``backward`` argument (refer to documentation).

-    .. warning::
-        this uses roughly O(N^2) memory.
-
-
     Args:
-
+        damping (float, optional): tikhonov regularizer value. Set this to 0 when using trust region. Defaults to 0.
         search_negative (bool, Optional):
             if True, whenever a negative eigenvalue is detected,
-            search direction is proposed along
+            search direction is proposed along weighted sum of eigenvectors corresponding to negative eigenvalues.
+        use_lstsq (bool, Optional):
+            if True, least squares will be used to solve the linear system, this may generate reasonable directions
+            when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
+            If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
+            argument will be ignored.
         hessian_method (str):
             how to calculate hessian. Defaults to "autograd".
         vectorize (bool, optional):
````
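The renamed helpers above implement the fallback cascade that the new docstring describes: try Cholesky, fall back to LU, and finally to least squares. A self-contained sketch of the same cascade on plain tensors (illustrative names, not the package's private helpers):

```python
import torch

def solve_newton_system(H: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """Solve H x = g, falling back to progressively more robust solvers."""
    # 1. Cholesky: cheapest, but requires H to be symmetric positive definite.
    L, info = torch.linalg.cholesky_ex(H)
    if info == 0:
        return torch.cholesky_solve(g.unsqueeze(1), L).squeeze(1)

    # 2. LU: works for any invertible H.
    x, info = torch.linalg.solve_ex(H, g)
    if info == 0:
        return x

    # 3. Least squares: always returns a solution, even for singular H.
    return torch.linalg.lstsq(H, g).solution

H = torch.tensor([[2.0, 0.0], [0.0, 0.0]])  # singular: falls through to lstsq
g = torch.tensor([1.0, 1.0])
print(solve_newton_system(H, g))            # tensor([0.5000, 0.0000])
```

The ``use_lstsq=True`` flag described in the docstring simply promotes that last fallback to the first choice.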
````diff
@@ -88,92 +91,107 @@ class Newton(Module):
             Or it returns a single tensor which is used as the update.

             Defaults to None.
-
-            optional eigenvalues transform, for example
+        eigval_fn (Callable | None, optional):
+            optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
             If this is specified, eigendecomposition will be used to invert the hessian.

-
-
+    # See also
+
+    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products,
+      useful for large scale problems as it doesn't form the full hessian.
+    * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
+    * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
+    * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.

-
+    # Notes

-
-        model.parameters(),
-        tz.m.Newton(),
-        tz.m.Backtracking()
-    )
+    ## Implementation details

-
+    ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
+    The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
+    Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.

-
+    Additionally, if ``eigval_fn`` is specified or ``search_negative`` is ``True``,
+    eigendecomposition of the hessian is computed, ``eigval_fn`` is applied to the eigenvalues,
+    and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues.
+    This is more generally more computationally expensive.

-
-        model.parameters(),
-        tz.m.Newton(eigval_tfm=lambda x: torch.abs(x).clip(min=0.1)),
-        tz.m.Backtracking()
-    )
+    ## Handling non-convexity

-
+    Standard Newton's method does not handle non-convexity well without some modifications.
+    This is because it jumps to the stationary point, which may be the maxima of the quadratic approximation.

-
+    The first modification to handle non-convexity is to modify the eignevalues to be positive,
+    for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.

-
-
-        tz.m.Newton(search_negative=True),
-        tz.m.Backtracking()
-    )
+    Second modification is ``search_negative=True``, which will search along a negative curvature direction if one is detected.
+    This also requires an eigendecomposition.

-
+    The Newton direction can also be forced to be a descent direction by using ``tz.m.GradSign()`` or ``tz.m.Cautious``,
+    but that may be significantly less efficient.

-
+    # Examples:

-
-        model.parameters(),
-        tz.m.Newton(inner=tz.m.EMA(0.9)),
-        tz.m.LR(0.1)
-    )
+    Newton's method with backtracking line search

-
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(),
+        tz.m.Backtracking()
+    )
+    ```

-
+    Newton preconditioning applied to momentum

-
-
-
-
-
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(inner=tz.m.EMA(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient,
+    but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+
+    ```py
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
+        tz.m.Backtracking()
+    )
+    ```

     """
     def __init__(
         self,
-
+        damping: float = 0,
         search_negative: bool = False,
+        use_lstsq: bool = False,
         update_freq: int = 1,
         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
         vectorize: bool = True,
         inner: Chainable | None = None,
         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
-
+        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
     ):
-        defaults = dict(
+        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, search_negative=search_negative, update_freq=update_freq)
         super().__init__(defaults)

         if inner is not None:
             self.set_child('inner', inner)

     @torch.no_grad
-    def
+    def update(self, var):
         params = TensorList(var.params)
         closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')

         settings = self.settings[params[0]]
-
-        search_negative = settings['search_negative']
+        damping = settings['damping']
         hessian_method = settings['hessian_method']
         vectorize = settings['vectorize']
-        H_tfm = settings['H_tfm']
-        eigval_tfm = settings['eigval_tfm']
         update_freq = settings['update_freq']

         step = self.global_state.get('step', 0)
````
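The ``eigval_fn`` option documented above amounts to solving the Newton system through an eigendecomposition with transformed eigenvalues, which keeps the step meaningful on non-convex problems. A minimal sketch of that idea in plain torch (the function name here is illustrative, not the module's internals):

```python
import torch

def eig_modified_newton_direction(H: torch.Tensor, g: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """Newton direction with eigenvalues forced positive, so the step
    is a descent direction even when H is indefinite."""
    L, Q = torch.linalg.eigh(H)       # H = Q diag(L) Q^T for symmetric H
    L = L.abs().clamp(min=eps)        # e.g. eigval_fn = lambda L: L.abs().clip(min=1e-4)
    return Q @ ((Q.mH @ g) / L)       # (Q diag(1/L) Q^T) g

# Indefinite Hessian: plain Newton would step toward a saddle point.
H = torch.tensor([[1.0, 0.0], [0.0, -2.0]])
g = torch.tensor([0.5, 0.5])
print(eig_modified_newton_direction(H, g))
```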
````diff
@@ -189,7 +207,7 @@ class Newton(Module):
                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
                 g_list = [t[0] for t in g_list] # remove leading dim from loss
                 var.grad = g_list
-                H =
+                H = flatten_jacobian(H_list)

             elif hessian_method in ('func', 'autograd.functional'):
                 strat = 'forward-mode' if vectorize else 'reverse-mode'
````
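For context on the ``hessian_method`` branches: a dense Hessian of a scalar loss can also be formed with ``torch.autograd.functional.hessian`` over a flat parameter vector. The snippet below is a rough standalone analogue of the ``'func'`` / ``'autograd.functional'`` path, not the package's own helper, and the toy loss is made up for illustration:

```python
import torch

def loss_fn(w: torch.Tensor) -> torch.Tensor:
    # Toy non-quadratic loss over a flat parameter vector.
    return (w ** 2).sum() + (w ** 4).sum()

w = torch.randn(5)
strat = 'forward-mode'  # or 'reverse-mode', mirroring the `vectorize` switch above
H = torch.autograd.functional.hessian(loss_fn, w, vectorize=True,
                                      outer_jacobian_strategy=strat)
print(H.shape)  # torch.Size([5, 5])
```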
````diff
@@ -201,23 +219,27 @@ class Newton(Module):
             else:
                 raise ValueError(hessian_method)

-
-
-            self.global_state['H'] = H
+            if damping != 0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping))
+            self.global_state['H'] = H

-
-
+    @torch.no_grad
+    def apply(self, var):
+        H = self.global_state["H"]

-
+        params = var.params
+        settings = self.settings[params[0]]
+        search_negative = settings['search_negative']
+        H_tfm = settings['H_tfm']
+        eigval_fn = settings['eigval_fn']
+        use_lstsq = settings['use_lstsq']

         # -------------------------------- inner step -------------------------------- #
         update = var.get_update()
         if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=
+            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)

         g = torch.cat([t.ravel() for t in update])

-
         # ----------------------------------- solve ---------------------------------- #
         update = None
         if H_tfm is not None:
````
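The new ``damping`` handling in ``update`` is plain Tikhonov regularization: the system being solved becomes ``(H + yI)x = g`` rather than ``Hx = g``, which keeps the step bounded along near-singular directions of the hessian. A tiny illustration with standard torch:

```python
import torch

H = torch.tensor([[4.0, 0.0], [0.0, 1e-8]])   # nearly singular hessian
g = torch.tensor([1.0, 1.0])
damping = 1e-2

# Same operation as `H.add_(torch.eye(...).mul_(damping))` in the diff above.
H_damped = H + damping * torch.eye(H.size(-1), dtype=H.dtype, device=H.device)
x = torch.linalg.solve(H_damped, g)
print(x)  # second component is ~1e2 instead of ~1e8 without damping
```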
````diff
@@ -230,17 +252,35 @@ class Newton(Module):
             H, is_inv = ret
             if is_inv: update = H @ g

-        if search_negative or (
-            update =
+        if search_negative or (eigval_fn is not None):
+            update = _eigh_solve(H, g, eigval_fn, search_negative=search_negative)

-        if update is None: update =
-        if update is None: update =
-        if update is None: update =
+        if update is None and use_lstsq: update = _least_squares_solve(H, g)
+        if update is None: update = _cholesky_solve(H, g)
+        if update is None: update = _lu_solve(H, g)
+        if update is None: update = _least_squares_solve(H, g)

         var.update = vec_to_tensors(update, params)

         return var

+    def get_H(self,var):
+        H = self.global_state["H"]
+        settings = self.defaults
+        if settings['eigval_fn'] is not None:
+            try:
+                L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+                L = settings['eigval_fn'](L)
+                H = Q @ L.diag_embed() @ Q.mH
+                H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
+                return DenseWithInverse(H, H_inv)
+
+            except torch.linalg.LinAlgError:
+                pass
+
+        return Dense(H)
+
+
 class InverseFreeNewton(Module):
     """Inverse-free newton's method

````
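The new ``get_H`` builds both the eigenvalue-transformed hessian and its inverse from a single ``eigh`` factorization before wrapping them in ``DenseWithInverse``. The identity it relies on, checked with plain tensors (the linear-operator wrapper itself is not reproduced here):

```python
import torch

H = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
L, Q = torch.linalg.eigh(H)

H_rebuilt = Q @ L.diag_embed() @ Q.mH               # Q diag(L) Q^H reconstructs H
H_inv     = Q @ L.reciprocal().diag_embed() @ Q.mH  # Q diag(1/L) Q^H is H^-1

print(torch.allclose(H_rebuilt, H, atol=1e-6))             # True
print(torch.allclose(H_inv @ H, torch.eye(2), atol=1e-6))  # True
```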
````diff
@@ -272,7 +312,7 @@ class InverseFreeNewton(Module):
             self.set_child('inner', inner)

     @torch.no_grad
-    def
+    def update(self, var):
         params = TensorList(var.params)
         closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
````
````diff
@@ -295,7 +335,7 @@ class InverseFreeNewton(Module):
                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
                 g_list = [t[0] for t in g_list] # remove leading dim from loss
                 var.grad = g_list
-                H =
+                H = flatten_jacobian(H_list)

             elif hessian_method in ('func', 'autograd.functional'):
                 strat = 'forward-mode' if vectorize else 'reverse-mode'
````
````diff
@@ -307,12 +347,14 @@ class InverseFreeNewton(Module):
             else:
                 raise ValueError(hessian_method)

+        self.global_state["H"] = H
+
         # inverse free part
         if 'Y' not in self.global_state:
             num = H.T
             denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
-
-            Y = self.global_state['Y'] = num.div_(denom.clip(min=
+            finfo = torch.finfo(H.dtype)
+            Y = self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))

         else:
             Y = self.global_state['Y']
````
````diff
@@ -320,19 +362,22 @@ class InverseFreeNewton(Module):
             I -= H @ Y
             Y = self.global_state['Y'] = Y @ I

-        if Y is None:
-            Y = self.global_state["Y"]

+    def apply(self, var):
+        Y = self.global_state["Y"]
+        params = var.params

         # -------------------------------- inner step -------------------------------- #
         update = var.get_update()
         if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=
+            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)

         g = torch.cat([t.ravel() for t in update])

-
         # ----------------------------------- solve ---------------------------------- #
         var.update = vec_to_tensors(Y@g, params)

         return var
+
+    def get_H(self,var):
+        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
````