torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/{momentum/experimental.py → experimental/momentum.py}

@@ -6,10 +6,10 @@ from typing import Literal
  import torch

  from ...core import Target, Transform
- from ...utils import NumberList, TensorList,
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
  from ..functional import ema_, ema_sq_, sqrt_ema_sq_
- from .
- from .
+ from ..momentum.momentum import nag_
+ from ..ops.higher_level import EMASquared, SqrtEMASquared


  def precentered_ema_sq_(

@@ -49,7 +49,7 @@ class PrecenteredEMASquared(Transform):
          super().__init__(defaults, uses_grad=False, target=target)

      @torch.no_grad
-     def
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1

          beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)

@@ -154,44 +154,7 @@ class CoordinateMomentum(Transform):
          super().__init__(defaults, uses_grad=False, target=target)

      @torch.no_grad
-     def
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          p = NumberList(s['p'] for s in settings)
          velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
          return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
-
-
- # def multiplicative_momentum_(
- #     tensors_: TensorList,
- #     velocity_: TensorList,
- #     momentum: float | NumberList,
- #     dampening: float | NumberList,
- #     normalize_velocity: bool = True,
- #     abs: bool = False,
- #     lerp: bool = False,
- # ):
- #     """
- #     abs: if True, tracks momentum of absolute magnitudes.
-
- #     returns `tensors_`.
- #     """
- #     tensors_into_velocity = tensors_.abs() if abs else tensors_
- #     ema_(tensors_into_velocity, exp_avg_=velocity_, beta=momentum, dampening=0, lerp=lerp)
-
- #     if normalize_velocity: velocity_ = velocity_ / velocity_.std().add_(1e-8)
- #     return tensors_.mul_(velocity_.lazy_mul(1-dampening) if abs else velocity_.abs().lazy_mul_(1-dampening))
-
-
- # class MultiplicativeMomentum(Transform):
- #     """sucks"""
- #     def __init__(self, momentum: float = 0.9, dampening: float = 0,normalize_velocity: bool = True, abs: bool = False, lerp: bool = False):
- #         defaults = dict(momentum=momentum, dampening=dampening, normalize_velocity=normalize_velocity,abs=abs, lerp=lerp)
- #         super().__init__(defaults, uses_grad=False)
-
- #     @torch.no_grad
- #     def apply(self, tensors, params, grads, loss, states, settings):
- #         momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
- #         abs,lerp,normalize_velocity = self.first_setting('abs','lerp','normalize_velocity', params=params)
- #         velocity = self.get_state('velocity', params=params, cls=TensorList)
- #         return multiplicative_momentum_(TensorList(target), velocity_=velocity, momentum=momentum, dampening=dampening,
- #                                         normalize_velocity=normalize_velocity,abs=abs,lerp=lerp)
-
torchzero/modules/experimental/newton_solver.py

@@ -3,28 +3,36 @@ from typing import Any, Literal, overload

  import torch

- from ...core import Chainable, Module, apply_transform
+ from ...core import Chainable, Modular, Module, apply_transform
  from ...utils import TensorList, as_tensorlist
- from ...utils.derivatives import hvp
+ from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
  from ..quasi_newton import LBFGS

+
  class NewtonSolver(Module):
-     """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)"""
+     """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)."""
      def __init__(
          self,
          solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
          maxiter=None,
-
+         maxiter1=None,
+         tol: float | None = 1e-3,
          reg: float = 0,
          warm_start=True,
+         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+         reset_solver: bool = False,
+         h: float = 1e-3,
          inner: Chainable | None = None,
      ):
-         defaults = dict(tol=tol, maxiter=maxiter, reg=reg, warm_start=warm_start, solver=solver)
+         defaults = dict(tol=tol, h=h, reset_solver=reset_solver, maxiter=maxiter, maxiter1=maxiter1, reg=reg, warm_start=warm_start, solver=solver, hvp_method=hvp_method)
          super().__init__(defaults,)

          if inner is not None:
              self.set_child('inner', inner)

+         self._num_hvps = 0
+         self._num_hvps_last_step = 0
+
      @torch.no_grad
      def step(self, var):
          params = TensorList(var.params)

@@ -34,19 +42,49 @@ class NewtonSolver(Module):
          settings = self.settings[params[0]]
          solver_cls = settings['solver']
          maxiter = settings['maxiter']
+         maxiter1 = settings['maxiter1']
          tol = settings['tol']
          reg = settings['reg']
+         hvp_method = settings['hvp_method']
          warm_start = settings['warm_start']
+         h = settings['h']
+         reset_solver = settings['reset_solver']

+         self._num_hvps_last_step = 0
          # ---------------------- Hessian vector product function --------------------- #
-
+         if hvp_method == 'autograd':
+             grad = var.get_grad(create_graph=True)

-
-
-
+             def H_mm(x):
+                 self._num_hvps_last_step += 1
+                 with torch.enable_grad():
+                     Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
                  if reg != 0: Hvp = Hvp + (x*reg)
                  return Hvp

+         else:
+
+             with torch.enable_grad():
+                 grad = var.get_grad()
+
+             if hvp_method == 'forward':
+                 def H_mm(x):
+                     self._num_hvps_last_step += 1
+                     Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+                     if reg != 0: Hvp = Hvp + (x*reg)
+                     return Hvp
+
+             elif hvp_method == 'central':
+                 def H_mm(x):
+                     self._num_hvps_last_step += 1
+                     Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+                     if reg != 0: Hvp = Hvp + (x*reg)
+                     return Hvp
+
+             else:
+                 raise ValueError(hvp_method)
+
+
          # -------------------------------- inner step -------------------------------- #
          b = as_tensorlist(grad)
          if 'inner' in self.children:

@@ -58,23 +96,46 @@ class NewtonSolver(Module):
          if x0 is None: x = b.zeros_like().requires_grad_(True)
          else: x = x0.clone().requires_grad_(True)

-
+
+         if 'solver' not in self.global_state:
+             if maxiter1 is not None: maxiter = maxiter1
+             solver = self.global_state['solver'] = solver_cls(x)
+             self.global_state['x'] = x
+
+         else:
+             if reset_solver:
+                 solver = self.global_state['solver'] = solver_cls(x)
+             else:
+                 solver_params = self.global_state['x']
+                 solver_params.set_(x)
+                 x = solver_params
+             solver = self.global_state['solver']
+
          def lstsq_closure(backward=True):
-             Hx = H_mm(x)
-             loss = (Hx-b).pow(2).global_mean()
+             Hx = H_mm(x).detach()
+             # loss = (Hx-b).pow(2).global_mean()
+             # if backward:
+             #     solver.zero_grad()
+             #     loss.backward(inputs=x)
+
+             residual = Hx - b
+             loss = residual.pow(2).global_mean()
              if backward:
-
-
+                 with torch.no_grad():
+                     H_residual = H_mm(residual)
+                     n = residual.global_numel()
+                     x.set_grad_((2.0 / n) * H_residual)
+
              return loss

          if maxiter is None: maxiter = b.global_numel()
          loss = None
-         initial_loss = lstsq_closure(False)
-         if initial_loss >
+         initial_loss = lstsq_closure(False) if tol is not None else None # skip unnecessary closure if tol is None
+         if initial_loss is None or initial_loss > torch.finfo(b[0].dtype).eps:
              for i in range(maxiter):
                  loss = solver.step(lstsq_closure)
                  assert loss is not None
-                 if
+                 if initial_loss is not None and loss/initial_loss < tol: break

          # print(f'{loss = }')

@@ -83,6 +144,7 @@ class NewtonSolver(Module):
              x0.copy_(x)

          var.update = x.detach()
+         self._num_hvps += self._num_hvps_last_step
          return var

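For orientation on what the new `hvp_method` branches compute: a finite-difference Hessian-vector product replaces the double-backward pass with two extra gradient evaluations. Below is a minimal, self-contained sketch of the central-difference version in plain PyTorch; it illustrates the idea and is not torchzero's `hvp_fd_central` (the direction normalization mirrors the `normalize=True` flag, but the exact scaling is an assumption).

    import torch

    def hvp_central_sketch(f, x, v, h=1e-3):
        # H(x) @ v  ~=  (grad f(x + h*u) - grad f(x - h*u)) / (2*h) * |v|,  with u = v / |v|
        v_norm = v.norm().clamp_min(torch.finfo(v.dtype).tiny)
        u = v / v_norm
        xp = (x + h * u).detach().requires_grad_(True)
        xm = (x - h * u).detach().requires_grad_(True)
        gp = torch.autograd.grad(f(xp), xp)[0]
        gm = torch.autograd.grad(f(xm), xm)[0]
        return (gp - gm) * (v_norm / (2 * h))

    # toy check on a quadratic, whose Hessian is 2*I
    f = lambda x: (x ** 2).sum()
    x, v = torch.randn(5), torch.randn(5)
    print(hvp_central_sketch(f, x, v))   # approximately 2 * v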
torchzero/modules/experimental/newtonnewton.py

@@ -10,20 +10,21 @@ import torch
  from ...core import Chainable, Module, apply_transform
  from ...utils import TensorList, vec_to_tensors
  from ...utils.derivatives import (
-
+     flatten_jacobian,
      jacobian_wrt,
  )
  from ..second_order.newton import (
-
-
-
-
+     _cholesky_solve,
+     _eigh_solve,
+     _least_squares_solve,
+     _lu_solve,
  )
-
+ from ...utils.linalg.linear_operator import Dense

  class NewtonNewton(Module):
-     """
-
+     """Applies Newton-like preconditioning to Newton step.
+
+     This is a method that I thought of and then it worked. Here is how it works:

      1. Calculate newton step by solving Hx=g

@@ -34,6 +35,9 @@ class NewtonNewton(Module):
      4. Optionally, repeat (if order is higher than 3.)

      Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unsable and have little benefit.
+
+     3rd order variant can minimize some convex functions with up to 100 variables in less time than Newton's method,
+     this is if pytorch can vectorize hessian computation efficiently.
      """
      def __init__(
          self,

@@ -47,10 +51,10 @@ class NewtonNewton(Module):
          super().__init__(defaults)

      @torch.no_grad
-     def
+     def update(self, var):
          params = TensorList(var.params)
          closure = var.closure
-         if closure is None: raise RuntimeError('
+         if closure is None: raise RuntimeError('NewtonNewton requires closure')

          settings = self.settings[params[0]]
          reg = settings['reg']

@@ -60,6 +64,7 @@ class NewtonNewton(Module):
          eigval_tfm = settings['eigval_tfm']

          # ------------------------ calculate grad and hessian ------------------------ #
+         Hs = []
          with torch.enable_grad():
              loss = var.loss = var.loss_approx = closure(False)
              g_list = torch.autograd.grad(loss, params, create_graph=True)

@@ -72,17 +77,29 @@ class NewtonNewton(Module):
              is_last = o == order
              H_list = jacobian_wrt([xp], params, create_graph=not is_last, batched=vectorize)
              with torch.no_grad() if is_last else nullcontext():
-                 H =
+                 H = flatten_jacobian(H_list)
                  if reg != 0: H = H + I * reg
+                 Hs.append(H)

                  x = None
                  if search_negative or (is_last and eigval_tfm is not None):
-                     x =
-                 if x is None: x =
-                 if x is None: x =
-                 if x is None: x =
+                     x = _eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
+                 if x is None: x = _cholesky_solve(H, xp)
+                 if x is None: x = _lu_solve(H, xp)
+                 if x is None: x = _least_squares_solve(H, xp)
                  xp = x.squeeze()

+         self.global_state["Hs"] = Hs
+         self.global_state['xp'] = xp.nan_to_num_(0,0,0)
+
+     @torch.no_grad
+     def apply(self, var):
+         params = var.params
+         xp = self.global_state['xp']
          var.update = vec_to_tensors(xp, params)
          return var

+     def get_H(self, var):
+         Hs = self.global_state["Hs"]
+         if len(Hs) == 1: return Dense(Hs[0])
+         return Dense(torch.linalg.multi_dot(self.global_state["Hs"])) # pylint:disable=not-callable
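Read together with the solve loop in the last hunk (x starts as the gradient, and each pass solves a linear system whose matrix is the Jacobian of the previous solution with respect to the parameters), the recursion the docstring describes appears to be, in my own notation:

$$x_1 = \nabla f(\theta), \qquad x_{k+1} = \left[\frac{\partial x_k}{\partial \theta}\right]^{-1} x_k, \qquad \Delta\theta = x_{\text{order}},$$

so order = 2 reduces to the ordinary Newton step $x_2 = [\nabla^2 f(\theta)]^{-1}\nabla f(\theta)$, and each further order differentiates through the previous solve, which is where the $n^{\text{order}}$ memory cost comes from.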
torchzero/modules/experimental/reduce_outward_lr.py

@@ -4,19 +4,19 @@ from ...core import Target, Transform
  from ...utils import TensorList, unpack_states, unpack_dicts

  class ReduceOutwardLR(Transform):
-     """
-     When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
+     """When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.

      This means updates that move weights towards zero have higher learning rates.

-
+     .. warning::
+         This sounded good but after testing turns out it sucks.
      """
      def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
          defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
          super().__init__(defaults, uses_grad=use_grad, target=target)

      @torch.no_grad
-     def
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          params = TensorList(params)
          tensors = TensorList(tensors)
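The rule in the docstring reduces to a per-element mask. A minimal sketch of it in plain PyTorch (the helper and the names `update` and `param` are illustrative; the module's `use_grad`/`invert` options and its sign convention are not reproduced):

    import torch

    def reduce_outward_lr_sketch(update: torch.Tensor, param: torch.Tensor, mul: float = 0.5) -> torch.Tensor:
        # per the docstring: wherever the update's sign matches the weight's sign,
        # scale that entry of the update by `mul`; other entries pass through unchanged
        same_sign = torch.sign(update) == torch.sign(param)
        return torch.where(same_sign, update * mul, update)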
torchzero/modules/experimental/scipy_newton_cg.py

@@ -0,0 +1,105 @@
+ from typing import Literal, overload
+
+ import torch
+ from scipy.sparse.linalg import LinearOperator, gcrotmk
+
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import NumberList, TensorList, as_tensorlist, generic_vector_norm, vec_to_tensors
+ from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+ from ...utils.linalg.solve import cg, minres
+
+
+ class ScipyNewtonCG(Module):
+     """NewtonCG with scipy solvers (any from scipy.sparse.linalg)"""
+     def __init__(
+         self,
+         solver = gcrotmk,
+         hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+         h: float = 1e-3,
+         warm_start=False,
+         inner: Chainable | None = None,
+         kwargs: dict | None = None,
+     ):
+         defaults = dict(hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
+         super().__init__(defaults,)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+         self._num_hvps = 0
+         self._num_hvps_last_step = 0
+
+         if kwargs is None: kwargs = {}
+         self._kwargs = kwargs
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         hvp_method = settings['hvp_method']
+         solver = settings['solver']
+         h = settings['h']
+         warm_start = settings['warm_start']
+
+         self._num_hvps_last_step = 0
+         # ---------------------- Hessian vector product function --------------------- #
+         device = params[0].device; dtype=params[0].dtype
+         if hvp_method == 'autograd':
+             grad = var.get_grad(create_graph=True)
+
+             def H_mm(x_np):
+                 self._num_hvps_last_step += 1
+                 x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                 with torch.enable_grad():
+                     Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
+                 return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+         else:
+
+             with torch.enable_grad():
+                 grad = var.get_grad()
+
+             if hvp_method == 'forward':
+                 def H_mm(x_np):
+                     self._num_hvps_last_step += 1
+                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                     Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+             elif hvp_method == 'central':
+                 def H_mm(x_np):
+                     self._num_hvps_last_step += 1
+                     x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
+                     Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+                     return torch.cat([t.ravel() for t in Hvp]).numpy(force=True)
+
+             else:
+                 raise ValueError(hvp_method)
+
+         ndim = sum(p.numel() for p in params)
+         H = LinearOperator(shape=(ndim,ndim), matvec=H_mm, rmatvec=H_mm) # type:ignore
+
+         # -------------------------------- inner step -------------------------------- #
+         b = var.get_update()
+         if 'inner' in self.children:
+             b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
+         b = as_tensorlist(b)
+
+         # ---------------------------------- run cg ---------------------------------- #
+         x0 = None
+         if warm_start: x0 = self.global_state.get('x_prev', None) # initialized to 0 which is default anyway
+
+         x_np = solver(H, b.to_vec().nan_to_num().numpy(force=True), x0=x0, **self._kwargs)
+         if isinstance(x_np, tuple): x_np = x_np[0]
+
+         if warm_start:
+             self.global_state['x_prev'] = x_np
+
+         var.update = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)
+
+         self._num_hvps += self._num_hvps_last_step
+         return var
+
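The new module hands the actual solve to SciPy: the Hessian only ever appears as a matrix-free LinearOperator whose matvec is the Hessian-vector product. A self-contained toy showing the same call pattern (a dense stand-in Hessian instead of torchzero's H_mm; everything else is plain SciPy):

    import numpy as np
    from scipy.sparse.linalg import LinearOperator, gcrotmk

    # stand-in for the Hessian-vector product: the solver only sees the matvec,
    # never the matrix itself, exactly like ScipyNewtonCG's H_mm
    rng = np.random.default_rng(0)
    A = rng.standard_normal((50, 50))
    H_dense = A @ A.T + 50 * np.eye(50)
    matvec = lambda v: H_dense @ v

    H = LinearOperator(shape=(50, 50), matvec=matvec, rmatvec=matvec)
    b = rng.standard_normal(50)

    x, info = gcrotmk(H, b)    # info == 0 means the solver converged
    print(info, np.linalg.norm(H_dense @ x - b))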
torchzero/modules/{projections/structural.py → experimental/structural_projections.py}

@@ -5,36 +5,19 @@ import torch
  from ...core import Chainable
  from ...utils import vec_to_tensors, TensorList
- from ..
- from
+ from ..adaptive.shampoo import _merge_small_dims
+ from ..projections import ProjectionBase


- class VectorProjection(Projection):
-     """
-     flattens and concatenates all parameters into a vector
-     """
-     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

-
-     def project(self, tensors, var, current):
-         return [torch.cat([u.view(-1) for u in tensors], dim=-1)]
-
-     @torch.no_grad
-     def unproject(self, tensors, var, current):
-         return vec_to_tensors(vec=tensors[0], reference=var.params)
-
-
-
- class TensorizeProjection(Projection):
+ class TensorizeProjection(ProjectionBase):
      """flattens and concatenates all parameters into a vector and then reshapes it into a tensor"""
      def __init__(self, modules: Chainable, max_side: int, project_update=True, project_params=False, project_grad=False):
          defaults = dict(max_side=max_side)
          super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)

      @torch.no_grad
-     def project(self, tensors,
-         params = var.params
+     def project(self, tensors, params, grads, loss, states, settings, current):
          max_side = self.settings[params[0]]['max_side']
          num_elems = sum(t.numel() for t in tensors)

@@ -60,23 +43,23 @@ class TensorizeProjection(Projection):
          return [vec.view(dims)]

      @torch.no_grad
-     def unproject(self,
+     def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
          remainder = self.global_state['remainder']
          # warnings.warn(f'{tensors[0].shape = }')
-         vec =
+         vec = projected_tensors[0].view(-1)
          if remainder > 0: vec = vec[:-remainder]
-         return vec_to_tensors(vec,
+         return vec_to_tensors(vec, params)

- class BlockPartition(
+ class BlockPartition(ProjectionBase):
      """splits parameters into blocks (for now flatttens them and chunks)"""
      def __init__(self, modules: Chainable, max_size: int, batched: bool = False, project_update=True, project_params=False, project_grad=False):
          defaults = dict(max_size=max_size, batched=batched)
          super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

      @torch.no_grad
-     def project(self, tensors,
+     def project(self, tensors, params, grads, loss, states, settings, current):
          partitioned = []
-         for p,t in zip(
+         for p,t in zip(params, tensors):
              settings = self.settings[p]
              max_size = settings['max_size']
              n = t.numel()

@@ -101,10 +84,10 @@ class BlockPartition(Projection):
          return partitioned

      @torch.no_grad
-     def unproject(self,
-         ti = iter(
+     def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+         ti = iter(projected_tensors)
          unprojected = []
-         for p in
+         for p in params:
              settings = self.settings[p]
              n = p.numel()

@@ -124,28 +107,3 @@ class BlockPartition(Projection):
          return unprojected

-
- class TensorNormsProjection(Projection):
-     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
-
-     @torch.no_grad
-     def project(self, tensors, var, current):
-         orig = self.get_state(var.params, f'{current}_orig')
-         torch._foreach_copy_(orig, tensors)
-
-         norms = torch._foreach_norm(tensors)
-         self.get_state(var.params, f'{current}_orig_norms', cls=TensorList).set_(norms)
-
-         return [torch.stack(norms)]
-
-     @torch.no_grad
-     def unproject(self, tensors, var, current):
-         orig = self.get_state(var.params, f'{current}_orig')
-         orig_norms = torch.stack(self.get_state(var.params, f'{current}_orig_norms'))
-         target_norms = tensors[0]
-
-         orig_norms = torch.where(orig_norms == 0, 1, orig_norms)
-
-         torch._foreach_mul_(orig, (target_norms/orig_norms).detach().cpu().tolist())
-         return orig
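To make the round trip described in TensorizeProjection's docstring concrete, here is a stand-alone sketch of flatten → pad → reshape and its inverse (the padding and shape policy are my own simplification, not the module's exact max_side/remainder handling):

    import math
    import torch

    def tensorize(tensors, max_side):
        # flatten everything into one vector, pad so it reshapes into a (rows, side) matrix
        vec = torch.cat([t.reshape(-1) for t in tensors])
        side = min(max_side, math.ceil(math.sqrt(vec.numel())))
        rows = math.ceil(vec.numel() / side)
        remainder = rows * side - vec.numel()
        if remainder: vec = torch.cat([vec, vec.new_zeros(remainder)])
        return vec.view(rows, side), remainder

    def untensorize(mat, remainder, reference):
        # undo the reshape: drop the padding, then carve the vector back into the original shapes
        vec = mat.reshape(-1)
        if remainder: vec = vec[:-remainder]
        out, i = [], 0
        for r in reference:
            out.append(vec[i:i + r.numel()].view_as(r))
            i += r.numel()
        return out

    params = [torch.randn(3, 4), torch.randn(7)]
    mat, rem = tensorize(params, max_side=5)
    restored = untensorize(mat, rem, params)
    assert all(torch.equal(a, b) for a, b in zip(params, restored))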
torchzero/modules/functional.py
CHANGED
@@ -7,10 +7,19 @@ storage is always indicated in the docstring.

  Additional functional variants are present in most module files, e.g. `adam_`, `rmsprop_`, `lion_`, etc.
  """
+ from collections.abc import Callable
+ from typing import overload

-
+ import torch

- from ..utils import
+ from ..utils import (
+     NumberList,
+     TensorList,
+     generic_finfo_eps,
+     generic_max,
+     generic_sum,
+     tofloat,
+ )

  inf = float('inf')

@@ -86,10 +95,10 @@ def root(tensors_:TensorList, p:float, inplace: bool):
      if p == 1: return tensors_.abs_()
      if p == 2: return tensors_.sqrt_()
      return tensors_.pow_(1/p)
-
-
-
-
+
+     if p == 1: return tensors_.abs()
+     if p == 2: return tensors_.sqrt()
+     return tensors_.pow(1/p)


  def ema_(

@@ -206,4 +215,41 @@ def sqrt_centered_ema_sq_(
          ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
      )

+ def initial_step_size(tensors: torch.Tensor | TensorList, eps=None) -> float:
+     """initial scaling taken from pytorch L-BFGS to avoid requiring a lot of line search iterations,
+     this version is safer and makes sure largest value isn't smaller than epsilon."""
+     tensors_abs = tensors.abs()
+     tensors_sum = generic_sum(tensors_abs)
+     tensors_max = generic_max(tensors_abs)
+
+     feps = generic_finfo_eps(tensors)
+     if eps is None: eps = feps
+     else: eps = max(eps, feps)
+
+     # scale should not make largest value smaller than epsilon
+     min = eps / tensors_max
+     if min >= 1: return 1.0
+
+     scale = 1 / tensors_sum
+     scale = scale.clip(min=min.item(), max=1)
+     return scale.item()
+
+
+ def epsilon_step_size(tensors: torch.Tensor | TensorList, alpha=1e-7) -> float:
+     """makes sure largest value isn't smaller than epsilon."""
+     tensors_abs = tensors.abs()
+     tensors_max = generic_max(tensors_abs)
+     if tensors_max < alpha: return 1.0
+
+     if tensors_max < 1: alpha = alpha / tensors_max
+     return tofloat(alpha)
+
+
+ def safe_clip(x: torch.Tensor, min=None):
+     """makes sure absolute value of scalar tensor x is not smaller than min"""
+     assert x.numel() == 1, x.shape
+     if min is None: min = torch.finfo(x.dtype).tiny * 2

+     if x.abs() < min: return x.new_full(x.size(), min).copysign(x)
+     return x