torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/quasi_newton.py

@@ -1,12 +1,14 @@
 """Use BFGS or maybe SR1."""
 from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Mapping, Callable
 from typing import Any, Literal
+import warnings

 import torch

 from ...core import Chainable, Module, TensorwiseTransform, Transform
 from ...utils import TensorList, set_storage_, unpack_states
+from ..functional import safe_scaling_


 def _safe_dict_update_(d1_:dict, d2:dict):
@@ -19,13 +21,111 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
     elif state[key].shape != value.shape: state[key] = value
     else: state[key].lerp_(value, 1-beta)

+def _safe_clip(x: torch.Tensor):
+    """makes sure the scalar tensor x is not smaller than epsilon"""
+    assert x.numel() == 1, x.shape
+    eps = torch.finfo(x.dtype).eps ** 2
+    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
+    return x
+
 class HessianUpdateStrategy(TensorwiseTransform, ABC):
+    """Base class for quasi-newton methods that store and update a hessian approximation B or its inverse H.
+
+    This is an abstract class; to use it, subclass it and override `update_H` and/or `update_B`.
+
+    Args:
+        defaults (dict | None, optional): defaults. Defaults to None.
+        init_scale (float | Literal["auto"], optional):
+            the initial hessian matrix is set to identity times this.
+
+            "auto" corresponds to a heuristic from Nocedal, Wright. Numerical Optimization, pp. 142-143.
+
+            Defaults to "auto".
+        tol (float, optional):
+            algorithm-dependent tolerance (usually on the curvature condition). Defaults to 1e-8.
+        ptol (float | None, optional):
+            tolerance on the minimal parameter difference, to avoid instability. Defaults to 1e-10.
+        ptol_reset (bool, optional): whether to reset the hessian approximation when the ptol tolerance is not met. Defaults to False.
+        gtol (float | None, optional):
+            tolerance on the minimal gradient difference, to avoid instability when there is no curvature. Defaults to 1e-10.
+        reset_interval (int | None | Literal["auto"], optional):
+            interval between resets of the hessian approximation.
+
+            "auto" corresponds to the number of decision variables + 1.
+
+            None - no resets.
+
+            Defaults to None.
+        beta (float | None, optional): momentum on H or B. Defaults to None.
+        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+        scale_first (bool, optional):
+            whether to downscale the first step, before a hessian approximation becomes available. Defaults to True.
+        scale_second (bool, optional): whether to downscale the second step. Defaults to False.
+        concat_params (bool, optional):
+            if True, all parameters are treated as a single vector.
+            If False, the update rule is applied to each parameter separately. Defaults to True.
+        inverse (bool, optional):
+            set to True if this method maintains a hessian inverse approximation H and has an `update_H` method;
+            set to False if it maintains a hessian approximation B and has an `update_B` method.
+            Defaults to True.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Example:
+        Implementing the BFGS method, which maintains an estimate of the hessian inverse (H):
+
+        .. code-block:: python
+
+            class BFGS(HessianUpdateStrategy):
+                def __init__(
+                    self,
+                    init_scale: float | Literal["auto"] = "auto",
+                    tol: float = 1e-8,
+                    ptol: float = 1e-10,
+                    ptol_reset: bool = False,
+                    reset_interval: int | None = None,
+                    beta: float | None = None,
+                    update_freq: int = 1,
+                    scale_first: bool = True,
+                    scale_second: bool = False,
+                    concat_params: bool = True,
+                    inner: Chainable | None = None,
+                ):
+                    super().__init__(
+                        defaults=None,
+                        init_scale=init_scale,
+                        tol=tol,
+                        ptol=ptol,
+                        ptol_reset=ptol_reset,
+                        reset_interval=reset_interval,
+                        beta=beta,
+                        update_freq=update_freq,
+                        scale_first=scale_first,
+                        scale_second=scale_second,
+                        concat_params=concat_params,
+                        inverse=True,
+                        inner=inner,
+                    )
+
+                def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+                    tol = settings["tol"]
+                    sy = torch.dot(s, y)
+                    if sy <= tol: return H
+                    num1 = (sy + (y @ H @ y)) * s.outer(s)
+                    term1 = num1.div_(sy**2)
+                    num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+                    term2 = num2.div_(sy)
+                    H += term1.sub_(term2)
+                    return H
+
+    """
     def __init__(
         self,
         defaults: dict | None = None,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
         reset_interval: int | None | Literal['auto'] = None,
         beta: float | None = None,
         update_freq: int = 1,
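The new `_safe_clip` helper above guards the scalar denominators used throughout the rewritten update rules. Below is a standalone copy with a small demonstration of its contract (a sign-preserving clamp away from zero); the example values are illustrative only:

```python
import torch

def _safe_clip(x: torch.Tensor):
    """clamps a scalar tensor away from zero while preserving its sign"""
    assert x.numel() == 1, x.shape
    eps = torch.finfo(x.dtype).eps ** 2
    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
    return x

print(_safe_clip(torch.tensor(-1e-40)))   # ~-1.4e-14: clamped to -(eps**2)
print(_safe_clip(torch.tensor(0.5)))      # 0.5: far from zero, returned unchanged
```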
@@ -36,9 +136,12 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
         inner: Chainable | None = None,
     ):
         if defaults is None: defaults = {}
-        _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol,
+        _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_reset=ptol_reset, gtol=gtol, scale_second=scale_second, inverse=inverse, beta=beta, reset_interval=reset_interval))
         super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, scale_first=scale_first, inner=inner)

+    def _init_M(self, size:int, device, dtype, is_inverse:bool):
+        return torch.eye(size, device=device, dtype=dtype)
+
     def _get_init_scale(self,s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
         """returns multiplier to H or B"""
         ys = y.dot(s)
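`_init_M` is now a hook, so subclasses can choose the form of the initial approximation. Combined with the `ndim == 1` branches added to `apply_tensor` further down, this appears to allow diagonal (vector-valued) approximations; a hypothetical sketch, not a class shipped in the package:

```python
import torch

class DiagonalInitMixin:
    """hypothetical override: keep a diagonal (vector) approximation instead of a dense matrix"""
    def _init_M(self, size: int, device, dtype, is_inverse: bool):
        # a vector of ones plays the role of the identity matrix;
        # apply_tensor then multiplies by H (or divides by B) elementwise
        return torch.ones(size, device=device, dtype=dtype)
```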
@@ -47,41 +150,83 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
             return 1

     def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]):
-        set_storage_(M,
+        set_storage_(M, self._init_M(s.numel(), device=M.device, dtype=M.dtype, is_inverse=inverse))
         if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
         if init_scale >= 1:
             if inverse: M /= init_scale
             else: M *= init_scale

     def update_H(self, H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
-                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any],
+                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
         """update hessian inverse"""
         raise NotImplementedError

     def update_B(self, B:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
-                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any],
+                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
         """update hessian"""
         raise NotImplementedError

+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('f_prev', 'p_prev', 'g_prev')
+
+    def get_B(self) -> tuple[torch.Tensor, bool]:
+        """returns (B or H, is_inverse)."""
+        state = next(iter(self.state.values()))
+        if "B" in state: return state["B"], False
+        return state["H"], True
+
+    def get_H(self) -> tuple[torch.Tensor, bool]:
+        """returns (H or B, is_inverse)."""
+        state = next(iter(self.state.values()))
+        if "H" in state: return state["H"], False
+        return state["B"], True
+
+    def make_Bv(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        B, is_inverse = self.get_B()
+
+        if is_inverse:
+            H = B
+            warnings.warn(f'{self} maintains H, so Bv will be inefficient!')
+            def Hxv(v): return torch.linalg.solve_ex(H, v)[0] # pylint:disable=not-callable
+            return Hxv
+
+        def Bv(v): return B@v
+        return Bv
+
+    def make_Hv(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        H, is_inverse = self.get_H()
+
+        if is_inverse:
+            B = H
+            warnings.warn(f'{self} maintains B, so Hv will be inefficient!')
+            def Bxv(v): return torch.linalg.solve_ex(B, v)[0] # pylint:disable=not-callable
+            return Bxv
+
+        def Hv(v): return H@v
+        return Hv
+
     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state,
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
         p = param.view(-1); g = tensor.view(-1)
-        inverse =
+        inverse = setting['inverse']
         M_key = 'H' if inverse else 'B'
         M = state.get(M_key, None)
-        step = state.get('step', 0)
-        state['step'] = step
-        init_scale =
-
-
-
+        step = state.get('step', 0) + 1
+        state['step'] = step
+        init_scale = setting['init_scale']
+        ptol = setting['ptol']
+        ptol_reset = setting['ptol_reset']
+        gtol = setting['gtol']
+        reset_interval = setting['reset_interval']
         if reset_interval == 'auto': reset_interval = tensor.numel() + 1

-        if M is None:
-            M
-
-            if
-
+        if M is None or 'f_prev' not in state:
+            if M is None: # won't be true on reset_for_online
+                M = self._init_M(p.numel(), device=p.device, dtype=p.dtype, is_inverse=inverse)
+                if isinstance(init_scale, (int, float)) and init_scale != 1:
+                    if inverse: M /= init_scale
+                    else: M *= init_scale

             state[M_key] = M
             state['f_prev'] = loss
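`make_Bv`/`make_Hv` hand out matrix-vector closures so callers (for example trust-region modules) don't need to know whether the module stores B or its inverse H; when only the inverse is available, a dense solve stands in for the product. A minimal standalone sketch of the same idea, assuming nothing beyond `torch`:

```python
import torch

def make_matvec(M: torch.Tensor, stored_is_inverse: bool):
    """returns a closure computing B @ v, given either B itself or its inverse H"""
    if stored_is_inverse:
        # B = H^-1, so B @ v is the solution x of H x = v
        def solve(v: torch.Tensor) -> torch.Tensor:
            return torch.linalg.solve_ex(M, v)[0]
        return solve
    def matmul(v: torch.Tensor) -> torch.Tensor:
        return M @ v
    return matmul

H = 2.0 * torch.eye(3)                  # pretend this is a stored inverse hessian
Bv = make_matvec(H, stored_is_inverse=True)
print(Bv(torch.ones(3)))                # tensor([0.5, 0.5, 0.5]), since B = 0.5 * I
```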
@@ -97,190 +242,511 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
             state['p_prev'].copy_(p)
             state['g_prev'].copy_(g)

-        if reset_interval is not None and step
+        if reset_interval is not None and step % reset_interval == 0:
             self._reset_M_(M, s, y, inverse, init_scale, state)
             return

-        # tolerance on
-        if
-            # reset history
-
+        # tolerance on parameter difference to avoid exploding after converging
+        if ptol is not None and s.abs().max() <= ptol:
+            if ptol_reset: self._reset_M_(M, s, y, inverse, init_scale, state) # reset history
+            return
+
+        # tolerance on gradient difference to avoid exploding when there is no curvature
+        if gtol is not None and y.abs().max() <= gtol:
             return

-        if step ==
+        if step == 2 and init_scale == 'auto':
             if inverse: M /= self._get_init_scale(s,y)
             else: M *= self._get_init_scale(s,y)

-        beta =
+        beta = setting['beta']
         if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place

         if inverse:
-            H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state,
+            H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
             _maybe_lerp_(state, 'H', H_new, beta)

         else:
-            B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state,
+            B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, setting=setting)
             _maybe_lerp_(state, 'B', B_new, beta)

         state['f_prev'] = loss

+    def _post_B(self, B: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+        """modifies B before applying the update rule. Must return (B, g)"""
+        return B, g
+
+    def _post_H(self, H: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+        """modifies H before applying the update rule. Must return (H, g)"""
+        return H, g
+
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         step = state.get('step', 0)

-        if
-
-            scale_factor = scale_factor.clip(min=torch.finfo(tensor.dtype).eps)
-            tensor = tensor * scale_factor
+        if setting['scale_second'] and step == 2:
+            tensor = safe_scaling_(tensor)

-        inverse =
+        inverse = setting['inverse']
         if inverse:
             H = state['H']
-
+            H, g = self._post_H(H, tensor.view(-1), state, setting)
+            if H.ndim == 1: return g.mul_(H).view_as(tensor)
+            return (H @ g).view_as(tensor)

         B = state['B']
+        B, g = self._post_B(B, tensor.view(-1), state, setting)
+
+        if B.ndim == 1: return g.div_(B).view_as(tensor)
+        x, info = torch.linalg.solve_ex(B, g) # pylint:disable=not-callable
+        if info == 0: return x.view_as(tensor)
+        return safe_scaling_(tensor)
+
+class _InverseHessianUpdateStrategyDefaults(HessianUpdateStrategy):
+    '''This is a :code:`HessianUpdateStrategy` subclass for algorithms with no extra defaults, to skip the lengthy __init__.
+    Refer to the :code:`HessianUpdateStrategy` documentation.
+
+    Example:
+        Implementing the BFGS method, which maintains an estimate of the hessian inverse (H):
+
+        .. code-block:: python
+
+            class BFGS(_InverseHessianUpdateStrategyDefaults):
+                """Broyden–Fletcher–Goldfarb–Shanno algorithm"""
+                def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+                    tol = settings["tol"]
+                    sy = torch.dot(s, y)
+                    if sy <= tol: return H
+                    num1 = (sy + (y @ H @ y)) * s.outer(s)
+                    term1 = num1.div_(sy**2)
+                    num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+                    term2 = num2.div_(sy)
+                    H += term1.sub_(term2)
+                    return H
+
+    Make sure to give the subclass at least a basic class-level docstring to overwrite this one.
+    '''
+    def __init__(
+        self,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = True,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        super().__init__(
+            defaults=None,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=inverse,
+            inner=inner,
+        )

-
-
-# to avoid typing all arguments for each method
-class HUpdateStrategy(HessianUpdateStrategy):
+class _HessianUpdateStrategyDefaults(HessianUpdateStrategy):
     def __init__(
         self,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
         reset_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
         scale_first: bool = True,
         scale_second: bool = False,
         concat_params: bool = True,
+        inverse: bool = False,
         inner: Chainable | None = None,
     ):
         super().__init__(
             defaults=None,
             init_scale=init_scale,
             tol=tol,
-
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
             reset_interval=reset_interval,
             beta=beta,
             update_freq=update_freq,
             scale_first=scale_first,
             scale_second=scale_second,
             concat_params=concat_params,
-            inverse=
+            inverse=inverse,
             inner=inner,
         )
+
 # ----------------------------------- BFGS ----------------------------------- #
+def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy < tol: return B
+
+    Bs = B@s
+    sBs = _safe_clip(s.dot(Bs))
+
+    term1 = y.outer(y).div_(sy)
+    term2 = (Bs.outer(s) @ B.T).div_(sBs)
+    B += term1.sub_(term2)
+    return B
+
 def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-    sy =
-    if sy <= tol: return H
-
-
-
+    sy = s.dot(y)
+    if sy <= tol: return H
+
+    sy_sq = _safe_clip(sy**2)
+
+    Hy = H@y
+    scale1 = (sy + y.dot(Hy)) / sy_sq
+    term1 = s.outer(s).mul_(scale1)
+
+    num2 = (Hy.outer(s)).add_(s.outer(y @ H))
     term2 = num2.div_(sy)
+
     H += term1.sub_(term2)
     return H

-class BFGS(
-
-
+class BFGS(_InverseHessianUpdateStrategyDefaults):
+    """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe()` is recommended, although this can be stable without a line search. Alternatively, a warmup (:code:`tz.m.Warmup`) can stabilize quasi-newton methods without a line search.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Args:
+        init_scale (float | Literal["auto"], optional):
+            the initial hessian matrix is set to identity times this.
+
+            "auto" corresponds to a heuristic from Nocedal, Wright. Numerical Optimization, pp. 142-143.
+
+            Defaults to "auto".
+        tol (float, optional):
+            tolerance on the curvature condition. Defaults to 1e-8.
+        ptol (float | None, optional):
+            skips the update if the maximum difference between current and previous parameters is less than this, to avoid instability.
+            Defaults to 1e-10.
+        ptol_reset (bool, optional): whether to reset the hessian approximation when the ptol tolerance is not met. Defaults to False.
+        reset_interval (int | None | Literal["auto"], optional):
+            interval between resets of the hessian approximation.
+
+            "auto" corresponds to the number of decision variables + 1.
+
+            None - no resets.
+
+            Defaults to None.
+        beta (float | None, optional): momentum on H or B. Defaults to None.
+        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+        scale_first (bool, optional):
+            whether to downscale the first step, before a hessian approximation becomes available. Defaults to True.
+        scale_second (bool, optional): whether to downscale the second step. Defaults to False.
+        concat_params (bool, optional):
+            if True, all parameters are treated as a single vector.
+            If False, the update rule is applied to each parameter separately. Defaults to True.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Examples:
+        BFGS with a strong-wolfe line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.BFGS(),
+                tz.m.StrongWolfe()
+            )
+
+        BFGS preconditioning applied to momentum:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.BFGS(inner=tz.m.EMA(0.9)),
+                tz.m.LR(1e-2)
+            )
+    """
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return bfgs_B_(B=B, s=s, y=y, tol=setting['tol'])

 # ------------------------------------ SR1 ----------------------------------- #
-def
+def sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
     z = s - H@y
-    denom =
+    denom = z.dot(y)

     z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
     y_norm = torch.linalg.norm(y) # pylint:disable=not-callable

-    if y_norm*z_norm < tol: return H
+    # if y_norm*z_norm < tol: return H

     # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
     if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
-    H +=
+    H += z.outer(z).div_(_safe_clip(denom))
     return H

-class SR1(
-
-
+class SR1(_InverseHessianUpdateStrategyDefaults):
+    """Symmetric Rank 1 Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        approximate hessians generated by the SR1 method progress towards the true hessian faster than those of other methods, but the method is less stable. SR1 is best used within a trust region module.
+
+    .. note::
+        SR1 doesn't force the hessian estimate to be positive definite, so it can generate directions that are not descent directions.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Args:
+        init_scale (float | Literal["auto"], optional):
+            the initial hessian matrix is set to identity times this.
+
+            "auto" corresponds to a heuristic from Nocedal, Wright. Numerical Optimization, pp. 142-143.
+
+            Defaults to "auto".
+        tol (float, optional):
+            tolerance for the denominator in the SR1 update rule, as in Nocedal, Wright. Numerical Optimization, 2nd ed., p. 146. Defaults to 1e-8.
+        ptol (float | None, optional):
+            skips the update if the maximum difference between current and previous parameters is less than this, to avoid instability.
+            Defaults to 1e-10.
+        ptol_reset (bool, optional): whether to reset the hessian approximation when the ptol tolerance is not met. Defaults to False.
+        reset_interval (int | None | Literal["auto"], optional):
+            interval between resets of the hessian approximation.
+
+            "auto" corresponds to the number of decision variables + 1.
+
+            None - no resets.
+
+            Defaults to None.
+        beta (float | None, optional): momentum on H or B. Defaults to None.
+        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
+        scale_first (bool, optional):
+            whether to downscale the first step, before a hessian approximation becomes available. Defaults to True.
+        scale_second (bool, optional): whether to downscale the second step. Defaults to False.
+        concat_params (bool, optional):
+            if True, all parameters are treated as a single vector.
+            If False, the update rule is applied to each parameter separately. Defaults to True.
+        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
+
+    Examples:
+        SR1 with a strong-wolfe line search:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SR1(),
+                tz.m.StrongWolfe()
+            )
+
+        SR1 preconditioning applied to momentum:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SR1(inner=tz.m.EMA(0.9)),
+                tz.m.LR(1e-2)
+            )
+    """
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return sr1_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return sr1_(H=B, s=y, y=s, tol=setting['tol'])
+

 # ------------------------------------ DFP ----------------------------------- #
 def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
-    sy =
+    sy = s.dot(y)
     if sy.abs() <= tol: return H
-    term1 =
-
-
-
+    term1 = s.outer(s).div_(sy)
+
+    yHy = _safe_clip(y.dot(H @ y))
+
+    num = (H @ y).outer(y) @ H
     term2 = num.div_(yHy)
+
     H += term1.sub_(term2)
     return H

-
-
-
+def dfp_B(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy.abs() <= tol: return B
+    I = torch.eye(B.size(0), device=B.device, dtype=B.dtype)
+    sub = y.outer(s).div_(sy)
+    term1 = I - sub
+    term2 = I.sub_(sub.T)
+    term3 = y.outer(y).div_(sy)
+    B = (term1 @ B @ term2).add_(term3)
+    return B
+
+
+class DFP(_InverseHessianUpdateStrategyDefaults):
+    """Davidon–Fletcher–Powell Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return dfp_H_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return dfp_B(B=B, s=s, y=y, tol=setting['tol'])


 # formulas for methods below from Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
 # H' = H - (Hy - s)c^T / (c^T y)
 # the difference is how `c` is calculated

-def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
+def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     c = H.T @ s
-    cy = c.dot(y)
-    if cy.abs() <= tol: return H
+    cy = _safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
     H -= num/cy
     return H
+def broyden_good_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    r = y - B@s
+    ss = _safe_clip(s.dot(s))
+    B += r.outer(s).div_(ss)
+    return B

-def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
-
-
-
-    num = (H@y).sub_(s).outer(c)
-    H -= num/cy
+def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    yy = _safe_clip(y.dot(y))
+    num = (s - (H @ y)).outer(y)
+    H += num/yy
     return H
+def broyden_bad_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    r = y - B@s
+    ys = _safe_clip(y.dot(s))
+    B += r.outer(y).div_(ys)
+    return B

-def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor
+def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor):
     c = g_prev
-    cy = c.dot(y)
-    if cy.abs() <= tol: return H
+    cy = _safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
     H -= num/cy
     return H

-def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
+def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     Hy = H @ y
     c = H @ Hy # pylint:disable=not-callable
-    cy = c.dot(y)
-    if cy.abs() <= tol: return H
+    cy = _safe_clip(c.dot(y))
     num = Hy.sub_(s).outer(c)
     H -= num/cy
     return H

-class BroydenGood(
-
-
+class BroydenGood(_InverseHessianUpdateStrategyDefaults):
+    """Broyden's "good" Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_good_H_(H=H, s=s, y=y)
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_good_B_(B=B, s=s, y=y)
+
+class BroydenBad(_InverseHessianUpdateStrategyDefaults):
+    """Broyden's "bad" Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_bad_H_(H=H, s=s, y=y)
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return broyden_bad_B_(B=B, s=s, y=y)
+
+class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
+    """Greenstadt's first Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.

-
-
-        return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])
+    .. warning::
+        this uses roughly O(N^2) memory.

-
-
-
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev)
+
+class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
+    """Greenstadt's second Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.

-
-
-        return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])
+    .. warning::
+        this uses roughly O(N^2) memory.

+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472

-
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return greenstadt2_H_(H=H, s=s, y=y)
+
+
+def icum_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
     j = y.abs().argmax()

-    denom = y[j]
-    if denom.abs() < tol: return H
+    denom = _safe_clip(y[j])

     Hy = H @ y.unsqueeze(1)
     num = s.unsqueeze(1) - Hy
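A property worth keeping in mind when reading the BFGS helpers above: whenever the curvature check passes, the inverse update satisfies the secant condition H_new @ y = s. A simplified standalone check (without the `_safe_clip` guards used in the diff):

```python
import torch

def bfgs_H(H, s, y, tol=1e-8):
    # simplified copy of the inverse BFGS update from the diff above
    sy = s.dot(y)
    if sy <= tol: return H
    Hy = H @ y
    term1 = s.outer(s) * ((sy + y.dot(Hy)) / sy**2)
    term2 = (Hy.outer(s) + s.outer(y @ H)) / sy
    return H + term1 - term2

torch.manual_seed(0)
s, y = torch.randn(5), torch.randn(5)
if s.dot(y) <= 0: y = -y                 # force positive curvature so the update fires
H_new = bfgs_H(torch.eye(5), s, y)
print(torch.allclose(H_new @ y, s, atol=1e-5))  # True: secant condition holds
```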
@@ -288,31 +754,55 @@ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float
     H[:, j] += num.squeeze() / denom
     return H

-class
-    """
-
-
+class ICUM(_InverseHessianUpdateStrategyDefaults):
+    """
+    Inverse Column-Updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods
+    because it updates only one column of the inverse hessian approximation per step.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return icum_H_(H=H, s=s, y=y)

-def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor
+def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
     s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
     I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
     d = (R + I * (s_norm/2)) @ s
-    ds = d.dot(s)
-    if ds.abs() <= tol: return H, R
+    ds = _safe_clip(d.dot(s))
     R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))

     c = H.T @ d
-    cy = c.dot(y)
-    if cy.abs() <= tol: return H, R
+    cy = _safe_clip(c.dot(y))
     num = (H@y).sub_(s).outer(c)
     H -= num/cy
     return H, R

-class ThomasOptimalMethod(
-    """
-
+class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
+    """
+    Thomas's "optimal" Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
-        H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y
+        H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y)
         return H

     def _reset_M_(self, M, s, y,inverse, init_scale, state):
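The cost advantage of ICUM is visible directly in `icum_H_`: only column `j = argmax |y|` of H changes, yet the secant condition still holds. A standalone sketch of both facts (again omitting `_safe_clip`):

```python
import torch

def icum_H(H, s, y):
    j = y.abs().argmax()
    Hy = H @ y
    H = H.clone()
    H[:, j] += (s - Hy) / y[j]
    return H

torch.manual_seed(0)
H0 = torch.eye(4)
s, y = torch.randn(4), torch.randn(4)
H1 = icum_H(H0, s, y)
print((H0 != H1).any(dim=0))                 # True only at column argmax|y|
print(torch.allclose(H1 @ y, s, atol=1e-5))  # True: secant condition holds
```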
@@ -321,97 +811,120 @@ class ThomasOptimalMethod(HUpdateStrategy):
         st.pop("R", None)

 # ------------------------ powell's symmetric broyden ------------------------ #
-def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor
+def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
     y_Bs = y - B@s
-    ss = s.dot(s)
-    if ss.abs() < tol: return B
+    ss = _safe_clip(s.dot(s))
     num1 = y_Bs.outer(s).add_(s.outer(y_Bs))
     term1 = num1.div_(ss)
-    term2 = s.outer(s).mul_(y_Bs.dot(s)/(ss**2))
+    term2 = s.outer(s).mul_(y_Bs.dot(s)/(_safe_clip(ss**2)))
     B += term1.sub_(term2)
     return B

 # I couldn't find a formula for H
-class PSB(
-
-
-
-
-
-
-        update_freq: int = 1,
-        scale_first: bool = True,
-        scale_second: bool = False,
-        concat_params: bool = True,
-        inner: Chainable | None = None,
-    ):
-        super().__init__(
-            defaults=None,
-            init_scale=init_scale,
-            tol=tol,
-            tol_reset=tol_reset,
-            reset_interval=reset_interval,
-            beta=beta,
-            update_freq=update_freq,
-            scale_first=scale_first,
-            scale_second=scale_second,
-            concat_params=concat_params,
-            inverse=False,
-            inner=inner,
-        )
+class PSB(_HessianUpdateStrategyDefaults):
+    """Powell's Symmetric Broyden Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.

-
-
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return psb_B_(B=B, s=s, y=y)

 # Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
-def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor
+def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
     Hy = H@y
-    yHy = y.dot(Hy)
-    if yHy.abs() <= tol: return H
+    yHy = _safe_clip(y.dot(Hy))
     num = (s - Hy).outer(Hy)
     H += num.div_(yHy)
     return H

-class Pearson(
-    """
+class Pearson(_InverseHessianUpdateStrategyDefaults):
+    """
+    Pearson's Quasi-Newton method.

-
-
-        return pearson_H_(H=H, s=s, y=y, tol=settings['tol'])
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.

-
-
-
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return pearson_H_(H=H, s=s, y=y)
+
+def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    sy = _safe_clip(s.dot(y))
     num = (s - H@y).outer(s)
     H += num.div_(sy)
     return H

-class McCormick(
-    """
+class McCormick(_InverseHessianUpdateStrategyDefaults):
+    """McCormick's Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.

-
-
-        return mccormick_H_(H=H, s=s, y=y, tol=settings['tol'])
+    .. warning::
+        this uses roughly O(N^2) memory.

-
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+        This is "Algorithm 2", attributed to McCormick in that paper; some other sources call it Pearson's 2nd method.
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return mccormick_H_(H=H, s=s, y=y)
+
+def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
     Hy = H @ y
-    yHy = y.dot(Hy)
-    if yHy.abs() < tol: return H, R
+    yHy = _safe_clip(y.dot(Hy))
     H -= Hy.outer(Hy) / yHy
     R += (s - R@y).outer(Hy) / yHy
     return H, R

 class ProjectedNewtonRaphson(HessianUpdateStrategy):
-    """
+    """
+    Projected Newton-Raphson method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        this is an experimental method.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+        This one is Algorithm 7.
+    """
     def __init__(
         self,
         init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-
-
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
         reset_interval: int | None | Literal['auto'] = 'auto',
         beta: float | None = None,
         update_freq: int = 1,
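`psb_B_` above is a symmetric rank-two update; like the others here, it enforces the secant condition B_new @ s = y and preserves symmetry. A simplified standalone check (no `_safe_clip`):

```python
import torch

def psb_B(B, s, y):
    y_Bs = y - B @ s
    ss = s.dot(s)
    term1 = (y_Bs.outer(s) + s.outer(y_Bs)) / ss
    term2 = s.outer(s) * (y_Bs.dot(s) / ss**2)
    return B + term1 - term2

torch.manual_seed(0)
s, y = torch.randn(6), torch.randn(6)
B1 = psb_B(torch.eye(6), s, y)
print(torch.allclose(B1 @ s, y, atol=1e-5))   # True: secant condition
print(torch.allclose(B1, B1.T))               # True: symmetry is preserved
```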
@@ -423,7 +936,9 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
         super().__init__(
             init_scale=init_scale,
             tol=tol,
-
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
             reset_interval=reset_interval,
             beta=beta,
             update_freq=update_freq,
@@ -434,9 +949,9 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
             inner=inner,
         )

-    def update_H(self, H, s, y, p, g, p_prev, g_prev, state,
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
-        H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y
+        H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y)
         state["R"] = R
         return H
@@ -454,12 +969,10 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
     # however p.12 says eps = gs / gHy

     Hy = H@y
-    gHy = g.dot(Hy)
-    yHy = y.dot(Hy)
+    gHy = _safe_clip(g.dot(Hy))
+    yHy = _safe_clip(y.dot(Hy))
     sy = s.dot(y)
-    if sy < tol: return H
-    if yHy.abs() < tol: return H
-    if gHy.abs() < tol: return H
+    if sy < tol: return H # the proof assumes sy > 0, though it is unclear whether the update should be skipped

     v_mul = yHy.sqrt()
     v_term1 = s/sy
@@ -474,28 +987,26 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
     e = gs / gHy
     if switch in (1, 3):
         if e/o <= 1:
-
-            phi = e/o
+            phi = e/_safe_clip(o)
             theta = 0
         elif o/t >= 1:
-
-            phi = o/t
+            phi = o/_safe_clip(t)
             theta = 1
         else:
             phi = 1
-            denom = e*t - o**2
-            if denom.abs() <= tol: return H
+            denom = _safe_clip(e*t - o**2)
             if switch == 1: theta = o * (e - o) / denom
             else: theta = o * (t - o) / denom

     elif switch == 2:
-
+        t = _safe_clip(t)
+        o = _safe_clip(o)
+        e = _safe_clip(e)
         phi = (e / t) ** 0.5
         theta = 1 / (1 + (t*e / o**2)**0.5)

     elif switch == 4:
-
-        phi = e/t
+        phi = e/_safe_clip(t)
         theta = 1/2

     else: raise ValueError(switch)
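The `(phi, theta)` selection in `ssvm_H_` is easier to follow outside the update. A sketch of the same control flow with `t`, `o`, `e` taken as precomputed scalars (their definitions sit in parts of the function not shown in this hunk), and `_safe_clip` omitted:

```python
def select_phi_theta(t: float, o: float, e: float, switch: int):
    """mirrors the switch logic of ssvm_H_ for switches 1-4"""
    if switch in (1, 3):
        if e / o <= 1:
            return e / o, 0.0
        if o / t >= 1:
            return o / t, 1.0
        denom = e * t - o**2
        theta = o * (e - o) / denom if switch == 1 else o * (t - o) / denom
        return 1.0, theta
    if switch == 2:
        return (e / t) ** 0.5, 1.0 / (1.0 + (t * e / o**2) ** 0.5)
    if switch == 4:
        return e / t, 0.5
    raise ValueError(switch)
```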
@@ -514,14 +1025,29 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi


 class SSVM(HessianUpdateStrategy):
-    """
+    """
+    Self-scaling variable metric Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
     """
     def __init__(
         self,
         switch: tuple[float,float] | Literal[1,2,3,4] = 3,
         init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-
-
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
         reset_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
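The new `ptol`, `ptol_reset`, and `gtol` arguments are forwarded to the `HessianUpdateStrategy` base class in the next hunk, alongside the now-concrete `tol=1e-8` default. A usage sketch in the spirit of the `GradientCorrection` docstring example further down, pairing `SSVM` with the line search its docstring recommends (assumes the usual `tz.Modular` / `tz.m` entry points):

    opt = tz.Modular(
        model.parameters(),
        tz.m.SSVM(switch=3, tol=1e-8, ptol=1e-10, gtol=1e-10),
        tz.m.StrongWolfe(plus_minus=True),  # highly recommended per the docstring
    )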
@@ -535,7 +1061,9 @@ class SSVM(HessianUpdateStrategy):
             defaults=defaults,
             init_scale=init_scale,
             tol=tol,
-
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
             reset_interval=reset_interval,
             beta=beta,
             update_freq=update_freq,
@@ -546,17 +1074,16 @@ class SSVM(HessianUpdateStrategy):
             inner=inner,
         )

-    def update_H(self, H, s, y, p, g, p_prev, g_prev, state,
-        return ssvm_H_(H=H, s=s, y=y, g=g, switch=
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return ssvm_H_(H=H, s=s, y=y, g=g, switch=setting['switch'], tol=setting['tol'])

 # HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
 def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     Hy = H@y
     ys = y.dot(s)
-    if ys.abs() <= tol: return H
+    if ys.abs() <= tol: return H # probably? because it is BFGS and DFP-like
     yHy = y.dot(Hy)
-    denom = ys + yHy
-    if denom.abs() <= tol: return H
+    denom = _safe_clip(ys + yHy)

     term1 = 1/denom
     term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
@@ -569,19 +1096,35 @@ def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     return H

 def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
-    sy = s.dot(y)
-    if sy.abs() < torch.finfo(g[0].dtype).eps: return g
+    sy = _safe_clip(s.dot(y))
     return g - (y * (s.dot(g) / sy))


 class GradientCorrection(Transform):
-    """
+    """
+    Estimates the gradient at the minimum along the search direction, assuming the function is quadratic.
+
+    This can be useful as an inner module for second-order methods with an inexact line search.
+
+    Example:
+        L-BFGS with gradient correction
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(inner=tz.m.GradientCorrection()),
+                tz.m.Backtracking()
+            )
+
+    Reference:
+        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394

-
+    """
     def __init__(self):
         super().__init__(None, uses_grad=False)

-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         if 'p_prev' not in states[0]:
             p_prev = unpack_states(states, tensors, 'p_prev', init=params)
             g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
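`gradient_correction` computes `g_hat = g - y * (s·g)/(s·y)`, which for a quadratic objective is the gradient at the minimizer along the previous step `s`. A flat-tensor sketch of the same arithmetic for illustration (the real helper operates on `TensorList`s and guards the denominator with `_safe_clip`):

    import torch

    g = torch.tensor([1.0, 2.0])   # current gradient
    s = torch.tensor([0.5, 0.0])   # previous step p - p_prev
    y = torch.tensor([0.2, 0.1])   # gradient difference g - g_prev
    g_hat = g - y * (s.dot(g) / s.dot(y))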
@@ -594,15 +1137,30 @@ class GradientCorrection(Transform):
         g_prev.copy_(tensors)
         return g_hat

-class Horisho(
-    """
-
-
+class Horisho(_InverseHessianUpdateStrategyDefaults):
+    """
+    Horisho's variable metric Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+    """
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return hoshino_H_(H=H, s=s, y=y, tol=setting['tol'])

 # Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
 def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
     sy = s.dot(y)
-    if sy.abs() < tol: return H
+    if sy.abs() < tol: return H # part of algorithm
     Hy = H @ y

     term1 = (s.outer(y) @ H).div_(sy)
@@ -613,16 +1171,30 @@ def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float)
     H -= (term1 + term2 - term4.mul_(term3))
     return H

-class FletcherVMM(
-    """
-
-
+class FletcherVMM(_InverseHessianUpdateStrategyDefaults):
+    """
+    Fletcher's variable metric Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is highly recommended.
+
+    .. note::
+        BFGS is the recommended QN method and will usually outperform this.
+
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+    """
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return fletcher_vmm_H_(H=H, s=s, y=y, tol=setting['tol'])


 # Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
 def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
     sy = s.dot(y)
-    if sy < tol: return H
+    if sy < tol: return H # part of algorithm

     term1 = (H @ y.outer(s) + s.outer(y) @ H) / sy

@@ -644,15 +1216,25 @@ def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol:


 class NewSSM(HessianUpdateStrategy):
-    """Self-scaling method
+    """Self-scaling Quasi-Newton method.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is required.

-
+    .. warning::
+        this uses roughly O(N^2) memory.
+
+    Reference:
+        Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+    """
     def __init__(
         self,
         type: Literal[1, 2] = 1,
         init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
         reset_interval: int | None = None,
         beta: float | None = None,
         update_freq: int = 1,
@@ -665,7 +1247,9 @@ class NewSSM(HessianUpdateStrategy):
             defaults=dict(type=type),
             init_scale=init_scale,
             tol=tol,
-
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
             reset_interval=reset_interval,
             beta=beta,
             update_freq=update_freq,
@@ -675,9 +1259,73 @@ class NewSSM(HessianUpdateStrategy):
             inverse=True,
             inner=inner,
         )
-    def update_H(self, H, s, y, p, g, p_prev, g_prev, state,
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
         f = state['f']
         f_prev = state['f_prev']
-        return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=
+        return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=setting['type'], tol=setting['tol'])
+
+# ---------------------------- Shor’s r-algorithm ---------------------------- #
+# def shor_r(B:torch.Tensor, y:torch.Tensor, gamma:float):
+#     r = B.T @ y
+#     r /= torch.linalg.vector_norm(r).clip(min=1e-8) # pylint:disable=not-callable
+
+#     I = torch.eye(B.size(1), device=B.device, dtype=B.dtype)
+#     return B @ (I - gamma*r.outer(r))
+
+# this is supposed to be equivalent
+def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
+    p = H@y
+    # H <- H - (1 - alpha^2) * (p p^T) / (p^T y)
+    term = p.outer(p).div_(p.dot(y).clip(min=1e-8))
+    H.sub_(term, alpha=1-alpha**2)
+    return H
+
+class ShorR(HessianUpdateStrategy):
+    """Shor’s r-algorithm.
+
+    .. note::
+        a line search such as :code:`tz.m.StrongWolfe(plus_minus=True)` is required.
+
+    Reference:
+        Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of Numerical Analysis 28.4 (2008): 711-720.
+
+        Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.
+    """

+    def __init__(
+        self,
+        alpha=0.5,
+        init_scale: float | Literal["auto"] = 1,
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None | Literal['auto'] = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        scale_second: bool = False,
+        concat_params: bool = True,
+        # inverse: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha=alpha)
+        super().__init__(
+            defaults=defaults,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )

+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return shor_r_(H=H, y=y, alpha=setting['alpha'])