torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/quasi_newton/quasi_newton.py

````diff
@@ -1,37 +1,27 @@
-
+import warnings
 from abc import ABC, abstractmethod
-from collections.abc import
+from collections.abc import Callable, Mapping
 from typing import Any, Literal
-import warnings
 
 import torch
 
 from ...core import Chainable, Module, TensorwiseTransform, Transform
-from ...utils import TensorList, set_storage_, unpack_states
-from
+from ...utils import TensorList, set_storage_, unpack_states, safe_dict_update_
+from ...utils.linalg import linear_operator
+from ..functional import initial_step_size, safe_clip
 
 
-def _safe_dict_update_(d1_:dict, d2:dict):
-    inter = set(d1_.keys()).intersection(d2.keys())
-    if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
-    d1_.update(d2)
 
 def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
     if (beta is None) or (beta == 0) or (key not in state): state[key] = value
     elif state[key].shape != value.shape: state[key] = value
     else: state[key].lerp_(value, 1-beta)
 
-def _safe_clip(x: torch.Tensor):
-    """makes sure scalar tensor x is not smaller than epsilon"""
-    assert x.numel() == 1, x.shape
-    eps = torch.finfo(x.dtype).eps ** 2
-    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
-    return x
-
 class HessianUpdateStrategy(TensorwiseTransform, ABC):
     """Base class for quasi-newton methods that store and update hessian approximation H or inverse B.
 
-    This is an abstract class, to use it, subclass it and override
+    This is an abstract class, to use it, subclass it and override ``update_H`` and/or ``update_B``,
+    and if necessary, ``initialize_P``, ``modify_H`` and ``modify_B``.
 
     Args:
        defaults (dict | None, optional): defaults. Defaults to None.
@@ -42,13 +32,13 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
 
            Defaults to "auto".
        tol (float, optional):
-            algorithm-dependent tolerance (usually on curvature condition). Defaults to 1e-
+            algorithm-dependent tolerance (usually on curvature condition). Defaults to 1e-32.
        ptol (float | None, optional):
-            tolerance for minimal parameter difference to avoid instability. Defaults to 1e-
-
+            tolerance for minimal parameter difference to avoid instability. Defaults to 1e-32.
+        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
        gtol (float | None, optional):
-            tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-
-
+            tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-32.
+        restart_interval (int | None | Literal["auto"], optional):
            interval between resetting the hessian approximation.
 
            "auto" corresponds to number of decision variables + 1.
@@ -70,141 +60,101 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
            Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            defaults=None,
-            init_scale=init_scale,
-            tol=tol,
-            ptol=ptol,
-            ptol_reset=ptol_reset,
-            reset_interval=reset_interval,
-            beta=beta,
-            update_freq=update_freq,
-            scale_first=scale_first,
-            scale_second=scale_second,
-            concat_params=concat_params,
-            inverse=True,
-            inner=inner,
-        )
-
-    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
-        tol = settings["tol"]
-        sy = torch.dot(s, y)
-        if sy <= tol: return H
-        num1 = (sy + (y @ H @ y)) * s.outer(s)
-        term1 = num1.div_(sy**2)
-        num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
-        term2 = num2.div_(sy)
-        H += term1.sub_(term2)
-        return H
+    ## Notes
+
+    ### update
+
+    On 1st ``update_tensor`` H or B is initialized using ``initialize_P``, which returns identity matrix by default.
+
+    2nd and subsequent ``update_tensor`` calls ``update_H`` or ``update_B``.
+
+    Whether ``H`` or ``B`` is used depends on value of ``inverse`` setting.
+
+    ### apply
+
+    ``apply_tensor`` computes ``H = modify_H(H)`` or ``B = modify_B(B)``, those methods do nothing by default.
+
+    Then it computes and returns ``H @ input`` or ``solve(B, input)``.
+
+    Whether ``H`` or ``B`` is used depends on value of ``inverse`` setting.
+
+    ### initial scale
+
+    If ``init_scale`` is a scalar, the preconditioner is multiplied or divided (if inverse) by it on first ``update_tensor``.
 
+    If ``init_scale="auto"``, it is computed and applied on the second ``update_tensor``.
+
+    ### get_H
+
+    First it computes ``H = modify_H(H)`` or ``B = modify_B(B)``.
+
+    Returns a ``Dense`` linear operator with ``B``, or ``DenseInverse`` linear operator with ``H``.
+
+    But if H/B has 1 dimension, ``Diagonal`` linear operator is returned with ``B`` or ``1/H``.
     """
     def __init__(
        self,
        defaults: dict | None = None,
        init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-        ptol: float | None = 1e-
-
-        gtol: float | None = 1e-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = None,
        beta: float | None = None,
        update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
        concat_params: bool = True,
        inverse: bool = True,
        inner: Chainable | None = None,
    ):
        if defaults is None: defaults = {}
-
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq,
-
-    def _init_M(self, size:int, device, dtype, is_inverse:bool):
-        return torch.eye(size, device=device, dtype=dtype)
+        safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_restart=ptol_restart, gtol=gtol, inverse=inverse, beta=beta, restart_interval=restart_interval, scale_first=scale_first))
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
 
-    def
-
-
-        yy = y.dot(y)
-        if ys != 0 and yy != 0: return yy/ys
-        return 1
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('f_prev', 'p_prev', 'g_prev')
 
-
-
-
-
-            if inverse: M /= init_scale
-            else: M *= init_scale
+    # ---------------------------- methods to override --------------------------- #
+    def initialize_P(self, size:int, device, dtype, is_inverse:bool) -> torch.Tensor:
+        """returns the initial torch.Tensor for H or B"""
+        return torch.eye(size, device=device, dtype=dtype)
 
    def update_H(self, H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
        """update hessian inverse"""
-        raise NotImplementedError
+        raise NotImplementedError(f"hessian inverse approximation is not implemented for {self.__class__.__name__}.")
 
    def update_B(self, B:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
                 p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]) -> torch.Tensor:
        """update hessian"""
-        raise NotImplementedError
-
-    def reset_for_online(self):
-        super().reset_for_online()
-        self.clear_state_keys('f_prev', 'p_prev', 'g_prev')
+        raise NotImplementedError(f"{self.__class__.__name__} only supports hessian inverse approximation. "
+                                  "Remove the `inverse=False` argument when initializing this module.")
 
-    def
-        """
-
-        if "B" in state: return state["B"], False
-        return state["H"], True
+    def modify_B(self, B: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+        """modifies B out of place before appling the update rule, doesn't affect the buffer B."""
+        return B
 
-    def
-        """
-
-        if "H" in state: return state["H"], False
-        return state["B"], True
-
-    def make_Bv(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        B, is_inverse = self.get_B()
-
-        if is_inverse:
-            H=B
-            warnings.warn(f'{self} maintains H, so Bv will be inefficient!')
-            def Hxv(v): return torch.linalg.solve_ex(H, v)[0] # pylint:disable=not-callable
-            return Hxv
-
-        def Bv(v): return B@v
-        return Bv
-
-    def make_Hv(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        H, is_inverse = self.get_H()
+    def modify_H(self, H: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
+        """modifies H out of place before appling the update rule, doesn't affect the buffer H."""
+        return H
 
-
-
-
-
-
+    # ------------------------------ common methods ------------------------------ #
+    def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
+        """returns multiplier to B on 2nd step if ``init_scale='auto'``. H should be divided by this!"""
+        ys = y.dot(s)
+        yy = y.dot(y)
+        if ys != 0 and yy != 0: return yy/ys
+        return 1
 
-
-
+    def reset_P(self, P: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
+        """resets ``P`` which is either B or H"""
+        set_storage_(P, self.initialize_P(s.numel(), device=P.device, dtype=P.dtype, is_inverse=inverse))
+        if init_scale == 'auto': init_scale = self.auto_initial_scale(s,y)
+        if init_scale >= 1:
+            if inverse: P /= init_scale
+            else: P *= init_scale
 
    @torch.no_grad
    def update_tensor(self, tensor, param, grad, loss, state, setting):
@@ -216,14 +166,14 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
        state['step'] = step
        init_scale = setting['init_scale']
        ptol = setting['ptol']
-
+        ptol_restart = setting['ptol_restart']
        gtol = setting['gtol']
-
-        if
+        restart_interval = setting['restart_interval']
+        if restart_interval == 'auto': restart_interval = tensor.numel() + 1
 
        if M is None or 'f_prev' not in state:
            if M is None: # won't be true on reset_for_online
-                M = self.
+                M = self.initialize_P(p.numel(), device=p.device, dtype=p.dtype, is_inverse=inverse)
            if isinstance(init_scale, (int, float)) and init_scale != 1:
                if inverse: M /= init_scale
                else: M *= init_scale
@@ -242,13 +192,13 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
            state['p_prev'].copy_(p)
            state['g_prev'].copy_(g)
 
-        if
-            self.
+        if restart_interval is not None and step % restart_interval == 0:
+            self.reset_P(M, s, y, inverse, init_scale, state)
            return
 
        # tolerance on parameter difference to avoid exploding after converging
        if ptol is not None and s.abs().max() <= ptol:
-            if
+            if ptol_restart: self.reset_P(M, s, y, inverse, init_scale, state) # reset history
            return
 
        # tolerance on gradient difference to avoid exploding when there is no curvature
@@ -256,8 +206,8 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
            return
 
        if step == 2 and init_scale == 'auto':
-            if inverse: M /= self.
-            else: M *= self.
+            if inverse: M /= self.auto_initial_scale(s,y)
+            else: M *= self.auto_initial_scale(s,y)
 
        beta = setting['beta']
        if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
@@ -272,72 +222,86 @@ class HessianUpdateStrategy(TensorwiseTransform, ABC):
 
        state['f_prev'] = loss
 
-    def _post_B(self, B: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
-        """modifies B before appling the update rule. Must return (B, g)"""
-        return B, g
-
-    def _post_H(self, H: torch.Tensor, g: torch.Tensor, state: dict[str, Any], setting: Mapping[str, Any]):
-        """modifies H before appling the update rule. Must return (H, g)"""
-        return H, g
-
    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
-        step = state
+        step = state['step']
 
-        if setting['
-            tensor
+        if setting['scale_first'] and step == 1:
+            tensor *= initial_step_size(tensor)
 
        inverse = setting['inverse']
+        g = tensor.view(-1)
+
        if inverse:
            H = state['H']
-            H
+            H = self.modify_H(H, state, setting)
            if H.ndim == 1: return g.mul_(H).view_as(tensor)
            return (H @ g).view_as(tensor)
 
        B = state['B']
-
+        B = self.modify_B(B, state, setting)
 
        if B.ndim == 1: return g.div_(B).view_as(tensor)
        x, info = torch.linalg.solve_ex(B, g) # pylint:disable=not-callable
        if info == 0: return x.view_as(tensor)
-
+
+        # failed to solve linear system, so reset state
+        self.state.clear()
+        self.global_state.clear()
+        return tensor.mul_(initial_step_size(tensor))
+
+    def get_H(self, var):
+        param = var.params[0]
+        state = self.state[param]
+        settings = self.settings[param]
+        if "B" in state:
+            B = self.modify_B(state["B"], state, settings)
+            if B.ndim == 2: return linear_operator.Dense(B)
+            assert B.ndim == 1, B.shape
+            return linear_operator.Diagonal(B)
+
+        if "H" in state:
+            H = self.modify_H(state["H"], state, settings)
+            if H.ndim != 1: return linear_operator.DenseInverse(H)
+            return linear_operator.Diagonal(1/H)
+
+        return None
 
 class _InverseHessianUpdateStrategyDefaults(HessianUpdateStrategy):
-    '''This is
-    Refer to
-
-    Example:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    '''This is ``HessianUpdateStrategy`` subclass for algorithms with no extra defaults, to skip the lengthy ``__init__``.
+    Refer to ``HessianUpdateStrategy`` documentation.
+
+    ## Example:
+
+    Implementing BFGS method that maintains an estimate of the hessian inverse (H):
+    ```python
+    class BFGS(_HessianUpdateStrategyDefaults):
+        """Broyden–Fletcher–Goldfarb–Shanno algorithm"""
+        def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+            tol = settings["tol"]
+            sy = torch.dot(s, y)
+            if sy <= tol: return H
+            num1 = (sy + (y @ H @ y)) * s.outer(s)
+            term1 = num1.div_(sy**2)
+            num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+            term2 = num2.div_(sy)
+            H += term1.sub_(term2)
+            return H
+    ```
 
    Make sure to put at least a basic class level docstring to overwrite this.
    '''
    def __init__(
        self,
        init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-        ptol: float | None = 1e-
-
-        gtol: float | None = 1e-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
        beta: float | None = None,
        update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
        concat_params: bool = True,
        inverse: bool = True,
        inner: Chainable | None = None,
@@ -347,13 +311,12 @@ class _InverseHessianUpdateStrategyDefaults(HessianUpdateStrategy):
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
-
+            ptol_restart=ptol_restart,
            gtol=gtol,
-
+            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
-            scale_second=scale_second,
            concat_params=concat_params,
            inverse=inverse,
            inner=inner,
@@ -363,15 +326,14 @@ class _HessianUpdateStrategyDefaults(HessianUpdateStrategy):
    def __init__(
        self,
        init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-        ptol: float | None = 1e-
-
-        gtol: float | None = 1e-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
        beta: float | None = None,
        update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
        concat_params: bool = True,
        inverse: bool = False,
        inner: Chainable | None = None,
@@ -381,13 +343,12 @@ class _HessianUpdateStrategyDefaults(HessianUpdateStrategy):
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
-
+            ptol_restart=ptol_restart,
            gtol=gtol,
-
+            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
-            scale_second=scale_second,
            concat_params=concat_params,
            inverse=inverse,
            inner=inner,
@@ -399,7 +360,7 @@ def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
    if sy < tol: return B
 
    Bs = B@s
-    sBs =
+    sBs = safe_clip(s.dot(Bs))
 
    term1 = y.outer(y).div_(sy)
    term2 = (Bs.outer(s) @ B.T).div_(sBs)
@@ -410,7 +371,7 @@ def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
    sy = s.dot(y)
    if sy <= tol: return H
 
-    sy_sq =
+    sy_sq = safe_clip(sy**2)
 
    Hy = H@y
    scale1 = (sy + y.dot(Hy)) / sy_sq
@@ -425,11 +386,11 @@ def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
 class BFGS(_InverseHessianUpdateStrategyDefaults):
    """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.
 
-
-        a line search
+    Note:
+        a line search or a trust region is recommended
 
-
-        this uses
+    Warning:
+        this uses at least O(N^2) memory.
 
    Args:
        init_scale (float | Literal["auto"], optional):
@@ -439,12 +400,12 @@
 
            Defaults to "auto".
        tol (float, optional):
-            tolerance on curvature condition. Defaults to 1e-
+            tolerance on curvature condition. Defaults to 1e-32.
        ptol (float | None, optional):
            skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
-            Defaults to 1e-
-
-
+            Defaults to 1e-32.
+        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+        restart_interval (int | None | Literal["auto"], optional):
            interval between resetting the hessian approximation.
 
            "auto" corresponds to number of decision variables + 1.
@@ -462,26 +423,25 @@
            If False, the update rule is applied to each parameter separately. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
 
-    Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
+    ## Examples:
+
+    BFGS with backtracking line search:
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.BFGS(),
+        tz.m.Backtracking()
+    )
+    ```
+
+    BFGS with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
+    )
+    ```
    """
 
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
@@ -501,38 +461,29 @@ def sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
 
    # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
    if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
-    H += z.outer(z).div_(
+    H += z.outer(z).div_(safe_clip(denom))
    return H
 
 class SR1(_InverseHessianUpdateStrategyDefaults):
-    """Symmetric Rank 1
-
-
-
-
-    .. note::
-        approximate Hessians generated by the SR1 method show faster progress towards the true Hessian than other methods, but it is more unstable. SR1 is best used within a trust region module.
-
-    .. note::
-        SR1 doesn't enforce the hessian estimate to be positive definite, therefore it can generate directions that are not descent directions.
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    """Symmetric Rank 1. This works best with a trust region:
+    ```python
+    tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False))
+    ```
 
    Args:
        init_scale (float | Literal["auto"], optional):
            initial hessian matrix is set to identity times this.
 
-            "auto" corresponds to a heuristic from
+            "auto" corresponds to a heuristic from [1] p.142-143.
 
            Defaults to "auto".
        tol (float, optional):
-            tolerance for denominator in SR1 update rule as in
+            tolerance for denominator in SR1 update rule as in [1] p.146. Defaults to 1e-32.
        ptol (float | None, optional):
            skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
-            Defaults to 1e-
-
-
+            Defaults to 1e-32.
+        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
+        restart_interval (int | None | Literal["auto"], optional):
            interval between resetting the hessian approximation.
 
            "auto" corresponds to number of decision variables + 1.
@@ -550,26 +501,18 @@ class SR1(_InverseHessianUpdateStrategyDefaults):
            If False, the update rule is applied to each parameter separately. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.
 
-    Examples:
-        SR1 with strong-wolfe line search
+    ### Examples:
 
-
+    SR1 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
+    )
+    ```
 
-
-
-            tz.m.SR1(),
-            tz.m.StrongWolfe()
-        )
-
-        BFGS preconditioning applied to momentum
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.SR1(inner=tz.m.EMA(0.9)),
-                tz.m.LR(1e-2)
-            )
+    ### References:
+    [1]. Nocedal. Stephen J. Wright. Numerical Optimization
    """
 
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
@@ -584,7 +527,7 @@ def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
    if sy.abs() <= tol: return H
    term1 = s.outer(s).div_(sy)
 
-    yHy =
+    yHy = safe_clip(y.dot(H @ y))
 
    num = (H @ y).outer(y) @ H
    term2 = num.div_(yHy)
@@ -607,15 +550,11 @@ def dfp_B(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
 class DFP(_InverseHessianUpdateStrategyDefaults):
    """Davidon–Fletcher–Powell Quasi-Newton method.
 
-
-        a
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Note:
+        a trust region or an accurate line search is recommended.
 
+    Warning:
+        this uses at least O(N^2) memory.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return dfp_H_(H=H, s=s, y=y, tol=setting['tol'])
@@ -629,30 +568,30 @@ class DFP(_InverseHessianUpdateStrategyDefaults):
 
 def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
    c = H.T @ s
-    cy =
+    cy = safe_clip(c.dot(y))
    num = (H@y).sub_(s).outer(c)
    H -= num/cy
    return H
 def broyden_good_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
    r = y - B@s
-    ss =
+    ss = safe_clip(s.dot(s))
    B += r.outer(s).div_(ss)
    return B
 
 def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
-    yy =
+    yy = safe_clip(y.dot(y))
    num = (s - (H @ y)).outer(y)
    H += num/yy
    return H
 def broyden_bad_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
    r = y - B@s
-    ys =
+    ys = safe_clip(y.dot(s))
    B += r.outer(y).div_(ys)
    return B
 
 def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor):
    c = g_prev
-    cy =
+    cy = safe_clip(c.dot(y))
    num = (H@y).sub_(s).outer(c)
    H -= num/cy
    return H
@@ -660,7 +599,7 @@ def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torc
 def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
    Hy = H @ y
    c = H @ Hy # pylint:disable=not-callable
-    cy =
+    cy = safe_clip(c.dot(y))
    num = Hy.sub_(s).outer(c)
    H -= num/cy
    return H
@@ -668,14 +607,11 @@ class BroydenGood(_InverseHessianUpdateStrategyDefaults):
 class BroydenGood(_InverseHessianUpdateStrategyDefaults):
    """Broyden's "good" Quasi-Newton method.
 
-
-        a
+    Note:
+        a trust region or an accurate line search is recommended.
 
-
-
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
@@ -688,14 +624,11 @@ class BroydenBad(_InverseHessianUpdateStrategyDefaults):
 class BroydenBad(_InverseHessianUpdateStrategyDefaults):
    """Broyden's "bad" Quasi-Newton method.
 
-
-        a
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
+    Note:
+        a trust region or an accurate line search is recommended.
 
-
-        this uses
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
@@ -708,14 +641,11 @@ class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
 class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
    """Greenstadt's first Quasi-Newton method.
 
-
-        a
+    Note:
+        a trust region or an accurate line search is recommended.
 
-
-
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
@@ -726,18 +656,14 @@ class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
 class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
    """Greenstadt's second Quasi-Newton method.
 
-
-        a line search
+    Note:
+        a line search is recommended.
 
-
-
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
-
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return greenstadt2_H_(H=H, s=s, y=y)
@@ -746,7 +672,7 @@ class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
 def icum_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor):
    j = y.abs().argmax()
 
-    denom =
+    denom = safe_clip(y[j])
 
    Hy = H @ y.unsqueeze(1)
    num = s.unsqueeze(1) - Hy
@@ -759,11 +685,11 @@ class ICUM(_InverseHessianUpdateStrategyDefaults):
    Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods
    due to only updating one column of the inverse hessian approximation per step.
 
-
-        a line search
+    Note:
+        a line search is recommended.
 
-
-        this uses
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf
@@ -775,11 +701,11 @@ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor)
    s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
    I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
    d = (R + I * (s_norm/2)) @ s
-    ds =
+    ds = safe_clip(d.dot(s))
    R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))
 
    c = H.T @ d
-    cy =
+    cy = safe_clip(c.dot(y))
    num = (H@y).sub_(s).outer(c)
    H -= num/cy
    return H, R
@@ -788,14 +714,11 @@ class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
    """
    Thomas's "optimal" Quasi-Newton method.
 
-
-        a line search
+    Note:
+        a line search is recommended.
 
-
-
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.
@@ -805,18 +728,18 @@ class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
        H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y)
        return H
 
-    def
-        super().
+    def reset_P(self, P, s, y, inverse, init_scale, state):
+        super().reset_P(P, s, y, inverse, init_scale, state)
        for st in self.state.values():
            st.pop("R", None)
 
 # ------------------------ powell's symmetric broyden ------------------------ #
 def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
    y_Bs = y - B@s
-    ss =
+    ss = safe_clip(s.dot(s))
    num1 = y_Bs.outer(s).add_(s.outer(y_Bs))
    term1 = num1.div_(ss)
-    term2 = s.outer(s).mul_(y_Bs.dot(s)/(
+    term2 = s.outer(s).mul_(y_Bs.dot(s)/(safe_clip(ss**2)))
    B += term1.sub_(term2)
    return B
 
@@ -824,14 +747,11 @@ def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
 class PSB(_HessianUpdateStrategyDefaults):
    """Powell's Symmetric Broyden Quasi-Newton method.
 
-
-        a line search
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
+    Note:
+        a line search or a trust region is recommended.
 
-
-        this uses
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
@@ -843,7 +763,7 @@ class PSB(_HessianUpdateStrategyDefaults):
 # Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
 def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
    Hy = H@y
-    yHy =
+    yHy = safe_clip(y.dot(Hy))
    num = (s - Hy).outer(Hy)
    H += num.div_(yHy)
    return H
@@ -852,14 +772,11 @@ class Pearson(_InverseHessianUpdateStrategyDefaults):
    """
    Pearson's Quasi-Newton method.
 
-
-        a line search
+    Note:
+        a line search is recommended.
 
-
-
-
-    .. warning::
-        this uses roughly O(N^2) memory.
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
@@ -868,7 +785,7 @@ class Pearson(_InverseHessianUpdateStrategyDefaults):
        return pearson_H_(H=H, s=s, y=y)
 
 def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
-    sy =
+    sy = safe_clip(s.dot(y))
    num = (s - H@y).outer(s)
    H += num.div_(sy)
    return H
@@ -876,14 +793,11 @@ def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
 class McCormick(_InverseHessianUpdateStrategyDefaults):
    """McCormicks's Quasi-Newton method.
 
-
-        a line search
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
+    Note:
+        a line search is recommended.
 
-
-        this uses
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
@@ -895,7 +809,7 @@ class McCormick(_InverseHessianUpdateStrategyDefaults):
 
 def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor):
    Hy = H @ y
-    yHy =
+    yHy = safe_clip(y.dot(Hy))
    H -= Hy.outer(Hy) / yHy
    R += (s - R@y).outer(Hy) / yHy
    return H, R
@@ -904,14 +818,11 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
    """
    Projected Newton Raphson method.
 
-
-        a line search
-
-    .. note::
-        this is an experimental method.
+    Note:
+        a line search is recommended.
 
-
-        this uses
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
@@ -921,15 +832,14 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
    def __init__(
        self,
        init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-
-        ptol: float | None = 1e-
-
-        gtol: float | None = 1e-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = 'auto',
        beta: float | None = None,
        update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
@@ -937,13 +847,12 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
            init_scale=init_scale,
            tol=tol,
            ptol = ptol,
-
+            ptol_restart=ptol_restart,
            gtol=gtol,
-
+            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
-            scale_second=scale_second,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
@@ -955,9 +864,10 @@ class ProjectedNewtonRaphson(HessianUpdateStrategy):
        state["R"] = R
        return H
 
-    def
+    def reset_P(self, P, s, y, inverse, init_scale, state):
        assert inverse
-
+        if 'R' not in state: state['R'] = torch.eye(P.size(-1), device=P.device, dtype=P.dtype)
+        P.copy_(state["R"])
 
 # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
 def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
@@ -969,8 +879,8 @@ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, swi
    # however p.12 says eps = gs / gHy
 
    Hy = H@y
-    gHy =
-    yHy =
+    gHy = safe_clip(g.dot(Hy))
+    yHy = safe_clip(y.dot(Hy))
    sy = s.dot(y)
    if sy < tol: return H # the proof is for sy>0. But not clear if it should be skipped
 
@@ -987,26 +897,26 @@
    e = gs / gHy
    if switch in (1, 3):
        if e/o <= 1:
-            phi = e/
+            phi = e/safe_clip(o)
            theta = 0
        elif o/t >= 1:
-            phi = o/
+            phi = o/safe_clip(t)
            theta = 1
        else:
            phi = 1
-            denom =
+            denom = safe_clip(e*t - o**2)
            if switch == 1: theta = o * (e - o) / denom
            else: theta = o * (t - o) / denom
 
    elif switch == 2:
-        t =
-        o =
-        e =
+        t = safe_clip(t)
+        o = safe_clip(o)
+        e = safe_clip(e)
        phi = (e / t) ** 0.5
        theta = 1 / (1 + (t*e / o**2)**0.5)
 
    elif switch == 4:
-        phi = e/
+        phi = e/safe_clip(t)
        theta = 1/2
 
    else: raise ValueError(switch)
@@ -1028,14 +938,11 @@ class SSVM(HessianUpdateStrategy):
    """
    Self-scaling variable metric Quasi-Newton method.
 
-
-        a line search
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
+    Note:
+        a line search is recommended.
 
-
-        this uses
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
@@ -1044,15 +951,14 @@
        self,
        switch: tuple[float,float] | Literal[1,2,3,4] = 3,
        init_scale: float | Literal["auto"] = 'auto',
-        tol: float = 1e-
-        ptol: float | None = 1e-
-
-        gtol: float | None = 1e-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
        beta: float | None = None,
        update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
@@ -1062,13 +968,12 @@ class SSVM(HessianUpdateStrategy):
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
-
+            ptol_restart=ptol_restart,
            gtol=gtol,
-
+            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
-            scale_second=scale_second,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
@@ -1083,7 +988,7 @@ def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
    ys = y.dot(s)
    if ys.abs() <= tol: return H # probably? because it is BFGS and DFP-like
    yHy = y.dot(Hy)
-    denom =
+    denom = safe_clip(ys + yHy)
 
    term1 = 1/denom
    term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
@@ -1096,7 +1001,7 @@ def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
    return H
 
 def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
-    sy =
+    sy = safe_clip(s.dot(y))
    return g - (y * (s.dot(g) / sy))
 
 
@@ -1106,16 +1011,16 @@ class GradientCorrection(Transform):
 
    This can useful as inner module for second order methods with inexact line search.
 
-    Example:
-
-
-    .. code-block :: python
+    ## Example:
+    L-BFGS with gradient correction
 
-
-
-
-
-
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LBFGS(inner=tz.m.GradientCorrection()),
+        tz.m.Backtracking()
+    )
+    ```
 
    Reference:
        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
@@ -1141,14 +1046,11 @@ class Horisho(_InverseHessianUpdateStrategyDefaults):
    """
    Horisho's variable metric Quasi-Newton method.
 
-
-        a line search
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
+    Note:
+        a line search is recommended.
 
-
-        this uses
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
@@ -1175,14 +1077,11 @@ class FletcherVMM(_InverseHessianUpdateStrategyDefaults):
    """
    Fletcher's variable metric Quasi-Newton method.
 
-
-        a line search
-
-    .. note::
-        BFGS is the recommended QN method and will usually outperform this.
+    Note:
+        a line search is recommended.
 
-
-        this uses
+    Warning:
+        this uses at least O(N^2) memory.
 
    Reference:
        Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
@@ -1218,10 +1117,10 @@ def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol:
 class NewSSM(HessianUpdateStrategy):
    """Self-scaling Quasi-Newton method.
 
-
-        a line search such as
+    Note:
+        a line search such as ``tz.m.StrongWolfe()`` is required.
 
-
+    Warning:
        this uses roughly O(N^2) memory.
 
    Reference:
@@ -1231,15 +1130,14 @@ class NewSSM(HessianUpdateStrategy):
        self,
        type: Literal[1, 2] = 1,
        init_scale: float | Literal["auto"] = "auto",
-        tol: float = 1e-
-        ptol: float | None = 1e-
-
-        gtol: float | None = 1e-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None = None,
        beta: float | None = None,
        update_freq: int = 1,
-        scale_first: bool =
-        scale_second: bool = False,
+        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
@@ -1248,13 +1146,12 @@ class NewSSM(HessianUpdateStrategy):
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
-
+            ptol_restart=ptol_restart,
            gtol=gtol,
-
+            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
-            scale_second=scale_second,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
@@ -1267,44 +1164,48 @@ class NewSSM(HessianUpdateStrategy):
 # ---------------------------- Shor’s r-algorithm ---------------------------- #
 # def shor_r(B:torch.Tensor, y:torch.Tensor, gamma:float):
 #     r = B.T @ y
-#     r /= torch.linalg.vector_norm(r).clip(min=1e-
+#     r /= torch.linalg.vector_norm(r).clip(min=1e-32) # pylint:disable=not-callable
 
 #     I = torch.eye(B.size(1), device=B.device, dtype=B.dtype)
 #     return B @ (I - gamma*r.outer(r))
 
-# this is supposed to be equivalent
+# this is supposed to be equivalent (and it is)
 def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
    p = H@y
    #(1-y)^2 (ppT)/(pTq)
-    term = p.outer(p).div_(p.dot(y).clip(min=1e-
+    #term = p.outer(p).div_(p.dot(y).clip(min=1e-32))
+    term = p.outer(p).div_(safe_clip(p.dot(y)))
    H.sub_(term, alpha=1-alpha**2)
    return H
 
 class ShorR(HessianUpdateStrategy):
    """Shor’s r-algorithm.
 
-
-
+    Note:
+        A line search such as ``tz.m.StrongWolfe(a_init="quadratic", fallback=True)`` is required.
+        Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling,
+        so setting ``a_init`` in the line search is recommended.
 
-
-
+    References:
+        SHOR, N. Z. (1985) Minimization Methods for Non-differentiable Functions. New York: Springer.
+
+        Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720. - good overview.
 
-        Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.
+        Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998. - this is where a more efficient formula is described.
    """
 
    def __init__(
        self,
        alpha=0.5,
        init_scale: float | Literal["auto"] = 1,
-        tol: float = 1e-
-        ptol: float | None = 1e-
-
-        gtol: float | None = 1e-
-
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = None,
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
-        scale_second: bool = False,
        concat_params: bool = True,
        # inverse: bool = True,
        inner: Chainable | None = None,
@@ -1315,13 +1216,12 @@ class ShorR(HessianUpdateStrategy):
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
-
+            ptol_restart=ptol_restart,
            gtol=gtol,
-
+            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
-            scale_second=scale_second,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
````
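A recurring change in this diff is routing every scalar denominator through `safe_clip`, which 0.3.13 imports from `torchzero.modules.functional`; the 0.3.11 code defined the same guard inline as `_safe_clip` (removed in the first hunk). Below is a minimal, self-contained sketch assembled from code visible in this diff: `_safe_clip` is copied from the removed helper, and `bfgs_update_H` (a name invented here for illustration, not part of the package) mirrors the BFGS inverse-hessian update from the `_InverseHessianUpdateStrategyDefaults` docstring example, with the `safe_clip(sy**2)` guard that `bfgs_H_` applies. The packaged `safe_clip` may differ from this reconstruction.

```python
# Sketch reconstructed from code shown in this diff -- not the packaged
# implementation. `_safe_clip` is the helper the old version defined inline;
# `bfgs_update_H` (a hypothetical free function) mirrors the BFGS
# inverse-hessian update from the docstring example, guarded like `bfgs_H_`.
import torch

def _safe_clip(x: torch.Tensor) -> torch.Tensor:
    """makes sure scalar tensor x is not smaller than epsilon"""
    assert x.numel() == 1, x.shape
    eps = torch.finfo(x.dtype).eps ** 2
    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
    return x

def bfgs_update_H(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float = 1e-32) -> torch.Tensor:
    sy = torch.dot(s, y)        # curvature along the step
    if sy <= tol: return H      # skip the update on a failed curvature condition
    sy_sq = _safe_clip(sy**2)   # clip the denominator away from zero
    term1 = ((sy + (y @ H @ y)) * s.outer(s)).div_(sy_sq)
    term2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H)).div_(sy)
    H += term1.sub_(term2)
    return H

# after the update, H maps the gradient difference y back to the step s
# (the secant equation), which is what makes this a quasi-newton update
H = torch.eye(3, dtype=torch.float64)
s = torch.tensor([0.1, -0.2, 0.3], dtype=torch.float64)
y = torch.tensor([0.4, 0.1, 0.5], dtype=torch.float64)
H = bfgs_update_H(H, s, y)
assert torch.allclose(H @ y, s)
```

Clipping rather than raising matters here because every update rule in `quasi_newton.py` divides by a curvature term such as `s·y`, `yᵀHy` or `sᵀs`, and those shrink toward zero as the iterates converge; the `eps**2` floor keeps the preconditioner finite in exactly the regime the new `ptol`/`gtol` checks are also guarding.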