torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as published.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/backtracking.py

````diff
@@ -4,7 +4,7 @@ from operator import itemgetter
 
 import torch
 
-from .line_search import LineSearchBase
+from .line_search import LineSearchBase, TerminationCondition, termination_condition
 
 
 def backtracking_line_search(
@@ -14,7 +14,7 @@ def backtracking_line_search(
     beta: float = 0.5,
     c: float = 1e-4,
     maxiter: int = 10,
-
+    condition: TerminationCondition = 'armijo',
 ) -> float | None:
     """
 
@@ -31,16 +31,20 @@ def backtracking_line_search(
     """
 
     a = init
-
+    f_0 = f(0)
     f_prev = None
 
     for iteration in range(maxiter):
         f_a = f(a)
+        if not math.isfinite(f_a):
+            a *= beta
+            continue
 
-        if (f_prev is not None) and (f_a > f_prev) and (f_prev <
+        if (f_prev is not None) and (f_a > f_prev) and (f_prev < f_0):
+            return a / beta # new value is larger than previous value
         f_prev = f_a
 
-        if
+        if termination_condition(condition, f_0=f_0, g_0=g_0, f_a=f_a, g_a=None, a=a, c=c):
             # found an acceptable alpha
             return a
 
@@ -48,53 +52,45 @@ def backtracking_line_search(
         a *= beta
 
     # fail
-    if try_negative:
-        def inv_objective(alpha): return f(-alpha)
-
-        v = backtracking_line_search(
-            inv_objective,
-            g_0=-g_0,
-            beta=beta,
-            c=c,
-            maxiter=maxiter,
-            try_negative=False,
-        )
-        if v is not None: return -v
-
     return None
 
 class Backtracking(LineSearchBase):
-    """Backtracking line search
+    """Backtracking line search.
 
     Args:
         init (float, optional): initial step size. Defaults to 1.0.
         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-        c (float, optional):
-
+        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
+        condition (TerminationCondition, optional):
+            termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
+            - "armijo" - sufficient decrease condition.
+            - "decrease" - any decrease in objective function value satisfies the condition.
+
+            "goldstein" can techincally be specified but it doesn't make sense because there is not zoom stage.
+            Defaults to 'armijo'.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
        adaptive (bool, optional):
-            when enabled, if line search failed,
-            Otherwise it
-        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+            when enabled, if line search failed, step size will continue decreasing on the next step.
+            Otherwise it will restart the line search from ``init`` step size. Defaults to True.
 
     Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
+        Gradient descent with backtracking line search:
+
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Backtracking()
+        )
+        ```
+
+        L-BFGS with backtracking line search:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.LBFGS(),
+            tz.m.Backtracking()
+        )
+        ```
 
     """
     def __init__(
@@ -102,41 +98,47 @@ class Backtracking(LineSearchBase):
         init: float = 1.0,
         beta: float = 0.5,
         c: float = 1e-4,
+        condition: TerminationCondition = 'armijo',
         maxiter: int = 10,
         adaptive=True,
-        try_negative: bool = False,
     ):
-        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,adaptive=adaptive
+        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,adaptive=adaptive)
         super().__init__(defaults=defaults)
-        self.global_state['beta_scale'] = 1.0
 
     def reset(self):
         super().reset()
-        self.global_state['beta_scale'] = 1.0
 
     @torch.no_grad
     def search(self, update, var):
-        init, beta, c, maxiter, adaptive
-            'init', 'beta', 'c', '
+        init, beta, c, condition, maxiter, adaptive = itemgetter(
+            'init', 'beta', 'c', 'condition', 'maxiter', 'adaptive')(self.defaults)
 
         objective = self.make_objective(var=var)
 
         # # directional derivative
-
+        if c == 0: d = 0
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
 
-        # scale
-
+        # scale init
+        init_scale = self.global_state.get('init_scale', 1)
+        if adaptive: init = init * init_scale
 
-        step_size = backtracking_line_search(objective, d, init=init,beta=beta,
-                                              c=c,maxiter=maxiter, try_negative=try_negative)
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta,c=c, condition=condition, maxiter=maxiter)
 
         # found an alpha that reduces loss
         if step_size is not None:
-            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            #self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            self.global_state['init_scale'] = 1
             return step_size
 
-        # on fail
-
+        # on fail set init_scale to continue decreasing the step size
+        # or set to large step size when it becomes too small
+        if adaptive:
+            finfo = torch.finfo(var.params[0].dtype)
+            if init_scale <= finfo.tiny * 2:
+                self.global_state["init_scale"] = finfo.max / 2
+            else:
+                self.global_state['init_scale'] = init_scale * beta**maxiter
         return 0
 
 def _lerp(start,end,weight):
@@ -147,30 +149,37 @@ class AdaptiveBacktracking(LineSearchBase):
     such that optimal step size in the procedure would be found on the second line search iteration.
 
     Args:
-        init (float, optional): step size
+        init (float, optional): initial step size. Defaults to 1.0.
         beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
-        c (float, optional):
-
+        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
+        condition (TerminationCondition, optional):
+            termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
+            - "armijo" - sufficient decrease condition.
+            - "decrease" - any decrease in objective function value satisfies the condition.
+
+            "goldstein" can techincally be specified but it doesn't make sense because there is not zoom stage.
+            Defaults to 'armijo'.
+        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
         target_iters (int, optional):
-
+            sets next step size such that this number of iterations are expected
+            to be performed until optimal step size is found. Defaults to 1.
         nplus (float, optional):
-
+            if initial step size is optimal, it is multiplied by this value. Defaults to 2.0.
         scale_beta (float, optional):
-
-        try_negative (bool, optional): Whether to perform line search in opposite direction on fail. Defaults to False.
+            momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
     """
     def __init__(
         self,
         init: float = 1.0,
         beta: float = 0.5,
         c: float = 1e-4,
+        condition: TerminationCondition = 'armijo',
         maxiter: int = 20,
         target_iters = 1,
         nplus = 2.0,
         scale_beta = 0.0,
-        try_negative: bool = False,
     ):
-        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta
+        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta)
         super().__init__(defaults=defaults)
 
         self.global_state['beta_scale'] = 1.0
@@ -183,8 +192,8 @@ class AdaptiveBacktracking(LineSearchBase):
 
     @torch.no_grad
     def search(self, update, var):
-        init, beta, c, maxiter, target_iters, nplus, scale_beta
-            'init','beta','c','maxiter','target_iters','nplus','scale_beta'
+        init, beta, c,condition, maxiter, target_iters, nplus, scale_beta=itemgetter(
+            'init','beta','c','condition', 'maxiter','target_iters','nplus','scale_beta')(self.defaults)
 
         objective = self.make_objective(var=var)
 
@@ -198,8 +207,7 @@ class AdaptiveBacktracking(LineSearchBase):
         # scale step size so that decrease is expected at target_iters
         init = init * self.global_state['initial_scale']
 
-        step_size = backtracking_line_search(objective, d, init=init, beta=beta,
-                                              c=c,maxiter=maxiter, try_negative=try_negative)
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta, c=c, condition=condition, maxiter=maxiter)
 
         # found an alpha that reduces loss
         if step_size is not None:
@@ -208,7 +216,12 @@ class AdaptiveBacktracking(LineSearchBase):
             # initial step size satisfied conditions, increase initial_scale by nplus
             if step_size == init and target_iters > 0:
                 self.global_state['initial_scale'] *= nplus ** target_iters
-
+
+                # clip by maximum possibel value to avoid overflow exception
+                self.global_state['initial_scale'] = min(
+                    self.global_state['initial_scale'],
+                    torch.finfo(var.params[0].dtype).max / 2,
+                )
 
             else:
                 # otherwise make initial_scale such that target_iters iterations will satisfy armijo
````
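For readers skimming the hunks above: the new `condition` parameter routes the acceptance test through the `termination_condition` helper that the `line_search.py` diff below introduces. The following is a minimal standalone sketch of that loop (Armijo variant) on a one-dimensional quadratic; the helper names mirror the diff, but nothing here imports torchzero and the toy objective is made up for illustration.

```python
import math

def sufficient_decrease(f_0, g_0, f_a, a, c):
    # Armijo test as written in the diff: f(a) < f(0) + c * a * min(g_0, 0)
    return f_a < f_0 + c * a * min(g_0, 0)

def backtrack(f, g_0, init=1.0, beta=0.5, c=1e-4, maxiter=10):
    """Standalone sketch of the diffed backtracking loop (condition='armijo')."""
    a = init
    f_0 = f(0.0)
    f_prev = None
    for _ in range(maxiter):
        f_a = f(a)
        if not math.isfinite(f_a):   # new in 0.3.14: skip non-finite evaluations
            a *= beta
            continue
        if f_prev is not None and f_a > f_prev and f_prev < f_0:
            return a / beta          # previous trial point was better, back up one step
        f_prev = f_a
        if sufficient_decrease(f_0, g_0, f_a, a, c):
            return a                 # acceptable step size found
        a *= beta
    return None                      # line search failed

# phi(a) = value along the ray x0 + a*d for f(x) = (x - 0.3)^2 with x0 = 0, d = 0.6
phi = lambda a: (0.6 * a - 0.3) ** 2
print(backtrack(phi, g_0=-0.36))     # directional derivative phi'(0) = -0.36; prints 0.5
```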
torchzero/modules/line_search/line_search.py

````diff
@@ -3,13 +3,13 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from functools import partial
 from operator import itemgetter
-from typing import Any
+from typing import Any, Literal
 
 import numpy as np
 import torch
 
 from ...core import Module, Target, Var
-from ...utils import tofloat
+from ...utils import tofloat, set_storage_
 
 
 class MaxLineSearchItersReached(Exception): pass
@@ -29,60 +29,59 @@ class LineSearchBase(Module, ABC):
         doesn't have a maxiter option. Defaults to None.
 
     Other useful methods:
-        *
-        *
-        *
-        *
+        * ``evaluate_f`` - returns loss with a given scalar step size
+        * ``evaluate_f_d`` - returns loss and directional derivative with a given scalar step size
+        * ``make_objective`` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
+        * ``make_objective_with_derivative`` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
 
     Examples:
-        #### Basic line search
 
-
+        #### Basic line search
 
-
+        This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+        ```python
+        class GridLineSearch(LineSearch):
+            def __init__(self, start, end, num):
+                defaults = dict(start=start,end=end,num=num)
+                super().__init__(defaults)
 
-
-
-                defaults = dict(start=start,end=end,num=num)
-                super().__init__(defaults)
+            @torch.no_grad
+            def search(self, update, var):
 
-
-
-
-                start = settings["start"]
-                end = settings["end"]
-                num = settings["num"]
+                start = self.defaults["start"]
+                end = self.defaults["end"]
+                num = self.defaults["num"]
 
-
-
+                lowest_loss = float("inf")
+                best_step_size = best_step_size
 
-
-
-
-
-
+                for step_size in torch.linspace(start,end,num):
+                    loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
+                    if loss < lowest_loss:
+                        lowest_loss = loss
+                        best_step_size = step_size
 
-
+                return best_step_size
+        ```
 
-
+        #### Using external solver via self.make_objective
 
-
+        Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
 
-
+        ```python
+        class ScipyMinimizeScalar(LineSearch):
+            def __init__(self, method: str | None = None):
+                defaults = dict(method=method)
+                super().__init__(defaults)
 
-
-
-
-
-
-            @torch.no_grad
-            def search(self, update, var):
-                objective = self.make_objective(var=var)
-                method = self.settings[var.params[0]]["method"]
-
-                res = self.scopt.minimize_scalar(objective, method=method)
-                return res.x
+            @torch.no_grad
+            def search(self, update, var):
+                objective = self.make_objective(var=var)
+                method = self.defaults["method"]
 
+                res = self.scopt.minimize_scalar(objective, method=method)
+                return res.x
+        ```
 
     """
     def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
         super().__init__(defaults)
@@ -94,6 +93,7 @@ class LineSearchBase(Module, ABC):
         self._lowest_loss = float('inf')
         self._best_step_size: float = 0
         self._current_iter = 0
+        self._initial_params = None
 
     def set_step_size_(
         self,
@@ -102,10 +102,27 @@ class LineSearchBase(Module, ABC):
         update: list[torch.Tensor],
     ):
         if not math.isfinite(step_size): return
-
-
-
-
+
+        # fixes overflow when backtracking keeps increasing alpha after converging
+        step_size = max(min(tofloat(step_size), 1e36), -1e36)
+
+        # skip is parameters are already at suggested step size
+        if self._current_step_size == step_size: return
+
+        # this was basically causing floating point imprecision to build up
+        #if False:
+        #     if abs(alpha) < abs(step_size) and step_size != 0:
+        #         torch._foreach_add_(params, update, alpha=alpha)
+
+        # else:
+        assert self._initial_params is not None
+        if step_size == 0:
+            new_params = [p.clone() for p in self._initial_params]
+        else:
+            new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+        for c, n in zip(params, new_params):
+            set_storage_(c, n)
+
         self._current_step_size = step_size
 
     def _set_per_parameter_step_size_(
@@ -114,10 +131,20 @@ class LineSearchBase(Module, ABC):
         params: list[torch.Tensor],
         update: list[torch.Tensor],
     ):
-        if not np.isfinite(step_size): step_size = [0 for _ in step_size]
-        alpha = [self._current_step_size - s for s in step_size]
-        if any(a!=0 for a in alpha):
-
+        # if not np.isfinite(step_size): step_size = [0 for _ in step_size]
+        # alpha = [self._current_step_size - s for s in step_size]
+        # if any(a!=0 for a in alpha):
+        #     torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+        assert self._initial_params is not None
+        if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
+
+        if any(s!=0 for s in step_size):
+            new_params = torch._foreach_sub(self._initial_params, torch._foreach_mul(update, step_size))
+        else:
+            new_params = [p.clone() for p in self._initial_params]
+
+        for c, n in zip(params, new_params):
+            set_storage_(c, n)
 
     def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
               update: list[torch.Tensor], backward:bool=False) -> float:
@@ -149,7 +176,7 @@ class LineSearchBase(Module, ABC):
 
         return tofloat(loss)
 
-    def
+    def _loss_derivative_gradient(self, step_size: float, var: Var, closure,
                                   params: list[torch.Tensor], update: list[torch.Tensor]):
         # if step_size is 0, we might already know the derivative
         if (var.grad is not None) and (step_size == 0):
@@ -164,18 +191,31 @@ class LineSearchBase(Module, ABC):
             derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
                                                                     else torch.zeros_like(p) for p in params], update))
 
-
+        assert var.grad is not None
+        return loss, tofloat(derivative), var.grad
 
-    def
+    def _loss_derivative(self, step_size: float, var: Var, closure,
+                         params: list[torch.Tensor], update: list[torch.Tensor]):
+        return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]
+
+    def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
+        """evaluate function value at alpha `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
         return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)
 
-    def
+    def evaluate_f_d(self, step_size: float, var: Var):
+        """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
         return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
 
+    def evaluate_f_d_g(self, step_size: float, var: Var):
+        """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
+        closure = var.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+
     def make_objective(self, var: Var, backward:bool=False):
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
@@ -186,6 +226,11 @@ class LineSearchBase(Module, ABC):
         if closure is None: raise RuntimeError('line search requires closure')
         return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())
 
+    def make_objective_with_derivative_and_gradient(self, var: Var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_update())
+
     @abstractmethod
     def search(self, update: list[torch.Tensor], var: Var) -> float:
         """Finds the step size to use"""
@@ -193,7 +238,9 @@ class LineSearchBase(Module, ABC):
     @torch.no_grad
     def step(self, var: Var) -> Var:
         self._reset()
+
         params = var.params
+        self._initial_params = [p.clone() for p in params]
         update = var.get_update()
 
         try:
@@ -206,7 +253,6 @@ class LineSearchBase(Module, ABC):
 
         # this is last module - set step size to found step_size times lr
         if var.is_last:
-
             if var.last_module_lrs is None:
                 self.set_step_size_(step_size, params=params, update=update)
 
@@ -223,17 +269,62 @@ class LineSearchBase(Module, ABC):
 
 
 
-
-
-
-
-
-
-# @torch.no_grad
-# def search(self, update, var):
-#     start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])
-
-#     for lr in torch.linspace(start,end,num):
-#         self.evaluate_step_size(lr.item(), var=var, backward=False)
+class GridLineSearch(LineSearchBase):
+    """"""
+    def __init__(self, start, end, num):
+        defaults = dict(start=start,end=end,num=num)
+        super().__init__(defaults)
 
-
+    @torch.no_grad
+    def search(self, update, var):
+        start,end,num=itemgetter('start','end','num')(self.defaults)
+
+        for lr in torch.linspace(start,end,num):
+            self.evaluate_f(lr.item(), var=var, backward=False)
+
+        return self._best_step_size
+
+
+def sufficient_decrease(f_0, g_0, f_a, a, c):
+    return f_a < f_0 + c*a*min(g_0, 0)
+
+def curvature(g_0, g_a, c):
+    if g_0 > 0: return True
+    return g_a >= c * g_0
+
+def strong_curvature(g_0, g_a, c):
+    """same as curvature condition except curvature can't be too positive (which indicates overstep)"""
+    if g_0 > 0: return True
+    return abs(g_a) <= c * abs(g_0)
+
+def wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+    return sufficient_decrease(f_0, g_0, f_a, a, c1) and curvature(g_0, g_a, c2)
+
+def strong_wolfe(f_0, g_0, f_a, g_a, a, c1, c2):
+    return sufficient_decrease(f_0, g_0, f_a, a, c1) and strong_curvature(g_0, g_a, c2)
+
+def goldstein(f_0, g_0, f_a, a, c):
+    """same as armijo (sufficient_decrease) but additional lower bound"""
+    g_0 = min(g_0, 0)
+    return f_0 + (1-c)*a*g_0 < f_a < f_0 + c*a*g_0
+
+TerminationCondition = Literal["armijo", "curvature", "strong_curvature", "wolfe", "strong_wolfe", "goldstein", "decrease"]
+def termination_condition(
+    condition: TerminationCondition,
+    f_0,
+    g_0,
+    f_a,
+    g_a: Any | None,
+    a,
+    c,
+    c2=None,
+):
+    if not math.isfinite(f_a): return False
+    if condition == 'armijo': return sufficient_decrease(f_0, g_0, f_a, a, c)
+    if condition == 'curvature': return curvature(g_0, g_a, c)
+    if condition == 'strong_curvature': return strong_curvature(g_0, g_a, c)
+    if condition == 'wolfe': return wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+    if condition == 'strong_wolfe': return strong_wolfe(f_0, g_0, f_a, g_a, a, c, c2)
+    if condition == 'goldstein': return goldstein(f_0, g_0, f_a, a, c)
+    if condition == 'decrease': return f_a < f_0
+    raise ValueError(f"unknown condition {condition}")
````
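The `set_step_size_` changes above replace incremental parameter nudges with a rebuild from a stored copy of the initial parameters, which avoids the floating-point drift the removed code and its comment allude to. Below is a rough sketch of that scheme under simplified assumptions: a plain `copy_` stands in for the package's `set_storage_` helper, and the function name and step-size schedule are illustrative only.

```python
import torch

def set_trial_point(params, initial_params, update, step_size):
    # Rebuild x0 - step_size * d from the saved initial parameters every time,
    # instead of moving the live parameters by the delta between trial step sizes.
    if step_size == 0:
        new_params = [p.clone() for p in initial_params]
    else:
        new_params = torch._foreach_sub(initial_params, update, alpha=step_size)
    for live, new in zip(params, new_params):
        live.copy_(new)  # the real module swaps storages via set_storage_

params = [torch.randn(3)]
initial = [p.clone() for p in params]
update = [torch.ones(3)]

for a in (1.0, 0.5, 0.25, 0.0):            # arbitrary trial step sizes
    set_trial_point(params, initial, update, a)

print(torch.equal(params[0], initial[0]))  # True: back at x0 exactly, no accumulated error
```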
torchzero/modules/line_search/scipy.py

````diff
@@ -1,3 +1,4 @@
+import math
 from collections.abc import Mapping
 from operator import itemgetter
 
@@ -17,6 +18,7 @@ class ScipyMinimizeScalar(LineSearchBase):
         bounds (Sequence | None, optional):
             For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
         tol (float | None, optional): Tolerance for termination. Defaults to None.
+        prev_init (bool, optional): uses previous step size as initial guess for the line search.
         options (dict | None, optional): A dictionary of solver options. Defaults to None.
 
     For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
@@ -29,9 +31,10 @@ class ScipyMinimizeScalar(LineSearchBase):
         bracket=None,
         bounds=None,
         tol: float | None = None,
+        prev_init: bool = False,
         options=None,
     ):
-        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter)
+        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter, prev_init=prev_init)
         super().__init__(defaults)
 
         import scipy.optimize
@@ -42,11 +45,20 @@ class ScipyMinimizeScalar(LineSearchBase):
     def search(self, update, var):
         objective = self.make_objective(var=var)
         method, bracket, bounds, tol, options, maxiter = itemgetter(
-            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.
+            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.defaults)
 
         if maxiter is not None:
             options = dict(options) if isinstance(options, Mapping) else {}
             options['maxiter'] = maxiter
 
-
-
+        if self.defaults["prev_init"] and "x_prev" in self.global_state:
+            if bracket is None: bracket = (0, 1)
+            bracket = (*bracket[:-1], self.global_state["x_prev"])
+
+        x = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options).x # pyright:ignore[reportAttributeAccessIssue]
+
+        max = torch.finfo(var.params[0].dtype).max / 2
+        if (not math.isfinite(x)) or abs(x) >= max: x = 0
+
+        self.global_state['x_prev'] = x
+        return x
````