torchzero 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl
- tests/test_opts.py +4 -10
- torchzero/core/__init__.py +4 -1
- torchzero/core/chain.py +50 -0
- torchzero/core/functional.py +37 -0
- torchzero/core/modular.py +237 -0
- torchzero/core/module.py +12 -599
- torchzero/core/reformulation.py +3 -1
- torchzero/core/transform.py +7 -5
- torchzero/core/var.py +376 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/conjugate_gradient/cg.py +16 -16
- torchzero/modules/experimental/__init__.py +1 -0
- torchzero/modules/experimental/newtonnewton.py +5 -5
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/functional.py +7 -0
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +2 -5
- torchzero/modules/grad_approximation/rfdm.py +27 -110
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +1 -1
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +11 -20
- torchzero/modules/line_search/scipy.py +15 -3
- torchzero/modules/line_search/strong_wolfe.py +3 -5
- torchzero/modules/misc/misc.py +2 -2
- torchzero/modules/misc/multistep.py +13 -13
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/quasi_newton.py +15 -6
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +5 -4
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +89 -0
- torchzero/modules/second_order/inm.py +105 -0
- torchzero/modules/second_order/newton.py +103 -193
- torchzero/modules/second_order/newton_cg.py +86 -110
- torchzero/modules/second_order/nystrom.py +1 -1
- torchzero/modules/second_order/rsn.py +227 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +6 -4
- torchzero/modules/wrappers/optim_wrapper.py +49 -42
- torchzero/modules/zeroth_order/__init__.py +1 -1
- torchzero/modules/zeroth_order/cd.py +1 -238
- torchzero/utils/derivatives.py +19 -19
- torchzero/utils/linalg/linear_operator.py +50 -2
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +1 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/RECORD +57 -48
- torchzero/modules/higher_order/__init__.py +0 -1
- /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
- {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/torchzero/modules/second_order/ifn.py
@@ -0,0 +1,89 @@
+import warnings
+from collections.abc import Callable
+from functools import partial
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, apply_transform, Var
+from ...utils import TensorList, vec_to_tensors
+from ...utils.linalg.linear_operator import DenseWithInverse, Dense
+from .newton import _get_H, _get_loss_grad_and_hessian, _newton_step
+
+
+class InverseFreeNewton(Module):
+    """Inverse-free newton's method
+
+    .. note::
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating the hessian.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference
+        [Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.](https://www.jaac-online.com/article/doi/10.11948/20240428)
+    """
+    def __init__(
+        self,
+        update_freq: int = 1,
+        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+        vectorize: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def update(self, var):
+        update_freq = self.defaults['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        if step % update_freq == 0:
+            loss, g_list, H = _get_loss_grad_and_hessian(
+                var, self.defaults['hessian_method'], self.defaults['vectorize']
+            )
+            self.global_state["H"] = H
+
+            # inverse free part
+            if 'Y' not in self.global_state:
+                num = H.T
+                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+
+                finfo = torch.finfo(H.dtype)
+                self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
+
+            else:
+                Y = self.global_state['Y']
+                I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+                I2 -= H @ Y
+                self.global_state['Y'] = Y @ I2
+
+
+    def apply(self, var):
+        Y = self.global_state["Y"]
+        params = var.params
+
+        # -------------------------------- inner step -------------------------------- #
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
+
+        g = torch.cat([t.ravel() for t in update])
+
+        # ----------------------------------- solve ---------------------------------- #
+        var.update = vec_to_tensors(Y@g, params)
+
+        return var
+
+    def get_H(self,var):
+        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
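For orientation, the `Y` recurrence above is the Newton-Schulz (Hotelling-Bodewig) iteration for a matrix inverse, `Y <- Y (2I - H Y)`, seeded with `Y0 = H^T / (||H||_1 * ||H||_inf)`. A minimal standalone sketch of that iteration in plain `torch` (the helper name is made up for illustration; the module above runs one refinement per Hessian refresh rather than iterating to convergence):

```python
import torch

def newton_schulz_inverse(H: torch.Tensor, iters: int = 30) -> torch.Tensor:
    # Seed: Y0 = H^T / (||H||_1 * ||H||_inf), as in InverseFreeNewton.update above.
    Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float("inf")))
    I = torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    for _ in range(iters):
        # One inverse-free refinement step: Y <- Y (2I - H Y).
        Y = Y @ (2 * I - H @ Y)
    return Y

H = torch.randn(8, 8, dtype=torch.float64)
H = H @ H.T + 0.1 * torch.eye(8, dtype=torch.float64)   # SPD test matrix
Y = newton_schulz_inverse(H)
print(torch.dist(Y, torch.linalg.inv(H)))                # should be tiny
```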
--- /dev/null
+++ b/torchzero/modules/second_order/inm.py
@@ -0,0 +1,105 @@
+from collections.abc import Callable
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module
+from ...utils import TensorList, vec_to_tensors
+from ..functional import safe_clip
+from .newton import _get_H, _get_loss_grad_and_hessian, _newton_step
+
+@torch.no_grad
+def inm(f:torch.Tensor, J:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
+
+    yy = safe_clip(y.dot(y))
+    ss = safe_clip(s.dot(s))
+
+    term1 = y.dot(y - J@s) / yy
+    FbT = f.outer(s).mul_(term1 / ss)
+
+    P = FbT.add_(J)
+    return P
+
+def _eigval_fn(J: torch.Tensor, fn) -> torch.Tensor:
+    if fn is None: return J
+    L, Q = torch.linalg.eigh(J) # pylint:disable=not-callable
+    return (Q * L.unsqueeze(-2)) @ Q.mH
+
+class INM(Module):
+    """Improved Newton's Method (INM).
+
+    Reference:
+        [Saheya, B., et al. "A new Newton-like method for solving nonlinear equations." SpringerPlus 5.1 (2016): 1269.](https://d-nb.info/1112813721/34)
+    """
+
+    def __init__(
+        self,
+        damping: float = 0,
+        use_lstsq: bool = False,
+        update_freq: int = 1,
+        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+        vectorize: bool = True,
+        inner: Chainable | None = None,
+        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+    ):
+        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child("inner", inner)
+
+    @torch.no_grad
+    def update(self, var):
+        update_freq = self.defaults['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        if step % update_freq == 0:
+            _, f_list, J = _get_loss_grad_and_hessian(
+                var, self.defaults['hessian_method'], self.defaults['vectorize']
+            )
+
+            f = torch.cat([t.ravel() for t in f_list])
+            J = _eigval_fn(J, self.defaults["eigval_fn"])
+
+            x_list = TensorList(var.params)
+            f_list = TensorList(var.get_grad())
+            x_prev, f_prev = self.get_state(var.params, "x_prev", "f_prev", cls=TensorList)
+
+            # initialize on 1st step, do Newton step
+            if step == 0:
+                x_prev.copy_(x_list)
+                f_prev.copy_(f_list)
+                self.global_state["P"] = J
+                return
+
+            # INM update
+            s_list = x_list - x_prev
+            y_list = f_list - f_prev
+            x_prev.copy_(x_list)
+            f_prev.copy_(f_list)
+
+            self.global_state["P"] = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())
+
+
+    @torch.no_grad
+    def apply(self, var):
+        params = var.params
+        update = _newton_step(
+            var=var,
+            H = self.global_state["P"],
+            damping=self.defaults["damping"],
+            inner=self.children.get("inner", None),
+            H_tfm=self.defaults["H_tfm"],
+            eigval_fn=None, # it is applied in `update`
+            use_lstsq=self.defaults["use_lstsq"],
+        )
+
+        var.update = vec_to_tensors(update, params)
+
+        return var
+
+    def get_H(self,var=...):
+        return _get_H(self.global_state["P"], eigval_fn=None)
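The `inm` helper above forms the rank-one corrected matrix `P = J + [y.(y - J s) / (y.y)] * (f s^T) / (s.s)` from the current gradient `f`, the Hessian/Jacobian `J`, the parameter difference `s`, and the gradient difference `y`. A rough standalone sketch of the same formula with plain tensors (the `safe_clip` guards are omitted and the names are made up for illustration):

```python
import torch

def inm_correction(f: torch.Tensor, J: torch.Tensor,
                   s: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # P = J + (y.(y - J s) / y.y) * (f s^T) / (s.s)
    term = y.dot(y - J @ s) / y.dot(y)
    return J + torch.outer(f, s) * (term / s.dot(s))

n = 5
J = torch.randn(n, n, dtype=torch.float64)
f = torch.randn(n, dtype=torch.float64)            # current gradient / residual
s = 1e-2 * torch.randn(n, dtype=torch.float64)     # x_k - x_{k-1}
y = 1e-2 * torch.randn(n, dtype=torch.float64)     # f_k - f_{k-1}
P = inm_correction(f, J, s, y)
d = torch.linalg.solve(P, f)                       # Newton-like direction from the corrected matrix
```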
--- a/torchzero/modules/second_order/newton.py
+++ b/torchzero/modules/second_order/newton.py
@@ -5,7 +5,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, apply_transform
+from ...core import Chainable, Module, apply_transform, Var
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
     flatten_jacobian,
@@ -50,7 +50,88 @@ def _eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_n
     return None
 
 
+def _get_loss_grad_and_hessian(var: Var, hessian_method:str, vectorize:bool):
+    """returns (loss, g_list, H). Also sets var.loss and var.grad.
+    If hessian_method isn't 'autograd', loss is not set and returned as None"""
+    closure = var.closure
+    if closure is None:
+        raise RuntimeError("Second order methods requires a closure to be provided to the `step` method.")
 
+    params = var.params
+
+    # ------------------------ calculate grad and hessian ------------------------ #
+    loss = None
+    if hessian_method == 'autograd':
+        with torch.enable_grad():
+            loss = var.loss = var.loss_approx = closure(False)
+            g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+            g_list = [t[0] for t in g_list] # remove leading dim from loss
+            var.grad = g_list
+            H = flatten_jacobian(H_list)
+
+    elif hessian_method in ('func', 'autograd.functional'):
+        strat = 'forward-mode' if vectorize else 'reverse-mode'
+        with torch.enable_grad():
+            g_list = var.get_grad(retain_graph=True)
+            H = hessian_mat(partial(closure, backward=False), params,
+                            method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+    else:
+        raise ValueError(hessian_method)
+
+    return loss, g_list, H
+
+def _newton_step(var: Var, H: torch.Tensor, damping:float, inner: Module | None, H_tfm, eigval_fn, use_lstsq:bool, g_proj: Callable | None = None) -> torch.Tensor:
+    """returns the update tensor, then do vec_to_tensor(update, params)"""
+    params = var.params
+
+    if damping != 0:
+        H = H + torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping)
+
+    # -------------------------------- inner step -------------------------------- #
+    update = var.get_update()
+    if inner is not None:
+        update = apply_transform(inner, update, params=params, grads=var.grad, loss=var.loss, var=var)
+
+    g = torch.cat([t.ravel() for t in update])
+    if g_proj is not None: g = g_proj(g)
+
+    # ----------------------------------- solve ---------------------------------- #
+    update = None
+
+    if H_tfm is not None:
+        ret = H_tfm(H, g)
+
+        if isinstance(ret, torch.Tensor):
+            update = ret
+
+        else: # returns (H, is_inv)
+            H, is_inv = ret
+            if is_inv: update = H @ g
+
+    if eigval_fn is not None:
+        update = _eigh_solve(H, g, eigval_fn, search_negative=False)
+
+    if update is None and use_lstsq: update = _least_squares_solve(H, g)
+    if update is None: update = _cholesky_solve(H, g)
+    if update is None: update = _lu_solve(H, g)
+    if update is None: update = _least_squares_solve(H, g)
+
+    return update
+
+def _get_H(H: torch.Tensor, eigval_fn):
+    if eigval_fn is not None:
+        try:
+            L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+            L: torch.Tensor = eigval_fn(L)
+            H = Q @ L.diag_embed() @ Q.mH
+            H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
+            return DenseWithInverse(H, H_inv)
+
+        except torch.linalg.LinAlgError:
+            pass
+
+    return Dense(H)
 
 class Newton(Module):
     """Exact newton's method via autograd.
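The docstring further down states the solve order: Cholesky, then LU, then least squares. `_newton_step` realizes this with the private `_cholesky_solve` / `_lu_solve` / `_least_squares_solve` helpers; a rough standalone equivalent using only public `torch.linalg` calls (an illustration of the fallback order, not the package's own helpers):

```python
import torch

def solve_with_fallbacks(H: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # 1) Cholesky: fastest, but requires H to be positive definite.
    L, info = torch.linalg.cholesky_ex(H)
    if info.item() == 0:
        return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
    # 2) LU-based solve: works for any non-singular H.
    try:
        return torch.linalg.solve(H, g)
    except torch.linalg.LinAlgError:
        # 3) Least squares: always returns something, even for singular H.
        return torch.linalg.lstsq(H, g.unsqueeze(-1)).solution.squeeze(-1)

H = torch.randn(6, 6, dtype=torch.float64)
H = H + H.T                                # symmetric, possibly indefinite "Hessian"
g = torch.randn(6, dtype=torch.float64)
d = solve_with_fallbacks(H, g)             # the module then maps d back with vec_to_tensors
```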
@@ -81,7 +162,6 @@ class Newton(Module):
             how to calculate hessian. Defaults to "autograd".
         vectorize (bool, optional):
             whether to enable vectorized hessian. Defaults to True.
-        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
         H_tfm (Callable | None, optional):
             optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
 
@@ -94,6 +174,7 @@ class Newton(Module):
         eigval_fn (Callable | None, optional):
             optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
             If this is specified, eigendecomposition will be used to invert the hessian.
+        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
 
     # See also
 
@@ -111,10 +192,9 @@ class Newton(Module):
     The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares.
     Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when linear system is overdetermined.
 
-    Additionally, if ``eigval_fn`` is specified
-
-
-    This is more generally more computationally expensive.
+    Additionally, if ``eigval_fn`` is specified, eigendecomposition of the hessian is computed,
+    ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is more generally more computationally expensive,
+    but not by much
 
     ## Handling non-convexity
 
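As an illustration of the `eigval_fn` path just described: an eigendecomposition `H = Q diag(L) Q^H` is taken, the transform is applied to `L`, and the system is solved with the modified spectrum, which keeps the direction well defined on indefinite Hessians. A sketch in plain `torch` (not the module's internal `_eigh_solve`; the transform below is merely in the spirit of the docstring's examples):

```python
import torch

def eigval_solve(H: torch.Tensor, g: torch.Tensor, eigval_fn) -> torch.Tensor:
    # H = Q diag(L) Q^H; solve with the transformed eigenvalues fn(L).
    L, Q = torch.linalg.eigh(H)
    L = eigval_fn(L)
    return Q @ ((Q.mH @ g) / L)      # (Q diag(fn(L)) Q^H)^-1 g

H = torch.randn(6, 6, dtype=torch.float64)
H = H + H.T                                                 # indefinite symmetric Hessian
g = torch.randn(6, dtype=torch.float64)
d = eigval_solve(H, g, lambda L: L.abs().clip(min=1e-8))    # absolute-value clamp of the spectrum
```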
@@ -167,16 +247,15 @@ class Newton(Module):
     def __init__(
         self,
         damping: float = 0,
-        search_negative: bool = False,
         use_lstsq: bool = False,
         update_freq: int = 1,
         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
         vectorize: bool = True,
-        inner: Chainable | None = None,
         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        inner: Chainable | None = None,
     ):
-        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn,
+        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, update_freq=update_freq)
         super().__init__(defaults)
 
         if inner is not None:
@@ -184,200 +263,31 @@ class Newton(Module):
 
     @torch.no_grad
     def update(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        damping = settings['damping']
-        hessian_method = settings['hessian_method']
-        vectorize = settings['vectorize']
-        update_freq = settings['update_freq']
-
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
-
-
-
-
-        if hessian_method == 'autograd':
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                g_list = [t[0] for t in g_list] # remove leading dim from loss
-                var.grad = g_list
-                H = flatten_jacobian(H_list)
-
-        elif hessian_method in ('func', 'autograd.functional'):
-            strat = 'forward-mode' if vectorize else 'reverse-mode'
-            with torch.enable_grad():
-                g_list = var.get_grad(retain_graph=True)
-                H = hessian_mat(partial(closure, backward=False), params,
-                                method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-        else:
-            raise ValueError(hessian_method)
-
-        if damping != 0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping))
-        self.global_state['H'] = H
+        if step % self.defaults['update_freq'] == 0:
+            loss, g_list, self.global_state['H'] = _get_loss_grad_and_hessian(
+                var, self.defaults['hessian_method'], self.defaults['vectorize']
+            )
 
     @torch.no_grad
     def apply(self, var):
-        H = self.global_state["H"]
-
         params = var.params
-
-
-
-
-
-
-
-
-
-            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-        g = torch.cat([t.ravel() for t in update])
-
-        # ----------------------------------- solve ---------------------------------- #
-        update = None
-        if H_tfm is not None:
-            ret = H_tfm(H, g)
-
-            if isinstance(ret, torch.Tensor):
-                update = ret
-
-            else: # returns (H, is_inv)
-                H, is_inv = ret
-                if is_inv: update = H @ g
-
-        if search_negative or (eigval_fn is not None):
-            update = _eigh_solve(H, g, eigval_fn, search_negative=search_negative)
-
-        if update is None and use_lstsq: update = _least_squares_solve(H, g)
-        if update is None: update = _cholesky_solve(H, g)
-        if update is None: update = _lu_solve(H, g)
-        if update is None: update = _least_squares_solve(H, g)
+        update = _newton_step(
+            var=var,
+            H = self.global_state["H"],
+            damping=self.defaults["damping"],
+            inner=self.children.get("inner", None),
+            H_tfm=self.defaults["H_tfm"],
+            eigval_fn=self.defaults["eigval_fn"],
+            use_lstsq=self.defaults["use_lstsq"],
+        )
 
         var.update = vec_to_tensors(update, params)
 
         return var
 
-    def get_H(self,var):
-
-        settings = self.defaults
-        if settings['eigval_fn'] is not None:
-            try:
-                L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
-                L = settings['eigval_fn'](L)
-                H = Q @ L.diag_embed() @ Q.mH
-                H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
-                return DenseWithInverse(H, H_inv)
-
-            except torch.linalg.LinAlgError:
-                pass
-
-        return Dense(H)
-
-
-class InverseFreeNewton(Module):
-    """Inverse-free newton's method
-
-    .. note::
-        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
-
-    .. note::
-        This module requires the a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating the hessian.
-        The closure must accept a ``backward`` argument (refer to documentation).
-
-    .. warning::
-        this uses roughly O(N^2) memory.
-
-    Reference
-        Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
-    """
-    def __init__(
-        self,
-        update_freq: int = 1,
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @torch.no_grad
-    def update(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        hessian_method = settings['hessian_method']
-        vectorize = settings['vectorize']
-        update_freq = settings['update_freq']
-
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        g_list = var.grad
-        Y = None
-        if step % update_freq == 0:
-            # ------------------------ calculate grad and hessian ------------------------ #
-            if hessian_method == 'autograd':
-                with torch.enable_grad():
-                    loss = var.loss = var.loss_approx = closure(False)
-                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                    g_list = [t[0] for t in g_list] # remove leading dim from loss
-                    var.grad = g_list
-                    H = flatten_jacobian(H_list)
-
-            elif hessian_method in ('func', 'autograd.functional'):
-                strat = 'forward-mode' if vectorize else 'reverse-mode'
-                with torch.enable_grad():
-                    g_list = var.get_grad(retain_graph=True)
-                    H = hessian_mat(partial(closure, backward=False), params,
-                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-            else:
-                raise ValueError(hessian_method)
-
-            self.global_state["H"] = H
-
-            # inverse free part
-            if 'Y' not in self.global_state:
-                num = H.T
-                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
-                finfo = torch.finfo(H.dtype)
-                Y = self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))
-
-            else:
-                Y = self.global_state['Y']
-                I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
-                I -= H @ Y
-                Y = self.global_state['Y'] = Y @ I
-
-
-    def apply(self, var):
-        Y = self.global_state["Y"]
-        params = var.params
-
-        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)
-
-        g = torch.cat([t.ravel() for t in update])
-
-        # ----------------------------------- solve ---------------------------------- #
-        var.update = vec_to_tensors(Y@g, params)
-
-        return var
+    def get_H(self,var=...):
+        return _get_H(self.global_state["H"], self.defaults["eigval_fn"])
 
-    def get_H(self,var):
-        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])