torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/quasi_newton.py

@@ -1,11 +1,13 @@
 """Use BFGS or maybe SR1."""
-from typing import Any, Literal
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
+from typing import Any, Literal
+
 import torch

-from ...core import Chainable, Module,
-from ...utils import TensorList, set_storage_
+from ...core import Chainable, Module, TensorwiseTransform, Transform
+from ...utils import TensorList, set_storage_, unpack_states
+

 def _safe_dict_update_(d1_:dict, d2:dict):
     inter = set(d1_.keys()).intersection(d2.keys())
@@ -17,14 +19,14 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
     elif state[key].shape != value.shape: state[key] = value
     else: state[key].lerp_(value, 1-beta)

-class HessianUpdateStrategy(
+class HessianUpdateStrategy(TensorwiseTransform, ABC):
     def __init__(
         self,
         defaults: dict | None = None,
         init_scale: float | Literal["auto"] = "auto",
         tol: float = 1e-10,
         tol_reset: bool = True,
-        reset_interval: int | None = None,
+        reset_interval: int | None | Literal['auto'] = None,
         beta: float | None = None,
         update_freq: int = 1,
         scale_first: bool = True,
@@ -44,7 +46,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         if ys != 0 and yy != 0: return yy/ys
         return 1

-    def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor,inverse:bool, init_scale: Any):
+    def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]):
         set_storage_(M, torch.eye(M.size(-1), device=M.device, dtype=M.dtype))
         if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
         if init_scale >= 1:
@@ -62,7 +64,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         raise NotImplementedError

     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, state, settings):
+    def update_tensor(self, tensor, param, grad, loss, state, settings):
         p = param.view(-1); g = tensor.view(-1)
         inverse = settings['inverse']
         M_key = 'H' if inverse else 'B'
@@ -73,6 +75,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         tol = settings['tol']
         tol_reset = settings['tol_reset']
         reset_interval = settings['reset_interval']
+        if reset_interval == 'auto': reset_interval = tensor.numel() + 1

         if M is None:
             M = torch.eye(p.size(0), device=p.device, dtype=p.dtype)
@@ -81,10 +84,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
             else: M *= init_scale

             state[M_key] = M
+            state['f_prev'] = loss
             state['p_prev'] = p.clone()
             state['g_prev'] = g.clone()
             return

+        state['f'] = loss
         p_prev = state['p_prev']
         g_prev = state['g_prev']
         s: torch.Tensor = p - p_prev
@@ -93,13 +98,13 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         state['g_prev'].copy_(g)

         if reset_interval is not None and step != 0 and step % reset_interval == 0:
-            self._reset_M_(M, s, y, inverse, init_scale)
+            self._reset_M_(M, s, y, inverse, init_scale, state)
             return

         # tolerance on gradient difference to avoid exploding after converging
-
+        if y.abs().max() <= tol:
             # reset history
-            if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
+            if tol_reset: self._reset_M_(M, s, y, inverse, init_scale, state)
             return

         if step == 1 and init_scale == 'auto':
@@ -117,8 +122,10 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
             B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
             _maybe_lerp_(state, 'B', B_new, beta)

+        state['f_prev'] = loss
+
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, state, settings):
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
         step = state.get('step', 0)

         if settings['scale_second'] and step == 2:
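A reading note, not part of the diff: the update formulas in the hunks below are written in terms of the quantities that `update_tensor` maintains above, the parameter difference and gradient difference, and most of the classical updates are derived to satisfy the secant condition

$$
s_k = p_k - p_{k-1}, \qquad
y_k = g_k - g_{k-1}, \qquad
B_{k+1} s_k = y_k \;\Longleftrightarrow\; H_{k+1} y_k = s_k,
$$

where $B \approx \nabla^2 f$ and $H = B^{-1}$; the `inverse` setting selects whether the class stores `B` or `H`.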
@@ -198,19 +205,15 @@ class SR1(HUpdateStrategy):
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
         return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])

-# BFGS has defaults - init_scale = "auto" and scale_second = False
-# SR1 has defaults - init_scale = 1 and scale_second = True
-# basically some methods work better with first and some with second.
-# I inherit from BFGS or SR1 to avoid writing all those arguments again
 # ------------------------------------ DFP ----------------------------------- #
 def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     sy = torch.dot(s, y)
     if sy.abs() <= tol: return H
     term1 = torch.outer(s, s).div_(sy)
-
-    if
+    yHy = torch.dot(y, H @ y) #
+    if yHy.abs() <= tol: return H
     num = H @ torch.outer(y, y) @ H
-    term2 = num.div_(
+    term2 = num.div_(yHy)
     H += term1.sub_(term2)
     return H

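As a reading aid (editorial note, not part of the package), `dfp_H_` above is the textbook DFP update of the inverse Hessian, with the two `tol` checks guarding the denominators:

$$
H_{k+1} = H_k + \frac{s\,s^\top}{s^\top y} - \frac{H_k y\, y^\top H_k}{y^\top H_k y}.
$$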
@@ -225,34 +228,35 @@ class DFP(HUpdateStrategy):

 def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     c = H.T @ s
-
-    if
+    cy = c.dot(y)
+    if cy.abs() <= tol: return H
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H

 def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     c = y
-
-    if
+    cy = c.dot(y)
+    if cy.abs() <= tol: return H
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H

 def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor, tol: float):
     c = g_prev
-
-    if
+    cy = c.dot(y)
+    if cy.abs() <= tol: return H
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H

 def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-
-
-
-
-
+    Hy = H @ y
+    c = H @ Hy # pylint:disable=not-callable
+    cy = c.dot(y)
+    if cy.abs() <= tol: return H
+    num = Hy.sub_(s).outer(c)
+    H -= num/cy
     return H

 class BroydenGood(HUpdateStrategy):
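The four helpers in the hunk above share one rank-one correction and differ only in the choice of the vector $c$; written out (an editorial summary, not code from the diff):

$$
H_{k+1} = H_k - \frac{(H_k y - s)\, c^\top}{c^\top y},
\qquad
c =
\begin{cases}
H_k^\top s & \texttt{broyden\_good\_H\_}\\
y & \texttt{broyden\_bad\_H\_}\\
g_{k-1} & \texttt{greenstadt1\_H\_}\\
H_k H_k y & \texttt{greenstadt2\_H\_}
\end{cases}
$$

with the new `cy = c.dot(y)` guard skipping the update whenever the denominator is within `tol` of zero.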
@@ -273,11 +277,7 @@ class Greenstadt2(HUpdateStrategy):


 def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float):
-    n = H.shape[0]
-
     j = y.abs().argmax()
-    u = torch.zeros(n, device=H.device, dtype=H.dtype)
-    u[j] = 1.0

     denom = y[j]
     if denom.abs() < tol: return H
@@ -297,15 +297,15 @@ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor,
     s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
     I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
     d = (R + I * (s_norm/2)) @ s
-
-    if
-    R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(
+    ds = d.dot(s)
+    if ds.abs() <= tol: return H, R
+    R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))

     c = H.T @ d
-
-    if
+    cy = c.dot(y)
+    if cy.abs() <= tol: return H, R
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H, R

 class ThomasOptimalMethod(HUpdateStrategy):
@@ -315,6 +315,11 @@ class ThomasOptimalMethod(HUpdateStrategy):
         H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
         return H

+    def _reset_M_(self, M, s, y,inverse, init_scale, state):
+        super()._reset_M_(M, s, y, inverse, init_scale, state)
+        for st in self.state.values():
+            st.pop("R", None)
+
 # ------------------------ powell's symmetric broyden ------------------------ #
 def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
     y_Bs = y - B@s
@@ -326,6 +331,7 @@ def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
     B += term1.sub_(term2)
     return B

+# I couldn't find formula for H
 class PSB(HessianUpdateStrategy):
     def __init__(
         self,
@@ -358,17 +364,85 @@ class PSB(HessianUpdateStrategy):
     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, settings):
         return psb_B_(B=B, s=s, y=y, tol=settings['tol'])

-
+
+# Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
+def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    Hy = H@y
+    yHy = y.dot(Hy)
+    if yHy.abs() <= tol: return H
+    num = (s - Hy).outer(Hy)
+    H += num.div_(yHy)
+    return H
+
+class Pearson(HUpdateStrategy):
+    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+    This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return pearson_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     sy = s.dot(y)
     if sy.abs() <= tol: return H
     num = (s - H@y).outer(s)
     H += num.div_(sy)
     return H

-class
-"""
+class McCormick(HUpdateStrategy):
+    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+    This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-        return
+        return mccormick_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
+    Hy = H @ y
+    yHy = y.dot(Hy)
+    if yHy.abs() < tol: return H, R
+    H -= Hy.outer(Hy) / yHy
+    R += (s - R@y).outer(Hy) / yHy
+    return H, R
+
+class ProjectedNewtonRaphson(HessianUpdateStrategy):
+    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+    Algorithm 7"""
+    def __init__(
+        self,
+        init_scale: float | Literal["auto"] = 'auto',
+        tol: float = 1e-10,
+        tol_reset: bool = True,
+        reset_interval: int | None | Literal['auto'] = 'auto',
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            init_scale=init_scale,
+            tol=tol,
+            tol_reset=tol_reset,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
+        H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
+        state["R"] = R
+        return H
+
+    def _reset_M_(self, M, s, y, inverse, init_scale, state):
+        assert inverse
+        M.copy_(state["R"])

 # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
 def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
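For orientation (editorial note, not part of the diff), the two new rank-one updates added above are

$$
\texttt{pearson\_H\_:}\quad H_{k+1} = H_k + \frac{(s - H_k y)(H_k y)^\top}{y^\top H_k y},
\qquad
\texttt{mccormick\_H\_:}\quad H_{k+1} = H_k + \frac{(s - H_k y)\, s^\top}{s^\top y},
$$

while `projected_newton_raphson_H_` additionally carries a second matrix $R$, accumulating $(s - Ry)(H y)^\top / (y^\top H y)$ into it and copying it back into the stored matrix on reset.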
@@ -473,4 +547,137 @@ class SSVM(HessianUpdateStrategy):
         )

     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-        return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+        return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+
+# HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    Hy = H@y
+    ys = y.dot(s)
+    if ys.abs() <= tol: return H
+    yHy = y.dot(Hy)
+    denom = ys + yHy
+    if denom.abs() <= tol: return H
+
+    term1 = 1/denom
+    term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
+    term3 = s.outer(y) @ H
+    term4 = Hy.outer(s)
+    term5 = Hy.outer(y) @ H
+
+    inner_term = term2 - term3 - term4 - term5
+    H += inner_term.mul_(term1)
+    return H
+
+def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
+    sy = s.dot(y)
+    if sy.abs() < torch.finfo(g[0].dtype).eps: return g
+    return g - (y * (s.dot(g) / sy))
+
+
+class GradientCorrection(Transform):
+    """estimates gradient at minima along search direction assuming function is quadratic as proposed in HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+
+    This can useful as inner module for second order methods."""
+    def __init__(self):
+        super().__init__(None, uses_grad=False)
+
+    def apply(self, tensors, params, grads, loss, states, settings):
+        if 'p_prev' not in states[0]:
+            p_prev = unpack_states(states, tensors, 'p_prev', init=params)
+            g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
+            return tensors
+
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
+        g_hat = gradient_correction(TensorList(tensors), params-p_prev, tensors-g_prev)
+
+        p_prev.copy_(params)
+        g_prev.copy_(tensors)
+        return g_hat
+
+class Horisho(HUpdateStrategy):
+    """HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394"""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return hoshino_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+# Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy.abs() < tol: return H
+    Hy = H @ y
+
+    term1 = (s.outer(y) @ H).div_(sy)
+    term2 = (Hy.outer(s)).div_(sy)
+    term3 = 1 + (y.dot(Hy) / sy)
+    term4 = s.outer(s).div_(sy)
+
+    H -= (term1 + term2 - term4.mul_(term3))
+    return H
+
+class FletcherVMM(HUpdateStrategy):
+    """Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317"""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return fletcher_vmm_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+
+# Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
+    sy = s.dot(y)
+    if sy < tol: return H
+
+    term1 = (H @ y.outer(s) + s.outer(y) @ H) / sy
+
+    if type == 1:
+        pba = (2*sy + 2*(f-f_prev)) / sy
+
+    elif type == 2:
+        pba = (f_prev - f + 1/(2*sy)) / sy
+
+    else:
+        raise RuntimeError(type)
+
+    term3 = 1/pba + y.dot(H@y) / sy
+    term4 = s.outer(s) / sy
+
+    H.sub_(term1)
+    H.add_(term4.mul_(term3))
+    return H
+
+
+class NewSSM(HessianUpdateStrategy):
+    """Self-scaling method, requires a line search.
+
+    Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U."""
+    def __init__(
+        self,
+        type: Literal[1, 2] = 1,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-10,
+        tol_reset: bool = True,
+        reset_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=dict(type=type),
+            init_scale=init_scale,
+            tol=tol,
+            tol_reset=tol_reset,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        f = state['f']
+        f_prev = state['f_prev']
+        return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=settings['type'], tol=settings['tol'])
+
+
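Editorial note on the new `GradientCorrection` transform above: under the quadratic assumption from Hoshino (1972), the gradient at the minimizer along the last step is estimated by projecting out the component along $y$,

$$
\hat g = g - y\,\frac{s^\top g}{s^\top y},
$$

which is exactly what `gradient_correction` computes before the previous parameters and gradients are refreshed.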
torchzero/modules/second_order/newton.py

@@ -5,7 +5,7 @@ from typing import Literal

 import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, apply_transform
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
     hessian_list_to_mat,
@@ -18,9 +18,12 @@ from ...utils.derivatives import (


 def lu_solve(H: torch.Tensor, g: torch.Tensor):
-
-
-
+    try:
+        x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
+        if info == 0: return x
+        return None
+    except RuntimeError:
+        return None

 def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
     x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
@@ -32,10 +35,15 @@ def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
 def least_squares_solve(H: torch.Tensor, g: torch.Tensor):
     return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable

-def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None):
+def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
     try:
         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
         if tfm is not None: L = tfm(L)
+        if search_negative and L[0] < 0:
+            d = Q[0]
+            # use eigvec or -eigvec depending on if it points in same direction as gradient
+            return g.dot(d).sign() * d
+
         L.reciprocal_()
         return torch.linalg.multi_dot([Q * L.unsqueeze(-2), Q.mH, g]) # pylint:disable=not-callable
     except torch.linalg.LinAlgError:
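A short editorial summary of `eigh_solve` above: after the optional eigenvalue transform $\varphi$, the Newton system is solved through the eigendecomposition,

$$
H = Q \operatorname{diag}(\lambda)\, Q^\top
\quad\Longrightarrow\quad
x = Q \operatorname{diag}\!\big(\varphi(\lambda)\big)^{-1} Q^\top g,
$$

and the new `search_negative` branch short-circuits this when the smallest (transformed) eigenvalue is negative, returning an eigenvector direction whose sign is chosen to agree with `g` (as described in the docstring added in the next hunk).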
@@ -56,6 +64,8 @@ class Newton(Module):
     Args:
         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
         eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
+        search_negative (bool, Optional):
+            if True, whenever a negative eigenvalue is detected, the direction is taken along an eigenvector corresponding to a negative eigenvalue.
         hessian_method (str):
             how to calculate hessian. Defaults to "autograd".
         vectorize (bool, optional):
@@ -75,27 +85,29 @@ class Newton(Module):
         self,
         reg: float = 1e-6,
         eig_reg: bool = False,
+        search_negative: bool = False,
         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
         vectorize: bool = True,
         inner: Chainable | None = None,
         H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
         eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
     ):
-        defaults = dict(reg=reg, eig_reg=eig_reg,
+        defaults = dict(reg=reg, eig_reg=eig_reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative)
         super().__init__(defaults)

         if inner is not None:
             self.set_child('inner', inner)

     @torch.no_grad
-    def step(self,
-        params = TensorList(
-        closure =
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')

         settings = self.settings[params[0]]
         reg = settings['reg']
         eig_reg = settings['eig_reg']
+        search_negative = settings['search_negative']
         hessian_method = settings['hessian_method']
         vectorize = settings['vectorize']
         H_tfm = settings['H_tfm']
@@ -104,16 +116,16 @@ class Newton(Module):
         # ------------------------ calculate grad and hessian ------------------------ #
         if hessian_method == 'autograd':
             with torch.enable_grad():
-                loss =
+                loss = var.loss = var.loss_approx = closure(False)
                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
                 g_list = [t[0] for t in g_list] # remove leading dim from loss
-
+                var.grad = g_list
                 H = hessian_list_to_mat(H_list)

         elif hessian_method in ('func', 'autograd.functional'):
             strat = 'forward-mode' if vectorize else 'reverse-mode'
             with torch.enable_grad():
-                g_list =
+                g_list = var.get_grad(retain_graph=True)
                 H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
                     method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]

@@ -121,10 +133,10 @@ class Newton(Module):
             raise ValueError(hessian_method)

         # -------------------------------- inner step -------------------------------- #
-        update =
+        update = var.get_update()
         if 'inner' in self.children:
-            update =
-        g = torch.cat([t.
+            update = apply_transform(self.children['inner'], update, params=params, grads=list(g_list), var=var)
+        g = torch.cat([t.ravel() for t in update])

         # ------------------------------- regulazition ------------------------------- #
         if eig_reg: H = eig_tikhonov_(H, reg)
@@ -134,14 +146,14 @@ class Newton(Module):
         update = None
         if H_tfm is not None:
             H, is_inv = H_tfm(H, g)
-            if is_inv: update = H
+            if is_inv: update = H @ g

-        if eigval_tfm is not None:
-            update = eigh_solve(H, g, eigval_tfm)
+        if search_negative or (eigval_tfm is not None):
+            update = eigh_solve(H, g, eigval_tfm, search_negative=search_negative)

         if update is None: update = cholesky_solve(H, g)
         if update is None: update = lu_solve(H, g)
         if update is None: update = least_squares_solve(H, g)

-
-        return
+        var.update = vec_to_tensors(update, params)
+        return var
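The hunks above keep the existing solver cascade (an optional `H_tfm`, then the eigendecomposition path, then Cholesky, LU and finally least squares). A standalone sketch of that "try progressively more forgiving solvers" pattern, kept independent of torchzero's `Module`/`var` plumbing (function and variable names here are illustrative only):

```python
import torch

def robust_solve(H: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """Solve H x = g, falling back to more tolerant solvers as needed."""
    # 1) Cholesky: cheapest, but requires H to be symmetric positive definite.
    L, info = torch.linalg.cholesky_ex(H)
    if info == 0:
        return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
    # 2) LU: handles indefinite but non-singular H.
    x, info = torch.linalg.solve_ex(H, g)
    if info == 0:
        return x
    # 3) Least squares: always produces an answer, even for singular H.
    return torch.linalg.lstsq(H, g.unsqueeze(-1)).solution.squeeze(-1)

if __name__ == "__main__":
    A = torch.randn(6, 6)
    H = A @ A.T + 1e-3 * torch.eye(6)   # SPD, so the Cholesky branch is taken
    g = torch.randn(6)
    x = robust_solve(H, g)
    print(torch.allclose(H @ x, g, atol=1e-4))
```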
torchzero/modules/second_order/newton_cg.py

@@ -6,14 +6,14 @@ import torch
 from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward

-from ...core import Chainable,
+from ...core import Chainable, apply_transform, Module
 from ...utils.linalg.solve import cg

 class NewtonCG(Module):
     def __init__(
         self,
         maxiter=None,
-        tol=1e-
+        tol=1e-4,
         reg: float = 1e-8,
         hvp_method: Literal["forward", "central", "autograd"] = "forward",
         h=1e-3,
@@ -27,9 +27,9 @@ class NewtonCG(Module):
             self.set_child('inner', inner)

     @torch.no_grad
-    def step(self,
-        params = TensorList(
-        closure =
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')

         settings = self.settings[params[0]]
@@ -42,7 +42,7 @@ class NewtonCG(Module):

         # ---------------------- Hessian vector product function --------------------- #
         if hvp_method == 'autograd':
-            grad =
+            grad = var.get_grad(create_graph=True)

             def H_mm(x):
                 with torch.enable_grad():
@@ -51,7 +51,7 @@ class NewtonCG(Module):
         else:

             with torch.enable_grad():
-                grad =
+                grad = var.get_grad()

             if hvp_method == 'forward':
                 def H_mm(x):
@@ -66,19 +66,20 @@ class NewtonCG(Module):


         # -------------------------------- inner step -------------------------------- #
-        b =
+        b = var.get_update()
         if 'inner' in self.children:
-            b = as_tensorlist(
+            b = as_tensorlist(apply_transform(self.children['inner'], b, params=params, grads=grad, var=var))

         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
-        if warm_start: x0 = self.get_state('prev_x',
+        if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
+
         x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
         if warm_start:
             assert x0 is not None
             x0.copy_(x)

-
-        return
+        var.update = x
+        return var


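Editorial note: with these changes `NewtonCG.step` reads the incoming update `b` from `var`, then solves the Newton system inexactly with matrix-free conjugate gradient, using Hessian-vector products and, when `warm_start` is on, the previous solution as the starting point:

$$
\nabla^2 f(x_k)\, d_k \approx b_k,
\qquad
d_k \leftarrow \mathrm{CG}\big(v \mapsto \nabla^2 f(x_k)\, v,\; b_k,\; x_0 = d_{k-1},\; \text{tol},\; \text{maxiter}\big),
$$

after which $d_k$ is written back into `var.update` and into the `prev_x` state for the next warm start.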