torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0

torchzero/modules/quasi_newton/damping.py (new file):

@@ -0,0 +1,105 @@
+import math
+from typing import Literal, Protocol, overload
+
+import torch
+
+from ...utils import TensorList
+from ...utils.linalg.linear_operator import DenseInverse, LinearOperator
+from ..functional import safe_clip
+
+
+class DampingStrategy(Protocol):
+    def __call__(
+        self,
+        s: torch.Tensor,
+        y: torch.Tensor,
+        g: torch.Tensor,
+        H: LinearOperator,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return s, y
+
+def _sy_Hs_sHs(s: torch.Tensor, y: torch.Tensor, H: LinearOperator):
+    if isinstance(H, DenseInverse):
+        Hs = H.solve(y)
+        sHs = y.dot(Hs)
+    else:
+        Hs = H.matvec(s)
+        sHs = s.dot(Hs)
+
+    return s.dot(y), Hs, sHs
+
+
+def powell_damping(s: torch.Tensor, y: torch.Tensor, g: torch.Tensor, H: LinearOperator, u=0.2):
+    # here H is the hessian, not the inverse
+    sy, Hs, sHs = _sy_Hs_sHs(s, y, H)
+    if sy < u*sHs:
+        phi = ((1-u) * sHs) / safe_clip(sHs - sy)
+        s = phi * s + (1 - phi) * Hs
+
+    return s, y
+
+def double_damping(s: torch.Tensor, y: torch.Tensor, g: torch.Tensor, H: LinearOperator, u1=0.2, u2=1/3):
+    # Goldfarb, Donald, Yi Ren, and Achraf Bahamou. "Practical quasi-Newton methods for training deep neural networks." Advances in Neural Information Processing Systems 33 (2020): 2386-2396.
+
+    # Powell's damping on H
+    sy, Hs, sHs = _sy_Hs_sHs(s, y, H)
+    if sy < u1*sHs:
+        phi = ((1-u1) * sHs) / safe_clip(sHs - sy)
+        s = phi * s + (1 - phi) * Hs
+
+    # Powell's damping with B = I
+    sy = s.dot(y)
+    ss = s.dot(s)
+
+    if sy < u2*ss:
+        phi = ((1-u2) * ss) / safe_clip(ss - sy)
+        y = phi * y + (1 - phi) * s
+
+    return s, y
+
+
+_DAMPING_KEYS = Literal["powell", "double"]
+_DAMPING_STRATEGIES: dict[_DAMPING_KEYS, DampingStrategy] = {
+    "powell": powell_damping,
+    "double": double_damping,
+}
+
+
+DampingStrategyType = _DAMPING_KEYS | DampingStrategy | None
+
+@overload
+def apply_damping(
+    strategy: DampingStrategyType,
+    s: torch.Tensor,
+    y: torch.Tensor,
+    g: torch.Tensor,
+    H: LinearOperator,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
+@overload
+def apply_damping(
+    strategy: DampingStrategyType,
+    s: TensorList,
+    y: TensorList,
+    g: TensorList,
+    H: LinearOperator,
+) -> tuple[TensorList, TensorList]: ...
+def apply_damping(
+    strategy: DampingStrategyType,
+    s,
+    y,
+    g,
+    H: LinearOperator,
+):
+    if strategy is None: return s, y
+    if isinstance(strategy, str): strategy = _DAMPING_STRATEGIES[strategy]
+
+    if isinstance(s, TensorList):
+        assert isinstance(y, TensorList) and isinstance(g, TensorList)
+        s_vec, y_vec = strategy(s.to_vec(), y.to_vec(), g.to_vec(), H)
+        return s.from_vec(s_vec), y.from_vec(y_vec)
+
+    assert isinstance(y, torch.Tensor) and isinstance(g, torch.Tensor)
+    return strategy(s, y, g, H)
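
The second stage of `double_damping` (Powell's damping with B = I) has a checkable guarantee: whenever sᵀy < u2·sᵀs, blending y toward s with φ = (1−u2)·sᵀs / (sᵀs − sᵀy) lands the curvature exactly on the threshold, sᵀy_damped = u2·sᵀs, which is what keeps a subsequent BFGS-type update positive definite. Below is a minimal standalone sketch of just that stage (plain torch, not the package API; `damp_y_identity` is a name made up for this example):

```python
import torch

def damp_y_identity(s: torch.Tensor, y: torch.Tensor, u2: float = 1/3):
    # Powell damping against B = I, mirroring the second block of double_damping
    sy, ss = s.dot(y), s.dot(s)
    if sy < u2 * ss:
        phi = ((1 - u2) * ss) / (ss - sy)  # ss - sy > 0 whenever this branch is taken
        y = phi * y + (1 - phi) * s
    return y

torch.manual_seed(0)
s = torch.randn(16)
y = -0.5 * s + 0.1 * torch.randn(16)   # bad curvature pair: s.y < 0
y_d = damp_y_identity(s, y)
# algebra: s.y_d = phi*s.y + (1-phi)*s.s = u2 * s.s exactly
assert torch.isclose(s.dot(y_d), s.dot(s) / 3, rtol=1e-4)
```

The first stage applies the same blend to s, with curvature measured against sᵀHs through the `LinearOperator` (hence `_sy_Hs_sHs`), and `safe_clip` guards the denominators against division by zero.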

torchzero/modules/quasi_newton/diagonal_quasi_newton.py:

@@ -1,163 +1,167 @@
-[the 163 removed lines are not preserved in this view]
+from typing import Any, Literal
+from collections.abc import Callable
+
+import torch
+from ...core import Chainable
+from .quasi_newton import (
+    HessianUpdateStrategy,
+    _HessianUpdateStrategyDefaults,
+    _InverseHessianUpdateStrategyDefaults,
+)
+
+from ..functional import safe_clip
+
+
+def diagonal_bfgs_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy < tol: return H
+
+    sy_sq = safe_clip(sy**2)
+
+    num1 = (sy + (y * H * y)) * s*s
+    term1 = num1.div_(sy_sq)
+    num2 = (H * y * s).add_(s * y * H)
+    term2 = num2.div_(sy)
+    H += term1.sub_(term2)
+    return H
+
+class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
+    """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
+
+    def initialize_P(self, size: int, device, dtype, is_inverse: bool): return torch.ones(size, device=device, dtype=dtype)
+
+def diagonal_sr1_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
+    z = s - H*y
+    denom = z.dot(y)
+
+    z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
+    y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
+
+    # check as in Nocedal, Wright. "Numerical Optimization" 2nd ed., p. 146
+    if denom.abs() <= tol * y_norm * z_norm: return H
+    H += (z*z).div_(safe_clip(denom))
+    return H
+
+class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
+    """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        # the SR1 formula turns into its inverse form under swapping s and y
+        return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])
+
+    def initialize_P(self, size: int, device, dtype, is_inverse: bool): return torch.ones(size, device=device, dtype=dtype)
+
+
+# Zhu M., Nazareth J. L., Wolkowicz H. "The quasi-Cauchy relation and diagonal updating." SIAM Journal on Optimization 9.4 (1999): 1192-1204.
+def diagonal_qc_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
+    denom = safe_clip((s**4).sum())
+    num = s.dot(y) - (s*B).dot(s)
+    B += s**2 * (num/denom)
+    return B
+
+class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-Cauchy method.
+
+    Reference:
+        Zhu M., Nazareth J. L., Wolkowicz H. "The quasi-Cauchy relation and diagonal updating." SIAM Journal on Optimization 9.4 (1999): 1192-1204.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_qc_B_(B=B, s=s, y=y)
+
+    def initialize_P(self, size: int, device, dtype, is_inverse: bool): return torch.ones(size, device=device, dtype=dtype)
+
+# Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+def diagonal_wqc_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
+    E_sq = s**2 * B**2
+    denom = safe_clip((s*E_sq).dot(s))
+    num = s.dot(y) - (s*B).dot(s)
+    B += E_sq * (num/denom)
+    return B
+
+class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
+    """Diagonal weighted quasi-Cauchy method.
+
+    Reference:
+        Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_wqc_B_(B=B, s=s, y=y)
+
+    def initialize_P(self, size: int, device, dtype, is_inverse: bool): return torch.ones(size, device=device, dtype=dtype)
+
+def _truncate(B: torch.Tensor, lb, ub):
+    return torch.where((B>lb).logical_and(B<ub), B, 1)
+
+# Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+def dnrtr_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
+    denom = safe_clip((s**4).sum())
+    num = s.dot(y) + s.dot(s) - (s*B).dot(s)
+    B += s**2 * (num/denom) - 1
+    return B
+
+class DNRTR(HessianUpdateStrategy):
+    """Diagonal quasi-Newton method.
+
+    Reference:
+        Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+    """
+    def __init__(
+        self,
+        lb: float = 1e-2,
+        ub: float = 1e5,
+        init_scale: float | Literal["auto"] = "auto",
+        tol: float = 1e-32,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        restart_interval: int | None | Literal['auto'] = None,
+        beta: float | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(lb=lb, ub=ub)
+        super().__init__(
+            defaults=defaults,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_restart=ptol_restart,
+            gtol=gtol,
+            restart_interval=restart_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            concat_params=concat_params,
+            inverse=False,
+            inner=inner,
+        )
+
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        # note: delegates to the weighted quasi-Cauchy update rather than dnrtr_B_ above
+        return diagonal_wqc_B_(B=B, s=s, y=y)
+
+    def modify_B(self, B, state, setting):
+        return _truncate(B, setting['lb'], setting['ub'])
+
+    def initialize_P(self, size: int, device, dtype, is_inverse: bool): return torch.ones(size, device=device, dtype=dtype)
+
+# Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+def new_dqn_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
+    denom = safe_clip((s**4).sum())
+    num = s.dot(y)
+    B += s**2 * (num/denom)
+    return B
+
+class NewDQN(DNRTR):
+    """Diagonal quasi-Newton method.
+
+    Reference:
+        Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return new_dqn_B_(B=B, s=s, y=y)