torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/multipoint.py
@@ -1,19 +1,17 @@
-from collections.abc import Callable
-from contextlib import nullcontext
 from abc import ABC, abstractmethod
+from collections.abc import Callable, Mapping
+from typing import Any
+
 import numpy as np
 import torch
 
-from ...core import Chainable,
-from ...utils import TensorList, vec_to_tensors
-
-    flatten_jacobian,
-    jacobian_wrt,
-)
+from ...core import Chainable, DerivativesMethod, Objective, Transform
+from ...utils import TensorList, vec_to_tensors
+
 
-class HigherOrderMethodBase(Module, ABC):
-    def __init__(self, defaults: dict | None = None,
-        self.
+class HigherOrderMethodBase(Transform, ABC):
+    def __init__(self, defaults: dict | None = None, derivatives_method: DerivativesMethod = 'batched_autograd'):
+        self._derivatives_method: DerivativesMethod = derivatives_method
         super().__init__(defaults)
 
     @abstractmethod
@@ -21,61 +19,27 @@ class HigherOrderMethodBase(Module, ABC):
         self,
         x: torch.Tensor,
         evaluate: Callable[[torch.Tensor, int], tuple[torch.Tensor, ...]],
-
+        objective: Objective,
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
         """"""
 
     @torch.no_grad
-    def
-        params = TensorList(
-
-        closure =
+    def apply_states(self, objective, states, settings):
+        params = TensorList(objective.params)
+
+        closure = objective.closure
         if closure is None: raise RuntimeError('MultipointNewton requires closure')
-
+        derivatives_method = self._derivatives_method
 
         def evaluate(x, order) -> tuple[torch.Tensor, ...]:
             """order=0 - returns (loss,), order=1 - returns (loss, grad), order=2 - returns (loss, grad, hessian), etc."""
-
-
-            if order == 0:
-                loss = closure(False)
-                params.copy_(x0)
-                return (loss, )
-
-            if order == 1:
-                with torch.enable_grad():
-                    loss = closure()
-                grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                params.copy_(x0)
-                return loss, torch.cat([g.ravel() for g in grad])
-
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-
-                g_list = torch.autograd.grad(loss, params, create_graph=True)
-                var.grad = list(g_list)
-
-                g = torch.cat([t.ravel() for t in g_list])
-                n = g.numel()
-                ret = [loss, g]
-                T = g # current derivatives tensor
-
-                # get all derivative up to order
-                for o in range(2, order + 1):
-                    is_last = o == order
-                    T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
-                    with torch.no_grad() if is_last else nullcontext():
-                        # the shape is (ndim, ) * order
-                        T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
-                        ret.append(T)
-
-            params.copy_(x0)
-            return tuple(ret)
+            return objective.derivatives_at(x, order, method=derivatives_method)
 
         x = torch.cat([p.ravel() for p in params])
-        dir = self.one_iteration(x, evaluate,
-
-        return
+        dir = self.one_iteration(x, evaluate, objective, settings[0])
+        objective.updates = vec_to_tensors(dir, objective.params)
+        return objective
 
 def _inv(A: torch.Tensor, lstsq:bool) -> torch.Tensor:
     if lstsq: return torch.linalg.pinv(A) # pylint:disable=not-callable
@@ -106,16 +70,15 @@ class SixthOrder3P(HigherOrderMethodBase):
 
     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
     """
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
        def f(x): return evaluate(x, 1)[1]
        def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_3p(x, f, f_j, lstsq)
+        x_star = sixth_order_3p(x, f, f_j, setting['lstsq'])
        return x - x_star
 
 # I don't think it works (I tested root finding with this and it goes all over the place)
@@ -173,15 +136,14 @@ def sixth_order_5p(x:torch.Tensor, f_j, lstsq:bool=False):
 
 class SixthOrder5P(HigherOrderMethodBase):
     """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
        def f_j(x): return evaluate(x, 2)[1:]
-        x_star = sixth_order_5p(x, f_j, lstsq)
+        x_star = sixth_order_5p(x, f_j, setting['lstsq'])
        return x - x_star
 
 # 2f 1J 2 solves
@@ -196,16 +158,15 @@ class TwoPointNewton(HigherOrderMethodBase):
     """two-point Newton method with frozen derivative with third order convergence.
 
     Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
        def f(x): return evaluate(x, 1)[1]
        def f_j(x): return evaluate(x, 2)[1:]
-        x_star = two_point_newton(x, f, f_j, lstsq)
+        x_star = two_point_newton(x, f, f_j, setting['lstsq'])
        return x - x_star
 
 #3f 2J 1inv
@@ -224,15 +185,14 @@ def sixth_order_3pm2(x:torch.Tensor, f, f_j, lstsq:bool=False):
 
 class SixthOrder3PM2(HigherOrderMethodBase):
     """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
-    def __init__(self, lstsq: bool=False,
+    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
         defaults=dict(lstsq=lstsq)
-        super().__init__(defaults=defaults,
+        super().__init__(defaults=defaults, derivatives_method=derivatives_method)
 
-
-
-        lstsq = settings['lstsq']
+    @torch.no_grad
+    def one_iteration(self, x, evaluate, objective, setting):
        def f_j(x): return evaluate(x, 2)[1:]
        def f(x): return evaluate(x, 1)[1]
-        x_star = sixth_order_3pm2(x, f, f_j, lstsq)
+        x_star = sixth_order_3pm2(x, f, f_j, setting['lstsq'])
        return x - x_star
 
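For orientation, here is a minimal sketch of a custom subclass written against the refactored base class above. `PlainNewtonStep` is a hypothetical example, not a class from the package; it assumes `evaluate(x, 2)` returns `(loss, grad, hessian)` as documented in the diffed docstring, and that the tensor returned by `one_iteration` becomes the update, mirroring the `x - x_star` pattern of the bundled methods.

# Hypothetical sketch, not part of torchzero: a plain Newton step on top of the
# new HigherOrderMethodBase API (import path taken from the file list above).
import torch
from torchzero.modules.second_order.multipoint import HigherOrderMethodBase

class PlainNewtonStep(HigherOrderMethodBase):
    def __init__(self, derivatives_method='batched_autograd'):
        super().__init__(defaults=None, derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        _, g, H = evaluate(x, 2)         # loss, flat gradient, (n, n) hessian at x
        return torch.linalg.solve(H, g)  # Newton update; becomes objective.updates

Note how `derivatives_method` replaces the old per-module autograd plumbing: derivative computation now goes through `objective.derivatives_at`, so every subclass shares the same set of derivative backends instead of duplicating the `jacobian_wrt` loop that was removed above.

torchzero/modules/second_order/newton.py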
@@ -1,21 +1,12 @@
-import warnings
 from collections.abc import Callable
-from functools import partial
 from typing import Literal
 
 import torch
 
-from ...core import Chainable,
-from ...utils import
-from ...
-
-    hessian_mat,
-    hvp,
-    hvp_fd_central,
-    hvp_fd_forward,
-    jacobian_and_hessian_wrt,
-)
-from ...utils.linalg.linear_operator import DenseWithInverse, Dense
+from ...core import Chainable, Transform, Objective, HessianMethod, Module
+from ...utils import vec_to_tensors
+from ...linalg.linear_operator import Dense, DenseWithInverse
+
 
 def _lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
@@ -26,10 +17,9 @@ def _lu_solve(H: torch.Tensor, g: torch.Tensor):
         return None
 
 def _cholesky_solve(H: torch.Tensor, g: torch.Tensor):
-
+    L, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
     if info == 0:
-        g.
-        return torch.cholesky_solve(g, x)
+        return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
     return None
 
 def _least_squares_solve(H: torch.Tensor, g: torch.Tensor):
@@ -49,49 +39,14 @@ def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_n
     except torch.linalg.LinAlgError:
         return None
 
-
-
-
-    If hessian_method isn't 'autograd', loss is not set and returned as None"""
-    closure = var.closure
-    if closure is None:
-        raise RuntimeError("Second order methods requires a closure to be provided to the `step` method.")
-
-    params = var.params
-
-    # ------------------------ calculate grad and hessian ------------------------ #
-    loss = None
-    if hessian_method == 'autograd':
-        with torch.enable_grad():
-            loss = var.loss = var.loss_approx = closure(False)
-            g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-            g_list = [t[0] for t in g_list] # remove leading dim from loss
-            var.grad = g_list
-            H = flatten_jacobian(H_list)
-
-    elif hessian_method in ('func', 'autograd.functional'):
-        strat = 'forward-mode' if vectorize else 'reverse-mode'
-        with torch.enable_grad():
-            g_list = var.get_grad(retain_graph=True)
-            H = hessian_mat(partial(closure, backward=False), params,
-                method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-    else:
-        raise ValueError(hessian_method)
-
-    return loss, g_list, H
-
-def _newton_step(var: Var, H: torch.Tensor, damping:float, inner: Module | None, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None) -> torch.Tensor:
-    """returns the update tensor, then do vec_to_tensor(update, params)"""
-    params = var.params
-
-    if damping != 0:
-        H = H + torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping)
-
+def _newton_step(objective: Objective, H: torch.Tensor, damping:float, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None, no_inner: Module | None = None) -> torch.Tensor:
+    """INNER SHOULD BE NONE IN MOST CASES! Because Transform already has inner.
+    Returns the update tensor, then do vec_to_tensor(update, params)"""
     # -------------------------------- inner step -------------------------------- #
-
-
-
+    if no_inner is not None:
+        objective = no_inner.step(objective)
+
+    update = objective.get_updates()
 
     g = torch.cat([t.ravel() for t in update])
     if g_proj is not None: g = g_proj(g)
@@ -99,6 +54,9 @@ def _newton_step(var: Var, H: torch.Tensor, damping:float, inner: Module | None,
     # ----------------------------------- solve ---------------------------------- #
     update = None
 
+    if damping != 0:
+        H = H + torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping)
+
     if H_tfm is not None:
         ret = H_tfm(H, g)
 
@@ -133,7 +91,7 @@ def _get_H(H: torch.Tensor, eigval_fn):
 
     return Dense(H)
 
-class Newton(Module):
+class Newton(Transform):
     """Exact newton's method via autograd.
 
     Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
@@ -141,7 +99,7 @@ class Newton(Module):
     ``g`` can be output of another module, if it is specifed in ``inner`` argument.
 
     Note:
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
 
     Note:
         This module requires the a closure passed to the optimizer step,
@@ -158,10 +116,6 @@ class Newton(Module):
            when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
            If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
            argument will be ignored.
-        hessian_method (str):
-            how to calculate hessian. Defaults to "autograd".
-        vectorize (bool, optional):
-            whether to enable vectorized hessian. Defaults to True.
         H_tfm (Callable | None, optional):
            optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
 
@@ -174,6 +128,21 @@ class Newton(Module):
         eigval_fn (Callable | None, optional):
             optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
             If this is specified, eigendecomposition will be used to invert the hessian.
+        hessian_method (str):
+            Determines how hessian is computed.
+
+            - ``"batched_autograd"`` - uses autograd to compute ``ndim`` batched hessian-vector products. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd to compute ``ndim`` hessian-vector products using for loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"functional_revrev"`` - uses ``torch.autograd.functional`` with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to ``"autograd"``.
+            - ``"functional_fwdrev"`` - uses ``torch.autograd.functional`` with vectorized "forward-over-reverse" strategy. Faster than ``"functional_fwdrev"`` but uses more memory (``"batched_autograd"`` seems to be faster)
+            - ``"func"`` - uses ``torch.func.hessian`` which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
+            - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+            - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
+
+            Defaults to ``"batched_autograd"``.
+        h (float, optional):
+            finite difference step size for "fd_forward" and "fd_central".
         inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
 
     # See also
@@ -249,45 +218,43 @@ class Newton(Module):
         damping: float = 0,
         use_lstsq: bool = False,
         update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        hessian_method: HessianMethod = "batched_autograd",
+        h: float = 1e-3,
         inner: Chainable | None = None,
     ):
-        defaults =
-
-
-        if inner is not None:
-            self.set_child('inner', inner)
+        defaults = locals().copy()
+        del defaults['self'], defaults['update_freq'], defaults["inner"]
+        super().__init__(defaults, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def
-
-        self.global_state['step'] = step + 1
+    def update_states(self, objective, states, settings):
+        fs = settings[0]
 
-
-
-
-
+        _, _, self.global_state['H'] = objective.hessian(
+            hessian_method=fs['hessian_method'],
+            h=fs['h'],
+            at_x0=True
+        )
 
     @torch.no_grad
-    def
-        params =
+    def apply_states(self, objective, states, settings):
+        params = objective.params
+        fs = settings[0]
+
         update = _newton_step(
-
+            objective=objective,
            H = self.global_state["H"],
-            damping=
-
-
-
-            use_lstsq=self.defaults["use_lstsq"],
+            damping = fs["damping"],
+            H_tfm = fs["H_tfm"],
+            eigval_fn = fs["eigval_fn"],
+            use_lstsq = fs["use_lstsq"],
         )
 
-
-
-        return var
+        objective.updates = vec_to_tensors(update, params)
+        return objective
 
-    def get_H(self,
+    def get_H(self,objective=...):
        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
 
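To see the reworked options end to end, here is a hedged usage sketch: the `Newton` arguments are taken from the diff above, while `tz.Modular`, the `tz.m` namespace, and the `closure(backward)` convention are assumed to carry over from earlier torchzero releases (the docstring itself refers to `tz.m.FDM`).

# Hedged usage sketch. Newton's arguments come from the diff above; tz.Modular,
# tz.m and the closure(backward) convention are assumed from earlier releases.
import torch
import torchzero as tz

model = torch.nn.Linear(8, 1)
X, y = torch.randn(64, 8), torch.randn(64, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.Newton(hessian_method="batched_autograd", eigval_fn=torch.abs),
)

def closure(backward=True):
    loss = (model(X) - y).square().mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)

The `h` argument should only matter for the finite-difference hessian methods; the autograd-based backends listed in the new docstring compute exact hessian-vector products and do not use a step size.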