torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -494
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -132
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.7.dist-info/METADATA +0 -120
- torchzero-0.1.7.dist-info/RECORD +0 -104
- torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/quasi_newton/olbfgs.py
@@ -0,0 +1,196 @@
+from collections import deque
+from functools import partial
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Transform, Vars, apply
+from ...utils import NumberList, TensorList, as_tensorlist
+from .lbfgs import _adaptive_damping, lbfgs
+
+
+@torch.no_grad
+def _store_sk_yk_after_step_hook(optimizer, vars: Vars, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
+    assert vars.closure is not None
+    with torch.enable_grad(): vars.closure()
+    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in vars.params]
+    s_k = vars.params - prev_params
+    y_k = grad - prev_grad
+    ys_k = s_k.dot(y_k)
+
+    if damping:
+        s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+    if ys_k > 1e-10:
+        s_history.append(s_k)
+        y_history.append(y_k)
+        sy_history.append(ys_k)
+
+
+class OnlineLBFGS(Module):
+    """Online L-BFGS.
+    Parameter and gradient differences are sampled from the same mini-batch by performing an extra forward and backward pass.
+    However, in my experiments the online part doesn't seem to help: normal L-BFGS is usually still
+    better because it performs twice as many steps for the same compute, and it is reasonably stable with normalization or grafting.
+
+    Args:
+        history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
+        sample_grads (str, optional):
+            - "before" - samples the current mini-batch gradient at the previous and current parameters, calculates y_k
+              and adds it to the history before stepping.
+            - "after" - samples the current mini-batch gradient at the parameters before and after the update.
+              s_k and y_k are added after the parameter update, therefore they are delayed by 1 step.
+
+            In practice both modes behave very similarly. Defaults to 'before'.
+        tol (float | None, optional):
+            tolerance on the minimal gradient difference, to avoid instability after converging to a minimum. Defaults to 1e-10.
+        damping (bool, optional):
+            whether to use adaptive damping. The learning rate might need to be lowered with this enabled. Defaults to False.
+        init_damping (float, optional):
+            initial damping for adaptive damping. Defaults to 0.9.
+        eigval_bounds (tuple, optional):
+            eigenvalue bounds for adaptive damping. Defaults to (0.5, 50).
+        params_beta (float | None, optional):
+            if not None, an EMA of parameters is used for the preconditioner update. Defaults to None.
+        grads_beta (float | None, optional):
+            if not None, an EMA of gradients is used for the preconditioner update. Defaults to None.
+        update_freq (int, optional):
+            how often to update the L-BFGS history. Defaults to 1.
+        z_beta (float | None, optional):
+            optional EMA for the initial H^-1 @ q guess. Acts as a kind of momentum but is prone to getting stuck. Defaults to None.
+        inner (Chainable | None, optional):
+            optional inner modules applied after updating the L-BFGS history and before preconditioning. Defaults to None.
+    """
+    def __init__(
+        self,
+        history_size=10,
+        sample_grads: Literal['before', 'after'] = 'before',
+        tol: float | None = 1e-10,
+        damping: bool = False,
+        init_damping=0.9,
+        eigval_bounds=(0.5, 50),
+        z_beta: float | None = None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, sample_grads=sample_grads, z_beta=z_beta)
+        super().__init__(defaults)
+
+        self.global_state['s_history'] = deque(maxlen=history_size)
+        self.global_state['y_history'] = deque(maxlen=history_size)
+        self.global_state['sy_history'] = deque(maxlen=history_size)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    def reset(self):
+        """Resets the internal state of the OnlineLBFGS module."""
+        # super().reset() # Clears self.state (per-parameter) if any, and "step"
+        # Re-initialize L-BFGS specific global state
+        self.state.clear()
+        self.global_state['step'] = 0
+        self.global_state['s_history'].clear()
+        self.global_state['y_history'].clear()
+        self.global_state['sy_history'].clear()
+
+    @torch.no_grad
+    def step(self, vars):
+        assert vars.closure is not None
+
+        params = as_tensorlist(vars.params)
+        update = as_tensorlist(vars.get_update())
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        # history of s and y
+        s_history: deque[TensorList] = self.global_state['s_history']
+        y_history: deque[TensorList] = self.global_state['y_history']
+        sy_history: deque[torch.Tensor] = self.global_state['sy_history']
+
+        tol, damping, init_damping, eigval_bounds, sample_grads, z_beta = itemgetter(
+            'tol', 'damping', 'init_damping', 'eigval_bounds', 'sample_grads', 'z_beta')(self.settings[params[0]])
+
+        # sample gradient at previous params with current mini-batch
+        if sample_grads == 'before':
+            prev_params = self.get_state('prev_params', params=params, cls=TensorList)
+            if step == 0:
+                s_k = None; y_k = None; ys_k = None
+            else:
+                s_k = params - prev_params
+
+                current_params = params.clone()
+                params.set_(prev_params)
+                with torch.enable_grad(): vars.closure()
+                y_k = update - params.grad
+                ys_k = s_k.dot(y_k)
+                params.set_(current_params)
+
+                if damping:
+                    s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+                if ys_k > 1e-10:
+                    s_history.append(s_k)
+                    y_history.append(y_k)
+                    sy_history.append(ys_k)
+
+            prev_params.copy_(params)
+
+        # use the previous s_k, y_k pair; samples the gradient on the current batch before and after updating parameters
+        elif sample_grads == 'after':
+            if len(s_history) == 0:
+                s_k = None; y_k = None; ys_k = None
+            else:
+                s_k = s_history[-1]
+                y_k = y_history[-1]
+                ys_k = s_k.dot(y_k)
+
+            # this will run after Modular updates the parameters, once all following modules have run
+            vars.post_step_hooks.append(
+                partial(
+                    _store_sk_yk_after_step_hook,
+                    prev_params=params.clone(),
+                    prev_grad=update.clone(),
+                    damping=damping,
+                    init_damping=init_damping,
+                    eigval_bounds=eigval_bounds,
+                    s_history=s_history,
+                    y_history=y_history,
+                    sy_history=sy_history,
+                ))
+
+        else:
+            raise ValueError(sample_grads)
+
+        # step with inner module before applying the preconditioner
+        if self.children:
+            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+
+        # tolerance on gradient difference to avoid exploding after converging
+        if tol is not None:
+            if y_k is not None and y_k.abs().global_max() <= tol:
+                vars.update = update # may have been updated by the inner module, probably makes sense to use it here
+                return vars
+
+        # lerp initial H^-1 @ q guess
+        z_ema = None
+        if z_beta is not None:
+            z_ema = self.get_state('z_ema', params=vars.params, cls=TensorList)
+
+        # precondition
+        dir = lbfgs(
+            tensors_=as_tensorlist(update),
+            s_history=s_history,
+            y_history=y_history,
+            sy_history=sy_history,
+            y_k=y_k,
+            ys_k=ys_k,
+            z_beta=z_beta,
+            z_ema=z_ema,
+            step=step,
+        )
+
+        vars.update = dir
+
+        return vars
+
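For reference, the quantities maintained by the hook and the `step` method above are the standard L-BFGS secant pairs; written out in standard notation (my transcription, with $f_k$ the current mini-batch loss):

$$ s_k = x_k - x_{k-1}, \qquad y_k = \nabla f_k(x_k) - \nabla f_k(x_{k-1}), \qquad ys_k = s_k^\top y_k $$

Both gradients are evaluated on the same mini-batch $f_k$, which is what makes the variant "online"; a pair is stored only when the curvature product $s_k^\top y_k > 10^{-10}$, so the two-loop recursion in `lbfgs` only ever sees positive-curvature pairs.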
torchzero/modules/quasi_newton/quasi_newton.py
@@ -0,0 +1,475 @@
+"""Use BFGS or maybe SR1."""
+from typing import Any, Literal
+from abc import ABC, abstractmethod
+from collections.abc import Mapping
+import torch
+
+from ...core import Chainable, Module, Preconditioner, TensorwisePreconditioner
+from ...utils import TensorList, set_storage_
+
+def _safe_dict_update_(d1_: dict, d2: dict):
+    inter = set(d1_.keys()).intersection(d2.keys())
+    if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
+    d1_.update(d2)
+
+def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
+    if (beta is None) or (beta == 0) or (key not in state): state[key] = value
+    elif state[key].shape != value.shape: state[key] = value
+    else: state[key].lerp_(value, 1 - beta)
+
+class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
+    def __init__(
+        self,
+        defaults: dict | None = None,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-10,
+        tol_reset: bool = True,
+        reset_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        if defaults is None: defaults = {}
+        _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, tol_reset=tol_reset, scale_second=scale_second, inverse=inverse, beta=beta, reset_interval=reset_interval))
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, scale_first=scale_first, inner=inner)
+
+    def _get_init_scale(self, s: torch.Tensor, y: torch.Tensor) -> torch.Tensor | float:
+        """returns multiplier to H or B"""
+        ys = y.dot(s)
+        yy = y.dot(y)
+        if ys != 0 and yy != 0: return yy/ys
+        return 1
+
+    def _reset_M_(self, M: torch.Tensor, s: torch.Tensor, y: torch.Tensor, inverse: bool, init_scale: Any):
+        set_storage_(M, torch.eye(M.size(-1), device=M.device, dtype=M.dtype))
+        if init_scale == 'auto': init_scale = self._get_init_scale(s, y)
+        if init_scale >= 1:
+            if inverse: M /= init_scale
+            else: M *= init_scale
+
+    def update_H(self, H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, p: torch.Tensor, g: torch.Tensor,
+                 p_prev: torch.Tensor, g_prev: torch.Tensor, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
+        """update hessian inverse"""
+        raise NotImplementedError
+
+    def update_B(self, B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, p: torch.Tensor, g: torch.Tensor,
+                 p_prev: torch.Tensor, g_prev: torch.Tensor, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
+        """update hessian"""
+        raise NotImplementedError
+
+    @torch.no_grad
+    def update_tensor(self, tensor, param, grad, state, settings):
+        p = param.view(-1); g = tensor.view(-1)
+        inverse = settings['inverse']
+        M_key = 'H' if inverse else 'B'
+        M = state.get(M_key, None)
+        step = state.get('step', 0)
+        init_scale = settings['init_scale']
+        tol = settings['tol']
+        tol_reset = settings['tol_reset']
+        reset_interval = settings['reset_interval']
+
+        if M is None:
+            M = torch.eye(p.size(0), device=p.device, dtype=p.dtype)
+            if isinstance(init_scale, (int, float)) and init_scale != 1:
+                if inverse: M /= init_scale
+                else: M *= init_scale
+
+            state[M_key] = M
+            state['p_prev'] = p.clone()
+            state['g_prev'] = g.clone()
+            return
+
+        p_prev = state['p_prev']
+        g_prev = state['g_prev']
+        s: torch.Tensor = p - p_prev
+        y: torch.Tensor = g - g_prev
+        state['p_prev'].copy_(p)
+        state['g_prev'].copy_(g)
+
+        if reset_interval is not None and step % reset_interval == 0:
+            self._reset_M_(M, s, y, inverse, init_scale)
+            return
+
+        # tolerance on gradient difference to avoid exploding after converging
+        if y.abs().max() <= tol:
+            # reset history
+            if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
+            return
+
+        if step == 1 and init_scale == 'auto':
+            if inverse: M /= self._get_init_scale(s, y)
+            else: M *= self._get_init_scale(s, y)
+
+        beta = settings['beta']
+        if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
+
+        if inverse:
+            H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
+            _maybe_lerp_(state, 'H', H_new, beta)
+
+        else:
+            B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
+            _maybe_lerp_(state, 'B', B_new, beta)
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, state, settings):
+        step = state['step'] = state.get('step', 0) + 1
+
+        if settings['scale_second'] and step == 2:
+            s = max(1, tensor.abs().sum()) # pyright:ignore[reportArgumentType]
+            if s < settings['tol']: tensor = tensor/s
+
+        inverse = settings['inverse']
+        if inverse:
+            H = state['H']
+            return (H @ tensor.view(-1)).view_as(tensor)
+
+        B = state['B']
+
+        return torch.linalg.solve_ex(B, tensor.view(-1))[0].view_as(tensor) # pylint:disable=not-callable
+
+# to avoid typing all arguments for each method
+class QuasiNewtonH(HessianUpdateStrategy):
+    def __init__(
+        self,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-10,
+        tol_reset: bool = True,
+        reset_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=None,
+            init_scale=init_scale,
+            tol=tol,
+            tol_reset=tol_reset,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+
+# ----------------------------------- BFGS ----------------------------------- #
+def bfgs_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    sy = torch.dot(s, y)
+    if sy <= tol: return H # don't reset H in this case
+    num1 = (sy + (y @ H @ y)) * s.outer(s)
+    term1 = num1.div_(sy**2)
+    num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+    term2 = num2.div_(sy)
+    H += term1.sub_(term2)
+    return H
+
+class BFGS(QuasiNewtonH):
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return bfgs_H_(H=H, s=s, y=y, tol=settings['tol'])
+
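For reference, `bfgs_H_` above is the textbook BFGS update of the inverse-Hessian approximation $H$, written in the algebraically equivalent form the code computes (my transcription; the update is skipped when $s^\top y \le \text{tol}$):

$$ H \leftarrow H + \frac{\left(s^\top y + y^\top H y\right) s s^\top}{(s^\top y)^2} - \frac{H y\, s^\top + s\, y^\top H}{s^\top y} $$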
+# ------------------------------------ SR1 ----------------------------------- #
+def sr1_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    z = s - H@y
+    denom = torch.dot(z, y)
+
+    z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
+    y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
+
+    if y_norm*z_norm < tol: return H
+
+    # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
+    if denom.abs() <= tol * y_norm * z_norm: return H
+    H += torch.outer(z, z).div_(denom)
+    return H
+
+class SR1(QuasiNewtonH):
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+# BFGS has defaults - init_scale = "auto" and scale_second = False
+# SR1 has defaults - init_scale = 1 and scale_second = True
+# basically some methods work better with first and some with second.
+# I inherit from BFGS or SR1 to avoid writing all those arguments again
+# ------------------------------------ DFP ----------------------------------- #
+def dfp_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    sy = torch.dot(s, y)
+    if sy.abs() <= tol: return H
+    term1 = torch.outer(s, s).div_(sy)
+    denom = torch.dot(y, H @ y)
+    if denom.abs() <= tol: return H
+    num = H @ torch.outer(y, y) @ H
+    term2 = num.div_(denom)
+    H += term1.sub_(term2)
+    return H
+
+class DFP(QuasiNewtonH):
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return dfp_H_(H=H, s=s, y=y, tol=settings['tol'])
+
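Likewise, `sr1_H_` and `dfp_H_` above implement the standard SR1 and DFP updates of the inverse Hessian (my transcription). With $z = s - Hy$:

$$ \text{SR1:}\quad H \leftarrow H + \frac{z z^\top}{z^\top y}, \qquad \text{skipped when } |z^\top y| \le \text{tol}\,\lVert z\rVert\,\lVert y\rVert $$

$$ \text{DFP:}\quad H \leftarrow H + \frac{s s^\top}{s^\top y} - \frac{H y\, y^\top H}{y^\top H y} $$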
+# formulas for methods below from Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+# H' = H - (Hy - s)c^T / c^T*y
+# the difference is how `c` is calculated
+
+def broyden_good_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    c = H.T @ s
+    denom = c.dot(y)
+    if denom.abs() <= tol: return H
+    num = (H@y).sub_(s).outer(c)
+    H -= num/denom
+    return H
+
+def broyden_bad_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    c = y
+    denom = c.dot(y)
+    if denom.abs() <= tol: return H
+    num = (H@y).sub_(s).outer(c)
+    H -= num/denom
+    return H
+
+def greenstadt1_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, g_prev: torch.Tensor, tol: float):
+    c = g_prev
+    denom = c.dot(y)
+    if denom.abs() <= tol: return H
+    num = (H@y).sub_(s).outer(c)
+    H -= num/denom
+    return H
+
+def greenstadt2_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    c = torch.linalg.multi_dot([H, H, y]) # pylint:disable=not-callable
+    denom = c.dot(y)
+    if denom.abs() <= tol: return H
+    num = (H@y).sub_(s).outer(c)
+    H -= num/denom
+    return H
+
+class BroydenGood(QuasiNewtonH):
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return broyden_good_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+class BroydenBad(QuasiNewtonH):
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+class Greenstadt1(QuasiNewtonH):
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])
+
+class Greenstadt2(QuasiNewtonH):
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])
+
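The four functions above all apply the same rank-one correction from the Spedicato & Huang family, differing only in the choice of $c$ (my transcription of the code's comment):

$$ H \leftarrow H - \frac{(H y - s)\, c^\top}{c^\top y}, \qquad c = \begin{cases} H^\top s & \text{Broyden's good method} \\ y & \text{Broyden's bad method} \\ g_{\text{prev}} & \text{Greenstadt's first method} \\ H H y & \text{Greenstadt's second method} \end{cases} $$

The update is skipped whenever $|c^\top y| \le \text{tol}$.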
+def column_updating_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    n = H.shape[0]
+
+    j = y.abs().argmax()
+    u = torch.zeros(n, device=H.device, dtype=H.dtype)
+    u[j] = 1.0
+
+    denom = y[j]
+    if denom.abs() < tol: return H
+
+    Hy = H @ y.unsqueeze(1)
+    num = s.unsqueeze(1) - Hy
+
+    H[:, j] += num.squeeze() / denom
+    return H
+
+class ColumnUpdatingMethod(QuasiNewtonH):
+    """Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf"""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return column_updating_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+def thomas_H_(H: torch.Tensor, R: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
+    I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
+    d = (R + I * (s_norm/2)) @ s
+    denom = d.dot(s)
+    if denom.abs() <= tol: return H, R
+    R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(denom)))
+
+    c = H.T @ d
+    denom = c.dot(y)
+    if denom.abs() <= tol: return H, R
+    num = (H@y).sub_(s).outer(c)
+    H -= num/denom
+    return H, R
+
+class ThomasOptimalMethod(QuasiNewtonH):
+    """Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
+        H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
+        return H
+
+# ------------------------ powell's symmetric broyden ------------------------ #
+def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    y_Bs = y - B@s
+    ss = s.dot(s)
+    if ss.abs() < tol: return B
+    num1 = y_Bs.outer(s).add_(s.outer(y_Bs))
+    term1 = num1.div_(ss)
+    term2 = s.outer(s).mul_(y_Bs.dot(s)/(ss**2))
+    B += term1.sub_(term2)
+    return B
+
+class PSB(HessianUpdateStrategy):
+    def __init__(
+        self,
+        init_scale: float | Literal["auto"] = 'auto',
+        tol: float = 1e-10,
+        tol_reset: bool = True,
+        reset_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=None,
+            init_scale=init_scale,
+            tol=tol,
+            tol_reset=tol_reset,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=False,
+            inner=inner,
+        )
+
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, settings):
+        return psb_B_(B=B, s=s, y=y, tol=settings['tol'])
+
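`psb_B_` above is the only update in this file that works on the Hessian approximation $B$ rather than its inverse (hence `inverse=False` in `PSB.__init__`, and `apply_tensor` solves a linear system instead of multiplying). Written out (my transcription):

$$ B \leftarrow B + \frac{(y - Bs)\, s^\top + s\, (y - Bs)^\top}{s^\top s} - \frac{\left((y - Bs)^\top s\right) s s^\top}{(s^\top s)^2} $$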
+def pearson2_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy.abs() <= tol: return H
+    num = (s - H@y).outer(s)
+    H += num.div_(sy)
+    return H
+
+class Pearson2(QuasiNewtonH):
+    """finally found a reference in https://www.recotechnologies.com/~beigi/ps/asme-jdsmc-93-2.pdf"""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return pearson2_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+# Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
+def ssvm_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, g: torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
+    # in the paper's notation p is s, q is y, H is D
+    # another p is lr
+    # omega (o) = sy
+    # tau (t) = yHy
+    # epsilon = p'D^-1 p
+    # however p.12 says eps = gs / gHy
+
+    Hy = H@y
+    gHy = g.dot(Hy)
+    yHy = y.dot(Hy)
+    sy = s.dot(y)
+    if sy < tol: return H
+    if yHy.abs() < tol: return H
+    if gHy.abs() < tol: return H
+
+    v_mul = yHy.sqrt()
+    v_term1 = s/sy
+    v_term2 = Hy/yHy
+    v = (v_term1.sub_(v_term2)).mul_(v_mul)
+    gs = g.dot(s)
+
+    if isinstance(switch, tuple): phi, theta = switch
+    else:
+        o = sy
+        t = yHy
+        e = gs / gHy
+        if switch in (1, 3):
+            if e/o <= 1:
+                if o.abs() <= tol: return H
+                phi = e/o
+                theta = 0
+            elif o/t >= 1:
+                if t.abs() <= tol: return H
+                phi = o/t
+                theta = 1
+            else:
+                phi = 1
+                denom = e*t - o**2
+                if denom.abs() <= tol: return H
+                if switch == 1: theta = o * (e - o) / denom
+                else: theta = o * (t - o) / denom
+
+        elif switch == 2:
+            if t.abs() <= tol or o.abs() <= tol or e.abs() <= tol: return H
+            phi = (e / t) ** 0.5
+            theta = 1 / (1 + (t*e / o**2)**0.5)
+
+        elif switch == 4:
+            if t.abs() <= tol: return H
+            phi = e/t
+            theta = 1/2
+
+        else: raise ValueError(switch)
+
+    u = phi * (gs/gHy) + (1 - phi) * (sy/yHy)
+    term1 = (H @ y.outer(y) @ H).div_(yHy)
+    term2 = v.outer(v).mul_(theta)
+    term3 = s.outer(s).div_(sy)
+
+    H -= term1
+    H += term2
+    H *= u
+    H += term3
+    return H
+
+
+class SSVM(HessianUpdateStrategy):
+    """This one is from Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654"""
+    def __init__(
+        self,
+        switch: tuple[float,float] | Literal[1,2,3,4] = 3,
+        init_scale: float | Literal["auto"] = 'auto',
+        tol: float = 1e-10,
+        tol_reset: bool = True,
+        reset_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(switch=switch)
+        super().__init__(
+            defaults=defaults,
+            init_scale=init_scale,
+            tol=tol,
+            tol_reset=tol_reset,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+        return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
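Finally, `ssvm_H_` above performs the self-scaling variable metric update of Oren & Spedicato (1976); in the notation of the code (my transcription), with $\phi$ and $\theta$ chosen by the `switch` rules:

$$ v = \sqrt{y^\top H y}\left(\frac{s}{s^\top y} - \frac{H y}{y^\top H y}\right), \qquad u = \phi\,\frac{g^\top s}{g^\top H y} + (1 - \phi)\,\frac{s^\top y}{y^\top H y} $$

$$ H \leftarrow u \left(H - \frac{H y\, y^\top H}{y^\top H y} + \theta\, v v^\top\right) + \frac{s s^\top}{s^\top y} $$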
torchzero/modules/second_order/__init__.py
@@ -1,4 +1,3 @@
-
-
-
-from .newton import ExactNewton, LinearSystemSolvers, FallbackLinearSystemSolvers, LINEAR_SYSTEM_SOLVERS
+from .newton import Newton
+from .newton_cg import NewtonCG
+from .nystrom import NystromSketchAndSolve, NystromPCG