torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
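Most of the per-file churn above comes from the `torchzero/modules/optimizers` subpackage being renamed to `torchzero/modules/adaptive`. A minimal before/after sketch of what that rename implies for direct imports; the class name `Adam` and the assumption that the moved modules keep exporting the same symbols are illustrative, not verified against either release:

```python
# 0.3.11 -- hypothetical direct import from the old subpackage
# from torchzero.modules.optimizers.adam import Adam

# 0.3.14 -- the same module now lives under the renamed subpackage
from torchzero.modules.adaptive.adam import Adam  # assumed export

# Code that only goes through the top-level re-exports (e.g. tz.m.Adam)
# would presumably be unaffected by the rename, but that is not verified here.
```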
torchzero/modules/quasi_newton/lbfgs.py

````diff
@@ -1,198 +1,257 @@
 from collections import deque
-from
+from collections.abc import Sequence
+from typing import overload

 import torch

-from ...core import Chainable,
-from ...utils import
-from
-
-
-
-
-
-
-
-
-)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+from ...core import Chainable, Transform
+from ...utils import TensorList, as_tensorlist, unpack_states
+from ...utils.linalg.linear_operator import LinearOperator
+from ..functional import initial_step_size
+from .damping import DampingStrategyType, apply_damping
+
+
+@torch.no_grad
+def _make_M(S:torch.Tensor, Y:torch.Tensor, B_0:torch.Tensor):
+    m,n = S.size()
+
+    M = torch.zeros((2 * m, 2 * m), device=S.device, dtype=S.dtype)
+
+    # top-left is B S^T S
+    M[:m, :m] = B_0 * S @ S.mT
+
+    # anti-diagonal is L^T and L
+    L = (S @ Y.mT).tril_(-1)
+
+    M[m:, :m] = L.mT
+    M[:m, m:] = L
+
+    # bottom-right
+    D_diag = (S * Y).sum(1).neg()
+    M[m:, m:] = D_diag.diag_embed()
+
+    return M
+
+
+@torch.no_grad
+def lbfgs_Bx(x: torch.Tensor, S: torch.Tensor, Y: torch.Tensor, sy_history, M=None):
+    """L-BFGS hessian-vector product based on compact representation,
+    returns (Bx, M), where M is an internal matrix that depends on S and Y so it can be reused."""
+    m = len(S)
+    if m == 0: return x.clone()
+
+    # initial scaling
+    y = Y[-1]
+    sy = sy_history[-1]
+    yy = y.dot(y)
+    B_0 = yy / sy
+    Bx = x * B_0
+
+    Psi = torch.zeros(2 * m, device=x.device, dtype=x.dtype)
+    Psi[:m] = B_0 * S@x
+    Psi[m:] = Y@x
+
+    if M is None: M = _make_M(S, Y, B_0)
+
+    # solve Mu = p
+    u, info = torch.linalg.solve_ex(M, Psi) # pylint:disable=not-callable
+    if info != 0:
+        return Bx
+
+    # Bx
+    u_S = u[:m]
+    u_Y = u[m:]
+    SuS = (S * u_S.unsqueeze(-1)).sum(0)
+    YuY = (Y * u_Y.unsqueeze(-1)).sum(0)
+    return Bx - (B_0 * SuS + YuY), M
+
+
+@overload
+def lbfgs_Hx(
+    x: torch.Tensor,
+    s_history: Sequence[torch.Tensor] | torch.Tensor,
+    y_history: Sequence[torch.Tensor] | torch.Tensor,
+    sy_history: Sequence[torch.Tensor] | torch.Tensor,
+) -> torch.Tensor: ...
+@overload
+def lbfgs_Hx(
+    x: TensorList,
+    s_history: Sequence[TensorList],
+    y_history: Sequence[TensorList],
+    sy_history: Sequence[torch.Tensor] | torch.Tensor,
+) -> TensorList: ...
+def lbfgs_Hx(
+    x,
+    s_history: Sequence | torch.Tensor,
+    y_history: Sequence | torch.Tensor,
+    sy_history: Sequence[torch.Tensor] | torch.Tensor,
 ):
-
-
-
-    return safe_scaling_(TensorList(tensors_))
+    """L-BFGS inverse-hessian-vector product, works with tensors and TensorLists"""
+    x = x.clone()
+    if len(s_history) == 0: return x

     # 1st loop
     alpha_list = []
-    q = tensors_.clone()
     for s_i, y_i, sy_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
-        p_i = 1 / sy_i
-        alpha = p_i * s_i.dot(
+        p_i = 1 / sy_i
+        alpha = p_i * s_i.dot(x)
         alpha_list.append(alpha)
-
-
-    # calculate z
-    # s.y/y.y is also this weird y-looking symbol I couldn't find
-    # z is it times q
-    # actually H0 = (s.y/y.y) * I, and z = H0 @ q
-    z = q * (sy / (y.dot(y)))
+        x.sub_(y_i, alpha=alpha)

-    #
-
-
-
-
-        z = z_ema
+    # scaled initial hessian inverse
+    # H_0 = (s.y/y.y) * I, and z = H_0 @ q
+    sy = sy_history[-1]
+    y = y_history[-1]
+    Hx = x * (sy / y.dot(y))

     # 2nd loop
     for s_i, y_i, sy_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
         p_i = 1 / sy_i
-        beta_i = p_i * y_i.dot(
-
+        beta_i = p_i * y_i.dot(Hx)
+        Hx.add_(s_i, alpha = alpha_i - beta_i)

-    return
+    return Hx

-def _lerp_params_update_(
-    self_: Module,
-    params: list[torch.Tensor],
-    update: list[torch.Tensor],
-    params_beta: list[float | None],
-    grads_beta: list[float | None],
-):
-    for i, (p, u, p_beta, u_beta) in enumerate(zip(params.copy(), update.copy(), params_beta, grads_beta)):
-        if p_beta is not None or u_beta is not None:
-            state = self_.state[p]

-
-
-
-
+class LBFGSLinearOperator(LinearOperator):
+    def __init__(self, s_history: Sequence[torch.Tensor] | torch.Tensor, y_history: Sequence[torch.Tensor] | torch.Tensor, sy_history: Sequence[torch.Tensor] | torch.Tensor):
+        super().__init__()
+        if len(s_history) == 0:
+            self.S = self.Y = self.yy = None
+        else:
+            self.S = s_history
+            self.Y = y_history
+
+        self.sy_history = sy_history
+        self.M = None
+
+    def _get_S(self):
+        if self.S is None: return None
+        if not isinstance(self.S, torch.Tensor):
+            self.S = torch.stack(tuple(self.S))
+        return self.S
+
+    def _get_Y(self):
+        if self.Y is None: return None
+        if not isinstance(self.Y, torch.Tensor):
+            self.Y = torch.stack(tuple(self.Y))
+        return self.Y

-
-
-
-
+    def solve(self, b):
+        S = self._get_S(); Y = self._get_Y()
+        if S is None or Y is None: return b.clone()
+        return lbfgs_Hx(b, S, Y, self.sy_history)
+
+    def matvec(self, x):
+        S = self._get_S(); Y = self._get_Y()
+        if S is None or Y is None: return x.clone()
+        Bx, self.M = lbfgs_Bx(x, S, Y, self.sy_history, M=self.M)
+        return Bx
+
+    def size(self):
+        if self.S is None: raise RuntimeError()
+        n = len(self.S[0])
+        return (n, n)

-    return TensorList(params), TensorList(update)

 class LBFGS(Transform):
-    """Limited-memory BFGS algorithm. A line search
+    """Limited-memory BFGS algorithm. A line search or trust region is recommended.

     Args:
         history_size (int, optional):
             number of past parameter differences and gradient differences to store. Defaults to 10.
-
-
-
-
-
-
-        tol (float | None, optional):
-            tolerance for minimal parameter difference to avoid instability. Defaults to 1e-10.
-        tol_reset (bool, optional):
-            If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
+        ptol (float | None, optional):
+            skips updating the history if maximum absolute value of
+            parameter difference is less than this value. Defaults to 1e-10.
+        ptol_restart (bool, optional):
+            If true, whenever parameter difference is less then ``ptol``,
+            L-BFGS state will be reset. Defaults to None.
         gtol (float | None, optional):
-
-
-
-
-
+            skips updating the history if if maximum absolute value of
+            gradient difference is less than this value. Defaults to 1e-10.
+        ptol_restart (bool, optional):
+            If true, whenever gradient difference is less then ``gtol``,
+            L-BFGS state will be reset. Defaults to None.
+        sy_tol (float | None, optional):
+            history will not be updated whenever s⋅y is less than this value (negative s⋅y means negative curvature)
+        scale_first (bool, optional):
+            makes first step, when hessian approximation is not available,
+            small to reduce number of line search iterations. Defaults to True.
         update_freq (int, optional):
-            how often to update L-BFGS history. Defaults to 1.
-
-
+            how often to update L-BFGS history. Larger values may be better for stochastic optimization. Defaults to 1.
+        damping (DampingStrategyType, optional):
+            damping to use, can be "powell" or "double". Defaults to None.
         inner (Chainable | None, optional):
             optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.

-    Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            tz.m.StrongWolfe()
-        )
-
-        L-BFGS preconditioning applied to momentum (may be unstable!)
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LBFGS(inner=tz.m.EMA(0.9)),
-                tz.m.LR(1e-2)
-            )
+    ## Examples:
+
+    L-BFGS with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LBFGS(100),
+        tz.m.Backtracking()
+    )
+    ```
+
+    L-BFGS with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.LBFGS())
+    )
+    ```
     """
     def __init__(
         self,
         history_size=10,
-
-
-
-
-
-
-        params_beta: float | None = None,
-        grads_beta: float | None = None,
+        ptol: float | None = 1e-32,
+        ptol_restart: bool = False,
+        gtol: float | None = 1e-32,
+        gtol_restart: bool = False,
+        sy_tol: float = 1e-32,
+        scale_first:bool=True,
         update_freq = 1,
-
+        damping: DampingStrategyType = None,
         inner: Chainable | None = None,
     ):
-        defaults = dict(
-
+        defaults = dict(
+            history_size=history_size,
+            scale_first=scale_first,
+            ptol=ptol,
+            gtol=gtol,
+            ptol_restart=ptol_restart,
+            gtol_restart=gtol_restart,
+            sy_tol=sy_tol,
+            damping = damping,
+        )
+        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)

         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
         self.global_state['sy_history'] = deque(maxlen=history_size)

-    def
+    def _reset_self(self):
         self.state.clear()
         self.global_state['step'] = 0
         self.global_state['s_history'].clear()
         self.global_state['y_history'].clear()
         self.global_state['sy_history'].clear()

+    def reset(self):
+        self._reset_self()
+        for c in self.children.values(): c.reset()
+
     def reset_for_online(self):
         super().reset_for_online()
-        self.clear_state_keys('
+        self.clear_state_keys('p_prev', 'g_prev')
         self.global_state.pop('step', None)

     @torch.no_grad
     def update_tensors(self, tensors, params, grads, loss, states, settings):
-
-
+        p = as_tensorlist(params)
+        g = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -201,86 +260,83 @@ class LBFGS(Transform):
         y_history: deque[TensorList] = self.global_state['y_history']
         sy_history: deque[torch.Tensor] = self.global_state['sy_history']

-
-
+        ptol = self.defaults['ptol']
+        gtol = self.defaults['gtol']
+        ptol_restart = self.defaults['ptol_restart']
+        gtol_restart = self.defaults['gtol_restart']
+        sy_tol = self.defaults['sy_tol']
+        damping = self.defaults['damping']

-
-        prev_l_params, prev_l_grad = unpack_states(states, tensors, 'prev_l_params', 'prev_l_grad', cls=TensorList)
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)

         # 1st step - there are no previous params and grads, lbfgs will do normalized SGD step
         if step == 0:
             s = None; y = None; sy = None
         else:
-            s =
-            y =
+            s = p - p_prev
+            y = g - g_prev
+
+            if damping is not None:
+                s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())
+
             sy = s.dot(y)
+            # damping to be added here

-
-
+        below_tol = False
+        # tolerance on parameter difference to avoid exploding after converging
+        if ptol is not None:
+            if s is not None and s.abs().global_max() <= ptol:
+                if ptol_restart:
+                    self._reset_self()
+                sy = None
+                below_tol = True

-
-
+        # tolerance on gradient difference to avoid exploding when there is no curvature
+        if gtol is not None:
+            if y is not None and y.abs().global_max() <= gtol:
+                if gtol_restart: self._reset_self()
+                sy = None
+                below_tol = True

-        #
-        if
-
-
-            s_history.append(s)
-            y_history.append(y)
-            sy_history.append(sy)
+        # store previous params and grads
+        if not below_tol:
+            p_prev.copy_(p)
+            g_prev.copy_(g)

-        #
-
-
-        self.global_state['sy'] = sy
+        # update effective preconditioning state
+        if sy is not None and sy > sy_tol:
+            assert s is not None and y is not None and sy is not None

-
-
+            s_history.append(s)
+            y_history.append(y)
+            sy_history.append(sy)

-    def
-
+    def get_H(self, var=...):
+        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
+        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
+        sy_history = self.global_state['sy_history']
+        return LBFGSLinearOperator(s_history, y_history, sy_history)

     @torch.no_grad
     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-
-
-        s = self.global_state.pop('s')
-        y = self.global_state.pop('y')
-        sy = self.global_state.pop('sy')
-
-        setting = settings[0]
-        tol = setting['tol']
-        gtol = setting['gtol']
-        tol_reset = setting['tol_reset']
-        z_beta = setting['z_beta']
-
-        # tolerance on parameter difference to avoid exploding after converging
-        if tol is not None:
-            if s is not None and s.abs().global_max() <= tol:
-                if tol_reset: self.reset()
-                return safe_scaling_(TensorList(tensors))
+        scale_first = self.defaults['scale_first']

-
-        if tol is not None:
-            if y is not None and y.abs().global_max() <= gtol:
-                return safe_scaling_(TensorList(tensors))
+        tensors = as_tensorlist(tensors)

-
-
-
-        z_ema = unpack_states(states, tensors, 'z_ema', cls=TensorList)
+        s_history = self.global_state['s_history']
+        y_history = self.global_state['y_history']
+        sy_history = self.global_state['sy_history']

         # precondition
-        dir =
-
-            s_history=
-            y_history=
-            sy_history=
-            y=y,
-            sy=sy,
-            z_beta = z_beta,
-            z_ema = z_ema,
-            step=self.global_state.get('step', 1)
+        dir = lbfgs_Hx(
+            x=tensors,
+            s_history=s_history,
+            y_history=y_history,
+            sy_history=sy_history,
         )

+        # scale 1st step
+        if scale_first and self.global_state.get('step', 1) == 1:
+            dir *= initial_step_size(dir, eps=1e-7)
+
         return dir
````