torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0

@@ -1,120 +0,0 @@ torchzero/modules/experimental/eigendescent.py (file removed)
-from contextlib import nullcontext
-import warnings
-from collections.abc import Callable
-from functools import partial
-import itertools
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, Module, apply_transform
-from ...utils import TensorList, vec_to_tensors
-from ...utils.derivatives import (
-    hessian_list_to_mat,
-    jacobian_wrt, jacobian_and_hessian_wrt, hessian_mat,
-)
-
-def _batched_dot(x, y):
-    return (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)
-
-def _cosine_similarity(x, y):
-    denom = torch.linalg.vector_norm(x, dim=-1) * torch.linalg.vector_norm(y, dim=-1).clip(min=torch.finfo(x.dtype).eps) # pylint:disable=not-callable
-    return _batched_dot(x, y) / denom
-
-class EigenDescent(Module):
-    """
-    Uses eigenvectors corresponding to certain eigenvalues. For now they are just extracted from hessian.
-
-    .. warning::
-        Experimental.
-
-    Args:
-        mode (str, optional):
-            - largest - use largest eigenvalue unless all eigenvalues are negative, then smallest is used.
-            - smallest - use smallest eigenvalue unless all eigenvalues are positive, then largest is used.
-            - mean-sign - use mean of eigenvectors multiplied by 1 or -1 if they point in opposite direction from gradient.
-            - mean-dot - use mean of eigenvectors multiplied by dot product with gradient.
-            - mean-cosine - use mean of eigenvectors multiplied by cosine similarity with gradient.
-            - mm - for testing.
-
-            Defaults to 'mean-sign'.
-        hessian_method (str, optional): how to calculate hessian. Defaults to "autograd".
-        vectorize (bool, optional): how to calculate hessian. Defaults to True.
-
-    """
-    def __init__(
-        self,
-        mode: Literal['largest', 'smallest','magnitude', 'mean-sign', 'mean-dot', 'mean-cosine', 'mm'] = 'mean-sign',
-        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-        vectorize: bool = True,
-    ):
-        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, mode=mode)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        mode = settings['mode']
-        hessian_method = settings['hessian_method']
-        vectorize = settings['vectorize']
-
-        # ------------------------ calculate grad and hessian ------------------------ #
-        if hessian_method == 'autograd':
-            with torch.enable_grad():
-                loss = var.loss = var.loss_approx = closure(False)
-                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                g_list = [t[0] for t in g_list] # remove leading dim from loss
-                var.grad = g_list
-                H = hessian_list_to_mat(H_list)
-
-        elif hessian_method in ('func', 'autograd.functional'):
-            strat = 'forward-mode' if vectorize else 'reverse-mode'
-            with torch.enable_grad():
-                g_list = var.get_grad(retain_graph=True)
-                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
-                                              method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-        else:
-            raise ValueError(hessian_method)
-
-
-        # ----------------------------------- solve ---------------------------------- #
-        g = torch.cat([t.ravel() for t in g_list])
-        L, Q = torch.linalg.eigh(H) # L is sorted # pylint:disable=not-callable
-        if mode == 'largest':
-            # smallest eigenvalue if all eigenvalues are negative else largest
-            if L[-1] <= 0: d = Q[0]
-            else: d = Q[-1]
-
-        elif mode == 'smallest':
-            # smallest eigenvalue if negative eigenvalues exist else largest
-            if L[0] <= 0: d = Q[0]
-            else: d = Q[-1]
-
-        elif mode == 'magnitude':
-            # largest by magnitude
-            if L[0].abs() > L[-1].abs(): d = Q[0]
-            else: d = Q[-1]
-
-        elif mode == 'mean-dot':
-            d = ((g.unsqueeze(0) @ Q).squeeze(0) * Q).mean(1)
-
-        elif mode == 'mean-sign':
-            d = ((g.unsqueeze(0) @ Q).squeeze(0).sign() * Q).mean(1)
-
-        elif mode == 'mean-cosine':
-            d = (Q * _cosine_similarity(Q, g)).mean(1)
-
-        elif mode == 'mm':
-            d = (g.unsqueeze(0) @ Q).squeeze(0) / g.numel()
-
-        else:
-            raise ValueError(mode)
-
-        var.update = vec_to_tensors(g.dot(d).sign() * d, params)
-        return var
-
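
The "mean-sign" mode documented above is easiest to see in isolation: eigendecompose the Hessian, flip each eigenvector so it does not oppose the gradient, and average the columns. Below is a minimal standalone sketch of that rule in plain torch; the matrices and names are illustrative stand-ins and this is not part of the torchzero API.

```python
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n)
H = A + A.T                  # symmetric stand-in for the Hessian
g = torch.randn(n)           # stand-in for the gradient

L, Q = torch.linalg.eigh(H)  # columns of Q are eigenvectors, L sorted ascending
signs = (g @ Q).sign()       # sign of <g, q_i> for each eigenvector q_i
d = (signs * Q).mean(dim=1)  # average of sign-aligned eigenvectors ("mean-sign")

# the removed module then orients the direction along the gradient
step = g.dot(d).sign() * d
print(step)
```
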
@@ -1,195 +0,0 @@ torchzero/modules/experimental/etf.py (file removed)
-from typing import cast
-import warnings
-
-import torch
-
-from ...core import Module
-from ...utils import vec_to_tensors, vec_to_tensors_, as_tensorlist
-
-
-class ExponentialTrajectoryFit(Module):
-    """A method.
-
-    .. warning::
-        Experimental.
-    """
-    def __init__(self, step_size=1e-2, adaptive:bool=True):
-        defaults = dict(step_size = step_size,adaptive=adaptive)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        assert closure is not None
-        step_size = self.settings[var.params[0]]['step_size']
-        adaptive = self.settings[var.params[0]]['adaptive']
-
-
-        # 1. perform 3 GD steps to obtain 4 points
-        points = [torch.cat([p.view(-1) for p in var.params])]
-        for i in range(3):
-            if i == 0:
-                grad = var.get_grad()
-                if adaptive:
-                    step_size /= as_tensorlist(grad).abs().global_mean().clip(min=1e-4)
-
-            else:
-                with torch.enable_grad(): closure()
-                grad = [cast(torch.Tensor, p.grad) for p in var.params]
-
-            # GD step
-            torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-            points.append(torch.cat([p.view(-1) for p in var.params]))
-
-        assert len(points) == 4, len(points)
-        x0, x1, x2, x3 = points
-        dim = x0.numel()
-
-        # 2. fit a generalized exponential curve
-        d0 = (x1 - x0).unsqueeze(1) # column vectors
-        d1 = (x2 - x1).unsqueeze(1)
-        d2 = (x3 - x2).unsqueeze(1)
-
-        # cat
-        D1 = torch.cat([d0, d1], dim=1)
-        D2 = torch.cat([d1, d2], dim=1)
-
-        # if points are collinear this will happen on sphere and a quadratic "line search" will minimize it
-        if x0.numel() >= 2:
-            if torch.linalg.matrix_rank(D1) < 2: # pylint:disable=not-callable
-                pass # need to put a quadratic fit there
-
-        M = D2 @ torch.linalg.pinv(D1) # pylint:disable=not-callable # this defines the curve
-
-        # now we can predict x*
-        I = torch.eye(dim, device=x0.device, dtype=x0.dtype)
-        B = I - M
-        z = x1 - M @ x0
-
-        x_star = torch.linalg.lstsq(B, z).solution # pylint:disable=not-callable
-
-        vec_to_tensors_(x0, var.params)
-        difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-        var.update = list(difference)
-        return var
-
-
-
-class ExponentialTrajectoryFitV2(Module):
-    """Should be better than one above, except it isn't.
-
-    .. warning::
-        Experimental.
-
-    """
-    def __init__(self, step_size=1e-3, num_steps: int= 4, adaptive:bool=True):
-        defaults = dict(step_size = step_size, num_steps=num_steps, adaptive=adaptive)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        assert closure is not None
-        step_size = self.settings[var.params[0]]['step_size']
-        num_steps = self.settings[var.params[0]]['num_steps']
-        adaptive = self.settings[var.params[0]]['adaptive']
-
-        # 1. perform 3 GD steps to obtain 4 points (or more)
-        grad = var.get_grad()
-        if adaptive:
-            step_size /= as_tensorlist(grad).abs().global_mean().clip(min=1e-4)
-
-        points = [torch.cat([p.view(-1) for p in var.params])]
-        point_grads = [torch.cat([g.view(-1) for g in grad])]
-
-        for i in range(num_steps):
-            # GD step
-            torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-            points.append(torch.cat([p.view(-1) for p in var.params]))
-
-            closure(backward=True)
-            grad = [cast(torch.Tensor, p.grad) for p in var.params]
-            point_grads.append(torch.cat([g.view(-1) for g in grad]))
-
-
-        X = torch.stack(points, 1) # dim, num_steps+1
-        G = torch.stack(point_grads, 1)
-        dim = points[0].numel()
-
-        X = torch.cat([X, torch.ones(1, num_steps+1, dtype=G.dtype, device=G.device)])
-
-        P = G @ torch.linalg.pinv(X) # pylint:disable=not-callable
-        A = P[:, :dim]
-        b = -P[:, dim]
-
-        # symmetrize
-        A = 0.5 * (A + A.T)
-
-        # predict x*
-        x_star = torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
-
-        vec_to_tensors_(points[0], var.params)
-        difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-        var.update = list(difference)
-        return var
-
-
-
-
-def _fit_exponential(y0, y1, y2):
-    """x0, x1 and x2 are assumed to be 0, 1, 2"""
-    r = (y2 - y1) / (y1 - y0)
-    ones = r==1
-    r[ones] = 0
-    B = (y1 - y0) / (r - 1)
-    A = y0 - B
-
-    A[ones] = 0
-    B[ones] = 0
-    return A, B, r
-
-class PointwiseExponential(Module):
-    """A stupid method (for my youtube channel).
-
-    .. warning::
-        Experimental.
-    """
-    def __init__(self, step_size: float = 1e-3, reg: float = 1e-2, steps = 10000):
-        defaults = dict(reg=reg, steps=steps, step_size=step_size)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        assert closure is not None
-        settings = self.settings[var.params[0]]
-        step_size = settings['step_size']
-        reg = settings['reg']
-        steps = settings['steps']
-
-        # 1. perform 2 GD steps to obtain 3 points
-        points = [torch.cat([p.view(-1) for p in var.params])]
-        for i in range(2):
-            if i == 0: grad = var.get_grad()
-            else:
-                with torch.enable_grad(): closure()
-                grad = [cast(torch.Tensor, p.grad) for p in var.params]
-
-            # GD step
-            torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-            points.append(torch.cat([p.view(-1) for p in var.params]))
-
-        assert len(points) == 3, len(points)
-        y0, y1, y2 = points
-
-        A, B, r = _fit_exponential(y0, y1, y2)
-        r = r.clip(max = 1-reg)
-        x_star = A + B * r**steps
-
-        vec_to_tensors_(y0, var.params)
-        difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-        var.update = list(difference)
-        return var
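
The helper `_fit_exponential` above fits y(t) = A + B·r^t through three samples taken at t = 0, 1, 2 in closed form: r = (y2 − y1)/(y1 − y0), B = (y1 − y0)/(r − 1), A = y0 − B. The snippet below is a standalone check of that closed form on one coordinate; it is illustrative only, and it omits the r == 1 special case that the removed code handles (the real module applies the fit per parameter coordinate and then extrapolates with a large exponent).

```python
import torch

def fit_exponential(y0, y1, y2):
    # closed-form fit of y(t) = A + B * r**t through (0, y0), (1, y1), (2, y2)
    r = (y2 - y1) / (y1 - y0)
    B = (y1 - y0) / (r - 1)
    A = y0 - B
    return A, B, r

# one decaying coordinate: y(t) = 2 + 3 * 0.5**t  ->  samples [5.0, 3.5, 2.75]
y = torch.tensor([5.0, 3.5, 2.75])
A, B, r = fit_exponential(y[0], y[1], y[2])
print(A.item(), B.item(), r.item())        # 2.0 3.0 0.5
print(A + B * r ** torch.arange(3))        # reproduces the three samples
print((A + B * r ** 10000).item())         # extrapolated limit "x*": 2.0
```
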
@@ -1,113 +0,0 @@ torchzero/modules/experimental/exp_adam.py (file removed)
-from operator import itemgetter
-from functools import partial
-import math
-import torch
-
-from ...core import Module, Target, Transform, apply_transform, Chainable
-from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..functional import (
-    debias, debiased_step_size,
-    ema_,
-    sqrt_ema_sq_,
-)
-from ..step_size.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_
-
-
-def exp_adam_(
-    tensors: TensorList,
-    exp_avg_: TensorList,
-    exp_avg_exp_: TensorList,
-    alpha: float | NumberList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    debiased: bool = True,
-    max_exp_avg_exp_: TensorList | None = None,
-
-    # inner args
-    inner: Module | None = None,
-    params: list[torch.Tensor] | None = None,
-    grads: list[torch.Tensor] | None = None,
-):
-    """Returns new tensors."""
-    tensors_exp = tensors.abs().clip_(max=math.log(torch.finfo(tensors[0].dtype).max) / 2).exp_()
-    exp_avg_exp_.lerp_(tensors_exp, 1-beta2)
-
-    if max_exp_avg_exp_ is not None:
-        max_exp_avg_exp_.maximum_(exp_avg_exp_)
-        exp_avg_exp_ = max_exp_avg_exp_
-
-    if inner is not None:
-        assert params is not None
-        tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
-
-    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-    return (exp_avg_.lazy_mul(alpha) / exp_avg_exp_.log().add_(eps))
-
-class ExpAdam(Transform):
-    """Adam but uses abs exp and log instead of square and sqrt.
-    The gradient will be clipped to half the maximum value representable by its dtype (around 50 for float32)
-
-    Args:
-        beta1 (float, optional): momentum. Defaults to 0.9.
-        beta2 (float, optional): second momentum. Defaults to 0.999.
-        eps (float, optional): epsilon. Defaults to 1e-8.
-        alpha (float, optional): learning rate. Defaults to 1.
-        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
-        pow (float, optional): power used in second momentum power and root. Defaults to 2.
-        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
-    """
-    def __init__(
-        self,
-        beta1: float = 0.9,
-        beta2: float = 0.999,
-        eps: float = 1e-8,
-        amsgrad: bool = False,
-        alpha: float = 1.,
-        pow: float = 2,
-        debiased: bool = True,
-        inner: Chainable | None = None
-    ):
-        defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None: self.set_child('inner', inner)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
-        amsgrad,pow,debiased = itemgetter('amsgrad','pow','debiased')(settings[0])
-
-        if amsgrad:
-            exp_avg, exp_avg_exp, max_exp_avg_exp = unpack_states(states, tensors, 'exp_avg', 'exp_avg_exp', 'max_exp_avg_exp', cls=TensorList)
-        else:
-            exp_avg, exp_avg_exp = unpack_states(states, tensors, 'exp_avg', 'exp_avg_exp', cls=TensorList)
-            max_exp_avg_exp = None
-
-
-        return exp_adam_(
-            tensors=TensorList(tensors),
-            exp_avg_=exp_avg,
-            exp_avg_exp_=exp_avg_exp,
-            alpha=alpha,
-            beta1=beta1,
-            beta2=beta2,
-            eps=eps,
-            step=step,
-            pow=pow,
-            debiased=debiased,
-            max_exp_avg_exp_=max_exp_avg_exp,
-
-            # inner args
-            inner=self.children.get("inner", None),
-            params=params,
-            grads=grads,
-
-        )
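
The twist behind the removed `ExpAdam` is stated in its docstring: the second moment tracks an EMA of exp(|g|) (with |g| clipped so the exponential cannot overflow) and the denominator is its log, in place of Adam's EMA of g² and its sqrt. Below is a single-tensor sketch of one such update in plain torch; the state initialization and the Adam-style bias correction are assumptions made for the sketch, and this is not the torchzero `Transform` API.

```python
import math
import torch

torch.manual_seed(0)
g = torch.randn(4)                               # stand-in gradient
beta1, beta2, eps, step = 0.9, 0.999, 1e-8, 10

exp_avg = torch.zeros_like(g)                    # first moment, as in Adam
exp_avg_exp = torch.ones_like(g)                 # EMA of exp(|g|); ones so log starts at 0 (sketch choice)

clip = math.log(torch.finfo(g.dtype).max) / 2    # ~44 for float32, keeps exp(|g|) finite
exp_avg_exp.lerp_(g.abs().clamp(max=clip).exp(), 1 - beta2)
exp_avg.lerp_(g, 1 - beta1)

alpha = math.sqrt(1 - beta2**step) / (1 - beta1**step)   # Adam-style debiasing (assumed form)
update = alpha * exp_avg / (exp_avg_exp.log() + eps)     # log replaces Adam's sqrt
print(update)
```
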
@@ -1,141 +0,0 @@ torchzero/modules/experimental/expanded_lbfgs.py (file removed)
-from collections import deque
-from operator import itemgetter
-import torch
-
-from ...core import Transform, Chainable, Module, Var, apply_transform
-from ...utils import TensorList, as_tensorlist, NumberList
-from ...modules.quasi_newton.lbfgs import _adaptive_damping, lbfgs, _lerp_params_update_
-
-class ExpandedLBFGS(Module):
-    """L-BFGS but uses differences between more pairs than just consequtive. Window size controls how far away the pairs are allowed to be.
-    """
-    def __init__(
-        self,
-        history_size=10,
-        window_size:int=3,
-        tol: float | None = 1e-10,
-        damping: bool = False,
-        init_damping=0.9,
-        eigval_bounds=(0.5, 50),
-        params_beta: float | None = None,
-        grads_beta: float | None = None,
-        update_freq = 1,
-        z_beta: float | None = None,
-        tol_reset: bool = False,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(history_size=history_size, window_size=window_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
-        super().__init__(defaults)
-
-        self.global_state['s_history'] = deque(maxlen=history_size)
-        self.global_state['y_history'] = deque(maxlen=history_size)
-        self.global_state['sy_history'] = deque(maxlen=history_size)
-        self.global_state['p_history'] = deque(maxlen=window_size)
-        self.global_state['g_history'] = deque(maxlen=window_size)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    def reset(self):
-        self.state.clear()
-        self.global_state['step'] = 0
-        self.global_state['s_history'].clear()
-        self.global_state['y_history'].clear()
-        self.global_state['sy_history'].clear()
-        self.global_state['p_history'].clear()
-        self.global_state['g_history'].clear()
-
-    @torch.no_grad
-    def step(self, var):
-        params = as_tensorlist(var.params)
-        update = as_tensorlist(var.get_update())
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        # history of s and k
-        s_history: deque[TensorList] = self.global_state['s_history']
-        y_history: deque[TensorList] = self.global_state['y_history']
-        sy_history: deque[torch.Tensor] = self.global_state['sy_history']
-        p_history: deque[TensorList] = self.global_state['p_history']
-        g_history: deque[TensorList] = self.global_state['g_history']
-
-        tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
-            'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
-        params_beta, grads_beta = self.get_settings(params, 'params_beta', 'grads_beta')
-
-        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)
-
-        # 1st step - there are no previous params and grads, lbfgs will do normalized GD step
-        if step == 0:
-            s = None; y = None; ys = None
-        else:
-            s = l_params - prev_l_params
-            y = l_update - prev_l_grad
-            ys = s.dot(y)
-
-            if damping:
-                s, y, ys = _adaptive_damping(s, y, ys, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-        prev_l_params.copy_(l_params)
-        prev_l_grad.copy_(l_update)
-
-        # update effective preconditioning state
-        if step % update_freq == 0:
-            if ys is not None and ys > 1e-10:
-                assert s is not None and y is not None
-                s_history.append(s)
-                y_history.append(y)
-                sy_history.append(ys)
-
-            if len(p_history) > 1:
-                for p_i, g_i in zip(list(p_history)[:-1], list(g_history)[:-1]):
-                    s_i = l_params - p_i
-                    y_i = l_update - g_i
-                    ys_i = s_i.dot(y_i)
-
-                    if ys_i > 1e-10:
-                        if damping:
-                            s_i, y_i, ys_i = _adaptive_damping(s_i, y_i, ys_i, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-                        s_history.append(s_i)
-                        y_history.append(y_i)
-                        sy_history.append(ys_i)
-
-            p_history.append(l_params.clone())
-            g_history.append(l_update.clone())
-
-
-        # step with inner module before applying preconditioner
-        if self.children:
-            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
-
-        # tolerance on gradient difference to avoid exploding after converging
-        if tol is not None:
-            if y is not None and y.abs().global_max() <= tol:
-                var.update = update # may have been updated by inner module, probably makes sense to use it here?
-                if tol_reset: self.reset()
-                return var
-
-        # lerp initial H^-1 @ q guess
-        z_ema = None
-        if z_beta is not None:
-            z_ema = self.get_state(var.params, 'z_ema', cls=TensorList)
-
-        # precondition
-        dir = lbfgs(
-            tensors_=as_tensorlist(update),
-            s_history=s_history,
-            y_history=y_history,
-            sy_history=sy_history,
-            y=y,
-            sy=ys,
-            z_beta = z_beta,
-            z_ema = z_ema,
-            step=step
-        )
-
-        var.update = dir
-
-        return var
-
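
What separates `ExpandedLBFGS` from plain L-BFGS is how it harvests curvature pairs: besides the consecutive (s, y) pair it also pairs the current iterate with every earlier iterate inside a small window, keeping only pairs with positive curvature s·y. The flat-vector toy below sketches just that bookkeeping on a simple quadratic; it is illustrative only, while the real module works on TensorLists with optional damping and an update frequency.

```python
from collections import deque
import torch

window_size, history_size = 3, 10
p_window = deque(maxlen=window_size)    # recent iterates
g_window = deque(maxlen=window_size)    # their gradients
s_history = deque(maxlen=history_size)  # kept parameter differences
y_history = deque(maxlen=history_size)  # kept gradient differences

torch.manual_seed(0)
x = torch.randn(5)
for step in range(6):
    grad = 2 * x                        # gradient of f(x) = ||x||^2
    for p_old, g_old in zip(p_window, g_window):
        s, y = x - p_old, grad - g_old
        if s.dot(y) > 1e-10:            # curvature condition, as in the removed module
            s_history.append(s)
            y_history.append(y)
    p_window.append(x.clone())
    g_window.append(grad.clone())
    x = x - 0.1 * grad                  # plain GD step, just to generate a trajectory

print(f"{len(s_history)} curvature pairs collected with a {window_size}-step window")
```
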
@@ -1,85 +0,0 @@ torchzero/modules/experimental/hnewton.py (file removed)
-from collections import deque
-
-import torch
-
-from ...core import TensorwiseTransform
-
-
-def eigh_solve(H: torch.Tensor, g: torch.Tensor):
-    try:
-        L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
-        return Q @ ((Q.mH @ g) / L)
-    except torch.linalg.LinAlgError:
-        return None
-
-
-class HNewton(TensorwiseTransform):
-    """This treats gradient differences as Hvps with vectors being parameter differences, using past gradients that are close to each other. Basically this is another limited memory quasi newton method to test.
-
-    .. warning::
-        Experimental.
-
-    """
-    def __init__(self, history_size: int, window_size: int, reg: float=0, tol: float = 1e-8, concat_params:bool=True, inner=None):
-        defaults = dict(history_size=history_size, window_size=window_size, reg=reg, tol=tol)
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner)
-
-    def update_tensor(self, tensor, param, grad, loss, state, setting):
-
-        history_size = setting['history_size']
-
-        if 'param_history' not in state:
-            state['param_history'] = deque(maxlen=history_size)
-            state['grad_history'] = deque(maxlen=history_size)
-
-        param_history: deque = state['param_history']
-        grad_history: deque = state['grad_history']
-        param_history.append(param.ravel())
-        grad_history.append(tensor.ravel())
-
-    def apply_tensor(self, tensor, param, grad, loss, state, setting):
-        window_size = setting['window_size']
-        reg = setting['reg']
-        tol = setting['tol']
-
-        param_history: deque = state['param_history']
-        grad_history: deque = state['grad_history']
-        g = tensor.ravel()
-
-        n = len(param_history)
-        s_list = []
-        y_list = []
-
-        for i in range(n):
-            for j in range(i):
-                if i - j <= window_size:
-                    p_i, g_i = param_history[i], grad_history[i]
-                    p_j, g_j = param_history[j], grad_history[j]
-                    s = p_i - p_j # vec in hvp
-                    y = g_i - g_j # hvp
-                    if s.dot(y) > tol:
-                        s_list.append(s)
-                        y_list.append(y)
-
-        if len(s_list) < 1:
-            scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
-            tensor.mul_(scale)
-            return tensor
-
-        S = torch.stack(s_list, 1)
-        Y = torch.stack(y_list, 1)
-
-        B = S.T @ Y
-        if reg != 0: B.add_(torch.eye(B.size(0), device=B.device, dtype=B.dtype).mul_(reg))
-        g_proj = g @ S
-
-        newton_proj, info = torch.linalg.solve_ex(B, g_proj) # pylint:disable=not-callable
-        if info != 0:
-            newton_proj = -torch.linalg.lstsq(B, g_proj).solution # pylint:disable=not-callable
-        newton = S @ newton_proj
-        return newton.view_as(tensor)
-
-
-        # scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
-        # tensor.mul_(scale)
-        # return tensor