torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/adaptive.py

````diff
@@ -1,58 +1,73 @@
 import math
+from bisect import insort
+from collections import deque
 from collections.abc import Callable
 from operator import itemgetter
 
+import numpy as np
 import torch
 
-from .line_search import LineSearchBase
-
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
 
 
 def adaptive_tracking(
     f,
+    a_init,
     maxiter: int,
     nplus: float = 2,
     nminus: float = 0.5,
+    f_0 = None,
 ):
+    niter = 0
+    if f_0 is None: f_0 = f(0)
 
+    a = a_init
+    f_a = f(a)
 
     # backtrack
+    a_prev = a
+    f_prev = math.inf
+    if (f_a > f_0) or (not math.isfinite(f_a)):
+        while (f_a < f_prev) or not math.isfinite(f_a):
+            a_prev, f_prev = a, f_a
             maxiter -= 1
-            if maxiter < 0:
+            if maxiter < 0: break
+
+            a = a*nminus
+            f_a = f(a)
+            niter += 1
+
+        if f_prev < f_0: return a_prev, f_prev, niter
+        return 0, f_0, niter
 
     # forwardtrack
-    while f_prev >= f_t:
+    a_prev = a
+    f_prev = math.inf
+    while (f_a <= f_prev) and math.isfinite(f_a):
+        a_prev, f_prev = a, f_a
         maxiter -= 1
-        if maxiter < 0:
+        if maxiter < 0: break
+
+        a *= nplus
+        f_a = f(a)
+        niter += 1
+
+    if f_prev < f_0: return a_prev, f_prev, niter
+    return 0, f_0, niter
 
 
+class AdaptiveTracking(LineSearchBase):
+    """A line search that evaluates the previous step size; if the value increased, it backtracks until the value stops decreasing,
+    otherwise it forward-tracks until the value stops decreasing.
 
     Args:
         init (float, optional): initial step size. Defaults to 1.0.
+        nplus (float, optional): multiplier to step size if initial step size is optimal. Defaults to 2.
+        nminus (float, optional): multiplier to step size if initial step size is too big. Defaults to 0.5.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         adaptive (bool, optional):
+            when enabled, if line search failed, step size will continue decreasing on the next step.
+            Otherwise it will restart the line search from ``init`` step size. Defaults to True.
     """
     def __init__(
         self,
@@ -62,38 +77,48 @@ class AdaptiveLineSearch(LineSearchBase):
         maxiter: int = 10,
         adaptive=True,
     ):
+        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive)
         super().__init__(defaults=defaults)
-        self.global_state['beta_scale'] = 1.0
 
     def reset(self):
         super().reset()
-        self.global_state['beta_scale'] = 1.0
 
     @torch.no_grad
     def search(self, update, var):
         init, nplus, nminus, maxiter, adaptive = itemgetter(
+            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.defaults)
 
         objective = self.make_objective(var=var)
 
+        # scale a_prev
+        a_prev = self.global_state.get('a_prev', init)
+        if adaptive: a_prev = a_prev * self.global_state.get('init_scale', 1)
 
+        a_init = a_prev
+        if a_init < torch.finfo(var.params[0].dtype).tiny * 2:
+            a_init = torch.finfo(var.params[0].dtype).max / 2
 
+        step_size, f, niter = adaptive_tracking(
+            objective,
+            a_init=a_init,
+            maxiter=maxiter,
+            nplus=nplus,
+            nminus=nminus,
+        )
 
         # found an alpha that reduces loss
        if step_size != 0:
+            assert (var.loss is None) or (math.isfinite(f) and f < var.loss)
+            self.global_state['init_scale'] = 1
+
+            # if niter == 1, forward tracking failed to decrease function value compared to f_a_prev
+            if niter == 1 and step_size >= a_init: step_size *= nminus
+
+            self.global_state['a_prev'] = step_size
             return step_size
 
         # on fail reduce beta scale value
+        self.global_state['init_scale'] = self.global_state.get('init_scale', 1) * nminus**maxiter
+        self.global_state['a_prev'] = init
         return 0
````
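Since `adaptive_tracking` is a plain scalar routine, the new control flow is easy to exercise in isolation. A minimal sketch on a one-dimensional quadratic (the import path follows the file list above; the objective and values are illustrative, not part of the package):

```python
from torchzero.modules.line_search.adaptive import adaptive_tracking

# Objective along the search ray; the minimizer is at a = 0.3.
f = lambda a: (a - 0.3) ** 2

# f(1.0) = 0.49 > f(0) = 0.09, so the routine backtracks:
# a = 1.0 -> 0.5 -> 0.25 -> 0.125, stops once f starts rising again,
# and returns the best point seen together with the evaluation count.
step, f_step, evals = adaptive_tracking(f, a_init=1.0, maxiter=10)
print(step, f_step, evals)  # 0.25 0.0025 3
```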
torchzero/modules/line_search/backtracking.py

````diff
@@ -4,7 +4,7 @@ from operator import itemgetter
 
 import torch
 
-from .line_search import LineSearchBase
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
 
 
 def backtracking_line_search(
@@ -14,7 +14,7 @@ def backtracking_line_search(
     beta: float = 0.5,
     c: float = 1e-4,
     maxiter: int = 10,
+    condition: TerminationCondition = 'armijo',
 ) -> float | None:
     """
 
@@ -31,16 +31,20 @@ def backtracking_line_search(
     """
 
     a = init
+    f_0 = f(0)
     f_prev = None
 
     for iteration in range(maxiter):
         f_a = f(a)
+        if not math.isfinite(f_a):
+            a *= beta
+            continue
 
+        if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_0):
+            return a / beta # new value is larger than previous value
         f_prev = f_a
 
+        if termination_condition(condition, f_0=f_0, g_0=g_0, f_a=f_a, g_a=None, a=a, c=c):
             # found an acceptable alpha
             return a
 
@@ -48,53 +52,45 @@ def backtracking_line_search(
         a *= beta
 
     # fail
-    if try_negative:
-        def inv_objective(alpha): return f(-alpha)
-
-        v = backtracking_line_search(
-            inv_objective,
-            g_0=-g_0,
-            beta=beta,
-            c=c,
-            maxiter=maxiter,
-            try_negative=False,
-        )
-        if v is not None: return -v
-
     return None
 
 class Backtracking(LineSearchBase):
-    """Backtracking line search
+    """Backtracking line search.
 
     Args:
         init (float, optional): initial step size. Defaults to 1.0.
         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
+        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
+        condition (TerminationCondition, optional):
+            termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
+            - "armijo" - sufficient decrease condition.
+            - "decrease" - any decrease in objective function value satisfies the condition.
+
+            "goldstein" can technically be specified but it doesn't make sense because there is no zoom stage.
+            Defaults to 'armijo'.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         adaptive (bool, optional):
-        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+            when enabled, if line search failed, step size will continue decreasing on the next step.
+            Otherwise it will restart the line search from ``init`` step size. Defaults to True.
 
     Examples:
+        Gradient descent with backtracking line search:
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Backtracking()
+        )
+        ```
+
+        L-BFGS with backtracking line search:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.LBFGS(),
+            tz.m.Backtracking()
+        )
+        ```
 
     """
     def __init__(
@@ -102,41 +98,47 @@ class Backtracking(LineSearchBase):
         init: float = 1.0,
         beta: float = 0.5,
         c: float = 1e-4,
+        condition: TerminationCondition = 'armijo',
         maxiter: int = 10,
         adaptive=True,
-        try_negative: bool = False,
     ):
+        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,adaptive=adaptive)
         super().__init__(defaults=defaults)
-        self.global_state['beta_scale'] = 1.0
 
     def reset(self):
         super().reset()
-        self.global_state['beta_scale'] = 1.0
 
     @torch.no_grad
     def search(self, update, var):
+        init, beta, c, condition, maxiter, adaptive = itemgetter(
+            'init', 'beta', 'c', 'condition', 'maxiter', 'adaptive')(self.defaults)
 
         objective = self.make_objective(var=var)
 
         # # directional derivative
+        if c == 0: d = 0
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
 
+        # scale init
+        init_scale = self.global_state.get('init_scale', 1)
+        if adaptive: init = init * init_scale
 
-        step_size = backtracking_line_search(objective, d, init=init,beta=beta,
-            c=c,maxiter=maxiter, try_negative=try_negative)
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta, c=c, condition=condition, maxiter=maxiter)
 
         # found an alpha that reduces loss
         if step_size is not None:
-            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            #self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            self.global_state['init_scale'] = 1
             return step_size
 
+        # on fail set init_scale to continue decreasing the step size
+        # or set to large step size when it becomes too small
+        if adaptive:
+            finfo = torch.finfo(var.params[0].dtype)
+            if init_scale <= finfo.tiny * 2:
+                self.global_state["init_scale"] = finfo.max / 2
+            else:
+                self.global_state['init_scale'] = init_scale * beta**maxiter
         return 0
 
 def _lerp(start,end,weight):
````
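The accept test now goes through `termination_condition(condition, f_0=..., g_0=..., f_a=..., g_a=None, a=..., c=...)`, whose body is not part of this diff. As a reference for the two gradient-free options documented above, here is a standalone sketch of what such a test conventionally checks (`accepts` is a hypothetical name, an assumption about the helper's semantics, not torchzero's implementation):

```python
import math

def accepts(condition: str, f_0: float, g_0: float, f_a: float, a: float, c: float = 1e-4) -> bool:
    """Gradient-free acceptance tests for a trial step `a` along a direction d.

    f_0 = f(x), f_a = f(x + a*d), and g_0 is the directional derivative of f
    along d at x (negative for a descent direction, matching the `d` computed
    in `Backtracking.search` above).
    """
    if not math.isfinite(f_a):
        return False
    if condition == "armijo":    # sufficient decrease: f_a <= f_0 + c*a*g_0
        return f_a <= f_0 + c * a * g_0
    if condition == "decrease":  # any strict decrease in the objective
        return f_a < f_0
    raise ValueError(f"unsupported condition: {condition!r}")
```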
````diff
@@ -147,30 +149,37 @@ class AdaptiveBacktracking(LineSearchBase):
     such that optimal step size in the procedure would be found on the second line search iteration.
 
     Args:
-        init (float, optional): step size
+        init (float, optional): initial step size. Defaults to 1.0.
         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
+        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
+        condition (TerminationCondition, optional):
+            termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
+            - "armijo" - sufficient decrease condition.
+            - "decrease" - any decrease in objective function value satisfies the condition.
+
+            "goldstein" can technically be specified but it doesn't make sense because there is no zoom stage.
+            Defaults to 'armijo'.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         target_iters (int, optional):
+            sets next step size such that this number of iterations are expected
+            to be performed until optimal step size is found. Defaults to 1.
         nplus (float, optional):
+            if initial step size is optimal, it is multiplied by this value. Defaults to 2.0.
         scale_beta (float, optional):
-        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+            momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
     """
     def __init__(
         self,
         init: float = 1.0,
         beta: float = 0.5,
         c: float = 1e-4,
+        condition: TerminationCondition = 'armijo',
         maxiter: int = 20,
         target_iters = 1,
         nplus = 2.0,
         scale_beta = 0.0,
-        try_negative: bool = False,
     ):
+        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta)
         super().__init__(defaults=defaults)
 
         self.global_state['beta_scale'] = 1.0
@@ -183,8 +192,8 @@ class AdaptiveBacktracking(LineSearchBase):
 
     @torch.no_grad
     def search(self, update, var):
+        init, beta, c, condition, maxiter, target_iters, nplus, scale_beta = itemgetter(
+            'init','beta','c','condition', 'maxiter','target_iters','nplus','scale_beta')(self.defaults)
 
         objective = self.make_objective(var=var)
 
@@ -198,8 +207,7 @@ class AdaptiveBacktracking(LineSearchBase):
         # scale step size so that decrease is expected at target_iters
         init = init * self.global_state['initial_scale']
 
-        step_size = backtracking_line_search(objective, d, init=init, beta=beta,
-            c=c,maxiter=maxiter, try_negative=try_negative)
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta, c=c, condition=condition, maxiter=maxiter)
 
         # found an alpha that reduces loss
         if step_size is not None:
@@ -208,7 +216,12 @@ class AdaptiveBacktracking(LineSearchBase):
             # initial step size satisfied conditions, increase initial_scale by nplus
             if step_size == init and target_iters > 0:
                 self.global_state['initial_scale'] *= nplus ** target_iters
+
+                # clip by maximum possible value to avoid overflow exception
+                self.global_state['initial_scale'] = min(
+                    self.global_state['initial_scale'],
+                    torch.finfo(var.params[0].dtype).max / 2,
+                )
 
             else:
                 # otherwise make initial_scale such that target_iters iterations will satisfy armijo
````