torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/multipoint.py:

@@ -1,19 +1,17 @@
-from collections.abc import Callable
-from contextlib import nullcontext
 from abc import ABC, abstractmethod
+from collections.abc import Callable, Mapping
+from typing import Any
+
 import numpy as np
 import torch
 
-from ...core import Chainable,
-from ...utils import TensorList, vec_to_tensors
-
-    flatten_jacobian,
-    jacobian_wrt,
-)
+from ...core import Chainable, DerivativesMethod, Objective, Transform
+from ...utils import TensorList, vec_to_tensors
+
 
-class HigherOrderMethodBase(
-    def __init__(self, defaults: dict | None = None,
-        self.
+class HigherOrderMethodBase(Transform, ABC):
+    def __init__(self, defaults: dict | None = None, derivatives_method: DerivativesMethod = 'batched_autograd'):
+        self._derivatives_method: DerivativesMethod = derivatives_method
         super().__init__(defaults)
 
     @abstractmethod
@@ -21,61 +19,27 @@ class HigherOrderMethodBase(Module, ABC):
         self,
         x: torch.Tensor,
         evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
-
+        objective: Objective,
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
         """"""
 
     @torch.no_grad
-    def
-        params = TensorList(
-
-        closure =
+    def apply_states(self, objective, states, settings):
+        params = TensorList(objective.params)
+
+        closure = objective.closure
         if closure is None: raise RuntimeError('MultipointNewton requires closure')
-
+        derivatives_method = self._derivatives_method
 
         def evaluate(x, order) -> tuple[torch.Tensor, ...]:
             """order=0 - returns (loss,), order=1 - returns (loss, grad), order=2 - returns (loss, grad, hessian), etc."""
-
-
-            if order == 0:
-                loss = closure(False)
-                params.copy_(x0)
-                return (loss, )
-
-            if order == 1:
-                with torch.enable_grad():
-                    loss = closure()
-                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                params.copy_(x0)
-                return loss, torch.cat([g.ravel() for g in grad])
-
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-
-                g_list = torch.autograd.grad(loss, params, create_graph=True)
-                var.grad = list(g_list)
-
-                g = torch.cat([t.ravel() for t in g_list])
-                n = g.numel()
-                ret = [loss, g]
-                T = g # current derivatives tensor
-
-                # get all derivative up to order
-                for o in range(2, order + 1):
-                    is_last = o == order
-                    T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
-                    with torch.no_grad() if is_last else nullcontext():
-                        # the shape is (ndim, ) * order
-                        T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
-                        ret.append(T)
-
-            params.copy_(x0)
-            return tuple(ret)
+            return objective.derivatives_at(x, order, method=derivatives_method)
 
         x = torch.cat([p.ravel() for p in params])
-        dir = self.one_iteration(x, evaluate,
-
-        return
+        dir = self.one_iteration(x, evaluate, objective, settings[0])
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective
 
 def _inv(A: torch.Tensor, lstsq:bool) -> torch.Tensor:
     if lstsq: return torch.linalg.pinv(A) # pylint:disable=not-callable
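The refactor above moves ``HigherOrderMethodBase`` from the old ``Module``/``var`` API onto the new ``Transform``/``Objective`` API: a subclass now receives the flattened parameter vector, an ``evaluate(x, order)`` callback that defers to ``objective.derivatives_at``, the ``Objective`` itself, and its settings mapping, and returns the update vector. A minimal sketch of what a subclass looks like after this change (the import path and re-exports are assumptions, not taken from the diff):

    import torch
    # assumed location of the base class in the 0.4.0 layout
    from torchzero.modules.second_order.multipoint import HigherOrderMethodBase

    class PlainNewton(HigherOrderMethodBase):
        """Toy one-step Newton written against the new one_iteration signature."""
        def __init__(self, lstsq: bool = False, derivatives_method="batched_autograd"):
            super().__init__(defaults=dict(lstsq=lstsq), derivatives_method=derivatives_method)

        @torch.no_grad
        def one_iteration(self, x, evaluate, objective, setting):
            # evaluate(x, 2) -> (loss, flat gradient, (n, n) hessian), per the docstring above
            _, g, H = evaluate(x, 2)
            if setting['lstsq']:
                newton_step = torch.linalg.lstsq(H, g.unsqueeze(-1)).solution.squeeze(-1)
            else:
                newton_step = torch.linalg.solve(H, g)
            # the built-in subclasses return x - x_star, i.e. the update to subtract
            return newton_step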
@@ -106,16 +70,15 @@ class SixthOrder3P(HigherOrderMethodBase):
 
     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
     """
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f(x): return evaluate(x, 1)[1]
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_3p(x, f, f_j, lstsq)
+        x_star = sixth_order_3p(x, f, f_j, setting['lstsq'])
         return x - x_star
 
 # I don't think it works (I tested root finding with this and it goes all over the place)
@@ -173,15 +136,14 @@ def sixth_order_5p(x:torch.Tensor, f_j, lstsq:bool=False):
 
 class SixthOrder5P(HigherOrderMethodBase):
     """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_5p(x, f_j, lstsq)
+        x_star = sixth_order_5p(x, f_j, setting['lstsq'])
         return x - x_star
 
 # 2f 1J 2 solves
@@ -196,16 +158,15 @@ class TwoPointNewton(HigherOrderMethodBase):
     """two-point Newton method with frozen derivative with third order convergence.
 
     Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f(x): return evaluate(x, 1)[1]
         def f_j(x): return evaluate(x, 2)[1:]
-        x_star = two_point_newton(x, f, f_j, lstsq)
+        x_star = two_point_newton(x, f, f_j, setting['lstsq'])
         return x - x_star
 
 #3f 2J 1inv
@@ -224,15 +185,14 @@ def sixth_order_3pm2(x:torch.Tensor, f, f_j, lstsq:bool=False):
 
 class SixthOrder3PM2(HigherOrderMethodBase):
     """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
         def f_j(x): return evaluate(x, 2)[1:]
         def f(x): return evaluate(x, 1)[1]
-        x_star = sixth_order_3pm2(x, f, f_j, lstsq)
+        x_star = sixth_order_3pm2(x, f, f_j, setting['lstsq'])
         return x - x_star
 
torchzero/modules/second_order/newton.py:

@@ -1,21 +1,12 @@
-import warnings
 from collections.abc import Callable
-from functools import partial
 from typing import Literal
 
 import torch
 
-from ...core import Chainable,
-from ...utils import
-from ...
-
-    hessian_mat,
-    hvp,
-    hvp_fd_central,
-    hvp_fd_forward,
-    jacobian_and_hessian_wrt,
-)
-from ...utils.linalg.linear_operator import DenseWithInverse, Dense
+from ...core import Chainable, Transform, Objective, HessianMethod, Module
+from ...utils import vec_to_tensors
+from ...linalg.linear_operator import Dense, DenseWithInverse
+
 
 def _lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
@@ -26,10 +17,9 @@ def _lu_solve(H: torch.Tensor, g: torch.Tensor):
         return None
 
 def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
-
+    L, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
     if info == 0:
-        g.
-        return torch.cholesky_solve(g, x)
+        return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
     return None
 
 def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
@@ -49,10 +39,59 @@ def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_n
     except torch.linalg.LinAlgError:
         return None
 
+def _newton_step(objective: Objective, H: torch.Tensor, damping:float, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None, no_inner: Module | None = None) -> torch.Tensor:
+    """INNER SHOULD BE NONE IN MOST CASES! Because Transform already has inner.
+    Returns the update tensor, then do vec_to_tensor(update, params)"""
+    # -------------------------------- inner step -------------------------------- #
+    if no_inner is not None:
+        objective = no_inner.step(objective)
+
+    update = objective.get_updates()
+
+    g = torch.cat([t.ravel() for t in update])
+    if g_proj is not None: g = g_proj(g)
+
+    # ----------------------------------- solve ---------------------------------- #
+    update = None
+
+    if damping != 0:
+        H = H + torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping)
+
+    if H_tfm is not None:
+        ret = H_tfm(H, g)
+
+        if isinstance(ret, torch.Tensor):
+            update = ret
+
+        else: # returns (H, is_inv)
+            H, is_inv = ret
+            if is_inv: update = H @ g
+
+    if eigval_fn is not None:
+        update = _eigh_solve(H, g, eigval_fn, search_negative=False)
+
+    if update is None and use_lstsq: update = _least_squares_solve(H, g)
+    if update is None: update = _cholesky_solve(H, g)
+    if update is None: update = _lu_solve(H, g)
+    if update is None: update = _least_squares_solve(H, g)
+
+    return update
+
+def _get_H(H: torch.Tensor, eigval_fn):
+    if eigval_fn is not None:
+        try:
+            L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+            L: torch.Tensor = eigval_fn(L)
+            H = Q @ L.diag_embed() @ Q.mH
+            H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
+            return DenseWithInverse(H, H_inv)
 
+        except torch.linalg.LinAlgError:
+            pass
 
+    return Dense(H)
 
-class Newton(
+class Newton(Transform):
     """Exact newton's method via autograd.
 
     Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
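The new ``_eigh_solve``/``_get_H`` path implements the ``eigval_fn`` regularization described further down in the docstring: the hessian is eigendecomposed, the transform is applied to the eigenvalues, and the system is solved with the transformed spectrum. A standalone sketch of the same idea (toy data, plain torch, not code from the package):

    import torch

    H = torch.tensor([[2.0, 0.0], [0.0, -1.0]])   # indefinite toy hessian
    g = torch.tensor([1.0, 1.0])                   # gradient

    L, Q = torch.linalg.eigh(H)
    L = L.abs().clamp(min=1e-8)                    # eigval_fn, e.g. torch.abs
    # solve (Q diag(L) Qᵀ) d = g without forming the inverse explicitly
    d = Q @ ((Q.T @ g) / L)

With ``torch.abs`` as the transform an indefinite hessian becomes positive definite, which is what the "Handling non-convexity" section of the docstring below relies on.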
@@ -60,7 +99,7 @@ class Newton(Module):
     ``g`` can be output of another module, if it is specifed in ``inner`` argument.
 
     Note:
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
     Note:
         This module requires the a closure passed to the optimizer step,
@@ -77,11 +116,6 @@ class Newton(Module):
             when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
             If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
             argument will be ignored.
-        hessian_method (str):
-            how to calculate hessian. Defaults to "autograd".
-        vectorize (bool, optional):
-            whether to enable vectorized hessian. Defaults to True.
-        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         H_tfm (Callable | None, optional):
             optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
 
@@ -94,6 +128,22 @@ class Newton(Module):
         eigval_fn (Callable | None, optional):
             optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
             If this is specified, eigendecomposition will be used to invert the hessian.
+        hessian_method (str):
+            Determines how hessian is computed.
+
+            - ``"batched_autograd"`` - uses autograd to compute ``ndim`` batched hessian-vector products. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd to compute ``ndim`` hessian-vector products using for loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"functional_revrev"`` - uses ``torch.autograd.functional`` with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to ``"autograd"``.
+            - ``"functional_fwdrev"`` - uses ``torch.autograd.functional`` with vectorized "forward-over-reverse" strategy. Faster than ``"functional_fwdrev"`` but uses more memory (``"batched_autograd"`` seems to be faster)
+            - ``"func"`` - uses ``torch.func.hessian`` which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
+            - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+            - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
+
+            Defaults to ``"batched_autograd"``.
+        h (float, optional):
+            finite difference step size for "fd_forward" and "fd_central".
+        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
 
     # See also
 
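Taken together, the docstring changes mirror the new constructor: the old ``hessian_method``/``vectorize`` pair is replaced by the richer ``hessian_method`` choices plus a finite-difference step ``h``, while ``inner`` stays. A hedged usage sketch of the 0.4.0 signature (``tz.Modular``, ``tz.m.Newton`` and ``tz.m.LR`` are assumed names based on the file list and the ``tz.m.FDM`` reference above, not verified against the release):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)
    X, y = torch.randn(64, 4), torch.randn(64, 1)

    opt = tz.Modular(
        model.parameters(),
        tz.m.Newton(hessian_method="batched_autograd", eigval_fn=torch.abs),  # assumed module name
        tz.m.LR(0.5),                                                         # assumed module name
    )

    # Newton needs a closure that accepts a ``backward`` argument (see the Note above).
    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    opt.step(closure)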
@@ -111,10 +161,9 @@ class Newton(Module):
     The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
     Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.
 
-    Additionally, if ``eigval_fn`` is specified
-
-
-    This is more generally more computationally expensive.
+    Additionally, if ``eigval_fn`` is specified, eigendecomposition of the hessian is computed,
+    ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive,
+    but not by much
 
     ## Handling non-convexity
 
@@ -167,217 +216,45 @@ class Newton(Module):
     def __init__(
         self,
         damping: float = 0,
-        search_negative: bool = False,
         use_lstsq: bool = False,
         update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-        inner: Chainable | None = None,
         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
+        inner: Chainable | None = None,
     ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @torch.no_grad
-    def update(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        damping = settings['damping']
-        hessian_method = settings['hessian_method']
-        vectorize = settings['vectorize']
-        update_freq = settings['update_freq']
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        g_list = var.grad
-        H = None
-        if step % update_freq == 0:
-            # ------------------------ calculate grad and hessian ------------------------ #
-            if hessian_method == 'autograd':
-                with torch.enable_grad():
-                    loss = var.loss = var.loss_approx = closure(False)
-                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                    g_list = [t[0] for t in g_list] # remove leading dim from loss
-                    var.grad = g_list
-                    H = flatten_jacobian(H_list)
-
-            elif hessian_method in ('func', 'autograd.functional'):
-                strat = 'forward-mode' if vectorize else 'reverse-mode'
-                with torch.enable_grad():
-                    g_list = var.get_grad(retain_graph=True)
-                    H = hessian_mat(partial(closure, backward=False), params,
-                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-            else:
-                raise ValueError(hessian_method)
-
-            if damping != 0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping))
-            self.global_state['H'] = H
+        defaults = locals().copy()
+        del defaults['self'], defaults['update_freq'], defaults["inner"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def
-
-
-        params = var.params
-        settings = self.settings[params[0]]
-        search_negative = settings['search_negative']
-        H_tfm = settings['H_tfm']
-        eigval_fn = settings['eigval_fn']
-        use_lstsq = settings['use_lstsq']
-
-        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-        g = torch.cat([t.ravel() for t in update])
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
 
-
-
-
-
+        _, _, self.global_state['H'] = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
 
-
-
-
-
-        H, is_inv = ret
-        if is_inv: update = H @ g
-
-        if search_negative or (eigval_fn is not None):
-            update = _eigh_solve(H, g, eigval_fn, search_negative=search_negative)
-
-        if update is None and use_lstsq: update = _least_squares_solve(H, g)
-        if update is None: update = _cholesky_solve(H, g)
-        if update is None: update = _lu_solve(H, g)
-        if update is None: update = _least_squares_solve(H, g)
-
-        var.update = vec_to_tensors(update, params)
-
-        return var
-
-    def get_H(self,var):
-        H = self.global_state["H"]
-        settings = self.defaults
-        if settings['eigval_fn'] is not None:
-            try:
-                L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
-                L = settings['eigval_fn'](L)
-                H = Q @ L.diag_embed() @ Q.mH
-                H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
-                return DenseWithInverse(H, H_inv)
-
-            except torch.linalg.LinAlgError:
-                pass
-
-        return Dense(H)
-
-
-class InverseFreeNewton(Module):
-    """Inverse-free newton's method
-
-    .. note::
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating the hessian.
-        The closure must accept a ``backward`` argument (refer to documentation).
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        params = objective.params
+        fs = settings[0]
 
-
-
+        update = _newton_step(
+            objective=objective,
+            H = self.global_state["H"],
+            damping = fs["damping"],
+            H_tfm = fs["H_tfm"],
+            eigval_fn = fs["eigval_fn"],
+            use_lstsq = fs["use_lstsq"],
+        )
 
-
-
-    """
-    def __init__(
-        self,
-        update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
-        super().__init__(defaults)
+        objective.updates = vec_to_tensors(update, params)
+        return objective
 
-
-
+    def get_H(self,objective=...):
+        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
 
-    @torch.no_grad
-    def update(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        hessian_method = settings['hessian_method']
-        vectorize = settings['vectorize']
-        update_freq = settings['update_freq']
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        g_list = var.grad
-        Y = None
-        if step % update_freq == 0:
-            # ------------------------ calculate grad and hessian ------------------------ #
-            if hessian_method == 'autograd':
-                with torch.enable_grad():
-                    loss = var.loss = var.loss_approx = closure(False)
-                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                    g_list = [t[0] for t in g_list] # remove leading dim from loss
-                    var.grad = g_list
-                    H = flatten_jacobian(H_list)
-
-            elif hessian_method in ('func', 'autograd.functional'):
-                strat = 'forward-mode' if vectorize else 'reverse-mode'
-                with torch.enable_grad():
-                    g_list = var.get_grad(retain_graph=True)
-                    H = hessian_mat(partial(closure, backward=False), params,
-                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-            else:
-                raise ValueError(hessian_method)
-
-            self.global_state["H"] = H
-
-            # inverse free part
-            if 'Y' not in self.global_state:
-                num = H.T
-                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
-                finfo = torch.finfo(H.dtype)
-                Y = self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
-
-            else:
-                Y = self.global_state['Y']
-                I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
-                I -= H @ Y
-                Y = self.global_state['Y'] = Y @ I
-
-
-    def apply(self, var):
-        Y = self.global_state["Y"]
-        params = var.params
-
-        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-        g = torch.cat([t.ravel() for t in update])
-
-        # ----------------------------------- solve ---------------------------------- #
-        var.update = vec_to_tensors(Y@g, params)
-
-        return var
-
-    def get_H(self,var):
-        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])