torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/second_order/newton.py:

@@ -44,8 +44,8 @@ def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_ne
 # use eigvec or -eigvec depending on if it points in same direction as gradient
 return g.dot(d).sign() * d

-
-
+return Q @ ((Q.mH @ g) / L)
+
 except torch.linalg.LinAlgError:
 return None

@@ -53,46 +53,109 @@ def tikhonov_(H: torch.Tensor, reg: float):
 if reg!=0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(reg))
 return H

-def eig_tikhonov_(H: torch.Tensor, reg: float):
-v = torch.linalg.eigvalsh(H).min().clamp_(max=0).neg_() + reg # pylint:disable=not-callable
-return tikhonov_(H, v)
-

 class Newton(Module):
-"""Exact newton via autograd.
+"""Exact newton's method via autograd.
+
+.. note::
+In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+.. note::
+This module requires the a closure passed to the optimizer step,
+as it needs to re-evaluate the loss and gradients for calculating the hessian.
+The closure must accept a ``backward`` argument (refer to documentation).
+
+.. warning::
+this uses roughly O(N^2) memory.
+

 Args:
 reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
-eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
 search_negative (bool, Optional):
-if True, whenever a negative eigenvalue is detected,
+if True, whenever a negative eigenvalue is detected,
+search direction is proposed along an eigenvector corresponding to a negative eigenvalue.
 hessian_method (str):
 how to calculate hessian. Defaults to "autograd".
 vectorize (bool, optional):
 whether to enable vectorized hessian. Defaults to True.
-inner (Chainable | None, optional):
+inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
 H_tfm (Callable | None, optional):
 optional hessian transforms, takes in two arguments - `(hessian, gradient)`.

-must return a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
-which must be True if transform inverted the hessian and False otherwise.
+must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
+which must be True if transform inverted the hessian and False otherwise.
+
+Or it returns a single tensor which is used as the update.
+
+Defaults to None.
 eigval_tfm (Callable | None, optional):
 optional eigenvalues transform, for example :code:`torch.abs` or :code:`lambda L: torch.clip(L, min=1e-8)`.
-If this is specified, eigendecomposition will be used to
+If this is specified, eigendecomposition will be used to invert the hessian.
+
+Examples:
+Newton's method with backtracking line search
+
+.. code-block:: python
+
+opt = tz.Modular(
+model.parameters(),
+tz.m.Newton(),
+tz.m.Backtracking()
+)
+
+Newton's method modified for non-convex functions by taking matrix absolute value of the hessian
+
+.. code-block:: python
+
+opt = tz.Modular(
+model.parameters(),
+tz.m.Newton(eigval_tfm=lambda x: torch.abs(x).clip(min=0.1)),
+tz.m.Backtracking()
+)
+
+Newton's method modified for non-convex functions by searching along negative curvature directions
+
+.. code-block:: python
+
+opt = tz.Modular(
+model.parameters(),
+tz.m.Newton(search_negative=True),
+tz.m.Backtracking()
+)
+
+Newton preconditioning applied to momentum
+
+.. code-block:: python
+
+opt = tz.Modular(
+model.parameters(),
+tz.m.Newton(inner=tz.m.EMA(0.9)),
+tz.m.LR(0.1)
+)
+
+Diagonal newton example. This will still evaluate the entire hessian so it isn't efficient, but if you wanted to see how diagonal newton behaves or compares to full newton, you can use this.
+
+.. code-block:: python
+
+opt = tz.Modular(
+model.parameters(),
+tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
+tz.m.Backtracking()
+)

 """
 def __init__(
 self,
 reg: float = 1e-6,
-eig_reg: bool = False,
 search_negative: bool = False,
+update_freq: int = 1,
 hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
 vectorize: bool = True,
 inner: Chainable | None = None,
-H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
+H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
 eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
 ):
-defaults = dict(reg=reg,
+defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative, update_freq=update_freq)
 super().__init__(defaults)

 if inner is not None:
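For orientation: the hunk above wires `reg` through `tikhonov_` (plain H + reg*I damping) and routes `eigval_tfm` through the eigendecomposition-based solve shown in `eigh_solve`. Below is a minimal standalone sketch of that step under the assumption of a dense Hessian; `newton_step` is a hypothetical helper for illustration, not torchzero's API.

```python
import torch

def newton_step(H: torch.Tensor, g: torch.Tensor, reg: float = 1e-6, eigval_tfm=None):
    # Tikhonov damping, same idea as tikhonov_: H <- H + reg * I
    H = H + reg * torch.eye(H.size(-1), dtype=H.dtype, device=H.device)
    if eigval_tfm is None:
        # plain damped Newton direction: solve H d = g
        return torch.linalg.solve(H, g)
    # eigendecomposition path: H = Q diag(L) Q^T, transform the eigenvalues,
    # then apply the (pseudo-)inverse, matching the Q @ ((Q.mH @ g) / L) form above
    L, Q = torch.linalg.eigh(H)
    L = eigval_tfm(L)  # e.g. torch.abs, or lambda L: torch.clip(L, min=1e-8)
    return Q @ ((Q.mH @ g) / L)
```

Transforming eigenvalues with `torch.abs` is what the docstring's "matrix absolute value" example relies on to keep the direction a descent direction on non-convex problems.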
@@ -106,47 +169,66 @@ class Newton(Module):

 settings = self.settings[params[0]]
 reg = settings['reg']
-eig_reg = settings['eig_reg']
 search_negative = settings['search_negative']
 hessian_method = settings['hessian_method']
 vectorize = settings['vectorize']
 H_tfm = settings['H_tfm']
 eigval_tfm = settings['eigval_tfm']
+update_freq = settings['update_freq']
+
+step = self.global_state.get('step', 0)
+self.global_state['step'] = step + 1
+
+g_list = var.grad
+H = None
+if step % update_freq == 0:
+# ------------------------ calculate grad and hessian ------------------------ #
+if hessian_method == 'autograd':
+with torch.enable_grad():
+loss = var.loss = var.loss_approx = closure(False)
+g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+g_list = [t[0] for t in g_list] # remove leading dim from loss
+var.grad = g_list
+H = hessian_list_to_mat(H_list)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+elif hessian_method in ('func', 'autograd.functional'):
+strat = 'forward-mode' if vectorize else 'reverse-mode'
+with torch.enable_grad():
+g_list = var.get_grad(retain_graph=True)
+H = hessian_mat(partial(closure, backward=False), params,
+method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+else:
+raise ValueError(hessian_method)
+
+H = tikhonov_(H, reg)
+if update_freq != 1:
+self.global_state['H'] = H
+
+if H is None:
+H = self.global_state["H"]
+
+# var.storage['hessian'] = H

 # -------------------------------- inner step -------------------------------- #
 update = var.get_update()
 if 'inner' in self.children:
-update = apply_transform(self.children['inner'], update, params=params, grads=
+update = apply_transform(self.children['inner'], update, params=params, grads=g_list, var=var)
+
 g = torch.cat([t.ravel() for t in update])

-# ------------------------------- regulazition ------------------------------- #
-if eig_reg: H = eig_tikhonov_(H, reg)
-else: H = tikhonov_(H, reg)

 # ----------------------------------- solve ---------------------------------- #
 update = None
 if H_tfm is not None:
-
-
+ret = H_tfm(H, g)
+
+if isinstance(ret, torch.Tensor):
+update = ret
+
+else: # returns (H, is_inv)
+H, is_inv = ret
+if is_inv: update = H @ g

 if search_negative or (eigval_tfm is not None):
 update = eigh_solve(H, g, eigval_tfm, search_negative=search_negative)
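The new `H_tfm` handling in this hunk accepts either a `(hessian, is_inverted)` tuple or a plain tensor that is used directly as the update. A hedged standalone sketch of both conventions (`solve_with_tfm` and the two lambdas are illustrative; only the `g / H.diag()` example comes from the docstring above):

```python
import torch

def solve_with_tfm(H: torch.Tensor, g: torch.Tensor, H_tfm):
    ret = H_tfm(H, g)
    if isinstance(ret, torch.Tensor):
        # H_tfm already produced the update vector
        return ret
    # otherwise H_tfm returned (hessian, is_inverted)
    H, is_inv = ret
    return H @ g if is_inv else torch.linalg.solve(H, g)

# tensor form: diagonal Newton, as in the docstring example
diag_newton = lambda H, g: g / H.diag()
# tuple form: hand back a pre-inverted hessian and flag it as inverted
pinv_newton = lambda H, g: (torch.linalg.pinv(H), True)
```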
@@ -156,4 +238,101 @@ class Newton(Module):
 if update is None: update = least_squares_solve(H, g)

 var.update = vec_to_tensors(update, params)
+
+return var
+
+class InverseFreeNewton(Module):
+"""Inverse-free newton's method
+
+.. note::
+In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+.. note::
+This module requires the a closure passed to the optimizer step,
+as it needs to re-evaluate the loss and gradients for calculating the hessian.
+The closure must accept a ``backward`` argument (refer to documentation).
+
+.. warning::
+this uses roughly O(N^2) memory.
+
+Reference
+Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
+"""
+def __init__(
+self,
+update_freq: int = 1,
+hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+vectorize: bool = True,
+inner: Chainable | None = None,
+):
+defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
+super().__init__(defaults)
+
+if inner is not None:
+self.set_child('inner', inner)
+
+@torch.no_grad
+def step(self, var):
+params = TensorList(var.params)
+closure = var.closure
+if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+settings = self.settings[params[0]]
+hessian_method = settings['hessian_method']
+vectorize = settings['vectorize']
+update_freq = settings['update_freq']
+
+step = self.global_state.get('step', 0)
+self.global_state['step'] = step + 1
+
+g_list = var.grad
+Y = None
+if step % update_freq == 0:
+# ------------------------ calculate grad and hessian ------------------------ #
+if hessian_method == 'autograd':
+with torch.enable_grad():
+loss = var.loss = var.loss_approx = closure(False)
+g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+g_list = [t[0] for t in g_list] # remove leading dim from loss
+var.grad = g_list
+H = hessian_list_to_mat(H_list)
+
+elif hessian_method in ('func', 'autograd.functional'):
+strat = 'forward-mode' if vectorize else 'reverse-mode'
+with torch.enable_grad():
+g_list = var.get_grad(retain_graph=True)
+H = hessian_mat(partial(closure, backward=False), params,
+method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+else:
+raise ValueError(hessian_method)
+
+# inverse free part
+if 'Y' not in self.global_state:
+num = H.T
+denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
+eps = torch.finfo(H.dtype).eps
+Y = self.global_state['Y'] = num.div_(denom.clip(min=eps, max=1/eps))
+
+else:
+Y = self.global_state['Y']
+I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
+I -= H @ Y
+Y = self.global_state['Y'] = Y @ I
+
+if Y is None:
+Y = self.global_state["Y"]
+
+
+# -------------------------------- inner step -------------------------------- #
+update = var.get_update()
+if 'inner' in self.children:
+update = apply_transform(self.children['inner'], update, params=params, grads=g_list, var=var)
+
+g = torch.cat([t.ravel() for t in update])
+
+
+# ----------------------------------- solve ---------------------------------- #
+var.update = vec_to_tensors(Y@g, params)
+
 return var
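The `InverseFreeNewton` module added above maintains an approximate inverse Hessian Y with a Newton-Schulz-style recurrence Y <- Y(2I - HY), seeded with H^T / (||H||_1 * ||H||_inf), so no linear solve or explicit inversion is needed. A minimal standalone sketch of that recurrence (dense H assumed; `inverse_free_step` is a hypothetical name, and the caching / `update_freq` bookkeeping of the module is left out):

```python
import torch

def inverse_free_step(H: torch.Tensor, g: torch.Tensor, Y: torch.Tensor | None):
    if Y is None:
        # seed: Y0 = H^T / (||H||_1 * ||H||_inf), as in the diff above
        Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf')))
    else:
        # refine: Y_{k+1} = Y_k (2I - H Y_k)
        two_I = 2 * torch.eye(H.size(0), dtype=H.dtype, device=H.device)
        Y = Y @ (two_I - H @ Y)
    # Newton-like direction using the current inverse estimate, plus the carried state
    return Y @ g, Y
```

This mirrors the `Y` state kept in `global_state` above: the estimate is refreshed from the latest Hessian on update steps and reused in between.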
torchzero/modules/second_order/newton_cg.py:

@@ -1,26 +1,102 @@
-from collections.abc import Callable
 from typing import Literal, overload
-import warnings
 import torch

-from ...utils import TensorList, as_tensorlist,
+from ...utils import TensorList, as_tensorlist, NumberList
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward

 from ...core import Chainable, apply_transform, Module
-from ...utils.linalg.solve import cg
+from ...utils.linalg.solve import cg, steihaug_toint_cg, minres

 class NewtonCG(Module):
+"""Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
+
+This optimizer implements Newton's method using a matrix-free conjugate
+gradient (CG) or a minimal-residual (MINRES) solver to approximate the search direction. Instead of
+forming the full Hessian matrix, it only requires Hessian-vector products
+(HVPs). These can be calculated efficiently using automatic
+differentiation or approximated using finite differences.
+
+.. note::
+In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+.. note::
+This module requires the a closure passed to the optimizer step,
+as it needs to re-evaluate the loss and gradients for calculating HVPs.
+The closure must accept a ``backward`` argument (refer to documentation).
+
+.. warning::
+CG may fail if hessian is not positive-definite.
+
+Args:
+maxiter (int | None, optional):
+Maximum number of iterations for the conjugate gradient solver.
+By default, this is set to the number of dimensions in the
+objective function, which is the theoretical upper bound for CG
+convergence. Setting this to a smaller value (truncated Newton)
+can still generate good search directions. Defaults to None.
+tol (float, optional):
+Relative tolerance for the conjugate gradient solver to determine
+convergence. Defaults to 1e-4.
+reg (float, optional):
+Regularization parameter (damping) added to the Hessian diagonal.
+This helps ensure the system is positive-definite. Defaults to 1e-8.
+hvp_method (str, optional):
+Determines how Hessian-vector products are evaluated.
+
+- ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+This requires creating a graph for the gradient.
+- ``"forward"``: Use a forward finite difference formula to
+approximate the HVP. This requires one extra gradient evaluation.
+- ``"central"``: Use a central finite difference formula for a
+more accurate HVP approximation. This requires two extra
+gradient evaluations.
+Defaults to "autograd".
+h (float, optional):
+The step size for finite differences if :code:`hvp_method` is
+``"forward"`` or ``"central"``. Defaults to 1e-3.
+warm_start (bool, optional):
+If ``True``, the conjugate gradient solver is initialized with the
+solution from the previous optimization step. This can accelerate
+convergence, especially in truncated Newton methods.
+Defaults to False.
+inner (Chainable | None, optional):
+NewtonCG will attempt to apply preconditioning to the output of this module.
+
+Examples:
+Newton-CG with a backtracking line search:
+
+.. code-block:: python
+
+opt = tz.Modular(
+model.parameters(),
+tz.m.NewtonCG(),
+tz.m.Backtracking()
+)
+
+Truncated Newton method (useful for large-scale problems):
+
+.. code-block:: python
+
+opt = tz.Modular(
+model.parameters(),
+tz.m.NewtonCG(maxiter=10, warm_start=True),
+tz.m.Backtracking()
+)
+
+
+"""
 def __init__(
 self,
-maxiter=None,
-tol=1e-4,
+maxiter: int | None = None,
+tol: float = 1e-4,
 reg: float = 1e-8,
-hvp_method: Literal["forward", "central", "autograd"] = "
-
+hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
+h: float = 1e-3,
 warm_start=False,
 inner: Chainable | None = None,
 ):
-defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, warm_start=warm_start)
+defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
 super().__init__(defaults,)

 if inner is not None:
@@ -37,6 +113,7 @@ class NewtonCG(Module):
 reg = settings['reg']
 maxiter = settings['maxiter']
 hvp_method = settings['hvp_method']
+solver = settings['solver'].lower().strip()
 h = settings['h']
 warm_start = settings['warm_start']

@@ -68,13 +145,25 @@ class NewtonCG(Module):
 # -------------------------------- inner step -------------------------------- #
 b = var.get_update()
 if 'inner' in self.children:
-b =
+b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+b = as_tensorlist(b)

 # ---------------------------------- run cg ---------------------------------- #
 x0 = None
 if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway

-
+if solver == 'cg':
+x = cg(A_mm=H_mm, b=b, x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
+
+elif solver == 'minres':
+x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
+
+elif solver == 'minres_npc':
+x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
+
+else:
+raise ValueError(f"Unknown solver {solver}")
+
 if warm_start:
 assert x0 is not None
 x0.copy_(x)
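The NewtonCG changes above dispatch the matrix-free solve to `cg`, `minres` or `minres_npc`, all driven by a Hessian-vector-product callback `H_mm`. For readers unfamiliar with the pattern, here is a hedged, self-contained sketch of the "autograd" HVP path (double backward) plus a textbook CG loop; `make_hvp` and `cg_solve` are hypothetical names, not torchzero's `cg`/`minres` implementations, which carry extra options such as `reg` and warm starts.

```python
import torch

def make_hvp(loss_fn, params):
    """Matrix-free Hessian-vector product via double backward (the 'autograd' path)."""
    loss = loss_fn()
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_g = torch.cat([g.reshape(-1) for g in grads])
    def H_mm(v):
        # v must be a flat vector with the same number of elements as flat_g
        hv = torch.autograd.grad(flat_g, params, grad_outputs=v, retain_graph=True)
        return torch.cat([h.reshape(-1) for h in hv])
    return flat_g, H_mm

def cg_solve(H_mm, b, maxiter=50, tol=1e-4):
    """Plain conjugate gradient on H x = b, using only H_mm (no explicit matrix)."""
    x = torch.zeros_like(b)
    r = b.clone()
    p = r.clone()
    rs = r.dot(r)
    for _ in range(maxiter):
        Hp = H_mm(p)
        alpha = rs / p.dot(Hp)
        x = x + alpha * p
        r = r - alpha * Hp
        rs_new = r.dot(r)
        if rs_new.sqrt() < tol * b.norm():
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x
```

Capping `maxiter` well below the parameter count gives the truncated-Newton behaviour the docstring's second example (`maxiter=10, warm_start=True`) is aiming for.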
@@ -83,3 +172,203 @@ class NewtonCG(Module):
 return var


+class TruncatedNewtonCG(Module):
+"""Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.
+
+This optimizer implements Newton's method using a matrix-free conjugate
+gradient (CG) solver to approximate the search direction. Instead of
+forming the full Hessian matrix, it only requires Hessian-vector products
+(HVPs). These can be calculated efficiently using automatic
+differentiation or approximated using finite differences.
+
+.. note::
+In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+.. note::
+This module requires the a closure passed to the optimizer step,
+as it needs to re-evaluate the loss and gradients for calculating HVPs.
+The closure must accept a ``backward`` argument (refer to documentation).
+
+.. warning::
+CG may fail if hessian is not positive-definite.
+
+Args:
+maxiter (int | None, optional):
+Maximum number of iterations for the conjugate gradient solver.
+By default, this is set to the number of dimensions in the
+objective function, which is the theoretical upper bound for CG
+convergence. Setting this to a smaller value (truncated Newton)
+can still generate good search directions. Defaults to None.
+eta (float, optional):
+whenever actual to predicted loss reduction ratio is larger than this, a step is accepted.
+nplus (float, optional):
+trust region multiplier on successful steps.
+nminus (float, optional):
+trust region multiplier on unsuccessful steps.
+init (float, optional): initial trust region.
+tol (float, optional):
+Relative tolerance for the conjugate gradient solver to determine
+convergence. Defaults to 1e-4.
+reg (float, optional):
+Regularization parameter (damping) added to the Hessian diagonal.
+This helps ensure the system is positive-definite. Defaults to 1e-8.
+hvp_method (str, optional):
+Determines how Hessian-vector products are evaluated.
+
+- ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+This requires creating a graph for the gradient.
+- ``"forward"``: Use a forward finite difference formula to
+approximate the HVP. This requires one extra gradient evaluation.
+- ``"central"``: Use a central finite difference formula for a
+more accurate HVP approximation. This requires two extra
+gradient evaluations.
+Defaults to "autograd".
+h (float, optional):
+The step size for finite differences if :code:`hvp_method` is
+``"forward"`` or ``"central"``. Defaults to 1e-3.
+inner (Chainable | None, optional):
+NewtonCG will attempt to apply preconditioning to the output of this module.
+
+Examples:
+Trust-region Newton-CG:
+
+.. code-block:: python
+
+opt = tz.Modular(
+model.parameters(),
+tz.m.NewtonCGSteihaug(),
+)
+
+Reference:
+Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
+"""
+def __init__(
+self,
+maxiter: int | None = None,
+eta: float= 1e-6,
+nplus: float = 2,
+nminus: float = 0.25,
+init: float = 1,
+tol: float = 1e-4,
+reg: float = 1e-8,
+hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
+h: float = 1e-3,
+max_attempts: int = 10,
+inner: Chainable | None = None,
+):
+defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, eta=eta, nplus=nplus, nminus=nminus, init=init, max_attempts=max_attempts, solver=solver)
+super().__init__(defaults,)
+
+if inner is not None:
+self.set_child('inner', inner)
+
+@torch.no_grad
+def step(self, var):
+params = TensorList(var.params)
+closure = var.closure
+if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+settings = self.settings[params[0]]
+tol = settings['tol']
+reg = settings['reg']
+maxiter = settings['maxiter']
+hvp_method = settings['hvp_method']
+h = settings['h']
+max_attempts = settings['max_attempts']
+solver = settings['solver'].lower().strip()
+
+eta = settings['eta']
+nplus = settings['nplus']
+nminus = settings['nminus']
+init = settings['init']
+
+# ---------------------- Hessian vector product function --------------------- #
+if hvp_method == 'autograd':
+grad = var.get_grad(create_graph=True)
+
+def H_mm(x):
+with torch.enable_grad():
+return TensorList(hvp(params, grad, x, retain_graph=True))
+
+else:
+
+with torch.enable_grad():
+grad = var.get_grad()
+
+if hvp_method == 'forward':
+def H_mm(x):
+return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+
+elif hvp_method == 'central':
+def H_mm(x):
+return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+
+else:
+raise ValueError(hvp_method)
+
+
+# -------------------------------- inner step -------------------------------- #
+b = var.get_update()
+if 'inner' in self.children:
+b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+b = as_tensorlist(b)
+
+# ---------------------------------- run cg ---------------------------------- #
+success = False
+x = None
+while not success:
+max_attempts -= 1
+if max_attempts < 0: break
+
+trust_region = self.global_state.get('trust_region', init)
+if trust_region < 1e-8 or trust_region > 1e8:
+trust_region = self.global_state['trust_region'] = init
+
+if solver == 'cg':
+x = steihaug_toint_cg(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg)
+
+elif solver == 'minres':
+x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
+
+elif solver == 'minres_npc':
+x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
+
+else:
+raise ValueError(f"unknown solver {solver}")
+
+# ------------------------------- trust region ------------------------------- #
+Hx = H_mm(x)
+pred_reduction = b.dot(x) - 0.5 * x.dot(Hx)
+
+params -= x
+loss_star = closure(False)
+params += x
+reduction = var.get_loss(False) - loss_star
+
+rho = reduction / (pred_reduction.clip(min=1e-8))
+
+# failed step
+if rho < 0.25:
+self.global_state['trust_region'] = trust_region * nminus
+
+# very good step
+elif rho > 0.75:
+diff = trust_region - x.abs()
+if (diff.global_min() / trust_region) > 1e-4: # hits boundary
+self.global_state['trust_region'] = trust_region * nplus
+
+# if the ratio is high enough then accept the proposed step
+if rho > eta:
+success = True
+
+assert x is not None
+if success:
+var.update = x
+
+else:
+var.update = params.zeros_like()
+
+return var
+
+
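The TruncatedNewtonCG hunk above accepts or rejects each proposed step from the ratio of actual to predicted loss reduction and rescales the radius with `nplus`/`nminus`. A hedged, standalone sketch of just that bookkeeping (hypothetical `update_trust_region` helper, scalar logic only; the module itself also retries up to `max_attempts` times and re-runs the solver with the new radius):

```python
def update_trust_region(radius: float, rho: float, step_hits_boundary: bool,
                        nplus: float = 2.0, nminus: float = 0.25) -> float:
    """Classic trust-region radius update driven by rho = actual / predicted reduction."""
    if rho < 0.25:
        # model fit the objective poorly: shrink the region
        radius *= nminus
    elif rho > 0.75 and step_hits_boundary:
        # good fit and the step was constrained by the boundary: expand
        radius *= nplus
    return radius

# a step is accepted when rho > eta (eta defaults to 1e-6 in the diff above)
```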