torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/adaptive/adan.py
@@ -0,0 +1,96 @@
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+def adan_(
+    g: TensorList,
+    g_prev_: TensorList,
+    m_: TensorList,  # exponential moving average
+    v_: TensorList,  # exponential moving average of gradient differences
+    n_: TensorList,  # kinda like squared momentum
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    beta3: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+):
+    """Returns new tensors"""
+    m_.lerp_(g, 1 - beta1)
+
+    if step == 1:
+        term = g
+    else:
+        diff = g - g_prev_
+        v_.lerp_(diff, 1 - beta2)
+        term = g + beta2 * diff
+
+    n_.mul_(beta3).addcmul_(term, term, value=(1 - beta3))
+
+    m = m_ / (1.0 - beta1**step)
+    v = v_ / (1.0 - beta2**step)
+    n = n_ / (1.0 - beta3**step)
+
+    denom = n.sqrt_().add_(eps)
+    num = m + beta2 * v
+
+    update = num.div_(denom)
+    g_prev_.copy_(g)
+
+    return update
+
+
+
+class Adan(Transform):
+    """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677
+
+    Args:
+        beta1 (float, optional): momentum. Defaults to 0.98.
+        beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
+        beta3 (float, optional): third (squared) momentum. Defaults to 0.99.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+
+    Example:
+        ```python
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.Adan(),
+            tz.m.LR(1e-3),
+        )
+        ```
+
+    Reference:
+        Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive Nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
+    """
+    def __init__(
+        self,
+        beta1: float = 0.98,
+        beta2: float = 0.92,
+        beta3: float = 0.99,
+        eps: float = 1e-8,
+    ):
+        defaults = dict(beta1=beta1, beta2=beta2, beta3=beta3, eps=eps)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1, beta2, beta3, eps = unpack_dicts(settings, 'beta1', 'beta2', 'beta3', 'eps', cls=NumberList)
+        g_prev, m, v, n = unpack_states(states, tensors, 'g_prev', 'm', 'v', 'n', cls=TensorList)
+
+        update = adan_(
+            g=tensors,
+            g_prev_=g_prev,
+            m_=m,
+            v_=v,
+            n_=n,
+            beta1=beta1,
+            beta2=beta2,
+            beta3=beta3,
+            eps=eps,
+            step=step,
+        )
+
+        return update
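Restated directly from the `adan_` code above (no claims beyond what the code does; at step 1 the gradient-difference term is skipped), the per-step update is

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t,\\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,(g_t - g_{t-1}),\\
n_t &= \beta_3 n_{t-1} + (1-\beta_3)\,\big(g_t + \beta_2 (g_t - g_{t-1})\big)^2,\\
\Delta_t &= \frac{\hat m_t + \beta_2 \hat v_t}{\sqrt{\hat n_t} + \epsilon},
\end{aligned}
$$

where each hatted quantity is the bias-corrected estimate, i.e. division by $1-\beta^{\,t}$ with the matching $\beta$.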
torchzero/modules/adaptive/adaptive_heavyball.py
@@ -0,0 +1,54 @@
+import torch
+from ...core import Transform
+from ...utils import TensorList, unpack_dicts, unpack_states
+
+
+def adaptive_heavy_ball(f, f_star, f_prev, g: TensorList, g_prev: TensorList, p: TensorList, p_prev: TensorList):
+    if f - f_star <= torch.finfo(p[0].dtype).tiny * 2: return g
+
+    g_g = g.dot(g)
+    g_gp = g.dot(g_prev)
+    num = -(f - f_star) * g_gp
+    denom = (f_prev - f_star) * g_g + (f - f_star) * g_gp
+    m = num / denom
+
+    h = 2 * (f - f_star) / g_g
+    return (1 + m) * h * g - m * (p - p_prev)
+
+
+class AdaptiveHeavyBall(Transform):
+    """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.
+
+    This is related to conjugate gradient methods; it can work very well on non-stochastic convex objectives, but it won't work on stochastic ones.
+
+    Note:
+        The step size is determined by the algorithm, so learning rate modules shouldn't be used.
+
+    Args:
+        f_star (float, optional):
+            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+    """
+    def __init__(self, f_star: float = 0):
+        defaults = dict(f_star=f_star)
+        super().__init__(defaults, uses_grad=False, uses_loss=True)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        tensors = TensorList(tensors)
+        f_star = self.defaults['f_star']
+
+        f_prev = self.global_state.get('f_prev', None)
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params, tensors], cls=TensorList)
+
+        if f_prev is None:
+            self.global_state['f_prev'] = loss
+            h = 2 * (loss - f_star) / tensors.dot(tensors)
+            return h * tensors
+
+        update = adaptive_heavy_ball(f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)
+
+        self.global_state['f_prev'] = loss
+        p_prev.copy_(params)
+        g_prev.copy_(tensors)
+        return update
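For reference, the direction returned by `adaptive_heavy_ball` above is (restated from the code; $f^\star$ is `f_star`, $f_{\text{prev}}$ the previous loss, $p$ the parameters):

$$
\beta = \frac{-(f - f^\star)\,\langle g, g_{\text{prev}}\rangle}{(f_{\text{prev}} - f^\star)\,\langle g, g\rangle + (f - f^\star)\,\langle g, g_{\text{prev}}\rangle},
\qquad
h = \frac{2\,(f - f^\star)}{\langle g, g\rangle},
\qquad
\Delta = (1+\beta)\, h\, g \;-\; \beta\,(p - p_{\text{prev}}),
$$

with the very first step falling back to the Polyak-style step $\Delta = h\,g$, since no previous point is available yet.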
torchzero/modules/adaptive/aegd.py
@@ -0,0 +1,54 @@
+import math
+
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+# i've verified, it is identical to the official implementation
+# https://github.com/txping/AEGD/blob/master/aegd.py
+def aegd_(f: torch.Tensor | float, g: TensorList, r_: TensorList, c: float | NumberList = 1, eta: float | NumberList = 0.1) -> TensorList:
+    v = g / (2 * (f + c)**0.5)
+    r_ /= 1 + (v ** 2).mul_(2*eta)  # update energy
+    return 2*eta * r_*v  # pyright:ignore[reportReturnType]
+
+class AEGD(Transform):
+    """AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.
+
+    Note:
+        AEGD has a learning rate hyperparameter that can't really be removed from the update rule.
+        To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had it.
+
+    Args:
+        lr (float, optional): step size (eta in the paper). Defaults to 0.1.
+        c (float, optional): c. Defaults to 1.
+    """
+    def __init__(
+        self,
+        lr: float = 0.1,
+        c: float = 1,
+    ):
+        defaults = dict(c=c, lr=lr)
+        super().__init__(defaults, uses_loss=True)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        tensors = TensorList(tensors)
+
+        c, lr = unpack_dicts(settings, 'c', 'lr', cls=NumberList)
+        r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss + c[0])**0.5), cls=TensorList)
+
+        update = aegd_(
+            f=loss,
+            g=tensors,
+            r_=r,
+            c=c,
+            eta=lr,
+        )
+
+        return update
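The energy update implemented by `aegd_` above, written out elementwise (restated from the code, where $r$ is initialized to $\sqrt{f_0 + c}$ by the `init` lambda in `apply_tensors` and $\eta$ is the `lr` setting):

$$
v = \frac{g}{2\sqrt{f + c}},
\qquad
r \leftarrow \frac{r}{1 + 2\eta\, v^{2}},
\qquad
\Delta = 2\eta\, r\, v .
$$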
torchzero/modules/adaptive/esgd.py
@@ -0,0 +1,171 @@
+import math
+from collections.abc import Callable
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, Transform, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist
+
+
+def esgd_(
+    tensors_: TensorList,
+    D: TensorList | None,
+    D_sq_acc_: TensorList,
+    damping: float | NumberList,
+    update_freq: int,
+    step: int,
+    i: int,
+):
+    # update preconditioner
+    if step % update_freq == 0:
+        assert D is not None
+        D_sq_acc_.addcmul_(D, D)
+        i += 1
+    else:
+        assert D is None
+
+    denom = (D_sq_acc_ / max(i, 1)).sqrt_().add_(damping)
+    return tensors_.div_(denom), i
+
+
+class ESGD(Module):
+    """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)
+
+    This is similar to Adagrad, but it accumulates squared randomized Hessian diagonal estimates instead of squared gradients.
+
+    .. note::
+        In most cases ESGD should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply ESGD preconditioning to another module's output.
+
+    .. note::
+        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        damping (float, optional): added to denominator for stability. Defaults to 1e-4.
+        update_freq (int, optional):
+            frequency of updating the hessian diagonal estimate via a hessian-vector product.
+            This value can be increased to reduce computational cost. Defaults to 20.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+
+            Defaults to "autograd".
+        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to a better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional):
+            Inner module. If this is specified, operations are performed in the following order:
+
+            1. compute the hessian diagonal estimate.
+            2. pass inputs to :code:`inner`.
+            3. preconditioning is applied to the outputs of :code:`inner`.
+
+    Examples:
+        Using ESGD:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ESGD(),
+                tz.m.LR(0.1)
+            )
+
+        The ESGD preconditioner can be applied to any other module's output by passing that module to the :code:`inner` argument. Here is an example of applying
+        ESGD preconditioning to Nesterov momentum (:code:`tz.m.NAG`):
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ESGD(inner=tz.m.NAG(0.9)),
+                tz.m.LR(0.1)
+            )
+
+    """
+    def __init__(
+        self,
+        damping: float = 1e-4,
+        update_freq: int = 20,
+        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+        fd_h: float = 1e-3,
+        n_samples = 1,
+        seed: int | None = None,
+        inner: Chainable | None = None
+    ):
+        defaults = dict(damping=damping, update_freq=update_freq, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        fd_h = settings['fd_h']
+        update_freq = settings['update_freq']
+        n_samples = settings['n_samples']
+
+        seed = settings['seed']
+        generator = None
+        if seed is not None:
+            if 'generator' not in self.global_state:
+                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+            generator = self.global_state['generator']
+
+        damping = self.get_settings(params, 'damping', cls=NumberList)
+        D_sq_acc = self.get_state(params, 'D_sq_acc', cls=TensorList)
+        i = self.global_state.get('i', 0)
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        closure = var.closure
+        assert closure is not None
+
+        D = None
+        if step % update_freq == 0:
+
+            rgrad = None
+            for j in range(n_samples):
+                u = [torch.randn(p.size(), generator=generator, device=p.device, dtype=p.dtype) for p in params]
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=fd_h, normalize=True, retain_grad=j < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            assert D is not None
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+            D = TensorList(D)
+
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+        var.update, self.global_state['i'] = esgd_(
+            tensors_=TensorList(update),
+            D=TensorList(D) if D is not None else None,
+            D_sq_acc_=D_sq_acc,
+            damping=damping,
+            update_freq=update_freq,
+            step=step,
+            i=i,
+        )
+        return var
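Since the ESGD docstring above requires a closure that accepts a ``backward`` argument, here is a minimal usage sketch. It assumes ``tz.Modular`` exposes the usual ``torch.optim.Optimizer`` interface (``zero_grad``/``step``) and follows the ``closure(backward=...)`` convention described in the note; the model, data, and loss are placeholders.

```python
import torch
import torchzero as tz  # assumed import alias, matching the "tz." used in the docstrings

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = tz.Modular(model.parameters(), tz.m.ESGD(), tz.m.LR(0.1))

def closure(backward=True):
    # modules such as ESGD may re-invoke this closure to get the extra loss
    # (and, when backward=True, gradient) evaluations needed for their HVPs
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    opt.step(closure)
```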
torchzero/modules/{optimizers → adaptive}/lion.py
@@ -28,7 +28,7 @@ class Lion(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
         exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
         return lion_(TensorList(tensors),exp_avg,beta1,beta2)
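The Lion hunk above reflects the renamed per-tensor hook, ``apply_tensors(self, tensors, params, grads, loss, states, settings)``, which the new Adan and AEGD modules earlier in this diff also implement. As a rough illustration only, a toy transform written against that hook might look like the sketch below; the ``torchzero.core``/``torchzero.utils`` import paths, the helper semantics, and the suitability of this hook for third-party modules are inferred from the hunks shown here, not from separate documentation.

```python
import torch
from torchzero.core import Transform  # assumed public path, mirroring `from ...core import Transform`
from torchzero.utils import NumberList, TensorList, unpack_dicts, unpack_states

class ScaledEMA(Transform):
    """Toy transform: replaces the update with a scaled exponential moving average of it."""
    def __init__(self, beta: float = 0.9, scale: float = 1.0):
        super().__init__(dict(beta=beta, scale=scale), uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        beta, scale = unpack_dicts(settings, 'beta', 'scale', cls=NumberList)
        exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
        exp_avg.lerp_(TensorList(tensors), 1 - beta)  # same EMA pattern as adan_/Lion above
        return scale * exp_avg                        # returns new tensors, not the state itself
```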
torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py}
@@ -1,55 +1,57 @@
-from abc import ABC, abstractmethod
-import math
 from collections import deque
 from typing import Literal, Any
-import
+import warnings

 import torch
 from ...core import Chainable, TensorwiseTransform
-from ...utils.linalg.matrix_funcs import matrix_power_eigh
-from ...utils.linalg.svd import randomized_svd
-from ...utils.linalg.qr import qr_householder

-def
-
-
-
+def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping):
+    if isinstance(history, torch.Tensor):
+        M = history
+    else:
+        M = torch.stack(tuple(history), dim=1)  # / len(history)
+
+    MTM = M.T @ M
+    if damping != 0:
+        MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))

     try:
-
-
+        L, Q = torch.linalg.eigh(MTM)  # pylint:disable=not-callable
+
+        tol = torch.finfo(M.dtype).eps * L.amax()  # remove small eigenvalues
+        indices = L > tol
+        L = L[indices]
+        Q = Q[:, indices]

-
-        if rdamping != 0: rdamping *= torch.linalg.vector_norm(S) # pylint:disable=not-callable
-        Iu = damping + rdamping
-        if true_damping:
-            S.pow_(2)
-            Iu **= 2
-        S.add_(Iu)
-        if true_damping: S.sqrt_()
+        U = (M @ Q) * L.rsqrt()

-
+        if rdamping != 0:
+            rdamping *= torch.linalg.vector_norm(L)  # pylint:disable=not-callable
+            L.add_(rdamping)
+
+        return U, L

     except torch.linalg.LinAlgError:
         return None, None

-def
-
-    return U @
-
+def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor):
+    Z = U.T @ g
+    return (U * L.rsqrt()) @ Z

 def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
     if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
     else:
-        if state_[key].shape != value.shape: state_[key] = value
+        if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
         else: state_[key].lerp_(value, 1-beta)

-class
+class LMAdagrad(TensorwiseTransform):
     """
-
-
-
-
+    Limited-memory full matrix Adagrad.
+
+    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate the update as U S^-1 Uᵀg.
+    But it uses an eigendecomposition of MᵀM to get U and S^2, because that is faster when you don't need V.
+
+    This is equivalent to full-matrix Adagrad on recent gradients.

     Args:
         history_size (int, optional): number of past gradients to store. Defaults to 10.
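Restating the docstring and the two helpers above in one place: with the recent gradients stacked as columns of $M$ and the (damped) Gram matrix eigendecomposed as $M^\top M + \lambda I = Q \Lambda Q^\top$ (near-zero eigenvalues dropped), the code forms

$$
U = M Q \Lambda^{-1/2},
\qquad
\Delta = U\, \Lambda^{-1/2}\, U^\top g,
$$

which, in terms of the thin SVD $M = U S V^\top$ (so $\Lambda = S^2$), is exactly the $U S^{-1} U^\top g$ update mentioned in the docstring. When ``rdamping`` is nonzero, $\text{rdamping}\cdot\lVert \Lambda \rVert$ is added to the eigenvalues before the apply step.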
@@ -61,54 +63,81 @@ class SpectralPreconditioner(TensorwiseTransform):
         true_damping (bool, optional):
             If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
         U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
-
+        L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
         interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
-        concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to
-        normalize (bool, optional): whether to normalize gradients, this doesn't work well so don't use it. Defaults to False.
-        centralize (bool, optional): whether to centralize gradients, this doesn't work well so don't use it. Defaults to False.
+        concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to True.
         inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
+
+    ## Examples:
+
+    Limited-memory Adagrad
+
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.LMAdagrad(),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    Adam with an L-Adagrad preconditioner (the second debiasing beta of 0.999 is arbitrary)
+
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.LMAdagrad(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.LR(0.01)
+    )
+    ```
+
+    Stable Adam with an L-Adagrad preconditioner (this is what I would recommend)
+
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.LMAdagrad(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.ClipNormByEMA(max_ema_growth=1.2),
+        tz.m.LR(0.01)
+    )
+    ```
+
+    Reference:
+        Agarwal N. et al. Efficient full-matrix adaptive regularization // International Conference on Machine Learning. PMLR, 2019, pp. 102-110.
     """

     def __init__(
         self,
-        history_size: int =
+        history_size: int = 100,
         update_freq: int = 1,
         damping: float = 1e-4,
         rdamping: float = 0,
         order: int = 1,
         true_damping: bool = True,
         U_beta: float | None = None,
-
+        L_beta: float | None = None,
         interval: int = 1,
-        concat_params: bool =
-        normalize: bool=False,
-        centralize:bool = False,
+        concat_params: bool = True,
         inner: Chainable | None = None,
     ):
         # history is still updated each step so Precondition's update_freq has different meaning
-        defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta,
+        defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, L_beta=L_beta)
         super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)

     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state,
-        order =
-        history_size =
-        update_freq =
-        damping =
-        rdamping =
-
-
-        S_beta = settings['S_beta']
-        normalize = settings['normalize']
-        centralize = settings['centralize']
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
+        order = setting['order']
+        history_size = setting['history_size']
+        update_freq = setting['update_freq']
+        damping = setting['damping']
+        rdamping = setting['rdamping']
+        U_beta = setting['U_beta']
+        L_beta = setting['L_beta']

         if 'history' not in state: state['history'] = deque(maxlen=history_size)
         history = state['history']

         if order == 1:
             t = tensor.clone().view(-1)
-            if centralize: t -= t.mean()
-            if normalize: t /= torch.linalg.vector_norm(t).clip(min=1e-8) # pylint:disable=not-callable
             history.append(t)
         else:

@@ -116,48 +145,42 @@ class SpectralPreconditioner(TensorwiseTransform):
             # scaled by parameter differences
             cur_p = param.clone()
             cur_g = tensor.clone()
+            eps = torch.finfo(cur_p.dtype).tiny * 2
             for i in range(1, order):
                 if f'prev_g_{i}' not in state:
                     state[f'prev_p_{i}'] = cur_p
                     state[f'prev_g_{i}'] = cur_g
                     break

-
-
+                s = cur_p - state[f'prev_p_{i}']
+                y = cur_g - state[f'prev_g_{i}']
                 state[f'prev_p_{i}'] = cur_p
                 state[f'prev_g_{i}'] = cur_g
-                cur_p =
-                cur_g =
+                cur_p = s
+                cur_g = y

                 if i == order - 1:
-
-                    if normalize: cur_g = cur_g / torch.linalg.vector_norm(cur_g).clip(min=1e-8) # pylint:disable=not-callable
-                    else: cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
+                    cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=eps) # pylint:disable=not-callable
                     history.append(cur_g.view(-1))

         step = state.get('step', 0)
         if step % update_freq == 0 and len(history) != 0:
-            U,
+            U, L = lm_adagrad_update(history, damping=damping, rdamping=rdamping)
             maybe_lerp_(state, U_beta, 'U', U)
-            maybe_lerp_(state,
+            maybe_lerp_(state, L_beta, 'L', L)

         if len(history) != 0:
             state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)

     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
-        history_size = settings['history_size']
-
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         U = state.get('U', None)
         if U is None:
             # make a conservative step to avoid issues due to different GD scaling
             return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

-
-        update =
+        L = state['L']
+        update = lm_adagrad_apply(tensor.view(-1), U, L).view_as(tensor)

-        n = len(state['history'])
-        mh = min(history_size, 10)
-        if n <= mh: update.mul_(n/mh)
         return update
