torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
     hessian_list_to_mat,
@@ -18,9 +18,12 @@ from ...utils.derivatives import (
 
 
 def lu_solve(H: torch.Tensor, g: torch.Tensor):
-
-
-
+    try:
+        x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
+        if info == 0: return x
+        return None
+    except RuntimeError:
+        return None
 
 def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
     x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
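The new `lu_solve` reports failure by returning `None`, either when `torch.linalg.solve_ex` sets a nonzero `info` or when it raises, which is what lets `Newton.step` later in this diff fall through a cascade of solvers (eigendecomposition, Cholesky, LU, least squares). A minimal standalone sketch of that fallback pattern, assuming a square `H` and a flattened gradient `g`; the `robust_solve` helper is illustrative and not part of the package:

```python
import torch

def lu_solve(H: torch.Tensor, g: torch.Tensor):
    # mirrors the new helper: solve_ex signals failure through `info` instead of raising
    try:
        x, info = torch.linalg.solve_ex(H, g)
        return x if info == 0 else None
    except RuntimeError:
        return None

def robust_solve(H: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # hypothetical cascade in the same spirit as Newton.step's fallback chain
    x = lu_solve(H, g)
    if x is None:
        x = torch.linalg.lstsq(H, g)[0]  # least-squares fallback, as in least_squares_solve
    return x

H = torch.tensor([[2.0, 0.0], [0.0, 4.0]])
g = torch.tensor([1.0, 2.0])
print(robust_solve(H, g))  # tensor([0.5000, 0.5000])
```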
@@ -32,12 +35,17 @@ def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
 def least_squares_solve(H: torch.Tensor, g: torch.Tensor):
     return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable
 
-def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None):
+def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
     try:
         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
         if tfm is not None: L = tfm(L)
-        L
-
+        if search_negative and L[0] < 0:
+            d = Q[0]
+            # use eigvec or -eigvec depending on if it points in same direction as gradient
+            return g.dot(d).sign() * d
+
+        return Q @ ((Q.mH @ g) / L)
+
     except torch.linalg.LinAlgError:
         return None
 
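The extended `eigh_solve` factors `H = Q diag(L) Q^T` and either applies the inverse through the factorization, `Q ((Q^T g) / L)`, or, when `search_negative` is set and the smallest eigenvalue is negative, returns an eigenvector of that eigenvalue signed to point along the gradient. A self-contained sketch of the same idea, assuming a symmetric `H` and a 1-D `g` (note this sketch takes the eigenvector as a column of `Q`):

```python
import torch

def eigh_newton_direction(H: torch.Tensor, g: torch.Tensor, search_negative: bool = True):
    # H = Q diag(L) Q^T with eigenvalues L in ascending order
    L, Q = torch.linalg.eigh(H)
    if search_negative and L[0] < 0:
        d = Q[:, 0]                 # eigenvector paired with the most negative eigenvalue
        return g.dot(d).sign() * d  # orient it to correlate with the gradient
    return Q @ ((Q.mH @ g) / L)     # H^{-1} g through the factorization

H = torch.tensor([[2.0, 0.0], [0.0, -1.0]])  # indefinite: one negative eigenvalue
g = torch.tensor([1.0, 1.0])
print(eigh_newton_direction(H, g))  # follows the negative-curvature eigenvector
```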
@@ -45,103 +53,286 @@ def tikhonov_(H: torch.Tensor, reg: float):
     if reg!=0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(reg))
     return H
 
-def eig_tikhonov_(H: torch.Tensor, reg: float):
-    v = torch.linalg.eigvalsh(H).min().clamp_(max=0).neg_() + reg # pylint:disable=not-callable
-    return tikhonov_(H, v)
-
 
 class Newton(Module):
-    """Exact newton via autograd.
+    """Exact newton's method via autograd.
+
+    .. note::
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating the hessian.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
 
     Args:
         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
-
+        search_negative (bool, Optional):
+            if True, whenever a negative eigenvalue is detected,
+            search direction is proposed along an eigenvector corresponding to a negative eigenvalue.
        hessian_method (str):
            how to calculate hessian. Defaults to "autograd".
        vectorize (bool, optional):
            whether to enable vectorized hessian. Defaults to True.
-        inner (Chainable | None, optional):
+        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
        H_tfm (Callable | None, optional):
            optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
 
-            must return a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
-            which must be True if transform inverted the hessian and False otherwise.
+            must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
+            which must be True if transform inverted the hessian and False otherwise.
+
+            Or it returns a single tensor which is used as the update.
+
+            Defaults to None.
        eigval_tfm (Callable | None, optional):
            optional eigenvalues transform, for example :code:`torch.abs` or :code:`lambda L: torch.clip(L, min=1e-8)`.
-            If this is specified, eigendecomposition will be used to
+            If this is specified, eigendecomposition will be used to invert the hessian.
+
+    Examples:
+        Newton's method with backtracking line search
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Newton(),
+                tz.m.Backtracking()
+            )
+
+        Newton's method modified for non-convex functions by taking matrix absolute value of the hessian
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Newton(eigval_tfm=lambda x: torch.abs(x).clip(min=0.1)),
+                tz.m.Backtracking()
+            )
+
+        Newton's method modified for non-convex functions by searching along negative curvature directions
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Newton(search_negative=True),
+                tz.m.Backtracking()
+            )
+
+        Newton preconditioning applied to momentum
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Newton(inner=tz.m.EMA(0.9)),
+                tz.m.LR(0.1)
+            )
+
+        Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient, but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
+                tz.m.Backtracking()
+            )
 
     """
     def __init__(
         self,
         reg: float = 1e-6,
-
+        search_negative: bool = False,
+        update_freq: int = 1,
        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
        vectorize: bool = True,
        inner: Chainable | None = None,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
+        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
     ):
-        defaults = dict(reg=reg,
+        defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative, update_freq=update_freq)
        super().__init__(defaults)
 
        if inner is not None:
            self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self,
-        params = TensorList(
-        closure =
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
        if closure is None: raise RuntimeError('NewtonCG requires closure')
 
        settings = self.settings[params[0]]
        reg = settings['reg']
-
+        search_negative = settings['search_negative']
        hessian_method = settings['hessian_method']
        vectorize = settings['vectorize']
        H_tfm = settings['H_tfm']
        eigval_tfm = settings['eigval_tfm']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        g_list = var.grad
+        H = None
+        if step % update_freq == 0:
+            # ------------------------ calculate grad and hessian ------------------------ #
+            if hessian_method == 'autograd':
+                with torch.enable_grad():
+                    loss = var.loss = var.loss_approx = closure(False)
+                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                    g_list = [t[0] for t in g_list] # remove leading dim from loss
+                    var.grad = g_list
+                    H = hessian_list_to_mat(H_list)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            elif hessian_method in ('func', 'autograd.functional'):
+                strat = 'forward-mode' if vectorize else 'reverse-mode'
+                with torch.enable_grad():
+                    g_list = var.get_grad(retain_graph=True)
+                    H = hessian_mat(partial(closure, backward=False), params,
+                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+            else:
+                raise ValueError(hessian_method)
+
+            H = tikhonov_(H, reg)
+            if update_freq != 1:
+                self.global_state['H'] = H
+
+        if H is None:
+            H = self.global_state["H"]
+
+        # var.storage['hessian'] = H
 
        # -------------------------------- inner step -------------------------------- #
-        update =
+        update = var.get_update()
        if 'inner' in self.children:
-            update =
-
+            update = apply_transform(self.children['inner'], update, params=params, grads=g_list, var=var)
+
+        g = torch.cat([t.ravel() for t in update])
 
-        # ------------------------------- regulazition ------------------------------- #
-        if eig_reg: H = eig_tikhonov_(H, reg)
-        else: H = tikhonov_(H, reg)
 
        # ----------------------------------- solve ---------------------------------- #
        update = None
        if H_tfm is not None:
-
-            if is_inv: update = H
+            ret = H_tfm(H, g)
 
-
-
+            if isinstance(ret, torch.Tensor):
+                update = ret
+
+            else: # returns (H, is_inv)
+                H, is_inv = ret
+                if is_inv: update = H @ g
+
+        if search_negative or (eigval_tfm is not None):
+            update = eigh_solve(H, g, eigval_tfm, search_negative=search_negative)
 
        if update is None: update = cholesky_solve(H, g)
        if update is None: update = lu_solve(H, g)
        if update is None: update = least_squares_solve(H, g)
 
-
-
+        var.update = vec_to_tensors(update, params)
+
+        return var
+
+class InverseFreeNewton(Module):
+    """Inverse-free newton's method
+
+    .. note::
+        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating the hessian.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference
+        Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
+    """
+    def __init__(
+        self,
+        update_freq: int = 1,
+        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+        vectorize: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        hessian_method = settings['hessian_method']
+        vectorize = settings['vectorize']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        g_list = var.grad
+        Y = None
+        if step % update_freq == 0:
+            # ------------------------ calculate grad and hessian ------------------------ #
+            if hessian_method == 'autograd':
+                with torch.enable_grad():
+                    loss = var.loss = var.loss_approx = closure(False)
+                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                    g_list = [t[0] for t in g_list] # remove leading dim from loss
+                    var.grad = g_list
+                    H = hessian_list_to_mat(H_list)
+
+            elif hessian_method in ('func', 'autograd.functional'):
+                strat = 'forward-mode' if vectorize else 'reverse-mode'
+                with torch.enable_grad():
+                    g_list = var.get_grad(retain_graph=True)
+                    H = hessian_mat(partial(closure, backward=False), params,
+                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+            else:
+                raise ValueError(hessian_method)
+
+            # inverse free part
+            if 'Y' not in self.global_state:
+                num = H.T
+                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+                eps = torch.finfo(H.dtype).eps
+                Y = self.global_state['Y'] = num.div_(denom.clip(min=eps, max=1/eps))
+
+            else:
+                Y = self.global_state['Y']
+                I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+                I -= H @ Y
+                Y = self.global_state['Y'] = Y @ I
+
+        if Y is None:
+            Y = self.global_state["Y"]
+
+
+        # -------------------------------- inner step -------------------------------- #
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], update, params=params, grads=g_list, var=var)
+
+        g = torch.cat([t.ravel() for t in update])
+
+
+        # ----------------------------------- solve ---------------------------------- #
+        var.update = vec_to_tensors(Y@g, params)
+
+        return var