torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/quasi_newton.py

@@ -1,11 +1,15 @@
 """Use BFGS or maybe SR1."""
-from typing import Any, Literal
 from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Mapping, Callable
+from typing import Any, Literal
+import warnings
+
 import torch
 
-from ...core import Chainable, Module,
-from ...utils import TensorList, set_storage_
+from ...core import Chainable, Module, TensorwiseTransform, Transform
+from ...utils import TensorList, set_storage_, unpack_states
+from ..functional import safe_scaling_
+
 
 def _safe_dict_update_(d1_:dict, d2:dict):
     inter = set(d1_.keys()).intersection(d2.keys())
@@ -17,14 +21,112 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
     elif state[key].shape != value.shape: state[key] = value
     else: state[key].lerp_(value, 1-beta)
 
-
+def _safe_clip(x: torch.Tensor):
+    """makes sure the scalar tensor x is not smaller than epsilon"""
+    assert x.numel() == 1, x.shape
+    eps = torch.finfo(x.dtype).eps ** 2
+    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
+    return x
+
+class HessianUpdateStrategy(TensorwiseTransform, ABC):
+    """Base class for quasi-newton methods that store and update a hessian approximation B or its inverse H.
+
+    This is an abstract class; to use it, subclass it and override `update_H` and/or `update_B`.
+
+    Args:
+        defaults (dict | None, optional): defaults. Defaults to None.
+        init_scale (float | Literal["auto"], optional):
+            initial hessian matrix is set to identity times this.
+
+            "auto" corresponds to a heuristic from Nocedal & Wright, Numerical Optimization, pp. 142-143.
+
+            Defaults to "auto".
+        tol (float, optional):
+            algorithm-dependent tolerance (usually on the curvature condition). Defaults to 1e-8.
+        ptol (float | None, optional):
+            tolerance on the minimal parameter difference, to avoid instability. Defaults to 1e-10.
+        ptol_reset (bool, optional): whether to reset the hessian approximation when the ptol tolerance is not met. Defaults to False.
+        gtol (float | None, optional):
+            tolerance on the minimal gradient difference, to avoid instability when there is no curvature. Defaults to 1e-10.
+        reset_interval (int | None | Literal["auto"], optional):
+            interval between resets of the hessian approximation.
+
+            "auto" corresponds to the number of decision variables + 1.
+
+            None - no resets.
+
+            Defaults to None.
+        beta (float | None, optional): momentum on H or B. Defaults to None.
+        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+        scale_first (bool, optional):
+            whether to downscale the first step, before a hessian approximation becomes available. Defaults to True.
+        scale_second (bool, optional): whether to downscale the second step. Defaults to False.
+        concat_params (bool, optional):
+            If True, all parameters are treated as a single vector.
+            If False, the update rule is applied to each parameter separately. Defaults to True.
+        inverse (bool, optional):
+            set to True if this method maintains a hessian inverse approximation H and has an `update_H` method.
+            Set to False if it maintains a hessian approximation B and has an `update_B` method.
+            Defaults to True.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Example:
+        Implementing the BFGS method, which maintains an estimate of the hessian inverse (H):
+
+        .. code-block:: python
+
+            class BFGS(HessianUpdateStrategy):
+                def __init__(
+                    self,
+                    init_scale: float | Literal["auto"] = "auto",
+                    tol: float = 1e-8,
+                    ptol: float = 1e-10,
+                    ptol_reset: bool = False,
+                    reset_interval: int | None = None,
+                    beta: float | None = None,
+                    update_freq: int = 1,
+                    scale_first: bool = True,
+                    scale_second: bool = False,
+                    concat_params: bool = True,
+                    inner: Chainable | None = None,
+                ):
+                    super().__init__(
+                        defaults=None,
+                        init_scale=init_scale,
+                        tol=tol,
+                        ptol=ptol,
+                        ptol_reset=ptol_reset,
+                        reset_interval=reset_interval,
+                        beta=beta,
+                        update_freq=update_freq,
+                        scale_first=scale_first,
+                        scale_second=scale_second,
+                        concat_params=concat_params,
+                        inverse=True,
+                        inner=inner,
+                    )
+
+                def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+                    tol = settings["tol"]
+                    sy = torch.dot(s, y)
+                    if sy <= tol: return H
+                    num1 = (sy + (y @ H @ y)) * s.outer(s)
+                    term1 = num1.div_(sy**2)
+                    num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+                    term2 = num2.div_(sy)
+                    H += term1.sub_(term2)
+                    return H
+    """
     def __init__(
         self,
         defaults: dict | None = None,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-
-
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None | Literal['auto'] = None,
         beta: float | None = None,
         update_freq: int = 1,
         scale_first: bool = True,
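The hunk above adds `_safe_clip`, which the update rules below route their scalar denominators through. A minimal standalone sketch of what the guard does (values shown are for float32 inputs; the snippet is illustrative, not part of the package):

    import torch

    def _safe_clip(x: torch.Tensor):
        # keep a scalar denominator at least eps**2 in magnitude, preserving its sign
        eps = torch.finfo(x.dtype).eps ** 2
        if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
        return x

    print(_safe_clip(torch.tensor(1e-30)))   # tensor(1.4211e-14), clamped up to eps**2
    print(_safe_clip(torch.tensor(-1e-30)))  # tensor(-1.4211e-14), sign preserved
    print(_safe_clip(torch.tensor(0.5)))     # tensor(0.5000), unchanged

Throughout this diff it replaces the old per-method `if denom.abs() <= tol: return H` guards, so near-singular denominators now clamp instead of silently skipping the update.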
@@ -34,9 +136,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         inner: Chainable | None = None,
     ):
         if defaults is None: defaults = {}
-        _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol,
+        _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_reset=ptol_reset, gtol=gtol, scale_second=scale_second, inverse=inverse, beta=beta, reset_interval=reset_interval))
         super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, scale_first=scale_first, inner=inner)
 
+    def _init_M(self, size:int, device, dtype, is_inverse:bool):
+        return torch.eye(size, device=device, dtype=dtype)
+
     def _get_init_scale(self,s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
         """returns multiplier to H or B"""
         ys = y.dot(s)
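The "auto" initial scale computed by `_get_init_scale` (the `yy/ys` multiplier, applied as a divisor to H or a multiplier to B) is the standard heuristic from Nocedal & Wright:

    H_0 = \frac{y^\top s}{y^\top y} I, \qquad B_0 = \frac{y^\top y}{y^\top s} I,

which sizes the initial identity so that it approximates the magnitude of the true (inverse) hessian along the most recent step.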
@@ -44,47 +149,92 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         if ys != 0 and yy != 0: return yy/ys
         return 1
 
-    def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor,inverse:bool, init_scale: Any):
-        set_storage_(M,
+    def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]):
+        set_storage_(M, self._init_M(s.numel(), device=M.device, dtype=M.dtype, is_inverse=inverse))
         if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
         if init_scale >= 1:
             if inverse: M /= init_scale
             else: M *= init_scale
 
     def update_H(self, H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
-                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any],
+                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
         """update hessian inverse"""
         raise NotImplementedError
 
     def update_B(self, B:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
-                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any],
+                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
         """update hessian"""
         raise NotImplementedError
 
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('f_prev', 'p_prev', 'g_prev')
+
+    def get_B(self) -> tuple[torch.Tensor, bool]:
+        """returns (B or H, is_inverse)."""
+        state = next(iter(self.state.values()))
+        if "B" in state: return state["B"], False
+        return state["H"], True
+
+    def get_H(self) -> tuple[torch.Tensor, bool]:
+        """returns (H or B, is_inverse)."""
+        state = next(iter(self.state.values()))
+        if "H" in state: return state["H"], False
+        return state["B"], True
+
+    def make_Bv(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        B, is_inverse = self.get_B()
+
+        if is_inverse:
+            H = B
+            warnings.warn(f'{self} maintains H, so Bv will be inefficient!')
+            def Hxv(v): return torch.linalg.solve_ex(H, v)[0] # pylint:disable=not-callable
+            return Hxv
+
+        def Bv(v): return B@v
+        return Bv
+
+    def make_Hv(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        H, is_inverse = self.get_H()
+
+        if is_inverse:
+            B = H
+            warnings.warn(f'{self} maintains B, so Hv will be inefficient!')
+            def Bxv(v): return torch.linalg.solve_ex(B, v)[0] # pylint:disable=not-callable
+            return Bxv
+
+        def Hv(v): return H@v
+        return Hv
+
     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, state,
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
         p = param.view(-1); g = tensor.view(-1)
-        inverse =
+        inverse = setting['inverse']
         M_key = 'H' if inverse else 'B'
         M = state.get(M_key, None)
-        step = state.get('step', 0)
-        state['step'] = step
-        init_scale =
-
-
-
-
-        if
-
-
-
-
+        step = state.get('step', 0) + 1
+        state['step'] = step
+        init_scale = setting['init_scale']
+        ptol = setting['ptol']
+        ptol_reset = setting['ptol_reset']
+        gtol = setting['gtol']
+        reset_interval = setting['reset_interval']
+        if reset_interval == 'auto': reset_interval = tensor.numel() + 1
+
+        if M is None or 'f_prev' not in state:
+            if M is None: # won't be true on reset_for_online
+                M = self._init_M(p.numel(), device=p.device, dtype=p.dtype, is_inverse=inverse)
+                if isinstance(init_scale, (int, float)) and init_scale != 1:
+                    if inverse: M /= init_scale
+                    else: M *= init_scale
 
             state[M_key] = M
+            state['f_prev'] = loss
             state['p_prev'] = p.clone()
             state['g_prev'] = g.clone()
             return
 
+        state['f'] = loss
         p_prev = state['p_prev']
        g_prev = state['g_prev']
         s: torch.Tensor = p - p_prev
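For orientation in the hunks below: `update_tensor` builds the standard quasi-newton curvature pair from consecutive iterates and gradients,

    s_k = p_k - p_{k-1}, \qquad y_k = g_k - g_{k-1},

and every `update_B` / `update_H` rule in this file is a different way of making the new approximation satisfy the secant equation

    B_{k+1} s_k = y_k \quad\Longleftrightarrow\quad H_{k+1} y_k = s_k.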
@@ -92,195 +242,511 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
         state['p_prev'].copy_(p)
         state['g_prev'].copy_(g)
 
-        if reset_interval is not None and step
-            self._reset_M_(M, s, y, inverse, init_scale)
+        if reset_interval is not None and step % reset_interval == 0:
+            self._reset_M_(M, s, y, inverse, init_scale, state)
+            return
+
+        # tolerance on parameter difference to avoid exploding after converging
+        if ptol is not None and s.abs().max() <= ptol:
+            if ptol_reset: self._reset_M_(M, s, y, inverse, init_scale, state) # reset history
             return
 
-        # tolerance on gradient difference to avoid exploding
-
-        # reset history
-        if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
+        # tolerance on gradient difference to avoid exploding when there is no curvature
+        if gtol is not None and y.abs().max() <= gtol:
             return
 
-        if step ==
+        if step == 2 and init_scale == 'auto':
             if inverse: M /= self._get_init_scale(s,y)
             else: M *= self._get_init_scale(s,y)
 
-        beta =
+        beta = setting['beta']
         if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
 
         if inverse:
-            H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state,
+            H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
             _maybe_lerp_(state, 'H', H_new, beta)
 
         else:
-            B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state,
+            B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
             _maybe_lerp_(state, 'B', B_new, beta)
 
+        state['f_prev'] = loss
+
+    def _post_B(self, B: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+        """modifies B before applying the update rule. Must return (B, g)"""
+        return B, g
+
+    def _post_H(self, H: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+        """modifies H before applying the update rule. Must return (H, g)"""
+        return H, g
+
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, state,
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         step = state.get('step', 0)
 
-        if
-
-            scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
-            tensor = tensor * scale_factor
+        if setting['scale_second'] and step == 2:
+            tensor = safe_scaling_(tensor)
 
-        inverse =
+        inverse = setting['inverse']
         if inverse:
             H = state['H']
-
+            H, g = self._post_H(H, tensor.view(-1), state, setting)
+            if H.ndim == 1: return g.mul_(H).view_as(tensor)
+            return (H @ g).view_as(tensor)
 
         B = state['B']
+        B, g = self._post_B(B, tensor.view(-1), state, setting)
+
+        if B.ndim == 1: return g.div_(B).view_as(tensor)
+        x, info = torch.linalg.solve_ex(B, g) # pylint:disable=not-callable
+        if info == 0: return x.view_as(tensor)
+        return safe_scaling_(tensor)
+
+class _InverseHessianUpdateStrategyDefaults(HessianUpdateStrategy):
+    '''This is a :code:`HessianUpdateStrategy` subclass for algorithms with no extra defaults, to skip the lengthy __init__.
+    Refer to the :code:`HessianUpdateStrategy` documentation.
+
+    Example:
+        Implementing the BFGS method, which maintains an estimate of the hessian inverse (H):
+
+        .. code-block:: python
+
+            class BFGS(_HessianUpdateStrategyDefaults):
+                """Broyden–Fletcher–Goldfarb–Shanno algorithm"""
+                def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+                    tol = settings["tol"]
+                    sy = torch.dot(s, y)
+                    if sy <= tol: return H
+                    num1 = (sy + (y @ H @ y)) * s.outer(s)
+                    term1 = num1.div_(sy**2)
+                    num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+                    term2 = num2.div_(sy)
+                    H += term1.sub_(term2)
+                    return H
+
+    Make sure to put at least a basic class-level docstring on each subclass to overwrite this one.
+    '''
+    def __init__(
+        self,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=None,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=inverse,
+            inner=inner,
+        )
 
-
-
-    # to avoid typing all arguments for each method
-    class HUpdateStrategy(HessianUpdateStrategy):
+class _HessianUpdateStrategyDefaults(HessianUpdateStrategy):
     def __init__(
         self,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
         reset_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
         scale_first: bool = True,
         scale_second: bool = False,
         concat_params: bool = True,
+        inverse: bool = False,
         inner: Chainable | None = None,
     ):
         super().__init__(
             defaults=None,
             init_scale=init_scale,
             tol=tol,
-
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
             reset_interval=reset_interval,
             beta=beta,
             update_freq=update_freq,
             scale_first=scale_first,
             scale_second=scale_second,
             concat_params=concat_params,
-            inverse=
+            inverse=inverse,
             inner=inner,
         )
+
 # ----------------------------------- BFGS ----------------------------------- #
+def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy < tol: return B
+
+    Bs = B@s
+    sBs = _safe_clip(s.dot(Bs))
+
+    term1 = y.outer(y).div_(sy)
+    term2 = (Bs.outer(s) @ B.T).div_(sBs)
+    B += term1.sub_(term2)
+    return B
+
 def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-    sy =
-    if sy <= tol: return H
-
-
-
+    sy = s.dot(y)
+    if sy <= tol: return H
+
+    sy_sq = _safe_clip(sy**2)
+
+    Hy = H@y
+    scale1 = (sy + y.dot(Hy)) / sy_sq
+    term1 = s.outer(s).mul_(scale1)
+
+    num2 = (Hy.outer(s)).add_(s.outer(y @ H))
     term2 = num2.div_(sy)
+
     H += term1.sub_(term2)
     return H
 
-class BFGS(
-
-
+class BFGS(_InverseHessianUpdateStrategyDefaults):
+    """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe()` is recommended, although this can be stable without a line search. Alternatively, warmup (:code:`tz.m.Warmup`) can stabilize quasi-newton methods without a line search.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Args:
+        init_scale (float | Literal["auto"], optional):
+            initial hessian matrix is set to identity times this.
+
+            "auto" corresponds to a heuristic from Nocedal & Wright, Numerical Optimization, pp. 142-143.
+
+            Defaults to "auto".
+        tol (float, optional):
+            tolerance on the curvature condition. Defaults to 1e-8.
+        ptol (float | None, optional):
+            skips the update if the maximum difference between current and previous parameters is less than this, to avoid instability.
+            Defaults to 1e-10.
+        ptol_reset (bool, optional): whether to reset the hessian approximation when the ptol tolerance is not met. Defaults to False.
+        reset_interval (int | None | Literal["auto"], optional):
+            interval between resets of the hessian approximation.
+
+            "auto" corresponds to the number of decision variables + 1.
+
+            None - no resets.
+
+            Defaults to None.
+        beta (float | None, optional): momentum on H or B. Defaults to None.
+        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+        scale_first (bool, optional):
+            whether to downscale the first step, before a hessian approximation becomes available. Defaults to True.
+        scale_second (bool, optional): whether to downscale the second step. Defaults to False.
+        concat_params (bool, optional):
+            If True, all parameters are treated as a single vector.
+            If False, the update rule is applied to each parameter separately. Defaults to True.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Examples:
+        BFGS with strong-wolfe line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.BFGS(),
+                tz.m.StrongWolfe()
+            )
+
+        BFGS preconditioning applied to momentum:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.BFGS(inner=tz.m.EMA(0.9)),
+                tz.m.LR(1e-2)
+            )
+    """
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return bfgs_B_(B=B, s=s, y=y, tol=setting['tol'])
 
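In the notation of `bfgs_B_` and `bfgs_H_` above (both skipped when the curvature s^T y falls below `tol`), the two helpers are the textbook BFGS pair:

    B_{k+1} = B_k - \frac{B_k s s^\top B_k}{s^\top B_k s} + \frac{y y^\top}{s^\top y},
    \qquad
    H_{k+1} = H_k + \frac{\left(s^\top y + y^\top H_k y\right) s s^\top}{\left(s^\top y\right)^2} - \frac{H_k y s^\top + s y^\top H_k}{s^\top y}.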
 # ------------------------------------ SR1 ----------------------------------- #
-def
+def sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
     z = s - H@y
-    denom =
+    denom = z.dot(y)
 
     z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
     y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
 
-    if y_norm*z_norm < tol: return H
+    # if y_norm*z_norm < tol: return H
 
     # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
     if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
-    H +=
+    H += z.outer(z).div_(_safe_clip(denom))
     return H
 
-class SR1(
-
-
+class SR1(_InverseHessianUpdateStrategyDefaults):
+    """Symmetric Rank 1 Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        approximate Hessians generated by the SR1 method show faster progress towards the true Hessian than other methods, but it is more unstable. SR1 is best used within a trust region module.
+
+    .. note::
+        SR1 doesn't enforce the hessian estimate to be positive definite, therefore it can generate directions that are not descent directions.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Args:
+        init_scale (float | Literal["auto"], optional):
+            initial hessian matrix is set to identity times this.
+
+            "auto" corresponds to a heuristic from Nocedal & Wright, Numerical Optimization, pp. 142-143.
+
+            Defaults to "auto".
+        tol (float, optional):
+            tolerance for the denominator in the SR1 update rule, as in Nocedal, Wright. “Numerical optimization” 2nd p.146. Defaults to 1e-8.
+        ptol (float | None, optional):
+            skips the update if the maximum difference between current and previous parameters is less than this, to avoid instability.
+            Defaults to 1e-10.
+        ptol_reset (bool, optional): whether to reset the hessian approximation when the ptol tolerance is not met. Defaults to False.
+        reset_interval (int | None | Literal["auto"], optional):
+            interval between resets of the hessian approximation.
+
+            "auto" corresponds to the number of decision variables + 1.
+
+            None - no resets.
+
+            Defaults to None.
+        beta (float | None, optional): momentum on H or B. Defaults to None.
+        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+        scale_first (bool, optional):
+            whether to downscale the first step, before a hessian approximation becomes available. Defaults to True.
+        scale_second (bool, optional): whether to downscale the second step. Defaults to False.
+        concat_params (bool, optional):
+            If True, all parameters are treated as a single vector.
+            If False, the update rule is applied to each parameter separately. Defaults to True.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Examples:
+        SR1 with strong-wolfe line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SR1(),
+                tz.m.StrongWolfe()
+            )
+
+        SR1 preconditioning applied to momentum:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SR1(inner=tz.m.EMA(0.9)),
+                tz.m.LR(1e-2)
+            )
+    """
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return sr1_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return sr1_(H=B, s=y, y=s, tol=setting['tol'])
+
 
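With z = s - H_k y as computed in `sr1_` above, the helper is the symmetric rank-1 update

    H_{k+1} = H_k + \frac{z z^\top}{z^\top y}, \qquad \text{skipped when } \lvert z^\top y\rvert \le tol \cdot \lVert z\rVert \lVert y\rVert,

the skip being the safeguard from Nocedal & Wright p.146 cited in the code. Because SR1 is self-dual, `update_B` reuses the same function with the roles of s and y exchanged.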
-# BFGS has defaults - init_scale = "auto" and scale_second = False
-# SR1 has defaults - init_scale = 1 and scale_second = True
-# basically some methods work better with first and some with second.
-# I inherit from BFGS or SR1 to avoid writing all those arguments again
 # ------------------------------------ DFP ----------------------------------- #
 def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-    sy =
+    sy = s.dot(y)
     if sy.abs() <= tol: return H
-    term1 =
-
-
-
-
+    term1 = s.outer(s).div_(sy)
+
+    yHy = _safe_clip(y.dot(H @ y))
+
+    num = (H @ y).outer(y) @ H
+    term2 = num.div_(yHy)
+
     H += term1.sub_(term2)
     return H
 
-
-
-
+def dfp_B(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy.abs() <= tol: return B
+    I = torch.eye(B.size(0), device=B.device, dtype=B.dtype)
+    sub = y.outer(s).div_(sy)
+    term1 = I - sub
+    term2 = I.sub_(sub.T)
+    term3 = y.outer(y).div_(sy)
+    B = (term1 @ B @ term2).add_(term3)
+    return B
+
+
+class DFP(_InverseHessianUpdateStrategyDefaults):
+    """Davidon–Fletcher–Powell Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return dfp_H_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return dfp_B(B=B, s=s, y=y, tol=setting['tol'])
 
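`dfp_H_` and `dfp_B` above implement the Davidon–Fletcher–Powell pair, the dual of BFGS:

    H_{k+1} = H_k + \frac{s s^\top}{s^\top y} - \frac{H_k y y^\top H_k}{y^\top H_k y},
    \qquad
    B_{k+1} = \left(I - \frac{y s^\top}{s^\top y}\right) B_k \left(I - \frac{s y^\top}{s^\top y}\right) + \frac{y y^\top}{s^\top y}.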
 
 # formulas for methods below from Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
 # H' = H - (Hy - s)c^T / (c^T y)
 # the difference is how `c` is calculated
 
-def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
+def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     c = H.T @ s
-
-    if denom.abs() <= tol: return H
+    cy = _safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H
+def broyden_good_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    r = y - B@s
+    ss = _safe_clip(s.dot(s))
+    B += r.outer(s).div_(ss)
+    return B
 
-def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
-
-
-
-    num = (H@y).sub_(s).outer(c)
-    H -= num/denom
+def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    yy = _safe_clip(y.dot(y))
+    num = (s - (H @ y)).outer(y)
+    H += num/yy
     return H
+def broyden_bad_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    r = y - B@s
+    ys = _safe_clip(y.dot(s))
+    B += r.outer(y).div_(ys)
+    return B
 
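All of the rank-one `*_H_` helpers in this block instantiate the single template quoted in the comment above,

    H' = H - \frac{(Hy - s)\, c^\top}{c^\top y},

and differ only in the choice of c. Reading off the code: c = H^T s for Broyden's "good" method, c = y for Broyden's "bad" method, c = g_prev for Greenstadt's first method below, and c = H(Hy) for Greenstadt's second. The companion `*_B_` helpers are the corresponding updates B += (y - Bs)c^T / (c^T s), with c = s ("good") and c = y ("bad").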
-def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor
+def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor):
     c = g_prev
-
-    if denom.abs() <= tol: return H
+    cy = _safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H
 
-def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
-
-
-
-    num =
-    H -= num/
+def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    Hy = H @ y
+    c = H @ Hy # pylint:disable=not-callable
+    cy = _safe_clip(c.dot(y))
+    num = Hy.sub_(s).outer(c)
+    H -= num/cy
     return H
 
-class BroydenGood(
-
-        return broyden_good_H_(H=H, s=s, y=y, tol=settings['tol'])
+class BroydenGood(_InverseHessianUpdateStrategyDefaults):
+    """Broyden's "good" Quasi-Newton method.
 
-
-
-        return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
 
-
-
-        return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
 
-
-
-
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_good_H_(H=H, s=s, y=y)
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_good_B_(B=B, s=s, y=y)
 
+class BroydenBad(_InverseHessianUpdateStrategyDefaults):
+    """Broyden's "bad" Quasi-Newton method.
 
-
-
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
 
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_bad_H_(H=H, s=s, y=y)
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_bad_B_(B=B, s=s, y=y)
+
+class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
+    """Greenstadt's first Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev)
+
+class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
+    """Greenstadt's second Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return greenstadt2_H_(H=H, s=s, y=y)
+
+
+def icum_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
     j = y.abs().argmax()
-    u = torch.zeros(n, device=H.device, dtype=H.dtype)
-    u[j] = 1.0
 
-    denom = y[j]
-    if denom.abs() < tol: return H
+    denom = _safe_clip(y[j])
 
     Hy = H @ y.unsqueeze(1)
     num = s.unsqueeze(1) - Hy
@@ -288,51 +754,178 @@ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float
     H[:, j] += num.squeeze() / denom
     return H
 
-class
-    """
-
-
+class ICUM(_InverseHessianUpdateStrategyDefaults):
+    """
+    Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods
+    due to only updating one column of the inverse hessian approximation per step.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
 
-
+    Reference:
+        Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return icum_H_(H=H, s=s, y=y)
+
+def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
     s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
     I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
     d = (R + I * (s_norm/2)) @ s
-
-
-    R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(denom)))
+    ds = _safe_clip(d.dot(s))
+    R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))
 
     c = H.T @ d
-
-    if denom.abs() <= tol: return H, R
+    cy = _safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
-    H -= num/
+    H -= num/cy
     return H, R
 
-class ThomasOptimalMethod(
-    """
-
+class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
+    """
+    Thomas's "optimal" Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
-        H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y
+        H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y)
         return H
 
+    def _reset_M_(self, M, s, y, inverse, init_scale, state):
+        super()._reset_M_(M, s, y, inverse, init_scale, state)
+        for st in self.state.values():
+            st.pop("R", None)
+
 # ------------------------ powell's symmetric broyden ------------------------ #
-def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor
+def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
     y_Bs = y - B@s
-    ss = s.dot(s)
-    if ss.abs() < tol: return B
+    ss = _safe_clip(s.dot(s))
     num1 = y_Bs.outer(s).add_(s.outer(y_Bs))
     term1 = num1.div_(ss)
-    term2 = s.outer(s).mul_(y_Bs.dot(s)/(ss**2))
+    term2 = s.outer(s).mul_(y_Bs.dot(s)/(_safe_clip(ss**2)))
     B += term1.sub_(term2)
     return B
 
-
+# I couldn't find a formula for H
+class PSB(_HessianUpdateStrategyDefaults):
+    """Powell's Symmetric Broyden Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return psb_B_(B=B, s=s, y=y)
+
+
+# Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
+def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    Hy = H@y
+    yHy = _safe_clip(y.dot(Hy))
+    num = (s - Hy).outer(Hy)
+    H += num.div_(yHy)
+    return H
+
+class Pearson(_InverseHessianUpdateStrategyDefaults):
+    """
+    Pearson's Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return pearson_H_(H=H, s=s, y=y)
+
+def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    sy = _safe_clip(s.dot(y))
+    num = (s - H@y).outer(s)
+    H += num.div_(sy)
+    return H
+
+class McCormick(_InverseHessianUpdateStrategyDefaults):
+    """McCormick's Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+        This is "Algorithm 2", attributed to McCormick in this paper. However, for some reason this method is also called Pearson's 2nd method in other sources.
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return mccormick_H_(H=H, s=s, y=y)
+
+def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
+    Hy = H @ y
+    yHy = _safe_clip(y.dot(Hy))
+    H -= Hy.outer(Hy) / yHy
+    R += (s - R@y).outer(Hy) / yHy
+    return H, R
+
+class ProjectedNewtonRaphson(HessianUpdateStrategy):
+    """
+    Projected Newton Raphson method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        this is an experimental method.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+        This one is Algorithm 7.
+    """
     def __init__(
         self,
         init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-
-
-
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None | Literal['auto'] = 'auto',
         beta: float | None = None,
         update_freq: int = 1,
         scale_first: bool = True,
@@ -341,34 +934,30 @@ class PSB(HessianUpdateStrategy):
         inner: Chainable | None = None,
     ):
         super().__init__(
-            defaults=None,
             init_scale=init_scale,
             tol=tol,
-
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
             reset_interval=reset_interval,
             beta=beta,
             update_freq=update_freq,
             scale_first=scale_first,
             scale_second=scale_second,
             concat_params=concat_params,
-            inverse=
+            inverse=True,
             inner=inner,
         )
 
-    def
-
-
-
-
-        if sy.abs() <= tol: return H
-        num = (s - H@y).outer(s)
-        H += num.div_(sy)
-        return H
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
+        H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y)
+        state["R"] = R
+        return H
 
-
-
-
-        return pearson2_H_(H=H, s=s, y=y, tol=settings['tol'])
+    def _reset_M_(self, M, s, y, inverse, init_scale, state):
+        assert inverse
+        M.copy_(state["R"])
 
 # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
 def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
@@ -380,12 +969,10 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
     # however p.12 says eps = gs / gHy
 
     Hy = H@y
-    gHy = g.dot(Hy)
-    yHy = y.dot(Hy)
+    gHy = _safe_clip(g.dot(Hy))
+    yHy = _safe_clip(y.dot(Hy))
     sy = s.dot(y)
-    if sy < tol: return H
-    if yHy.abs() < tol: return H
-    if gHy.abs() < tol: return H
+    if sy < tol: return H # the proof is for sy>0. But not clear if it should be skipped
 
     v_mul = yHy.sqrt()
     v_term1 = s/sy
@@ -400,28 +987,26 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
     e = gs / gHy
     if switch in (1, 3):
         if e/o <= 1:
-
-            phi = e/o
+            phi = e/_safe_clip(o)
             theta = 0
         elif o/t >= 1:
-
-            phi = o/t
+            phi = o/_safe_clip(t)
             theta = 1
         else:
             phi = 1
-            denom = e*t - o**2
-            if denom.abs() <= tol: return H
+            denom = _safe_clip(e*t - o**2)
             if switch == 1: theta = o * (e - o) / denom
             else: theta = o * (t - o) / denom
 
     elif switch == 2:
-
+        t = _safe_clip(t)
+        o = _safe_clip(o)
+        e = _safe_clip(e)
         phi = (e / t) ** 0.5
         theta = 1 / (1 + (t*e / o**2)**0.5)
 
     elif switch == 4:
-
-        phi = e/t
+        phi = e/_safe_clip(t)
         theta = 1/2
 
     else: raise ValueError(switch)
@@ -440,14 +1025,29 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):


 class SSVM(HessianUpdateStrategy):
-    """
+    """
+    Self-scaling variable metric Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
     """
     def __init__(
         self,
         switch: tuple[float,float] | Literal[1,2,3,4] = 3,
         init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-
-
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
         reset_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
@@ -461,7 +1061,262 @@ class SSVM(HessianUpdateStrategy):
             defaults=defaults,
             init_scale=init_scale,
             tol=tol,
-
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return ssvm_H_(H=H, s=s, y=y, g=g, switch=setting['switch'], tol=setting['tol'])
+
|
+
# HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
|
|
1081
|
+
def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
|
|
1082
|
+
Hy = H@y
|
|
1083
|
+
ys = y.dot(s)
|
|
1084
|
+
if ys.abs() <= tol: return H # probably? because it is BFGS and DFP-like
|
|
1085
|
+
yHy = y.dot(Hy)
|
|
1086
|
+
denom = _safe_clip(ys + yHy)
|
|
1087
|
+
|
|
1088
|
+
term1 = 1/denom
|
|
1089
|
+
term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
|
|
1090
|
+
term3 = s.outer(y) @ H
|
|
1091
|
+
term4 = Hy.outer(s)
|
|
1092
|
+
term5 = Hy.outer(y) @ H
|
|
1093
|
+
|
|
1094
|
+
inner_term = term2 - term3 - term4 - term5
|
|
1095
|
+
H += inner_term.mul_(term1)
|
|
1096
|
+
return H
|
|
1097
|
+
|
|
1098
|
+
def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
|
|
1099
|
+
sy = _safe_clip(s.dot(y))
|
|
1100
|
+
return g - (y * (s.dot(g) / sy))
|
|
1101
|
+
|
|
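A quick check of what gradient_correction computes: the corrected gradient \(\hat g = g - y\,(s^{\top}g)/(s^{\top}y)\) satisfies

\[ s^{\top}\hat g = s^{\top}g - \frac{(s^{\top}y)(s^{\top}g)}{s^{\top}y} = 0, \]

so it has no component along the previous step s. On a quadratic with Hessian B (where y = Bs), this is exactly the gradient at the minimizer along s, i.e. the gradient an exact line search would have produced.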
+
+class GradientCorrection(Transform):
+    """
+    Estimates the gradient at the minimum along the search direction, assuming the objective is quadratic.
+
+    This can be useful as an inner module for second order methods paired with an inexact line search.
+
+    Example:
+        L-BFGS with gradient correction
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(inner=tz.m.GradientCorrection()),
+                tz.m.Backtracking()
+            )
+
+    Reference:
+        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+
+    """
+    def __init__(self):
+        super().__init__(None, uses_grad=False)
+
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        if 'p_prev' not in states[0]:
+            p_prev = unpack_states(states, tensors, 'p_prev', init=params)
+            g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
+            return tensors
+
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
+        g_hat = gradient_correction(TensorList(tensors), params-p_prev, tensors-g_prev)
+
+        p_prev.copy_(params)
+        g_prev.copy_(tensors)
+        return g_hat
+
+class Horisho(_InverseHessianUpdateStrategyDefaults):
+    """
+    Hoshino's variable metric Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+    """
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return hoshino_H_(H=H, s=s, y=y, tol=setting['tol'])
+
+# Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy.abs() < tol: return H # part of algorithm
+    Hy = H @ y
+
+    term1 = (s.outer(y) @ H).div_(sy)
+    term2 = (Hy.outer(s)).div_(sy)
+    term3 = 1 + (y.dot(Hy) / sy)
+    term4 = s.outer(s).div_(sy)
+
+    H -= (term1 + term2 - term4.mul_(term3))
+    return H
+
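Written out, fletcher_vmm_H_ computes

\[
H \leftarrow H - \frac{s y^{\top} H + H y\, s^{\top}}{s^{\top}y} + \left(1 + \frac{y^{\top}Hy}{s^{\top}y}\right)\frac{s s^{\top}}{s^{\top}y},
\]

which is algebraically the standard BFGS update of the inverse Hessian; this is no coincidence, since Fletcher's 1970 paper is one of the four original derivations of BFGS.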
+class FletcherVMM(_InverseHessianUpdateStrategyDefaults):
+    """
+    Fletcher's variable metric Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return fletcher_vmm_H_(H=H, s=s, y=y, tol=setting['tol'])
+
+
+# Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
+    sy = s.dot(y)
+    if sy < tol: return H # part of algorithm
+
+    term1 = (H @ y.outer(s) + s.outer(y) @ H) / sy
+
+    if type == 1:
+        pba = (2*sy + 2*(f-f_prev)) / sy
+
+    elif type == 2:
+        pba = (f_prev - f + 1/(2*sy)) / sy
+
+    else:
+        raise RuntimeError(type)
+
+    term3 = 1/pba + y.dot(H@y) / sy
+    term4 = s.outer(s) / sy
+
+    H.sub_(term1)
+    H.add_(term4.mul_(term3))
+    return H
+
+
+class NewSSM(HessianUpdateStrategy):
+    """Self-scaling Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is required.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+    """
+    def __init__(
+        self,
+        type: Literal[1, 2] = 1,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=dict(type=type),
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        f = state['f']
+        f_prev = state['f_prev']
+        return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=setting['type'], tol=setting['tol'])
+
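Unlike the other strategies in this diff, NewSSM's update_H also consumes objective values, reading state['f'] and state['f_prev'] (presumably populated by the HessianUpdateStrategy base class). A construction sketch in the same illustrative style as above:

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.NewSSM(type=1),                # type=2 selects the alternative pba formula
        tz.m.StrongWolfe(plus_minus=True),  # the docstring marks a line search as required
    )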
+# ---------------------------- Shor's r-algorithm ---------------------------- #
+# def shor_r(B:torch.Tensor, y:torch.Tensor, gamma:float):
+#     r = B.T @ y
+#     r /= torch.linalg.vector_norm(r).clip(min=1e-8) # pylint:disable=not-callable
+
+#     I = torch.eye(B.size(1), device=B.device, dtype=B.dtype)
+#     return B @ (I - gamma*r.outer(r))
+
+# this is supposed to be equivalent, with alpha playing the role of (1 - gamma); see the check below
+def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
+    p = H@y
+    # subtracts (1 - alpha**2) * (p p^T) / (p^T y)
+    term = p.outer(p).div_(p.dot(y).clip(min=1e-8))
+    H.sub_(term, alpha=1-alpha**2)
+    return H
+
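The equivalence claimed in the comment can be checked directly. With H = B Bᵀ and r = Bᵀy / ||Bᵀy||, the commented-out B-space update gives

\[
B' B'^{\top} = B (I - \gamma\, r r^{\top})^{2} B^{\top} = H - (2\gamma - \gamma^{2})\,\frac{p p^{\top}}{p^{\top}y} = H - \left(1 - (1-\gamma)^{2}\right)\frac{p p^{\top}}{p^{\top}y},
\]

using p = Hy, B r rᵀ Bᵀ = p pᵀ/(pᵀy) and ||Bᵀy||² = pᵀy; shor_r_ reproduces this with alpha in the role of 1 - gamma.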
+class ShorR(HessianUpdateStrategy):
+    """Shor's r-algorithm.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is required.
+
+    Reference:
+        Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of Numerical Analysis 28.4 (2008): 711-720.
+
+        Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.
+    """
+
+    def __init__(
+        self,
+        alpha=0.5,
+        init_scale: float | Literal["auto"] = 1,
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None | Literal['auto'] = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        # inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha=alpha)
+        super().__init__(
+            defaults=defaults,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
             reset_interval=reset_interval,
             beta=beta,
             update_freq=update_freq,
@@ -472,5 +1327,5 @@ class SSVM(HessianUpdateStrategy):
             inner=inner,
         )

-    def update_H(self, H, s, y, p, g, p_prev, g_prev, state,
-        return
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return shor_r_(H=H, y=y, alpha=setting['alpha'])
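Finally, a matching construction sketch for ShorR (illustrative; the required line search is included per the note in its docstring):

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.ShorR(alpha=0.5),
        tz.m.StrongWolfe(plus_minus=True),
    )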