torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in the public registry.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
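Beyond the per-file counts, the listing above documents a package-level reorganization: the `optimizers` subpackage becomes `adaptive`, conjugate-gradient and trust-region code move into their own subpackages, and a number of experimental modules are removed. A minimal sketch of what the rename means for deep imports, assuming the moved files remain importable by module path (the flat `tz.m` namespace used in the docstrings below is presumably unaffected):

```python
import importlib

# 0.3.14 layout: files listed under torchzero/modules/{optimizers -> adaptive}/
# are now importable from the adaptive subpackage, e.g. shampoo.py:
shampoo = importlib.import_module("torchzero.modules.adaptive.shampoo")
print(shampoo.__name__)  # -> torchzero.modules.adaptive.shampoo

# The 0.3.11 path "torchzero.modules.optimizers.shampoo" would no longer resolve
# after the rename; importing it in 0.3.14 raises ModuleNotFoundError.
```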
torchzero/modules/line_search/strong_wolfe.py
CHANGED

@@ -1,276 +1,375 @@
-"""this needs to be reworked maybe but it also works"""
 import math
 import warnings
 from operator import itemgetter
+from typing import Literal

+import numpy as np
 import torch
 from torch.optim.lbfgs import _cubic_interpolate

-from
-from
+from ...utils import as_tensorlist, totensor
+from ._polyinterp import polyinterp, polyinterp2
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
+from ..step_size.adaptive import _bb_geom
+
+def _totensor(x):
+    if not isinstance(x, torch.Tensor): return torch.tensor(x, dtype=torch.float32)
+    return x
+
+def _within_bounds(x, bounds):
+    if bounds is None: return True
+    lb,ub = bounds
+    if lb is not None and x < lb: return False
+    if ub is not None and x > ub: return False
+    return True
+
+def _apply_bounds(x, bounds):
+    if bounds is None: return True
+    lb,ub = bounds
+    if lb is not None and x < lb: return lb
+    if ub is not None and x > ub: return ub
+    return x
+
+class _StrongWolfe:
+    def __init__(
+        self,
+        f,
+        f_0,
+        g_0,
+        d_norm,
+        a_init,
+        a_max,
+        c1,
+        c2,
+        maxiter,
+        maxeval,
+        maxzoom,
+        tol_change,
+        interpolation: Literal["quadratic", "cubic", "bisection", "polynomial", "polynomial2"],
+    ):
+        self._f = f
+        self.f_0 = f_0
+        self.g_0 = g_0
+        self.d_norm = d_norm
+        self.a_init = a_init
+        self.a_max = a_max
+        self.c1 = c1
+        self.c2 = c2
+        self.maxiter = maxiter
+        if maxeval is None: maxeval = float('inf')
+        self.maxeval = maxeval
+        self.tol_change = tol_change
+        self.num_evals = 0
+        self.maxzoom = maxzoom
+        self.interpolation = interpolation
+
+        self.history = {}
+
+    def f(self, a):
+        if a in self.history: return self.history[a]
+        self.num_evals += 1
+        f_a, g_a = self._f(a)
+        self.history[a] = (f_a, g_a)
+        return f_a, g_a
+
+    def interpolate(self, a_lo, f_lo, g_lo, a_hi, f_hi, g_hi, bounds=None):
+        if self.interpolation == 'cubic':
+            # pytorch cubic interpolate needs tensors
+            a_lo = _totensor(a_lo); f_lo = _totensor(f_lo); g_lo = _totensor(g_lo)
+            a_hi = _totensor(a_hi); f_hi = _totensor(f_hi); g_hi = _totensor(g_hi)
+            return float(_cubic_interpolate(x1=a_lo, f1=f_lo, g1=g_lo, x2=a_hi, f2=f_hi, g2=g_hi, bounds=bounds))
+
+        if self.interpolation == 'bisection':
+            return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
+
+        if self.interpolation == 'quadratic':
+            a = a_hi - a_lo
+            denom = 2 * (f_hi - f_lo - g_lo*a)
+            if denom > 1e-32:
+                num = g_lo * a**2
+                a_min = num / -denom
+                return _apply_bounds(a_min, bounds)
+            return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
+
+        if self.interpolation in ('polynomial', 'polynomial2'):
+            finite_history = [(a, f, g) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
+            if bounds is None: bounds = (None, None)
+            polyinterp_fn = polyinterp if self.interpolation == 'polynomial' else polyinterp2
+            try:
+                return _apply_bounds(polyinterp_fn(np.array(finite_history), *bounds), bounds) # pyright:ignore[reportArgumentType]
+            except torch.linalg.LinAlgError:
+                return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
+        else:
+            raise ValueError(self.interpolation)
+
+    def zoom(self, a_lo, f_lo, g_lo, a_hi, f_hi, g_hi):
+        if a_lo >= a_hi:
+            a_hi, f_hi, g_hi, a_lo, f_lo, g_lo = a_lo, f_lo, g_lo, a_hi, f_hi, g_hi
+
+        insuf_progress = False
+        for _ in range(self.maxzoom):
+            if self.num_evals >= self.maxeval: break
+            if (a_hi - a_lo) * self.d_norm < self.tol_change: break # small bracket
+
+            if not (math.isfinite(f_hi) and math.isfinite(g_hi)):
+                a_hi = a_hi / 2
+                f_hi, g_hi = self.f(a_hi)
+                continue
+
+            a_j = self.interpolate(a_lo, f_lo, g_lo, a_hi, f_hi, g_hi, bounds=(a_lo, min(a_hi, self.a_max)))
+
+            # this part is from https://github.com/pytorch/pytorch/blob/main/torch/optim/lbfgs.py:
+            eps = 0.1 * (a_hi - a_lo)
+            if min(a_hi - a_j, a_j - a_lo) < eps:
+                # interpolation close to boundary
+                if insuf_progress or a_j >= a_hi or a_j <= a_lo:
+                    # evaluate at 0.1 away from boundary
+                    if abs(a_j - a_hi) < abs(a_j - a_lo):
+                        a_j = a_hi - eps
+                    else:
+                        a_j = a_lo + eps
+                    insuf_progress = False
+                else:
+                    insuf_progress = True
+            else:
+                insuf_progress = False

+            f_j, g_j = self.f(a_j)

-
-
-            f_l, g_l,
-            f_h, g_h,
-            f_0, g_0,
-            c1, c2,
-            maxzoom):
+            if f_j > self.f_0 + self.c1*a_j*self.g_0 or f_j > f_lo:
+                a_hi, f_hi, g_hi = a_j, f_j, g_j

-
-
-
+            else:
+                if abs(g_j) <= -self.c2 * self.g_0:
+                    return a_j, f_j, g_j

-
+                if g_j * (a_hi - a_lo) >= 0:
+                    a_hi, f_hi, g_hi = a_lo, f_lo, g_lo

-
-        delta = abs(a_h - a_l)
-        if a_j is None or a_j == a_l or a_j == a_h:
-            a_j = a_l + 0.5 * delta
+                a_lo, f_lo, g_lo = a_j, f_j, g_j

+        # fail
+        return None, None, None

-
+    def search(self):
+        a_i = min(self.a_init, self.a_max)
+        f_i = g_i = None
+        a_prev = 0
+        f_prev = self.f_0
+        g_prev = self.g_0
+        for i in range(self.maxiter):
+            if self.num_evals >= self.maxeval: break
+            f_i, g_i = self.f(a_i)

-
-
+            if f_i > self.f_0 + self.c1*a_i*self.g_0 or (i > 0 and f_i > f_prev):
+                return self.zoom(a_prev, f_prev, g_prev, a_i, f_i, g_i)

-
-
+            if abs(g_i) <= -self.c2 * self.g_0:
+                return a_i, f_i, g_i

+            if g_i >= 0:
+                return self.zoom(a_i, f_i, g_i, a_prev, f_prev, g_prev)
+
+            # from pytorch
+            min_step = a_i + 0.01 * (a_i - a_prev)
+            max_step = a_i * 10
+            a_i_next = self.interpolate(a_prev, f_prev, g_prev, a_i, f_i, g_i, bounds=(min_step, min(max_step, self.a_max)))
+            # a_i_next = self.interpolate(a_prev, f_prev, g_prev, a_i, f_i, g_i, bounds=(0, self.a_max))
+
+            a_prev, f_prev, g_prev = a_i, f_i, g_i
+            a_i = a_i_next
+
+        if self.num_evals < self.maxeval:
+            assert f_i is not None and g_i is not None
+            return self.zoom(0, self.f_0, self.g_0, a_i, f_i, g_i)
+
+        return None, None, None

-        # minimum between alpha_low and alpha_j
-        if not armijo or f_j >= f_l:
-            a_h = a_j
-            f_h = f_j
-            g_h = g_j
-        else:
-            # alpha_j satisfies armijo
-            if wolfe:
-                return a_j, f_j
-
-            # minimum between alpha_j and alpha_high
-            if g_j * (a_h - a_l) >= 0:
-                # between alpha_low and alpha_j
-                # a_h = a_l
-                # f_h = f_l
-                # g_h = g_l
-                a_h = a_j
-                f_h = f_j
-                g_h = g_j
-
-                # is this messing it up?
-            else:
-                a_l = a_j
-                f_l = f_j
-                g_l = g_j
-
-
-
-
-        # check if interval too small
-        delta = abs(a_h - a_l)
-        if delta <= 1e-9 or delta <= 1e-6 * max(abs(a_l), abs(a_h)):
-            l_satisfies_wolfe = (f_l <= f_0 + c1 * a_l * g_0) and (abs(g_l) <= c2 * abs(g_0))
-            h_satisfies_wolfe = (f_h <= f_0 + c1 * a_h * g_0) and (abs(g_h) <= c2 * abs(g_0))
-
-            if l_satisfies_wolfe and h_satisfies_wolfe: return a_l if f_l <= f_h else a_h, f_h
-            if l_satisfies_wolfe: return a_l, f_l
-            if h_satisfies_wolfe: return a_h, f_h
-            if f_l <= f_0 + c1 * a_l * g_0: return a_l, f_l
-            return None,None
-
-        if a_j is None or a_j == a_l or a_j == a_h:
-            a_j = a_l + 0.5 * delta
-
-
-    return None,None
-
-
-def strong_wolfe(
-    f,
-    f_0,
-    g_0,
-    init: float = 1.0,
-    c1: float = 1e-4,
-    c2: float = 0.9,
-    maxiter: int = 25,
-    maxzoom: int = 15,
-    # a_max: float = 1e30,
-    expand: float = 2.0, # Factor to increase alpha in bracketing
-    plus_minus: bool = False,
-) -> tuple[float,float] | tuple[None,None]:
-    a_prev = 0.0
-
-    if g_0 == 0: return None,None
-    if g_0 > 0:
-        # if direction is not a descent direction, perform line search in opposite direction
-        if plus_minus:
-            def inverted_objective(alpha):
-                l, g = f(-alpha)
-                return l, -g
-            a, v = strong_wolfe(
-                inverted_objective,
-                init=init,
-                f_0=f_0,
-                g_0=-g_0,
-                c1=c1,
-                c2=c2,
-                maxiter=maxiter,
-                # a_max=a_max,
-                expand=expand,
-                plus_minus=False,
-            )
-            if a is not None and v is not None: return -a, v
-        return None, None
-
-    f_prev = f_0
-    g_prev = g_0
-    a_cur = init
-
-    # bracket
-    for i in range(maxiter):
-
-        f_cur, g_cur = f(a_cur)
-
-        # check armijo
-        armijo_violated = f_cur > f_0 + c1 * a_cur * g_0
-        func_increased = f_cur >= f_prev and i > 0
-
-        if armijo_violated or func_increased:
-            return _zoom(f,
-                a_prev, a_cur,
-                f_prev, g_prev,
-                f_cur, g_cur,
-                f_0, g_0,
-                c1, c2,
-                maxzoom=maxzoom,
-            )
-
-
-
-        # check strong wolfe
-        if abs(g_cur) <= c2 * abs(g_0):
-            return a_cur, f_cur
-
-        # minimum is bracketed
-        if g_cur >= 0:
-            return _zoom(f,
-                #alpha_curr, alpha_prev,
-                a_prev, a_cur,
-                #phi_curr, phi_prime_curr,
-                f_prev, g_prev,
-                f_cur, g_cur,
-                f_0, g_0,
-                c1, c2,
-                maxzoom=maxzoom,)
-
-        # otherwise continue bracketing
-        a_next = a_cur * expand
-
-        # update previous point and continue loop with increased step size
-        a_prev = a_cur
-        f_prev = f_cur
-        g_prev = g_cur
-        a_cur = a_next
-
-
-    # max iters reached
-    return None, None
-
-def _notfinite(x):
-    if isinstance(x, torch.Tensor): return not torch.isfinite(x).all()
-    return not math.isfinite(x)

 class StrongWolfe(LineSearchBase):
-    """
+    """Interpolation line search satisfying Strong Wolfe condition.

     Args:
-
-
-
-
-
-
-
-
+        c1 (float, optional): sufficient descent condition. Defaults to 1e-4.
+        c2 (float, optional): strong curvature condition. For CG set to 0.1. Defaults to 0.9.
+        a_init (str, optional):
+            strategy for initializing the initial step size guess.
+            - "fixed" - uses a fixed value specified in `init_value` argument.
+            - "first-order" - assumes first-order change in the function at iterate will be the same as that obtained at the previous step.
+            - "quadratic" - interpolates quadratic to f(x_{-1}) and f_x.
+            - "quadratic-clip" - same as quad, but uses min(1, 1.01*alpha) as described in Numerical Optimization.
+            - "previous" - uses final step size found on previous iteration.
+
+            For 2nd order methods it is usually best to leave at "fixed".
+            For methods that do not produce well scaled search directions, e.g. conjugate gradient,
+            "first-order" or "quadratic-clip" are recommended. Defaults to 'init'.
+        a_max (float, optional): upper bound for the proposed step sizes. Defaults to 1e12.
+        init_value (float, optional):
+            initial step size. Used when ``a_init``="fixed", and with other strategies as fallback value. Defaults to 1.
+        maxiter (int, optional): maximum number of line search iterations. Defaults to 25.
+        maxzoom (int, optional): maximum number of zoom iterations. Defaults to 10.
+        maxeval (int | None, optional): maximum number of function evaluations. Defaults to None.
+        tol_change (float, optional): tolerance, terminates on small brackets. Defaults to 1e-9.
+        interpolation (str, optional):
+            What type of interpolation to use.
+            - "bisection" - uses the middle point. This is robust, especially if the objective function is non-smooth, however it may need more function evaluations.
+            - "quadratic" - minimizes a quadratic model, generally outperformed by "cubic".
+            - "cubic" - minimizes a cubic model - this is the most widely used interpolation strategy.
+            - "polynomial" - fits a a polynomial to all points obtained during line search.
+            - "polynomial2" - alternative polynomial fit, where if a point is outside of bounds, a lower degree polynomial is tried.
+              This may have faster convergence than "cubic" and "polynomial".
+
+            Defaults to 'cubic'.
         adaptive (bool, optional):
-
-
+            if True, the initial step size will be halved when line search failed to find a good direction.
+            When a good direction is found, initial step size is reset to the original value. Defaults to True.
+        fallback (bool, optional):
+            if True, when no point satisfied strong wolfe criteria,
+            returns a point with value lower than initial value that doesn't satisfy the criteria. Defaults to False.
         plus_minus (bool, optional):
-
-
+            if True, enables the plus-minus variant, where if curvature is negative, line search is performed
+            in the opposite direction. Defaults to False.

-    Examples:
-        Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.

-
+    ## Examples:

-
-            model.parameters(),
-            tz.m.PolakRibiere(),
-            tz.m.StrongWolfe(c2=0.1)
-        )
+    Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG. Since CG doesn't produce well scaled directions, initial alpha can be determined from function values by ``a_init="first-order"``.

-
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.PolakRibiere(),
+        tz.m.StrongWolfe(c2=0.1, a_init="first-order")
+    )
+    ```

-
-
-
-
-
-
-
+    LBFGS strong wolfe line search:
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LBFGS(),
+        tz.m.StrongWolfe()
+    )
+    ```

     """
     def __init__(
         self,
-        init: float = 1.0,
         c1: float = 1e-4,
         c2: float = 0.9,
+        a_init: Literal['first-order', 'quadratic', 'quadratic-clip', 'previous', 'fixed'] = 'fixed',
+        a_max: float = 1e12,
+        init_value: float = 1,
         maxiter: int = 25,
         maxzoom: int = 10,
-
-
-
+        maxeval: int | None = None,
+        tol_change: float = 1e-9,
+        interpolation: Literal["quadratic", "cubic", "bisection", "polynomial", 'polynomial2'] = 'cubic',
         adaptive = True,
+        fallback:bool = False,
         plus_minus = False,
     ):
-        defaults=dict(init=
-
+        defaults=dict(init_value=init_value,init=a_init,a_max=a_max,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom, fallback=fallback,
+                      maxeval=maxeval, adaptive=adaptive, interpolation=interpolation, plus_minus=plus_minus, tol_change=tol_change)
         super().__init__(defaults=defaults)

         self.global_state['initial_scale'] = 1.0
-        self.global_state['beta_scale'] = 1.0

     @torch.no_grad
     def search(self, update, var):
+        self._g_prev = self._f_prev = None
         objective = self.make_objective_with_derivative(var=var)

-        init, c1, c2, maxiter, maxzoom,
-        'init', 'c1', 'c2', 'maxiter', 'maxzoom',
-        '
-
-
-
-
-
-
-
-
+        init_value, init, c1, c2, a_max, maxiter, maxzoom, maxeval, interpolation, adaptive, plus_minus, fallback, tol_change = itemgetter(
+            'init_value', 'init', 'c1', 'c2', 'a_max', 'maxiter', 'maxzoom',
+            'maxeval', 'interpolation', 'adaptive', 'plus_minus', 'fallback', 'tol_change')(self.defaults)
+
+        dir = as_tensorlist(var.get_update())
+        grad_list = var.get_grad()
+
+        g_0 = -sum(t.sum() for t in torch._foreach_mul(grad_list, dir))
+        f_0 = var.get_loss(False)
+        dir_norm = dir.global_vector_norm()
+
+        inverted = False
+        if plus_minus and g_0 > 0:
+            original_objective = objective
+            def inverted_objective(a):
+                l, g_a = original_objective(-a)
+                return l, -g_a
+            objective = inverted_objective
+            inverted = True
+
+        # --------------------- determine initial step size guess -------------------- #
+        init = init.lower().strip()
+
+        a_init = init_value
+        if init == 'fixed':
+            pass # use init_value
+
+        elif init == 'previous':
+            if 'a_prev' in self.global_state:
+                a_init = self.global_state['a_prev']
+
+        elif init == 'first-order':
+            if 'g_prev' in self.global_state and g_0 < -torch.finfo(dir[0].dtype).tiny * 2:
+                a_prev = self.global_state['a_prev']
+                g_prev = self.global_state['g_prev']
+                if g_prev < 0:
+                    a_init = a_prev * g_prev / g_0
+
+        elif init in ('quadratic', 'quadratic-clip'):
+            if 'f_prev' in self.global_state and g_0 < -torch.finfo(dir[0].dtype).tiny * 2:
+                f_prev = self.global_state['f_prev']
+                if f_0 < f_prev:
+                    a_init = 2 * (f_0 - f_prev) / g_0
+                    if init == 'quadratic-clip': a_init = min(1, 1.01*a_init)
+        else:
+            raise ValueError(init)
+
+        if adaptive:
+            a_init *= self.global_state.get('initial_scale', 1)
+
+        strong_wolfe = _StrongWolfe(
+            f=objective,
+            f_0=f_0,
+            g_0=g_0,
+            d_norm=dir_norm,
+            a_init=a_init,
+            a_max=a_max,
             c1=c1,
             c2=c2,
             maxiter=maxiter,
             maxzoom=maxzoom,
-
-
+            maxeval=maxeval,
+            tol_change=tol_change,
+            interpolation=interpolation,
         )

-
-        if
-
-
-
+        a, f_a, g_a = strong_wolfe.search()
+        if inverted and a is not None: a = -a
+        if f_a is not None and (f_a > f_0 or not math.isfinite(f_a)): a = None
+
+        if fallback:
+            if a is None or a==0 or not math.isfinite(a):
+                lowest = min(strong_wolfe.history.items(), key=lambda x: x[1][0])
+                if lowest[1][0] < f_0:
+                    a = lowest[0]
+                    f_a, g_a = lowest[1]
+                    if inverted: a = -a
+
+        if a is not None and a != 0 and math.isfinite(a):
+            self.global_state['initial_scale'] = 1
+            self.global_state['a_prev'] = a
+            self.global_state['f_prev'] = f_0
+            self.global_state['g_prev'] = g_0
+            return a
+
+        # fail
+        if adaptive:
+            self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
+            finfo = torch.finfo(dir[0].dtype)
+            if self.global_state['initial_scale'] < finfo.tiny * 2:
+                self.global_state['initial_scale'] = finfo.max / 2

-        if adaptive: self.global_state['initial_scale'] *= 0.5
         return 0
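The rewrite above replaces the old recursive `strong_wolfe`/`_zoom` helpers with a stateful `_StrongWolfe` class and widens the `StrongWolfe` module's signature (`a_init`, `a_max`, `init_value`, `maxeval`, `tol_change`, `interpolation`, and `fallback` are new). A hedged construction sketch based on that signature and the docstring examples shown in the diff; the `model` object is a stand-in and the surrounding training loop is omitted:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)  # placeholder model for illustration

# Conjugate gradient with the new initial-step-size and fallback options.
opt = tz.Modular(
    model.parameters(),
    tz.m.PolakRibiere(),
    tz.m.StrongWolfe(
        c2=0.1,                 # looser curvature condition, recommended for CG
        a_init="first-order",   # scale the first trial step from the previous accepted step
        interpolation="cubic",  # default; "polynomial"/"polynomial2" fit all evaluated points
        fallback=True,          # if no point satisfies strong Wolfe, take the best evaluated one
    ),
)
```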
torchzero/modules/misc/__init__.py
CHANGED

@@ -1,6 +1,13 @@
 from .debug import PrintLoss, PrintParams, PrintShape, PrintUpdate
 from .escape import EscapeAnnealing
 from .gradient_accumulation import GradientAccumulation
+from .homotopy import (
+    ExpHomotopy,
+    LambdaHomotopy,
+    LogHomotopy,
+    SqrtHomotopy,
+    SquareHomotopy,
+)
 from .misc import (
     DivByLoss,
     FillLoss,
@@ -20,6 +27,7 @@ from .misc import (
     RandomHvp,
     Relative,
     UpdateSign,
+    SaveBest,
 )
 from .multistep import Multistep, NegateOnLossIncrease, Online, Sequential
 from .regularization import Dropout, PerturbWeights, WeightDropout
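These `__init__.py` hunks only widen the re-export surface of the `misc` subpackage; the homotopy classes and `SaveBest` become importable directly from it (their constructors are not shown in this diff, so no usage is assumed):

```python
# Names added to torchzero.modules.misc by the hunks above.
from torchzero.modules.misc import (
    ExpHomotopy,
    LambdaHomotopy,
    LogHomotopy,
    SqrtHomotopy,
    SquareHomotopy,
    SaveBest,
)
```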
torchzero/modules/misc/debug.py
CHANGED

@@ -12,7 +12,7 @@ class PrintUpdate(Module):
         super().__init__(defaults)

     def step(self, var):
-        self.
+        self.defaults["print_fn"](f'{self.defaults["text"]}{var.update}')
         return var

 class PrintShape(Module):
@@ -23,7 +23,7 @@ class PrintShape(Module):

     def step(self, var):
         shapes = [u.shape for u in var.update] if var.update is not None else None
-        self.
+        self.defaults["print_fn"](f'{self.defaults["text"]}{shapes}')
         return var

 class PrintParams(Module):
@@ -33,7 +33,7 @@ class PrintParams(Module):
         super().__init__(defaults)

     def step(self, var):
-        self.
+        self.defaults["print_fn"](f'{self.defaults["text"]}{var.params}')
         return var


@@ -44,5 +44,5 @@ class PrintLoss(Module):
         super().__init__(defaults)

     def step(self, var):
-        self.
+        self.defaults["print_fn"](f'{self.defaults["text"]}{var.get_loss(False)}')
         return var
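All four `debug.py` hunks make the same change: the `Print*` modules now read a `print_fn` callable and a `text` prefix from `self.defaults` inside `step(self, var)`. A minimal sketch of a custom module written in the same style; the `PrintUpdateNorm` class, its keyword arguments, and the `Module` import path are illustrative assumptions rather than part of the package:

```python
from torchzero.core.module import Module  # path taken from the file listing above; assumed importable


class PrintUpdateNorm(Module):
    """Hypothetical debug module following the pattern of the Print* modules in the diff."""

    def __init__(self, text="update norm: ", print_fn=print):
        # settings live in the defaults dict, as in the diff above
        super().__init__(dict(text=text, print_fn=print_fn))

    def step(self, var):
        # var.update is a list of tensors (see PrintShape above); compute its global L2 norm
        norm = None
        if var.update is not None:
            norm = sum(u.pow(2).sum() for u in var.update) ** 0.5
        self.defaults["print_fn"](f'{self.defaults["text"]}{norm}')
        return var
```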