torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0

torchzero/modules/quasi_newton/quasi_newton.py

@@ -1,18 +1,16 @@
-
+import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 from typing import Any, Literal

 import torch

 from ...core import Chainable, Module, TensorwiseTransform, Transform
-from ...utils import TensorList, set_storage_, unpack_states
+from ...utils import TensorList, set_storage_, unpack_states, safe_dict_update_
+from ...utils.linalg import linear_operator
+from ..functional import initial_step_size, safe_clip


-def _safe_dict_update_(d1_:dict, d2:dict):
-    inter = set(d1_.keys()).intersection(d2.keys())
-    if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
-    d1_.update(d2)

 def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
     if (beta is None) or (beta == 0) or (key not in state): state[key] = value
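All of the update rules in this file are driven by the step difference `s` and the gradient difference `y` that `update_tensor` passes to `update_H`/`update_B`. As standard quasi-Newton background (an editorial note, not part of the diff), both the hessian approximation `B` and the inverse approximation `H` maintained below target the secant condition:

```latex
% s and y as used throughout quasi_newton.py (standard notation, not from the diff)
\[
s_k = x_{k+1} - x_k, \qquad y_k = \nabla f(x_{k+1}) - \nabla f(x_k),
\]
% the approximations are updated so that they reproduce the observed gradient change
\[
B_{k+1} s_k = y_k \quad (\text{hessian approximation } B), \qquad
H_{k+1} y_k = s_k \quad (\text{inverse approximation } H).
\]
```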
@@ -20,68 +18,165 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
     else: state[key].lerp_(value, 1-beta)

 class HessianUpdateStrategy(TensorwiseTransform, ABC):
+    """Base class for quasi-newton methods that store and update hessian approximation H or inverse B.
+
+    This is an abstract class, to use it, subclass it and override ``update_H`` and/or ``update_B``,
+    and if necessary, ``initialize_P``, ``modify_H`` and ``modify_B``.
+
+    Args:
+        defaults (dict | None, optional): defaults. Defaults to None.
+        init_scale (float | Literal["auto"], optional):
+            initial hessian matrix is set to identity times this.
+
+            "auto" corresponds to a heuristic from Nocedal. Stephen J. Wright. Numerical Optimization p.142-143.
+
+            Defaults to "auto".
+        tol (float, optional):
+            algorithm-dependent tolerance (usually on curvature condition). Defaults to 1e-32.
+        ptol (float | None, optional):
+            tolerance for minimal parameter difference to avoid instability. Defaults to 1e-32.
+        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+        gtol (float | None, optional):
+            tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-32.
+        restart_interval (int | None | Literal["auto"], optional):
+            interval between resetting the hessian approximation.
+
+            "auto" corresponds to number of decision variables + 1.
+
+            None - no resets.
+
+            Defaults to None.
+        beta (float | None, optional): momentum on H or B. Defaults to None.
+        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+        scale_first (bool, optional):
+            whether to downscale first step before hessian approximation becomes available. Defaults to True.
+        scale_second (bool, optional): whether to downscale second step. Defaults to False.
+        concat_params (bool, optional):
+            If true, all parameters are treated as a single vector.
+            If False, the update rule is applied to each parameter separately. Defaults to True.
+        inverse (bool, optional):
+            set to True if this method uses hessian inverse approximation H and has `update_H` method.
+            set to False if this maintains hessian approximation B and has `update_B method`.
+            Defaults to True.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    ## Notes
+
+    ### update
+
+    On 1st ``update_tensor`` H or B is initialized using ``initialize_P``, which returns identity matrix by default.
+
+    2nd and subsequent ``update_tensor`` calls ``update_H`` or ``update_B``.
+
+    Whether ``H`` or ``B`` is used depends on value of ``inverse`` setting.
+
+    ### apply
+
+    ``apply_tensor`` computes ``H = modify_H(H)`` or ``B = modify_B(B)``, those methods do nothing by default.
+
+    Then it computes and returns ``H @ input`` or ``solve(B, input)``.
+
+    Whether ``H`` or ``B`` is used depends on value of ``inverse`` setting.
+
+    ### initial scale
+
+    If ``init_scale`` is a scalar, the preconditioner is multiplied or divided (if inverse) by it on first ``update_tensor``.
+
+    If ``init_scale="auto"``, it is computed and applied on the second ``update_tensor``.
+
+    ### get_H
+
+    First it computes ``H = modify_H(H)`` or ``B = modify_B(B)``.
+
+    Returns a ``Dense`` linear operator with ``B``, or ``DenseInverse`` linear operator with ``H``.
+
+    But if H/B has 1 dimension, ``Diagonal`` linear operator is returned with ``B`` or ``1/H``.
+    """
     def __init__(
         self,
         defaults: dict | None = None,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = None,
         beta: float | None = None,
         update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
         concat_params: bool = True,
         inverse: bool = True,
         inner: Chainable | None = None,
     ):
         if defaults is None: defaults = {}
-
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq,
+        safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_restart=ptol_restart, gtol=gtol, inverse=inverse, beta=beta, restart_interval=restart_interval, scale_first=scale_first))
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)

-    def
-
-
-        yy = y.dot(y)
-        if ys != 0 and yy != 0: return yy/ys
-        return 1
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('f_prev', 'p_prev', 'g_prev')

-
-
-
-
-        if inverse: M /= init_scale
-        else: M *= init_scale
+    # ---------------------------- methods to override --------------------------- #
+    def initialize_P(self, size:int, device, dtype, is_inverse:bool) -> torch.Tensor:
+        """returns the initial torch.Tensor for H or B"""
+        return torch.eye(size, device=device, dtype=dtype)

     def update_H(self, H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
-                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any],
+                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
         """update hessian inverse"""
-        raise NotImplementedError
+        raise NotImplementedError(f"hessian inverse approximation is not implemented for {self.__class__.__name__}.")

     def update_B(self, B:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
-                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any],
+                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
         """update hessian"""
-        raise NotImplementedError
+        raise NotImplementedError(f"{self.__class__.__name__} only supports hessian inverse approximation. "
+                                  "Remove the `inverse=False` argument when initializing this module.")
+
+    def modify_B(self, B: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+        """modifies B out of place before appling the update rule, doesn't affect the buffer B."""
+        return B
+
+    def modify_H(self, H: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+        """modifies H out of place before appling the update rule, doesn't affect the buffer H."""
+        return H
+
+    # ------------------------------ common methods ------------------------------ #
+    def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
+        """returns multiplier to B on 2nd step if ``init_scale='auto'``. H should be divided by this!"""
+        ys = y.dot(s)
+        yy = y.dot(y)
+        if ys != 0 and yy != 0: return yy/ys
+        return 1
+
+    def reset_P(self, P: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
+        """resets ``P`` which is either B or H"""
+        set_storage_(P, self.initialize_P(s.numel(), device=P.device, dtype=P.dtype, is_inverse=inverse))
+        if init_scale == 'auto': init_scale = self.auto_initial_scale(s,y)
+        if init_scale >= 1:
+            if inverse: P /= init_scale
+            else: P *= init_scale

     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state,
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
         p = param.view(-1); g = tensor.view(-1)
-        inverse =
+        inverse = setting['inverse']
         M_key = 'H' if inverse else 'B'
         M = state.get(M_key, None)
-        step = state.get('step', 0)
-        state['step'] = step
-        init_scale =
-
-
-
-
-
-
-        if
-
-
+        step = state.get('step', 0) + 1
+        state['step'] = step
+        init_scale = setting['init_scale']
+        ptol = setting['ptol']
+        ptol_restart = setting['ptol_restart']
+        gtol = setting['gtol']
+        restart_interval = setting['restart_interval']
+        if restart_interval == 'auto': restart_interval = tensor.numel() + 1
+
+        if M is None or 'f_prev' not in state:
+            if M is None: # won't be true on reset_for_online
+                M = self.initialize_P(p.numel(), device=p.device, dtype=p.dtype, is_inverse=inverse)
+                if isinstance(init_scale, (int, float)) and init_scale != 1:
+                    if inverse: M /= init_scale
+                    else: M *= init_scale

             state[M_key] = M
             state['f_prev'] = loss
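A minimal editorial sketch (not part of the diff) of what the new `update_H` hook above makes possible: a custom rank-one method defined against the `HessianUpdateStrategy` API shown in this hunk, using McCormick's inverse update as a stand-in rule. The import path and the `tz.Modular` usage in the trailing comment are assumptions, mirroring the examples that appear later in the BFGS docstring.

```python
# Editorial sketch, not shipped by torchzero. Assumes HessianUpdateStrategy is
# importable from the file shown in this diff.
import torch
from torchzero.modules.quasi_newton.quasi_newton import HessianUpdateStrategy

class RankOneSketch(HessianUpdateStrategy):
    """Illustrative inverse update: H <- H + (s - H y) s^T / (s^T y)."""
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        sy = s.dot(y)
        # 'tol' is one of the defaults registered by HessianUpdateStrategy.__init__
        if sy.abs() <= setting['tol']:
            return H  # skip the update when curvature information is degenerate
        H += (s - H @ y).outer(s) / sy
        return H

# Assumed usage, mirroring the tz.Modular examples in the BFGS docstring below:
# import torchzero as tz
# opt = tz.Modular(model.parameters(), RankOneSketch(), tz.m.Backtracking())
```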
@@ -97,190 +192,487 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
             state['p_prev'].copy_(p)
             state['g_prev'].copy_(g)

-        if
-            self.
+        if restart_interval is not None and step % restart_interval == 0:
+            self.reset_P(M, s, y, inverse, init_scale, state)
+            return
+
+        # tolerance on parameter difference to avoid exploding after converging
+        if ptol is not None and s.abs().max() <= ptol:
+            if ptol_restart: self.reset_P(M, s, y, inverse, init_scale, state) # reset history
             return

-        # tolerance on gradient difference to avoid exploding
-        if y.abs().max() <=
-            # reset history
-            if tol_reset: self._reset_M_(M, s, y, inverse, init_scale, state)
+        # tolerance on gradient difference to avoid exploding when there is no curvature
+        if gtol is not None and y.abs().max() <= gtol:
             return

-        if step ==
-            if inverse: M /= self.
-            else: M *= self.
+        if step == 2 and init_scale == 'auto':
+            if inverse: M /= self.auto_initial_scale(s,y)
+            else: M *= self.auto_initial_scale(s,y)

-        beta =
+        beta = setting['beta']
         if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place

         if inverse:
-            H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state,
+            H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
             _maybe_lerp_(state, 'H', H_new, beta)

         else:
-            B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state,
+            B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
             _maybe_lerp_(state, 'B', B_new, beta)

         state['f_prev'] = loss

     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
-        step = state
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        step = state['step']
+
+        if setting['scale_first'] and step == 1:
+            tensor *= initial_step_size(tensor)

-
-
-            scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
-            tensor = tensor * scale_factor
+        inverse = setting['inverse']
+        g = tensor.view(-1)

-        inverse = settings['inverse']
         if inverse:
             H = state['H']
-
+            H = self.modify_H(H, state, setting)
+            if H.ndim == 1: return g.mul_(H).view_as(tensor)
+            return (H @ g).view_as(tensor)

         B = state['B']
+        B = self.modify_B(B, state, setting)
+
+        if B.ndim == 1: return g.div_(B).view_as(tensor)
+        x, info = torch.linalg.solve_ex(B, g) # pylint:disable=not-callable
+        if info == 0: return x.view_as(tensor)
+
+        # failed to solve linear system, so reset state
+        self.state.clear()
+        self.global_state.clear()
+        return tensor.mul_(initial_step_size(tensor))
+
+    def get_H(self, var):
+        param = var.params[0]
+        state = self.state[param]
+        settings = self.settings[param]
+        if "B" in state:
+            B = self.modify_B(state["B"], state, settings)
+            if B.ndim == 2: return linear_operator.Dense(B)
+            assert B.ndim == 1, B.shape
+            return linear_operator.Diagonal(B)
+
+        if "H" in state:
+            H = self.modify_H(state["H"], state, settings)
+            if H.ndim != 1: return linear_operator.DenseInverse(H)
+            return linear_operator.Diagonal(1/H)
+
+        return None
+
+class _InverseHessianUpdateStrategyDefaults(HessianUpdateStrategy):
+    '''This is ``HessianUpdateStrategy`` subclass for algorithms with no extra defaults, to skip the lengthy ``__init__``.
+    Refer to ``HessianUpdateStrategy`` documentation.
+
+    ## Example:
+
+    Implementing BFGS method that maintains an estimate of the hessian inverse (H):
+    ```python
+    class BFGS(_HessianUpdateStrategyDefaults):
+        """Broyden–Fletcher–Goldfarb–Shanno algorithm"""
+        def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+            tol = settings["tol"]
+            sy = torch.dot(s, y)
+            if sy <= tol: return H
+            num1 = (sy + (y @ H @ y)) * s.outer(s)
+            term1 = num1.div_(sy**2)
+            num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+            term2 = num2.div_(sy)
+            H += term1.sub_(term2)
+            return H
+    ```
+
+    Make sure to put at least a basic class level docstring to overwrite this.
+    '''
+    def __init__(
+        self,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        concat_params: bool = True,
+        inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=None,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_restart=ptol_restart,
+            gtol=gtol,
+            restart_interval=restart_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            concat_params=concat_params,
+            inverse=inverse,
+            inner=inner,
+        )

-
-
-# to avoid typing all arguments for each method
-class HUpdateStrategy(HessianUpdateStrategy):
+class _HessianUpdateStrategyDefaults(HessianUpdateStrategy):
     def __init__(
         self,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
         concat_params: bool = True,
+        inverse: bool = False,
         inner: Chainable | None = None,
     ):
         super().__init__(
             defaults=None,
             init_scale=init_scale,
             tol=tol,
-
-
+            ptol=ptol,
+            ptol_restart=ptol_restart,
+            gtol=gtol,
+            restart_interval=restart_interval,
             beta=beta,
             update_freq=update_freq,
             scale_first=scale_first,
-            scale_second=scale_second,
             concat_params=concat_params,
-            inverse=
+            inverse=inverse,
             inner=inner,
         )
+
 # ----------------------------------- BFGS ----------------------------------- #
+def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy < tol: return B
+
+    Bs = B@s
+    sBs = safe_clip(s.dot(Bs))
+
+    term1 = y.outer(y).div_(sy)
+    term2 = (Bs.outer(s) @ B.T).div_(sBs)
+    B += term1.sub_(term2)
+    return B
+
 def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-    sy =
-    if sy <= tol: return H
-
-
-
+    sy = s.dot(y)
+    if sy <= tol: return H
+
+    sy_sq = safe_clip(sy**2)
+
+    Hy = H@y
+    scale1 = (sy + y.dot(Hy)) / sy_sq
+    term1 = s.outer(s).mul_(scale1)
+
+    num2 = (Hy.outer(s)).add_(s.outer(y @ H))
     term2 = num2.div_(sy)
+
     H += term1.sub_(term2)
     return H

-class BFGS(
-
-
+class BFGS(_InverseHessianUpdateStrategyDefaults):
+    """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.
+
+    Note:
+        a line search or a trust region is recommended
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Args:
+        init_scale (float | Literal["auto"], optional):
+            initial hessian matrix is set to identity times this.
+
+            "auto" corresponds to a heuristic from Nocedal. Stephen J. Wright. Numerical Optimization p.142-143.
+
+            Defaults to "auto".
+        tol (float, optional):
+            tolerance on curvature condition. Defaults to 1e-32.
+        ptol (float | None, optional):
+            skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
+            Defaults to 1e-32.
+        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+        restart_interval (int | None | Literal["auto"], optional):
+            interval between resetting the hessian approximation.
+
+            "auto" corresponds to number of decision variables + 1.
+
+            None - no resets.
+
+            Defaults to None.
+        beta (float | None, optional): momentum on H or B. Defaults to None.
+        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+        scale_first (bool, optional):
+            whether to downscale first step before hessian approximation becomes available. Defaults to True.
+        scale_second (bool, optional): whether to downscale second step. Defaults to False.
+        concat_params (bool, optional):
+            If true, all parameters are treated as a single vector.
+            If False, the update rule is applied to each parameter separately. Defaults to True.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    ## Examples:
+
+    BFGS with backtracking line search:
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.BFGS(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    BFGS with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
+    )
+    ```
+    """
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return bfgs_B_(B=B, s=s, y=y, tol=setting['tol'])

 # ------------------------------------ SR1 ----------------------------------- #
-def
+def sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
     z = s - H@y
-    denom =
+    denom = z.dot(y)

     z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
     y_norm = torch.linalg.norm(y) # pylint:disable=not-callable

-    if y_norm*z_norm < tol: return H
+    # if y_norm*z_norm < tol: return H

     # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
     if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
-    H +=
+    H += z.outer(z).div_(safe_clip(denom))
     return H

-class SR1(
-
-
+class SR1(_InverseHessianUpdateStrategyDefaults):
+    """Symmetric Rank 1. This works best with a trust region:
+    ```python
+    tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False))
+    ```
+
+    Args:
+        init_scale (float | Literal["auto"], optional):
+            initial hessian matrix is set to identity times this.
+
+            "auto" corresponds to a heuristic from [1] p.142-143.
+
+            Defaults to "auto".
+        tol (float, optional):
+            tolerance for denominator in SR1 update rule as in [1] p.146. Defaults to 1e-32.
+        ptol (float | None, optional):
+            skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
+            Defaults to 1e-32.
+        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+        restart_interval (int | None | Literal["auto"], optional):
+            interval between resetting the hessian approximation.
+
+            "auto" corresponds to number of decision variables + 1.
+
+            None - no resets.
+
+            Defaults to None.
+        beta (float | None, optional): momentum on H or B. Defaults to None.
+        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+        scale_first (bool, optional):
+            whether to downscale first step before hessian approximation becomes available. Defaults to True.
+        scale_second (bool, optional): whether to downscale second step. Defaults to False.
+        concat_params (bool, optional):
+            If true, all parameters are treated as a single vector.
+            If False, the update rule is applied to each parameter separately. Defaults to True.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    ### Examples:
+
+    SR1 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
+    )
+    ```
+
+    ### References:
+    [1]. Nocedal. Stephen J. Wright. Numerical Optimization
+    """
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return sr1_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return sr1_(H=B, s=y, y=s, tol=setting['tol'])
+

 # ------------------------------------ DFP ----------------------------------- #
 def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-    sy =
+    sy = s.dot(y)
     if sy.abs() <= tol: return H
-    term1 =
-
-
-
+    term1 = s.outer(s).div_(sy)
+
+    yHy = safe_clip(y.dot(H @ y))
+
+    num = (H @ y).outer(y) @ H
     term2 = num.div_(yHy)
+
     H += term1.sub_(term2)
     return H

-
-
-
+def dfp_B(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy.abs() <= tol: return B
+    I = torch.eye(B.size(0), device=B.device, dtype=B.dtype)
+    sub = y.outer(s).div_(sy)
+    term1 = I - sub
+    term2 = I.sub_(sub.T)
+    term3 = y.outer(y).div_(sy)
+    B = (term1 @ B @ term2).add_(term3)
+    return B
+
+
+class DFP(_InverseHessianUpdateStrategyDefaults):
+    """Davidon–Fletcher–Powell Quasi-Newton method.
+
+    Note:
+        a trust region or an accurate line search is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return dfp_H_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return dfp_B(B=B, s=s, y=y, tol=setting['tol'])


 # formulas for methods below from Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
 # H' = H - (Hy - S)c^T / c^T*y
 # the difference is how `c` is calculated

-def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
+def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     c = H.T @ s
-    cy = c.dot(y)
-    if cy.abs() <= tol: return H
+    cy = safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
     H -= num/cy
     return H
+def broyden_good_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    r = y - B@s
+    ss = safe_clip(s.dot(s))
+    B += r.outer(s).div_(ss)
+    return B

-def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
-
-
-
-    num = (H@y).sub_(s).outer(c)
-    H -= num/cy
+def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    yy = safe_clip(y.dot(y))
+    num = (s - (H @ y)).outer(y)
+    H += num/yy
     return H
+def broyden_bad_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    r = y - B@s
+    ys = safe_clip(y.dot(s))
+    B += r.outer(y).div_(ys)
+    return B

-def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor
+def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor):
     c = g_prev
-    cy = c.dot(y)
-    if cy.abs() <= tol: return H
+    cy = safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
     H -= num/cy
     return H

-def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
+def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     Hy = H @ y
     c = H @ Hy # pylint:disable=not-callable
-    cy = c.dot(y)
-    if cy.abs() <= tol: return H
+    cy = safe_clip(c.dot(y))
     num = Hy.sub_(s).outer(c)
     H -= num/cy
     return H

-class BroydenGood(
-
-
+class BroydenGood(_InverseHessianUpdateStrategyDefaults):
+    """Broyden's "good" Quasi-Newton method.
+
+    Note:
+        a trust region or an accurate line search is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_good_H_(H=H, s=s, y=y)
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_good_B_(B=B, s=s, y=y)

-class BroydenBad(
-
-        return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])
+class BroydenBad(_InverseHessianUpdateStrategyDefaults):
+    """Broyden's "bad" Quasi-Newton method.

-
-
-        return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])
+    Note:
+        a trust region or an accurate line search is recommended.

-
-
-        return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])
+    Warning:
+        this uses at least O(N^2) memory.

+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_bad_H_(H=H, s=s, y=y)
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_bad_B_(B=B, s=s, y=y)
+
+class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
+    """Greenstadt's first Quasi-Newton method.
+
+    Note:
+        a trust region or an accurate line search is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev)

-
+class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
+    """Greenstadt's second Quasi-Newton method.
+
+    Note:
+        a line search is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return greenstadt2_H_(H=H, s=s, y=y)
+
+
+def icum_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
     j = y.abs().argmax()

-    denom = y[j]
-    if denom.abs() < tol: return H
+    denom = safe_clip(y[j])

     Hy = H @ y.unsqueeze(1)
     num = s.unsqueeze(1) - Hy
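For readers skimming the hunk above, the dense update rules that `bfgs_B_`/`bfgs_H_` and `sr1_` implement are the standard forms, restated here as an editorial aside (not part of the diff):

```latex
% BFGS, as implemented by bfgs_B_ and bfgs_H_ above:
\[
B_{k+1} = B_k + \frac{y y^\top}{s^\top y} - \frac{B_k s\, s^\top B_k}{s^\top B_k s},
\qquad
H_{k+1} = H_k + \frac{\bigl(s^\top y + y^\top H_k y\bigr)\, s s^\top}{(s^\top y)^2}
        - \frac{H_k y\, s^\top + s\, y^\top H_k}{s^\top y}.
\]
% SR1, as implemented by sr1_ above; the update is skipped when the denominator is
% small relative to \|y\|\,\|s - H_k y\| (the Nocedal & Wright p.146 check in the code):
\[
H_{k+1} = H_k + \frac{(s - H_k y)(s - H_k y)^\top}{(s - H_k y)^\top y}.
\]
```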
@@ -288,161 +680,194 @@ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float
     H[:, j] += num.squeeze() / denom
     return H

-class
-    """
-
-
+class ICUM(_InverseHessianUpdateStrategyDefaults):
+    """
+    Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods
+    due to only updating one column of the inverse hessian approximation per step.
+
+    Note:
+        a line search is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return icum_H_(H=H, s=s, y=y)

-def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor
+def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
     s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
     I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
     d = (R + I * (s_norm/2)) @ s
-    ds = d.dot(s)
-    if ds.abs() <= tol: return H, R
+    ds = safe_clip(d.dot(s))
     R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))

     c = H.T @ d
-    cy = c.dot(y)
-    if cy.abs() <= tol: return H, R
+    cy = safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
     H -= num/cy
     return H, R

-class ThomasOptimalMethod(
-    """
-
+class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
+    """
+    Thomas's "optimal" Quasi-Newton method.
+
+    Note:
+        a line search is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
-        H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y
+        H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y)
         return H

-    def
-        super().
+    def reset_P(self, P, s, y, inverse, init_scale, state):
+        super().reset_P(P, s, y, inverse, init_scale, state)
         for st in self.state.values():
             st.pop("R", None)

 # ------------------------ powell's symmetric broyden ------------------------ #
-def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor
+def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
     y_Bs = y - B@s
-    ss = s.dot(s)
-    if ss.abs() < tol: return B
+    ss = safe_clip(s.dot(s))
     num1 = y_Bs.outer(s).add_(s.outer(y_Bs))
     term1 = num1.div_(ss)
-    term2 = s.outer(s).mul_(y_Bs.dot(s)/(ss**2))
+    term2 = s.outer(s).mul_(y_Bs.dot(s)/(safe_clip(ss**2)))
     B += term1.sub_(term2)
     return B

 # I couldn't find formula for H
-class PSB(
-
-        self,
-        init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-10,
-        tol_reset: bool = True,
-        reset_interval: int | None = None,
-        beta: float | None = None,
-        update_freq: int = 1,
-        scale_first: bool = True,
-        scale_second: bool = False,
-        concat_params: bool = True,
-        inner: Chainable | None = None,
-    ):
-        super().__init__(
-            defaults=None,
-            init_scale=init_scale,
-            tol=tol,
-            tol_reset=tol_reset,
-            reset_interval=reset_interval,
-            beta=beta,
-            update_freq=update_freq,
-            scale_first=scale_first,
-            scale_second=scale_second,
-            concat_params=concat_params,
-            inverse=False,
-            inner=inner,
-        )
+class PSB(_HessianUpdateStrategyDefaults):
+    """Powell's Symmetric Broyden Quasi-Newton method.

-
-
+    Note:
+        a line search or a trust region is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return psb_B_(B=B, s=s, y=y)


 # Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
-def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
+def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     Hy = H@y
-    yHy = y.dot(Hy)
-    if yHy.abs() <= tol: return H
+    yHy = safe_clip(y.dot(Hy))
     num = (s - Hy).outer(Hy)
     H += num.div_(yHy)
     return H

-class Pearson(
-    """
+class Pearson(_InverseHessianUpdateStrategyDefaults):
+    """
+    Pearson's Quasi-Newton method.

-
-
-        return pearson_H_(H=H, s=s, y=y, tol=settings['tol'])
+    Note:
+        a line search is recommended.

-
-
-
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return pearson_H_(H=H, s=s, y=y)
+
+def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    sy = safe_clip(s.dot(y))
     num = (s - H@y).outer(s)
     H += num.div_(sy)
     return H

-class McCormick(
-    """
+class McCormick(_InverseHessianUpdateStrategyDefaults):
+    """McCormicks's Quasi-Newton method.

-
-
-        return mccormick_H_(H=H, s=s, y=y, tol=settings['tol'])
+    Note:
+        a line search is recommended.

-
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+        This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method in other sources.
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return mccormick_H_(H=H, s=s, y=y)
+
+def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
     Hy = H @ y
-    yHy = y.dot(Hy)
-    if yHy.abs() < tol: return H, R
+    yHy = safe_clip(y.dot(Hy))
     H -= Hy.outer(Hy) / yHy
     R += (s - R@y).outer(Hy) / yHy
     return H, R

 class ProjectedNewtonRaphson(HessianUpdateStrategy):
-    """
+    """
+    Projected Newton Raphson method.
+
+    Note:
+        a line search is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

-
+    This one is Algorithm 7.
+    """
     def __init__(
         self,
         init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-
-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = 'auto',
         beta: float | None = None,
         update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
         concat_params: bool = True,
         inner: Chainable | None = None,
     ):
         super().__init__(
             init_scale=init_scale,
             tol=tol,
-
-
+            ptol = ptol,
+            ptol_restart=ptol_restart,
+            gtol=gtol,
+            restart_interval=restart_interval,
             beta=beta,
             update_freq=update_freq,
             scale_first=scale_first,
-            scale_second=scale_second,
             concat_params=concat_params,
             inverse=True,
             inner=inner,
         )

-    def update_H(self, H, s, y, p, g, p_prev, g_prev, state,
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
-        H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y
+        H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y)
         state["R"] = R
         return H

-    def
+    def reset_P(self, P, s, y, inverse, init_scale, state):
         assert inverse
-
+        if 'R' not in state: state['R'] = torch.eye(P.size(-1), device=P.device, dtype=P.dtype)
+        P.copy_(state["R"])

 # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
 def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
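An editorial restatement (not part of the diff) of what the rank-one family and `psb_B_` above compute. The Broyden/Greenstadt updates all share the form noted in the code comment, differing only in the choice of `c`:

```latex
% Shared rank-one form from the Spedicato & Huang comment above:
\[
H' = H - \frac{(H y - s)\, c^\top}{c^\top y},
\qquad
c = H^\top s \ \text{(good Broyden)}, \quad
c = g_{\mathrm{prev}} \ \text{(Greenstadt 1)}, \quad
c = H (H y) \ \text{(Greenstadt 2)}.
\]
% Powell's symmetric Broyden, as implemented by psb_B_ above:
\[
B_{k+1} = B_k + \frac{(y - B_k s) s^\top + s (y - B_k s)^\top}{s^\top s}
        - \frac{\bigl((y - B_k s)^\top s\bigr)\, s s^\top}{(s^\top s)^2}.
\]
```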
@@ -454,12 +879,10 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
|
|
|
454
879
|
# however p.12 says eps = gs / gHy
|
|
455
880
|
|
|
456
881
|
Hy = H@y
|
|
457
|
-
gHy = g.dot(Hy)
|
|
458
|
-
yHy = y.dot(Hy)
|
|
882
|
+
gHy = safe_clip(g.dot(Hy))
|
|
883
|
+
yHy = safe_clip(y.dot(Hy))
|
|
459
884
|
sy = s.dot(y)
|
|
460
|
-
if sy < tol: return H
|
|
461
|
-
if yHy.abs() < tol: return H
|
|
462
|
-
if gHy.abs() < tol: return H
|
|
885
|
+
if sy < tol: return H # the proof is for sy>0. But not clear if it should be skipped
|
|
463
886
|
|
|
464
887
|
v_mul = yHy.sqrt()
|
|
465
888
|
v_term1 = s/sy
|
|
@@ -474,28 +897,26 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
|
|
|
474
897
|
e = gs / gHy
|
|
475
898
|
if switch in (1, 3):
|
|
476
899
|
if e/o <= 1:
|
|
477
|
-
|
|
478
|
-
phi = e/o
|
|
900
|
+
phi = e/safe_clip(o)
|
|
479
901
|
theta = 0
|
|
480
902
|
elif o/t >= 1:
|
|
481
|
-
|
|
482
|
-
phi = o/t
|
|
903
|
+
phi = o/safe_clip(t)
|
|
483
904
|
theta = 1
|
|
484
905
|
else:
|
|
485
906
|
phi = 1
|
|
486
|
-
denom = e*t - o**2
|
|
487
|
-
if denom.abs() <= tol: return H
|
|
907
|
+
denom = safe_clip(e*t - o**2)
|
|
488
908
|
if switch == 1: theta = o * (e - o) / denom
|
|
489
909
|
else: theta = o * (t - o) / denom
|
|
490
910
|
|
|
491
911
|
elif switch == 2:
|
|
492
|
-
|
|
912
|
+
t = safe_clip(t)
|
|
913
|
+
o = safe_clip(o)
|
|
914
|
+
e = safe_clip(e)
|
|
493
915
|
phi = (e / t) ** 0.5
|
|
494
916
|
theta = 1 / (1 + (t*e / o**2)**0.5)
|
|
495
917
|
|
|
496
918
|
elif switch == 4:
|
|
497
|
-
|
|
498
|
-
phi = e/t
|
|
919
|
+
phi = e/safe_clip(t)
|
|
499
920
|
theta = 1/2
|
|
500
921
|
|
|
501
922
|
else: raise ValueError(switch)
|
|
@@ -514,19 +935,30 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
 
 
 class SSVM(HessianUpdateStrategy):
-    """
+    """
+    Self-scaling variable metric Quasi-Newton method.
+
+    Note:
+        a line search is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
     """
     def __init__(
         self,
         switch: tuple[float,float] | Literal[1,2,3,4] = 3,
         init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-
-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
         concat_params: bool = True,
         inner: Chainable | None = None,
     ):
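The new ``SSVM`` signature drops ``scale_second`` and routes the restart tolerances through the base class. A hedged usage sketch in the same style as the ``GradientCorrection`` example further down in this file; whether ``SSVM`` is re-exported as ``tz.m.SSVM`` is an assumption, and its docstring only recommends (not requires) the line search:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.SSVM(switch=3),    # self-scaling variable metric update, default switching rule
    tz.m.StrongWolfe(),     # the docstring recommends pairing it with a line search
)
```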
@@ -535,28 +967,28 @@ class SSVM(HessianUpdateStrategy):
             defaults=defaults,
             init_scale=init_scale,
             tol=tol,
-
-
+            ptol=ptol,
+            ptol_restart=ptol_restart,
+            gtol=gtol,
+            restart_interval=restart_interval,
             beta=beta,
             update_freq=update_freq,
             scale_first=scale_first,
-            scale_second=scale_second,
             concat_params=concat_params,
             inverse=True,
             inner=inner,
         )
 
-    def update_H(self, H, s, y, p, g, p_prev, g_prev, state,
-        return ssvm_H_(H=H, s=s, y=y, g=g, switch=
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return ssvm_H_(H=H, s=s, y=y, g=g, switch=setting['switch'], tol=setting['tol'])
 
 # HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
 def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     Hy = H@y
     ys = y.dot(s)
-    if ys.abs() <= tol: return H
+    if ys.abs() <= tol: return H # probably? because it is BFGS and DFP-like
     yHy = y.dot(Hy)
-    denom = ys + yHy
-    if denom.abs() <= tol: return H
+    denom = safe_clip(ys + yHy)
 
     term1 = 1/denom
     term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
@@ -569,19 +1001,35 @@ def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     return H
 
 def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
-    sy = s.dot(y)
-    if sy.abs() < torch.finfo(g[0].dtype).eps: return g
+    sy = safe_clip(s.dot(y))
     return g - (y * (s.dot(g) / sy))
 
 
 class GradientCorrection(Transform):
-    """
+    """
+    Estimates gradient at minima along search direction assuming function is quadratic.
+
+    This can useful as inner module for second order methods with inexact line search.
+
+    ## Example:
+    L-BFGS with gradient correction
 
-
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LBFGS(inner=tz.m.GradientCorrection()),
+        tz.m.Backtracking()
+    )
+    ```
+
+    Reference:
+        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+
+    """
     def __init__(self):
         super().__init__(None, uses_grad=False)
 
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         if 'p_prev' not in states[0]:
             p_prev = unpack_states(states, tensors, 'p_prev', init=params)
             g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
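The correction implemented by ``gradient_correction`` follows from the quadratic model: with step ``s`` and gradient difference ``y``, the gradient along the line is affine in the step length, so the gradient at the exact line minimum is ``g - y * (s·g)/(s·y)``. A small self-contained check of that identity on a random quadratic (plain torch, independent of the module; all names here are illustrative):

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 5)
A = A @ A.T + torch.eye(5)                  # SPD Hessian of f(x) = 0.5 * x^T A x
x_prev, x = torch.randn(5), torch.randn(5)
g_prev, g = A @ x_prev, A @ x               # exact gradients at the two iterates
s, y = x - x_prev, g - g_prev               # step and gradient difference (y = A s)

g_hat = g - y * (s.dot(g) / s.dot(y))       # the corrected gradient
t_star = -s.dot(g) / s.dot(A @ s)           # exact minimizer of f along the direction s
print(torch.allclose(g_hat, A @ (x + t_star * s), atol=1e-5))  # True
```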
@@ -594,15 +1042,27 @@ class GradientCorrection(Transform):
         g_prev.copy_(tensors)
         return g_hat
 
-class Horisho(
-    """
-
-
+class Horisho(_InverseHessianUpdateStrategyDefaults):
+    """
+    Horisho's variable metric Quasi-Newton method.
+
+    Note:
+        a line search is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+    """
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return hoshino_H_(H=H, s=s, y=y, tol=setting['tol'])
 
 # Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
 def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     sy = s.dot(y)
-    if sy.abs() < tol: return H
+    if sy.abs() < tol: return H # part of algorithm
     Hy = H @ y
 
     term1 = (s.outer(y) @ H).div_(sy)
@@ -613,16 +1073,27 @@ def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     H -= (term1 + term2 - term4.mul_(term3))
     return H
 
-class FletcherVMM(
-    """
-
-
+class FletcherVMM(_InverseHessianUpdateStrategyDefaults):
+    """
+    Fletcher's variable metric Quasi-Newton method.
+
+    Note:
+        a line search is recommended.
+
+    Warning:
+        this uses at least O(N^2) memory.
+
+    Reference:
+        Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return fletcher_vmm_H_(H=H, s=s, y=y, tol=setting['tol'])
 
 
 # Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
 def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
     sy = s.dot(y)
-    if sy < tol: return H
+    if sy < tol: return H # part of algorithm
 
     term1 = (H @ y.outer(s) + s.outer(y) @ H) / sy
 
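``Horisho`` and ``FletcherVMM`` now subclass ``_InverseHessianUpdateStrategyDefaults``, so each one only supplies ``update_H`` and inherits the shared constructor. A sketch of a hypothetical wrapper written in the same pattern, as if it lived next to the classes above (the ``ToyDFP`` name and its DFP formula are illustrative, not part of the package):

```python
class ToyDFP(_InverseHessianUpdateStrategyDefaults):
    """Hypothetical wrapper: DFP inverse-Hessian update in the Horisho/FletcherVMM style."""
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        Hy = H @ y
        sy = safe_clip(s.dot(y))      # guard denominators the same way the hunks above do
        yHy = safe_clip(y.dot(Hy))
        return H + s.outer(s) / sy - Hy.outer(Hy) / yHy
```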
@@ -644,20 +1115,29 @@ def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
 
 
 class NewSSM(HessianUpdateStrategy):
-    """Self-scaling method
+    """Self-scaling Quasi-Newton method.
+
+    Note:
+        a line search such as ``tz.m.StrongWolfe()`` is required.
+
+    Warning:
+        this uses roughly O(N^2) memory.
 
-
+    Reference:
+        Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+    """
     def __init__(
         self,
         type: Literal[1, 2] = 1,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
         concat_params: bool = True,
         inner: Chainable | None = None,
     ):
@@ -665,19 +1145,87 @@ class NewSSM(HessianUpdateStrategy):
             defaults=dict(type=type),
             init_scale=init_scale,
             tol=tol,
-
-
+            ptol=ptol,
+            ptol_restart=ptol_restart,
+            gtol=gtol,
+            restart_interval=restart_interval,
             beta=beta,
             update_freq=update_freq,
             scale_first=scale_first,
-            scale_second=scale_second,
             concat_params=concat_params,
             inverse=True,
             inner=inner,
         )
-    def update_H(self, H, s, y, p, g, p_prev, g_prev, state,
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         f = state['f']
         f_prev = state['f_prev']
-        return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=
+        return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=setting['type'], tol=setting['tol'])
+
+# ---------------------------- Shor’s r-algorithm ---------------------------- #
+# def shor_r(B:torch.Tensor, y:torch.Tensor, gamma:float):
+#     r = B.T @ y
+#     r /= torch.linalg.vector_norm(r).clip(min=1e-32) # pylint:disable=not-callable
+
+#     I = torch.eye(B.size(1), device=B.device, dtype=B.dtype)
+#     return B @ (I - gamma*r.outer(r))
+
+# this is supposed to be equivalent (and it is)
+def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
+    p = H@y
+    #(1-y)^2 (ppT)/(pTq)
+    #term = p.outer(p).div_(p.dot(y).clip(min=1e-32))
+    term = p.outer(p).div_(safe_clip(p.dot(y)))
+    H.sub_(term, alpha=1-alpha**2)
+    return H
+
+class ShorR(HessianUpdateStrategy):
+    """Shor’s r-algorithm.
+
+    Note:
+        A line search such as ``tz.m.StrongWolfe(a_init="quadratic", fallback=True)`` is required.
+        Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling,
+        so setting ``a_init`` in the line search is recommended.
 
+    References:
+        SHOR, N. Z. (1985) Minimization Methods for Non-differentiable Functions. New York: Springer.
+
+        Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720. - good overview.
+
+        Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998. - this is where a more efficient formula is described.
+    """
+
+    def __init__(
+        self,
+        alpha=0.5,
+        init_scale: float | Literal["auto"] = 1,
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        concat_params: bool = True,
+        # inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha=alpha)
+        super().__init__(
+            defaults=defaults,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_restart=ptol_restart,
+            gtol=gtol,
+            restart_interval=restart_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
 
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return shor_r_(H=H, y=y, alpha=setting['alpha'])
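Finally, a hedged usage sketch for the new ``ShorR`` module, using exactly the line-search configuration its docstring asks for; the ``tz.m.ShorR`` export and the tiny model are assumptions made for illustration:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.ShorR(alpha=0.5),                                # alpha as in shor_r_ above (default 0.5)
    tz.m.StrongWolfe(a_init="quadratic", fallback=True),  # configuration named in the docstring
)
```

As with the other quasi-Newton modules in this file, the line search supplies the step size, since the r-algorithm has no automatic step-size scaling of its own.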