torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/step_size/adaptive.py
@@ -1,122 +1,387 @@
 """Various step size strategies"""
-
+import math
 from operator import itemgetter
+from typing import Any, Literal
+
 import torch
 
-from ...core import
-from ...utils import TensorList, unpack_dicts, unpack_states
+from ...core import Chainable, Transform
+from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
+from ...utils.linalg.linear_operator import ScaledIdentity
+from ..functional import epsilon_step_size
+
+def _acceptable_alpha(alpha, param:torch.Tensor):
+    finfo = torch.finfo(param.dtype)
+    if (alpha is None) or (alpha < finfo.tiny*2) or (not math.isfinite(alpha)) or (alpha > finfo.max/2):
+        return False
+    return True
+
+def _get_H(self: Transform, var):
+    n = sum(p.numel() for p in var.params)
+    p = var.params[0]
+    alpha = self.global_state.get('alpha', 1)
+    if not _acceptable_alpha(alpha, p): alpha = 1
+
+    return ScaledIdentity(1 / alpha, shape=(n,n), device=p.device, dtype=p.dtype)
 
 
 class PolyakStepSize(Transform):
-    """Polyak's subgradient method
+    """Polyak's subgradient method with known or unknown f*.
 
     Args:
-        f_star (
-
+        f_star (float | Mone, optional):
+            minimal possible value of the objective function. If not known, set to ``None``. Defaults to 0.
+        y (float, optional):
+            when ``f_star`` is set to None, it is calculated as ``f_best - y``.
+        y_decay (float, optional):
+            ``y`` is multiplied by ``(1 - y_decay)`` after each step. Defaults to 1e-3.
         max (float | None, optional): maximum possible step size. Defaults to None.
         use_grad (bool, optional):
            if True, uses dot product of update and gradient to compute the step size.
-            Otherwise, dot product of update with itself is used
-            Defaults to False.
+            Otherwise, dot product of update with itself is used.
        alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
     """
-    def __init__(self, f_star: float = 0, max: float | None = None, use_grad=
+    def __init__(self, f_star: float | None = 0, y: float = 1, y_decay: float = 1e-3, max: float | None = None, use_grad=True, alpha: float = 1, inner: Chainable | None = None):
 
-        defaults = dict(alpha=alpha, max=max, f_star=f_star,
+        defaults = dict(alpha=alpha, max=max, f_star=f_star, y=y, y_decay=y_decay)
         super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
 
+    @torch.no_grad
     def update_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None and loss is not None
         tensors = TensorList(tensors)
         grads = TensorList(grads)
 
-
+        # load variables
+        max, f_star, y, y_decay = itemgetter('max', 'f_star', 'y', 'y_decay')(settings[0])
+        y_val = self.global_state.get('y_val', y)
+        f_best = self.global_state.get('f_best', None)
 
-
+        # gg
+        if self._uses_grad: gg = tensors.dot(grads)
         else: gg = tensors.dot(tensors)
 
-
-
+        # store loss
+        if f_best is None or loss < f_best: f_best = tofloat(loss)
+        if f_star is None: f_star = f_best - y_val
+
+        # calculate the step size
+        if gg <= torch.finfo(gg.dtype).tiny * 2: alpha = 0 # converged
+        else: alpha = (loss - f_star) / gg
 
+        # clip
         if max is not None:
-            if
+            if alpha > max: alpha = max
 
-
+        # store state
+        self.global_state['f_best'] = f_best
+        self.global_state['y_val'] = y_val * (1 - y_decay)
+        self.global_state['alpha'] = alpha
 
     @torch.no_grad
     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-
-
+        alpha = self.global_state.get('alpha', 1)
+        if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))
+
+        torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
         return tensors
 
+    def get_H(self, var):
+        return _get_H(self, var)
 
 
-def _bb_short(s: TensorList, y: TensorList, sy, eps
+def _bb_short(s: TensorList, y: TensorList, sy, eps):
     yy = y.dot(y)
     if yy < eps:
-        if sy < eps: return
+        if sy < eps: return None # try to fallback on long
         ss = s.dot(s)
         return ss/sy
     return sy/yy
 
-def _bb_long(s: TensorList, y: TensorList, sy, eps
+def _bb_long(s: TensorList, y: TensorList, sy, eps):
     ss = s.dot(s)
     if sy < eps:
         yy = y.dot(y) # try to fallback on short
-        if yy < eps: return
+        if yy < eps: return None
         return sy/yy
     return ss/sy
 
-def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback):
-    short = _bb_short(s, y, sy, eps
-    long = _bb_long(s, y, sy, eps
+def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback:bool):
+    short = _bb_short(s, y, sy, eps)
+    long = _bb_long(s, y, sy, eps)
+    if long is None or short is None:
+        if fallback:
+            if short is not None: return short
+            if long is not None: return long
+        return None
     return (short * long) ** 0.5
 
 class BarzilaiBorwein(Transform):
-    """Barzilai-Borwein method.
+    """Barzilai-Borwein step size method.
 
     Args:
         type (str, optional):
             one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
-            Defaults to
-        scale_first (bool, optional):
-            whether to make first step very small when previous gradient is not available. Defaults to True.
+            Defaults to "geom".
         fallback (float, optional): step size when denominator is less than 0 (will happen on negative curvature). Defaults to 1e-3.
         inner (Chainable | None, optional):
            step size will be applied to outputs of this module. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
+        alpha_0: float = 1e-7,
+        use_grad=True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(type=type, alpha_0=alpha_0)
+        super().__init__(defaults, uses_grad=use_grad, inner=inner)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_g')
+        self.global_state['reset'] = True
+
+    @torch.no_grad
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
+        type = self.defaults['type']
+
+        g = grads if self._uses_grad else tensors
+        assert g is not None
+
+        reset = self.global_state.get('reset', False)
+        self.global_state.pop('reset', None)
+
+        if step != 0 and not reset:
+            s = params-prev_p
+            y = g-prev_g
+            sy = s.dot(y)
+            eps = torch.finfo(sy.dtype).tiny * 2
+
+            if type == 'short': alpha = _bb_short(s, y, sy, eps)
+            elif type == 'long': alpha = _bb_long(s, y, sy, eps)
+            elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
+            elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
+            else: raise ValueError(type)
+
+            # if alpha is not None:
+            self.global_state['alpha'] = alpha
+
+        prev_p.copy_(params)
+        prev_g.copy_(g)
+
+    def get_H(self, var):
+        return _get_H(self, var)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        alpha = self.global_state.get('alpha', None)
+
+        if not _acceptable_alpha(alpha, tensors[0]):
+            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])
+
+        torch._foreach_mul_(tensors, alpha)
+        return tensors
+
+
+class BBStab(Transform):
+    """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).
+
+    This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.
+
+    Args:
+        c (float, optional):
+            adaptive delta parameter. If ``delta`` is set to None, first ``inf_iters`` updates are performed
+            with non-stabilized Barzilai-Borwein step size. Then delta is set to norm of
+            the update that had the smallest norm, and multiplied by ``c``. Defaults to 0.2.
+        delta (float | None, optional):
+            Barzilai-Borwein update is clipped to this value. Set to ``None`` to use an adaptive choice. Defaults to None.
+        type (str, optional):
+            one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
+            Defaults to "geom". Note that "long" corresponds to BB1stab and "short" to BB2stab,
+            however I found that "geom" works really well.
+        inner (Chainable | None, optional):
+            step size will be applied to outputs of this module. Defaults to None.
 
     """
-    def __init__(
-
-
+    def __init__(
+        self,
+        c=0.2,
+        delta:float | None = None,
+        type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
+        alpha_0: float = 1e-7,
+        use_grad=True,
+        inf_iters: int = 3,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(type=type,alpha_0=alpha_0, c=c, delta=delta, inf_iters=inf_iters)
+        super().__init__(defaults, uses_grad=use_grad, inner=inner)
 
     def reset_for_online(self):
         super().reset_for_online()
-        self.clear_state_keys('
+        self.clear_state_keys('prev_g')
+        self.global_state['reset'] = True
 
     @torch.no_grad
     def update_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
         prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
-
-
+        type = self.defaults['type']
+        c = self.defaults['c']
+        delta = self.defaults['delta']
+        inf_iters = self.defaults['inf_iters']
+
+        g = grads if self._uses_grad else tensors
+        assert g is not None
+        g = TensorList(g)
+
+        reset = self.global_state.get('reset', False)
+        self.global_state.pop('reset', None)
+
+        if step != 0 and not reset:
+            s = params-prev_p
+            y = g-prev_g
+            sy = s.dot(y)
+            eps = torch.finfo(sy.dtype).tiny
+
+            if type == 'short': alpha = _bb_short(s, y, sy, eps)
+            elif type == 'long': alpha = _bb_long(s, y, sy, eps)
+            elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
+            elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
+            else: raise ValueError(type)
+
+            if alpha is not None:
+
+                # adaptive delta
+                if delta is None:
+                    niters = self.global_state.get('niters', 0) # this accounts for skipped negative curvature steps
+                    self.global_state['niters'] = niters + 1
+
+
+                    if niters == 0: pass # 1st iteration is scaled GD step, shouldn't be used to find s_norm_min
+                    elif niters <= inf_iters:
+                        s_norm_min = self.global_state.get('s_norm_min', None)
+                        if s_norm_min is None: s_norm_min = s.global_vector_norm()
+                        else: s_norm_min = min(s_norm_min, s.global_vector_norm())
+                        self.global_state['s_norm_min'] = s_norm_min
+                        # first few steps use delta=inf, so delta remains None
 
-
-
-            sy = s.dot(y)
-            eps = torch.finfo(sy.dtype).eps
+                    else:
+                        delta = c * self.global_state['s_norm_min']
 
-
-
-            elif type == 'geom': step_size = _bb_geom(s, y, sy, eps, fallback)
-            else: raise ValueError(type)
+                if delta is None: # delta is inf for first few steps
+                    self.global_state['alpha'] = alpha
 
-
+                # BBStab step size
+                else:
+                    a_stab = delta / g.global_vector_norm()
+                    self.global_state['alpha'] = min(alpha, a_stab)
 
         prev_p.copy_(params)
-        prev_g.copy_(
+        prev_g.copy_(g)
+
+    def get_H(self, var):
+        return _get_H(self, var)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        alpha = self.global_state.get('alpha', None)
+
+        if not _acceptable_alpha(alpha, tensors[0]):
+            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])
+
+        torch._foreach_mul_(tensors, alpha)
+        return tensors
+
+
+class AdGD(Transform):
+    """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
+    def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
+        defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
+        super().__init__(defaults, uses_grad=use_grad, inner=inner,)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_g')
+        self.global_state['reset'] = True
+
+    @torch.no_grad
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        variant = settings[0]['variant']
+        theta_0 = 0 if variant == 1 else 1/3
+        theta = self.global_state.get('theta', theta_0)
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        p = TensorList(params)
+        g = grads if self._uses_grad else tensors
+        assert g is not None
+        g = TensorList(g)
+
+        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
+
+        # online
+        if self.global_state.get('reset', False):
+            del self.global_state['reset']
+            prev_p.copy_(p)
+            prev_g.copy_(g)
+            return
 
+        if step == 0:
+            alpha_0 = settings[0]['alpha_0']
+            if alpha_0 is None: alpha_0 = epsilon_step_size(g)
+            self.global_state['alpha'] = alpha_0
+            prev_p.copy_(p)
+            prev_g.copy_(g)
+            return
+
+        sqrt = settings[0]['sqrt']
+        alpha = self.global_state.get('alpha', math.inf)
+        L = (g - prev_g).global_vector_norm() / (p - prev_p).global_vector_norm()
+        eps = torch.finfo(L.dtype).tiny * 2
+
+        if variant == 1:
+            a1 = math.sqrt(1 + theta)*alpha
+            val = math.sqrt(2) if sqrt else 2
+            if L > eps: a2 = 1 / (val*L)
+            else: a2 = math.inf
+
+        elif variant == 2:
+            a1 = math.sqrt(2/3 + theta)*alpha
+            a2 = alpha / math.sqrt(max(eps, 2 * alpha**2 * L**2 - 1))
+
+        else:
+            raise ValueError(variant)
+
+        alpha_new = min(a1, a2)
+        if alpha_new < 0: alpha_new = max(a1, a2)
+        if alpha_new > eps:
+            self.global_state['theta'] = alpha_new/alpha
+            self.global_state['alpha'] = alpha_new
+
+        prev_p.copy_(p)
+        prev_g.copy_(g)
+
+    @torch.no_grad
     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-
-
+        alpha = self.global_state.get('alpha', None)
+
+        if not _acceptable_alpha(alpha, tensors[0]):
+            # alpha isn't None on 1st step
+            self.state.clear()
+            self.global_state.clear()
+            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])
+
+        torch._foreach_mul_(tensors, alpha)
         return tensors
 
+    def get_H(self, var):
+        return _get_H(self, var)
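
For context, a usage sketch (not part of the diff) of the new step-size transforms added above. It assumes the tz.Modular wrapper, the tz.m module namespace, and the closure(backward=True) convention from earlier torchzero releases; the exact entry points for 0.3.14 should be checked against the package documentation.

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)

    # BarzilaiBorwein, BBStab and AdGD compute only a step size, so they are
    # chained with other modules (or applied directly to the raw gradient).
    opt = tz.Modular(
        model.parameters(),
        tz.m.BarzilaiBorwein(type="geom"),  # or tz.m.BBStab(c=0.2), tz.m.AdGD(variant=2)
    )

    def closure(backward=True):
        loss = model(torch.randn(32, 10)).pow(2).mean()
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(100):
        opt.step(closure)
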
torchzero/modules/termination/__init__.py
@@ -0,0 +1,14 @@
+from .termination import (
+    TerminateAfterNEvaluations,
+    TerminateAfterNSeconds,
+    TerminateAfterNSteps,
+    TerminateAll,
+    TerminateAny,
+    TerminateByGradientNorm,
+    TerminateByUpdateNorm,
+    TerminateOnLossReached,
+    TerminateOnNoImprovement,
+    TerminationCriteriaBase,
+    TerminateNever,
+    make_termination_criteria
+)
torchzero/modules/termination/termination.py
@@ -0,0 +1,207 @@
+import time
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import cast
+
+import torch
+
+from ...core import Module, Var
+from ...utils import Metrics, TensorList, safe_dict_update_, tofloat
+
+
+class TerminationCriteriaBase(Module):
+    def __init__(self, defaults:dict | None = None, n: int = 1):
+        if defaults is None: defaults = {}
+        safe_dict_update_(defaults, {"_n": n})
+        super().__init__(defaults)
+
+    @abstractmethod
+    def termination_criteria(self, var: Var) -> bool:
+        ...
+
+    def should_terminate(self, var: Var) -> bool:
+        n_bad = self.global_state.get('_n_bad', 0)
+        n = self.defaults['_n']
+
+        if self.termination_criteria(var):
+            n_bad += 1
+            if n_bad >= n:
+                self.global_state['_n_bad'] = 0
+                return True
+
+        else:
+            n_bad = 0
+
+        self.global_state['_n_bad'] = n_bad
+        return False
+
+
+    def update(self, var):
+        var.should_terminate = self.should_terminate(var)
+        if var.should_terminate: self.global_state['_n_bad'] = 0
+
+    def apply(self, var):
+        return var
+
+
+class TerminateAfterNSteps(TerminationCriteriaBase):
+    def __init__(self, steps:int):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+
+    def termination_criteria(self, var):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        max_steps = self.defaults['steps']
+        return step >= max_steps
+
+class TerminateAfterNEvaluations(TerminationCriteriaBase):
+    def __init__(self, maxevals:int):
+        defaults = dict(maxevals=maxevals)
+        super().__init__(defaults)
+
+    def termination_criteria(self, var):
+        maxevals = self.defaults['maxevals']
+        return var.modular.num_evaluations >= maxevals
+
+class TerminateAfterNSeconds(TerminationCriteriaBase):
+    def __init__(self, seconds:float, sec_fn = time.time):
+        defaults = dict(seconds=seconds, sec_fn=sec_fn)
+        super().__init__(defaults)
+
+    def termination_criteria(self, var):
+        max_seconds = self.defaults['seconds']
+        sec_fn = self.defaults['sec_fn']
+
+        if 'start' not in self.global_state:
+            self.global_state['start'] = sec_fn()
+            return False
+
+        seconds_passed = sec_fn() - self.global_state['start']
+        return seconds_passed >= max_seconds
+
+
+
+class TerminateByGradientNorm(TerminationCriteriaBase):
+    def __init__(self, tol:float = 1e-8, n: int = 3, ord: Metrics = 2):
+        defaults = dict(tol=tol, ord=ord)
+        super().__init__(defaults, n=n)
+
+    def termination_criteria(self, var):
+        tol = self.defaults['tol']
+        ord = self.defaults['ord']
+        return TensorList(var.get_grad()).global_metric(ord) <= tol
+
+
+class TerminateByUpdateNorm(TerminationCriteriaBase):
+    """update is calculated as parameter difference"""
+    def __init__(self, tol:float = 1e-8, n: int = 3, ord: Metrics = 2):
+        defaults = dict(tol=tol, ord=ord)
+        super().__init__(defaults, n=n)
+
+    def termination_criteria(self, var):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        tol = self.defaults['tol']
+        ord = self.defaults['ord']
+
+        p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
+        if step == 0:
+            p_prev.copy_(var.params)
+            return False
+
+        should_terminate = (p_prev - var.params).global_metric(ord) <= tol
+        p_prev.copy_(var.params)
+        return should_terminate
+
+
+class TerminateOnNoImprovement(TerminationCriteriaBase):
+    def __init__(self, tol:float = 1e-8, n: int = 10):
+        defaults = dict(tol=tol)
+        super().__init__(defaults, n=n)
+
+    def termination_criteria(self, var):
+        tol = self.defaults['tol']
+
+        f = tofloat(var.get_loss(False))
+        if 'f_min' not in self.global_state:
+            self.global_state['f_min'] = f
+            return False
+
+        f_min = self.global_state['f_min']
+        d = f_min - f
+        should_terminate = d <= tol
+        self.global_state['f_min'] = min(f, f_min)
+        return should_terminate
+
+class TerminateOnLossReached(TerminationCriteriaBase):
+    def __init__(self, value: float):
+        defaults = dict(value=value)
+        super().__init__(defaults)
+
+    def termination_criteria(self, var):
+        value = self.defaults['value']
+        return var.get_loss(False) <= value
+
+class TerminateAny(TerminationCriteriaBase):
+    def __init__(self, *criteria: TerminationCriteriaBase):
+        super().__init__()
+
+        self.set_children_sequence(criteria)
+
+    def termination_criteria(self, var: Var) -> bool:
+        for c in self.get_children_sequence():
+            if cast(TerminationCriteriaBase, c).termination_criteria(var): return True
+
+        return False
+
+class TerminateAll(TerminationCriteriaBase):
+    def __init__(self, *criteria: TerminationCriteriaBase):
+        super().__init__()
+
+        self.set_children_sequence(criteria)
+
+    def termination_criteria(self, var: Var) -> bool:
+        for c in self.get_children_sequence():
+            if not cast(TerminationCriteriaBase, c).termination_criteria(var): return False
+
+        return True
+
+class TerminateNever(TerminationCriteriaBase):
+    def __init__(self):
+        super().__init__()
+
+    def termination_criteria(self, var): return False
+
+def make_termination_criteria(
+    ftol: float | None = None,
+    gtol: float | None = None,
+    stol: float | None = None,
+    maxiter: int | None = None,
+    maxeval: int | None = None,
+    maxsec: float | None = None,
+    target_loss: float | None = None,
+    extra: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
+    n: int = 3,
+):
+    criteria: list[TerminationCriteriaBase] = []
+
+    if ftol is not None: criteria.append(TerminateOnNoImprovement(ftol, n=n))
+    if gtol is not None: criteria.append(TerminateByGradientNorm(gtol, n=n))
+    if stol is not None: criteria.append(TerminateByUpdateNorm(stol, n=n))
+
+    if maxiter is not None: criteria.append(TerminateAfterNSteps(maxiter))
+    if maxeval is not None: criteria.append(TerminateAfterNEvaluations(maxeval))
+    if maxsec is not None: criteria.append(TerminateAfterNSeconds(maxsec))
+
+    if target_loss is not None: criteria.append(TerminateOnLossReached(target_loss))
+
+    if extra is not None:
+        if isinstance(extra, TerminationCriteriaBase): criteria.append(extra)
+        else: criteria.extend(extra)
+
+    if len(criteria) == 0: return TerminateNever()
+    if len(criteria) == 1: return criteria[0]
+    return TerminateAny(*criteria)
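
A minimal sketch (not part of the diff) of how the make_termination_criteria factory added above composes criteria, based only on the code shown in this hunk. The import path follows the new torchzero.modules.termination package; how the resulting module is attached to an optimizer chain is an assumption to verify against the 0.3.14 docs.

    from torchzero.modules.termination import (
        TerminateAny, TerminateNever, make_termination_criteria,
    )

    # no tolerances given -> a criterion that never terminates
    assert isinstance(make_termination_criteria(), TerminateNever)

    # several tolerances -> an "any" combination of the individual criteria;
    # each criterion must fire n consecutive times before it triggers
    crit = make_termination_criteria(ftol=1e-10, gtol=1e-8, maxiter=1000, n=3)
    assert isinstance(crit, TerminateAny)

    # the result is a regular Module: its update() sets var.should_terminate,
    # so it would presumably sit at the end of a modular optimizer's chain.
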