torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff shows the changes between these two publicly released package versions as they appear in the public registry; it is provided for informational purposes only.
- tests/test_opts.py +55 -22
- tests/test_tensorlist.py +3 -3
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +20 -130
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +111 -0
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +76 -26
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +15 -15
- torchzero/modules/quasi_newton/lsr1.py +18 -17
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +257 -48
- torchzero/modules/second_order/newton.py +38 -21
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +19 -19
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.8.dist-info/RECORD +0 -130
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/quasi_newton.py

@@ -1,11 +1,13 @@
 """Use BFGS or maybe SR1."""
-from typing import Any, Literal
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
+from typing import Any, Literal
+
 import torch

-from ...core import Chainable, Module,
-from ...utils import TensorList, set_storage_
+from ...core import Chainable, Module, TensorwiseTransform, Transform
+from ...utils import TensorList, set_storage_, unpack_states
+

 def _safe_dict_update_(d1_:dict, d2:dict):
     inter = set(d1_.keys()).intersection(d2.keys())
@@ -17,14 +19,14 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
     elif state[key].shape != value.shape: state[key] = value
     else: state[key].lerp_(value, 1-beta)

-class HessianUpdateStrategy(
+class HessianUpdateStrategy(TensorwiseTransform, ABC):
     def __init__(
         self,
         defaults: dict | None = None,
         init_scale: float | Literal["auto"] = "auto",
         tol: float = 1e-10,
         tol_reset: bool = True,
-        reset_interval: int | None = None,
+        reset_interval: int | None | Literal['auto'] = None,
         beta: float | None = None,
         update_freq: int = 1,
         scale_first: bool = True,
@@ -44,7 +46,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         if ys != 0 and yy != 0: return yy/ys
         return 1

-    def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor,inverse:bool, init_scale: Any):
+    def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]):
         set_storage_(M, torch.eye(M.size(-1), device=M.device, dtype=M.dtype))
         if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
         if init_scale >= 1:
@@ -62,7 +64,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         raise NotImplementedError

     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, state, settings):
+    def update_tensor(self, tensor, param, grad, loss, state, settings):
         p = param.view(-1); g = tensor.view(-1)
         inverse = settings['inverse']
         M_key = 'H' if inverse else 'B'
@@ -73,6 +75,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         tol = settings['tol']
         tol_reset = settings['tol_reset']
         reset_interval = settings['reset_interval']
+        if reset_interval == 'auto': reset_interval = tensor.numel() + 1

         if M is None:
             M = torch.eye(p.size(0), device=p.device, dtype=p.dtype)
@@ -81,10 +84,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
             else: M *= init_scale

             state[M_key] = M
+            state['f_prev'] = loss
             state['p_prev'] = p.clone()
             state['g_prev'] = g.clone()
             return

+        state['f'] = loss
         p_prev = state['p_prev']
         g_prev = state['g_prev']
         s: torch.Tensor = p - p_prev
@@ -93,13 +98,13 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         state['g_prev'].copy_(g)

         if reset_interval is not None and step != 0 and step % reset_interval == 0:
-            self._reset_M_(M, s, y, inverse, init_scale)
+            self._reset_M_(M, s, y, inverse, init_scale, state)
             return

         # tolerance on gradient difference to avoid exploding after converging
-
+        if y.abs().max() <= tol:
             # reset history
-            if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
+            if tol_reset: self._reset_M_(M, s, y, inverse, init_scale, state)
             return

         if step == 1 and init_scale == 'auto':
@@ -117,12 +122,16 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
             B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
             _maybe_lerp_(state, 'B', B_new, beta)

+        state['f_prev'] = loss
+
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, state, settings):
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
         step = state.get('step', 0)

         if settings['scale_second'] and step == 2:
-
+            scale_factor = 1 / tensor.abs().sum().clip(min=1)
+            scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
+            tensor = tensor * scale_factor

         inverse = settings['inverse']
         if inverse:
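A note on the scaling logic above, restated from the code itself and not part of the diff (here s and y denote the per-step parameter and gradient differences):

    \gamma_0 = \frac{y^\top y}{y^\top s} \quad \text{("auto" initial scale, falling back to 1 when a dot product is zero)}, \qquad
    u \leftarrow u \cdot \max\!\left(\frac{1}{\max(1,\ \lVert u \rVert_1)},\ \varepsilon_{\mathrm{dtype}}\right) \quad \text{(the new scale\_second branch, applied at step 2)}

and reset_interval='auto' now resets the stored matrix every numel + 1 updates.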
@@ -196,19 +205,15 @@ class SR1(HUpdateStrategy):
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])

-# BFGS has defaults - init_scale = "auto" and scale_second = False
-# SR1 has defaults - init_scale = 1 and scale_second = True
-# basically some methods work better with first and some with second.
-# I inherit from BFGS or SR1 to avoid writing all those arguments again
 # ------------------------------------ DFP ----------------------------------- #
 def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     sy = torch.dot(s, y)
     if sy.abs() <= tol: return H
     term1 = torch.outer(s, s).div_(sy)
-
-    if
+    yHy = torch.dot(y, H @ y) #
+    if yHy.abs() <= tol: return H
     num = H @ torch.outer(y, y) @ H
-    term2 = num.div_(
+    term2 = num.div_(yHy)
     H += term1.sub_(term2)
     return H

@@ -223,34 +228,35 @@ class DFP(HUpdateStrategy):

 def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     c = H.T @ s
-
-    if
+    cy = c.dot(y)
+    if cy.abs() <= tol: return H
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H

 def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     c = y
-
-    if
+    cy = c.dot(y)
+    if cy.abs() <= tol: return H
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H

 def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor, tol: float):
     c = g_prev
-
-    if
+    cy = c.dot(y)
+    if cy.abs() <= tol: return H
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H

 def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-
-
-
-
-
+    Hy = H @ y
+    c = H @ Hy # pylint:disable=not-callable
+    cy = c.dot(y)
+    if cy.abs() <= tol: return H
+    num = Hy.sub_(s).outer(c)
+    H -= num/cy
     return H

 class BroydenGood(HUpdateStrategy):
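For readability, restated from the code above: the Broyden-family helpers in this hunk all apply the same rank-one correction to the inverse-Hessian estimate H,

    H \leftarrow H - \frac{(Hy - s)\,c^\top}{c^\top y}, \qquad
    c = H^\top s \ \text{(broyden\_good\_H\_)}, \quad c = y \ \text{(broyden\_bad\_H\_)}, \quad c = g_{\mathrm{prev}} \ \text{(greenstadt1\_H\_)}, \quad c = H H y \ \text{(greenstadt2\_H\_)},

skipping the update when |c^\top y| \le tol, while dfp_H_ is the rank-two DFP update

    H \leftarrow H + \frac{s s^\top}{s^\top y} - \frac{H y\,y^\top H}{y^\top H y}.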
@@ -271,11 +277,7 @@ class Greenstadt2(HUpdateStrategy):


 def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float):
-    n = H.shape[0]
-
     j = y.abs().argmax()
-    u = torch.zeros(n, device=H.device, dtype=H.dtype)
-    u[j] = 1.0

     denom = y[j]
     if denom.abs() < tol: return H
@@ -295,15 +297,15 @@ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor,
     s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
     I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
     d = (R + I * (s_norm/2)) @ s
-
-    if
-    R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(
+    ds = d.dot(s)
+    if ds.abs() <= tol: return H, R
+    R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))

     c = H.T @ d
-
-    if
+    cy = c.dot(y)
+    if cy.abs() <= tol: return H, R
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H, R

 class ThomasOptimalMethod(HUpdateStrategy):
@@ -313,6 +315,11 @@ class ThomasOptimalMethod(HUpdateStrategy):
         H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
         return H

+    def _reset_M_(self, M, s, y,inverse, init_scale, state):
+        super()._reset_M_(M, s, y, inverse, init_scale, state)
+        for st in self.state.values():
+            st.pop("R", None)
+
 # ------------------------ powell's symmetric broyden ------------------------ #
 def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
     y_Bs = y - B@s
@@ -324,6 +331,7 @@ def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
     B += term1.sub_(term2)
     return B

+# I couldn't find formula for H
 class PSB(HessianUpdateStrategy):
     def __init__(
         self,
@@ -356,17 +364,85 @@ class PSB(HessianUpdateStrategy):
     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, settings):
         return psb_B_(B=B, s=s, y=y, tol=settings['tol'])

-
+
+# Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
+def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    Hy = H@y
+    yHy = y.dot(Hy)
+    if yHy.abs() <= tol: return H
+    num = (s - Hy).outer(Hy)
+    H += num.div_(yHy)
+    return H
+
+class Pearson(HUpdateStrategy):
+    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+    This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return pearson_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     sy = s.dot(y)
     if sy.abs() <= tol: return H
     num = (s - H@y).outer(s)
     H += num.div_(sy)
     return H

-class
-    """
+class McCormick(HUpdateStrategy):
+    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+    This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-        return
+        return mccormick_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
+    Hy = H @ y
+    yHy = y.dot(Hy)
+    if yHy.abs() < tol: return H, R
+    H -= Hy.outer(Hy) / yHy
+    R += (s - R@y).outer(Hy) / yHy
+    return H, R
+
+class ProjectedNewtonRaphson(HessianUpdateStrategy):
+    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+    Algorithm 7"""
+    def __init__(
+        self,
+        init_scale: float | Literal["auto"] = 'auto',
+        tol: float = 1e-10,
+        tol_reset: bool = True,
+        reset_interval: int | None | Literal['auto'] = 'auto',
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            init_scale=init_scale,
+            tol=tol,
+            tol_reset=tol_reset,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
+        H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
+        state["R"] = R
+        return H
+
+    def _reset_M_(self, M, s, y, inverse, init_scale, state):
+        assert inverse
+        M.copy_(state["R"])

 # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
 def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
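In the same notation, and as read directly from the code above, the new pearson_H_, mccormick_H_ and projected_newton_raphson_H_ helpers implement

    H \leftarrow H + \frac{(s - Hy)(Hy)^\top}{y^\top H y} \ \text{(Pearson)}, \qquad
    H \leftarrow H + \frac{(s - Hy)\,s^\top}{s^\top y} \ \text{(McCormick)},

    H \leftarrow H - \frac{Hy\,(Hy)^\top}{y^\top H y}, \qquad R \leftarrow R + \frac{(s - Ry)(Hy)^\top}{y^\top H y} \ \text{(projected Newton–Raphson pair)},

each guarded by the same tol check on its denominator; ProjectedNewtonRaphson additionally restores H from R when the approximation is reset.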
@@ -471,4 +547,137 @@ class SSVM(HessianUpdateStrategy):
         )

     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-        return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+        return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+
+# HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    Hy = H@y
+    ys = y.dot(s)
+    if ys.abs() <= tol: return H
+    yHy = y.dot(Hy)
+    denom = ys + yHy
+    if denom.abs() <= tol: return H
+
+    term1 = 1/denom
+    term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
+    term3 = s.outer(y) @ H
+    term4 = Hy.outer(s)
+    term5 = Hy.outer(y) @ H
+
+    inner_term = term2 - term3 - term4 - term5
+    H += inner_term.mul_(term1)
+    return H
+
+def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
+    sy = s.dot(y)
+    if sy.abs() < torch.finfo(g[0].dtype).eps: return g
+    return g - (y * (s.dot(g) / sy))
+
+
+class GradientCorrection(Transform):
+    """estimates gradient at minima along search direction assuming function is quadratic as proposed in HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+
+    This can useful as inner module for second order methods."""
+    def __init__(self):
+        super().__init__(None, uses_grad=False)
+
+    def apply(self, tensors, params, grads, loss, states, settings):
+        if 'p_prev' not in states[0]:
+            p_prev = unpack_states(states, tensors, 'p_prev', init=params)
+            g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
+            return tensors
+
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
+        g_hat = gradient_correction(TensorList(tensors), params-p_prev, tensors-g_prev)
+
+        p_prev.copy_(params)
+        g_prev.copy_(tensors)
+        return g_hat
+
+class Horisho(HUpdateStrategy):
+    """HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394"""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return hoshino_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+# Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy.abs() < tol: return H
+    Hy = H @ y
+
+    term1 = (s.outer(y) @ H).div_(sy)
+    term2 = (Hy.outer(s)).div_(sy)
+    term3 = 1 + (y.dot(Hy) / sy)
+    term4 = s.outer(s).div_(sy)
+
+    H -= (term1 + term2 - term4.mul_(term3))
+    return H
+
+class FletcherVMM(HUpdateStrategy):
+    """Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317"""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return fletcher_vmm_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+
+# Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
+    sy = s.dot(y)
+    if sy < tol: return H
+
+    term1 = (H @ y.outer(s) + s.outer(y) @ H) / sy
+
+    if type == 1:
+        pba = (2*sy + 2*(f-f_prev)) / sy
+
+    elif type == 2:
+        pba = (f_prev - f + 1/(2*sy)) / sy
+
+    else:
+        raise RuntimeError(type)
+
+    term3 = 1/pba + y.dot(H@y) / sy
+    term4 = s.outer(s) / sy
+
+    H.sub_(term1)
+    H.add_(term4.mul_(term3))
+    return H
+
+
+class NewSSM(HessianUpdateStrategy):
+    """Self-scaling method, requires a line search.
+
+    Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U."""
+    def __init__(
+        self,
+        type: Literal[1, 2] = 1,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-10,
+        tol_reset: bool = True,
+        reset_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=dict(type=type),
+            init_scale=init_scale,
+            tol=tol,
+            tol_reset=tol_reset,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        f = state['f']
+        f_prev = state['f_prev']
+        return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=settings['type'], tol=settings['tol'])
+
+
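Two of the additions above are worth restating in formula form. gradient_correction, used by the new GradientCorrection transform, estimates the gradient at the minimizer along the previous step under a quadratic model:

    \hat g = g - y\,\frac{s^\top g}{s^\top y},

and is skipped when |s^\top y| is below machine epsilon. The loss argument now threaded through update_tensor/apply_tensor is what allows NewSSM to read f and f_prev from the per-tensor state and use them as the self-scaling factor inside new_ssm1.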
torchzero/modules/second_order/newton.py

@@ -1,22 +1,29 @@
 import warnings
+from collections.abc import Callable
 from functools import partial
 from typing import Literal
-
+
 import torch

-from ...core import Chainable,
-from ...utils import
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
     hessian_list_to_mat,
     hessian_mat,
+    hvp,
+    hvp_fd_central,
+    hvp_fd_forward,
     jacobian_and_hessian_wrt,
 )


 def lu_solve(H: torch.Tensor, g: torch.Tensor):
-
-
-
+    try:
+        x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
+        if info == 0: return x
+        return None
+    except RuntimeError:
+        return None

 def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
     x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
@@ -28,10 +35,15 @@ def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
 def least_squares_solve(H: torch.Tensor, g: torch.Tensor):
     return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable

-def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None):
+def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
     try:
         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
         if tfm is not None: L = tfm(L)
+        if search_negative and L[0] < 0:
+            d = Q[0]
+            # use eigvec or -eigvec depending on if it points in same direction as gradient
+            return g.dot(d).sign() * d
+
         L.reciprocal_()
         return torch.linalg.multi_dot([Q * L.unsqueeze(-2), Q.mH, g]) # pylint:disable=not-callable
     except torch.linalg.LinAlgError:
@@ -52,6 +64,8 @@ class Newton(Module):
     Args:
         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
         eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
+        search_negative (bool, Optional):
+            if True, whenever a negative eigenvalue is detected, the direction is taken along an eigenvector corresponding to a negative eigenvalue.
         hessian_method (str):
             how to calculate hessian. Defaults to "autograd".
         vectorize (bool, optional):
@@ -71,27 +85,29 @@ class Newton(Module):
         self,
         reg: float = 1e-6,
         eig_reg: bool = False,
+        search_negative: bool = False,
         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
         vectorize: bool = True,
         inner: Chainable | None = None,
         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
         eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
     ):
-        defaults = dict(reg=reg, eig_reg=eig_reg,
+        defaults = dict(reg=reg, eig_reg=eig_reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative)
         super().__init__(defaults)

         if inner is not None:
             self.set_child('inner', inner)

     @torch.no_grad
-    def step(self,
-        params = TensorList(
-        closure =
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')

         settings = self.settings[params[0]]
         reg = settings['reg']
         eig_reg = settings['eig_reg']
+        search_negative = settings['search_negative']
         hessian_method = settings['hessian_method']
         vectorize = settings['vectorize']
         H_tfm = settings['H_tfm']
@@ -100,16 +116,16 @@ class Newton(Module):
         # ------------------------ calculate grad and hessian ------------------------ #
         if hessian_method == 'autograd':
             with torch.enable_grad():
-                loss =
+                loss = var.loss = var.loss_approx = closure(False)
                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
                 g_list = [t[0] for t in g_list] # remove leading dim from loss
-
+                var.grad = g_list
                 H = hessian_list_to_mat(H_list)

         elif hessian_method in ('func', 'autograd.functional'):
             strat = 'forward-mode' if vectorize else 'reverse-mode'
             with torch.enable_grad():
-                g_list =
+                g_list = var.get_grad(retain_graph=True)
                 H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
                     method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]

@@ -117,9 +133,10 @@ class Newton(Module):
             raise ValueError(hessian_method)

         # -------------------------------- inner step -------------------------------- #
+        update = var.get_update()
         if 'inner' in self.children:
-
-        g = torch.cat([t.
+            update = apply_transform(self.children['inner'], update, params=params, grads=list(g_list), var=var)
+        g = torch.cat([t.ravel() for t in update])

         # ------------------------------- regulazition ------------------------------- #
         if eig_reg: H = eig_tikhonov_(H, reg)
@@ -129,14 +146,14 @@ class Newton(Module):
         update = None
         if H_tfm is not None:
             H, is_inv = H_tfm(H, g)
-            if is_inv: update = H
+            if is_inv: update = H @ g

-        if eigval_tfm is not None:
-            update = eigh_solve(H, g, eigval_tfm)
+        if search_negative or (eigval_tfm is not None):
+            update = eigh_solve(H, g, eigval_tfm, search_negative=search_negative)

         if update is None: update = cholesky_solve(H, g)
         if update is None: update = lu_solve(H, g)
         if update is None: update = least_squares_solve(H, g)

-
-        return
+        var.update = vec_to_tensors(update, params)
+        return var
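As a worked summary of the solver fallbacks above (restated from the code, not part of the diff): with the factorization H = Q diag(λ) Qᵀ, eigh_solve returns the Newton step

    \Delta = Q\,\operatorname{diag}(1/\lambda)\,Q^\top g,

optionally transforming the eigenvalues with eigval_tfm first. Per the new search_negative docstring, when a negative eigenvalue is detected the returned direction is instead an eigenvector for it, signed by sign(gᵀd) so it points the same way as the gradient. When the eigendecomposition path is not taken, the solve falls back to Cholesky, then LU (torch.linalg.solve_ex), then least squares.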
torchzero/modules/second_order/newton_cg.py

@@ -6,14 +6,14 @@ import torch
 from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward

-from ...core import Chainable,
+from ...core import Chainable, apply_transform, Module
 from ...utils.linalg.solve import cg

 class NewtonCG(Module):
     def __init__(
         self,
         maxiter=None,
-        tol=1e-
+        tol=1e-4,
         reg: float = 1e-8,
         hvp_method: Literal["forward", "central", "autograd"] = "forward",
         h=1e-3,
@@ -27,9 +27,9 @@ class NewtonCG(Module):
             self.set_child('inner', inner)

     @torch.no_grad
-    def step(self,
-        params = TensorList(
-        closure =
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')

         settings = self.settings[params[0]]
@@ -42,7 +42,7 @@ class NewtonCG(Module):

         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
-            grad =
+            grad = var.get_grad(create_graph=True)

             def H_mm(x):
                 with torch.enable_grad():
@@ -51,7 +51,7 @@ class NewtonCG(Module):
         else:

             with torch.enable_grad():
-                grad =
+                grad = var.get_grad()

             if hvp_method == 'forward':
                 def H_mm(x):
@@ -66,19 +66,20 @@ class NewtonCG(Module):


         # -------------------------------- inner step -------------------------------- #
-        b =
+        b = var.get_update()
         if 'inner' in self.children:
-            b = as_tensorlist(
+            b = as_tensorlist(apply_transform(self.children['inner'], b, params=params, grads=grad, var=var))

         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
-        if warm_start: x0 = self.get_state('prev_x',
+        if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
+
         x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
         if warm_start:
             assert x0 is not None
             x0.copy_(x)

-
-        return
+        var.update = x
+        return var

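For context on the NewtonCG changes: the module solves H x = b with conjugate gradients, where b is the incoming update and H_mm supplies Hessian-vector products. With hvp_method="forward" or "central" these are presumably the standard finite-difference approximations (the actual implementations live in hvp_fd_forward/hvp_fd_central, which are not shown in this diff):

    Hv \approx \frac{\nabla f(x + hv) - \nabla f(x)}{h}, \qquad
    Hv \approx \frac{\nabla f(x + hv) - \nabla f(x - hv)}{2h},

with h defaulting to 1e-3, while hvp_method="autograd" differentiates through the gradient. warm_start stores the previous CG solution in prev_x and reuses it as the starting point x0.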