torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/hnewton.py
@@ -0,0 +1,85 @@
+from collections import deque
+
+import torch
+
+from ...core import TensorwiseTransform
+
+
+def eigh_solve(H: torch.Tensor, g: torch.Tensor):
+    try:
+        L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+        return Q @ ((Q.mH @ g) / L)
+    except torch.linalg.LinAlgError:
+        return None
+
+
+class HNewton(TensorwiseTransform):
+    """This treats gradient differences as Hvps with vectors being parameter differences, using past gradients that are close to each other. Basically this is another limited memory quasi newton method to test.
+
+    .. warning::
+        Experimental.
+
+    """
+    def __init__(self, history_size: int, window_size: int, reg: float=0, tol: float = 1e-8, concat_params:bool=True, inner=None):
+        defaults = dict(history_size=history_size, window_size=window_size, reg=reg, tol=tol)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner)
+
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
+
+        history_size = setting['history_size']
+
+        if 'param_history' not in state:
+            state['param_history'] = deque(maxlen=history_size)
+            state['grad_history'] = deque(maxlen=history_size)
+
+        param_history: deque = state['param_history']
+        grad_history: deque = state['grad_history']
+        param_history.append(param.ravel())
+        grad_history.append(tensor.ravel())
+
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        window_size = setting['window_size']
+        reg = setting['reg']
+        tol = setting['tol']
+
+        param_history: deque = state['param_history']
+        grad_history: deque = state['grad_history']
+        g = tensor.ravel()
+
+        n = len(param_history)
+        s_list = []
+        y_list = []
+
+        for i in range(n):
+            for j in range(i):
+                if i - j <= window_size:
+                    p_i, g_i = param_history[i], grad_history[i]
+                    p_j, g_j = param_history[j], grad_history[j]
+                    s = p_i - p_j # vec in hvp
+                    y = g_i - g_j # hvp
+                    if s.dot(y) > tol:
+                        s_list.append(s)
+                        y_list.append(y)
+
+        if len(s_list) < 1:
+            scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
+            tensor.mul_(scale)
+            return tensor
+
+        S = torch.stack(s_list, 1)
+        Y = torch.stack(y_list, 1)
+
+        B = S.T @ Y
+        if reg != 0: B.add_(torch.eye(B.size(0), device=B.device, dtype=B.dtype).mul_(reg))
+        g_proj = g @ S
+
+        newton_proj, info = torch.linalg.solve_ex(B, g_proj) # pylint:disable=not-callable
+        if info != 0:
+            newton_proj = -torch.linalg.lstsq(B, g_proj).solution # pylint:disable=not-callable
+        newton = S @ newton_proj
+        return newton.view_as(tensor)
+
+
+        # scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
+        # tensor.mul_(scale)
+        # return tensor
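For context, the core of `HNewton.apply_tensor` is a small projected solve: parameter differences from the sliding window form the columns of S, the matching gradient differences (treated as Hessian-vector products) form Y, and the direction comes from the m×m system (SᵀY)d = Sᵀg. Below is a standalone sketch of that solve on a flat vector; `windowed_newton_direction` is a hypothetical name used only for this illustration.

```python
# Standalone sketch of the projected solve in HNewton.apply_tensor, assuming a
# flat parameter vector. `windowed_newton_direction` is a hypothetical name.
import torch

def windowed_newton_direction(param_hist, grad_hist, g, window_size=3, tol=1e-8, reg=0.0):
    s_list, y_list = [], []
    for i in range(len(param_hist)):
        for j in range(i):
            if i - j <= window_size:
                s = param_hist[i] - param_hist[j]  # parameter difference (the "vector" in the Hvp)
                y = grad_hist[i] - grad_hist[j]    # gradient difference (approximate Hvp)
                if s.dot(y) > tol:                 # keep only positive-curvature pairs
                    s_list.append(s); y_list.append(y)
    if not s_list:
        return g                                   # no usable curvature yet; fall back to the gradient
    S = torch.stack(s_list, 1)                     # (n, m)
    Y = torch.stack(y_list, 1)                     # (n, m)
    B = S.T @ Y                                    # small m x m system
    if reg: B = B + reg * torch.eye(B.size(0), dtype=B.dtype, device=B.device)
    coeffs = torch.linalg.lstsq(B, (S.T @ g).unsqueeze(-1)).solution.squeeze(-1)
    return S @ coeffs                              # direction in the subspace spanned by S

# toy check on a quadratic f(x) = 0.5 x^T A x, where the exact Newton direction is A^-1 g
A = torch.diag(torch.tensor([1.0, 10.0, 100.0]))
xs = [torch.randn(3) for _ in range(4)]
gs = [A @ x for x in xs]
print(windowed_newton_direction(xs, gs, gs[-1]))   # close to A^-1 @ gs[-1] once S spans the space
```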
torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py
@@ -4,8 +4,8 @@ from typing import Any
 
 import torch
 
-from
-from
+from ...core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
+from ...utils import NumberList, TensorList, as_tensorlist
 
 
 def _adaptive_damping(
@@ -28,7 +28,7 @@ def _adaptive_damping(
 
 def lbfgs(
     tensors_: TensorList,
-
+    var: Var,
     s_history: deque[TensorList],
     y_history: deque[TensorList],
     sy_history: deque[torch.Tensor],
@@ -43,58 +43,57 @@ def lbfgs(
         if scale < 1e-5: scale = 1 / tensors_.abs().mean()
         return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    return z
+    # 1st loop
+    alpha_list = []
+    q = tensors_.clone()
+    for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
+        p_i = 1 / ys_i # this is also denoted as ρ (rho)
+        alpha = p_i * s_i.dot(q)
+        alpha_list.append(alpha)
+        q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
+
+    # calculate z
+    # s.y/y.y is also this weird y-looking symbol I couldn't find
+    # z is it times q
+    # actually H0 = (s.y/y.y) * I, and z = H0 @ q
+    z = q * (ys_k / (y_k.dot(y_k)))
+
+    if z_tfm is not None:
+        z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))
+
+    # 2nd loop
+    for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
+        p_i = 1 / ys_i
+        beta_i = p_i * y_i.dot(z)
+        z.add_(s_i, alpha = alpha_i - beta_i)
+
+    return z
 
 def _apply_tfms_into_history(
     self: Module,
     params: list[torch.Tensor],
-
+    var: Var,
    update: list[torch.Tensor],
 ):
     if 'params_history_tfm' in self.children:
-        params =
+        params = apply_transform(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
 
     if 'grad_history_tfm' in self.children:
-        update =
+        update = apply_transform(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=var.grad, var=var)
 
     return params, update
 
 def _apply_tfms_into_precond(
     self: Module,
     params: list[torch.Tensor],
-
+    var: Var,
     update: list[torch.Tensor],
 ):
     if 'params_precond_tfm' in self.children:
-        params =
+        params = apply_transform(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
 
     if 'grad_precond_tfm' in self.children:
-        update =
+        update = apply_transform(self.children['grad_precond_tfm'], tensors=update, params=params, grads=var.grad, var=var)
 
     return params, update
 
@@ -165,9 +164,9 @@ class ModularLBFGS(Module):
         self.global_state['sy_history'].clear()
 
     @torch.no_grad
-    def step(self,
-        params = as_tensorlist(
-        update = as_tensorlist(
+    def step(self, var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -186,11 +185,11 @@ class ModularLBFGS(Module):
         params_h, update_h = _apply_tfms_into_history(
             self,
             params=params,
-
+            var=var,
             update=update,
         )
 
-        prev_params_h, prev_grad_h = self.get_state('prev_params_h', 'prev_grad_h',
+        prev_params_h, prev_grad_h = self.get_state(params, 'prev_params_h', 'prev_grad_h', cls=TensorList)
 
         # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
         if step == 0:
@@ -217,16 +216,16 @@ class ModularLBFGS(Module):
         # step with inner module before applying preconditioner
         if 'update_precond_tfm' in self.children:
             update_precond_tfm = self.children['update_precond_tfm']
-
-
-            tensors =
+            inner_var = update_precond_tfm.step(var.clone(clone_update=True))
+            var.update_attrs_from_clone_(inner_var)
+            tensors = inner_var.update
             assert tensors is not None
         else:
             tensors = update.clone()
 
         # transforms into preconditioner
-        params_p, update_p = _apply_tfms_into_precond(self, params=params,
-        prev_params_p, prev_grad_p = self.get_state('prev_params_p', 'prev_grad_p',
+        params_p, update_p = _apply_tfms_into_precond(self, params=params, var=var, update=update)
+        prev_params_p, prev_grad_p = self.get_state(params, 'prev_params_p', 'prev_grad_p', cls=TensorList)
 
         if step == 0:
             s_k_p = None; y_k_p = None; ys_k_p = None
@@ -245,13 +244,13 @@ class ModularLBFGS(Module):
         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k_p is not None and y_k_p.abs().global_max() <= tol:
-
-                return
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
+                return var
 
         # precondition
         dir = lbfgs(
             tensors_=as_tensorlist(tensors),
-
+            var=var,
             s_history=s_history,
             y_history=y_history,
             sy_history=sy_history,
@@ -260,7 +259,7 @@ class ModularLBFGS(Module):
             z_tfm=self.children.get('z_tfm', None),
         )
 
-
+        var.update = dir
 
-        return
+        return var
 
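The rewritten `lbfgs` helper above is the classical two-loop recursion, now taking the step's `Var` so that `z_tfm` and the history transforms can be routed through `apply_transform`. For reference, here is the textbook recursion on flat tensors, stripped of the TensorList and `var` plumbing; `two_loop` is a name used only for this sketch.

```python
# Reference two-loop L-BFGS recursion on flat tensors; `two_loop` is a name
# used only for this sketch (the package version works on TensorLists and
# optionally routes z through a child transform).
import torch

def two_loop(g, s_history, y_history):
    rhos = [1.0 / y.dot(s) for s, y in zip(s_history, y_history)]
    q = g.clone()
    alphas = []
    # 1st loop: newest pair to oldest
    for s, y, rho in zip(reversed(s_history), reversed(y_history), reversed(rhos)):
        alpha = rho * s.dot(q)
        alphas.append(alpha)
        q = q - alpha * y
    # initial inverse-Hessian guess H0 = (s.y / y.y) * I applied to q
    s_k, y_k = s_history[-1], y_history[-1]
    z = q * (s_k.dot(y_k) / y_k.dot(y_k))
    # 2nd loop: oldest pair to newest
    for s, y, rho, alpha in zip(s_history, y_history, rhos, reversed(alphas)):
        beta = rho * y.dot(z)
        z = z + s * (alpha - beta)
    return z  # approximates H^-1 @ g
```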
torchzero/modules/experimental/newton_solver.py
@@ -3,13 +3,13 @@ from typing import Any, Literal, overload
 
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, apply_transform, Modular
 from ...utils import TensorList, as_tensorlist
 from ...utils.derivatives import hvp
 from ..quasi_newton import LBFGS
 
 class NewtonSolver(Module):
-    """Matrix free newton via with any custom solver (
+    """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)"""
     def __init__(
         self,
         solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
@@ -26,9 +26,9 @@ class NewtonSolver(Module):
         self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self,
-        params = TensorList(
-        closure =
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
@@ -39,7 +39,7 @@ class NewtonSolver(Module):
         warm_start = settings['warm_start']
 
         # ---------------------- Hessian vector product function --------------------- #
-        grad =
+        grad = var.get_grad(create_graph=True)
 
         def H_mm(x):
             with torch.enable_grad():
@@ -50,11 +50,11 @@ class NewtonSolver(Module):
         # -------------------------------- inner step -------------------------------- #
         b = as_tensorlist(grad)
         if 'inner' in self.children:
-            b = as_tensorlist(
+            b = as_tensorlist(apply_transform(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, var=var))
 
         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
-        if warm_start: x0 = self.get_state('prev_x',
+        if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
         if x0 is None: x = b.zeros_like().requires_grad_(True)
         else: x = x0.clone().requires_grad_(True)
 
@@ -76,13 +76,13 @@ class NewtonSolver(Module):
             assert loss is not None
             if min(loss, loss/initial_loss) < tol: break
 
-            print(f'{loss = }')
+            # print(f'{loss = }')
 
         if warm_start:
             assert x0 is not None
             x0.copy_(x)
 
-
-        return
+        var.update = x.detach()
+        return var
 
 
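`NewtonSolver` never materializes the Hessian: it builds a Hessian-vector product from `var.get_grad(create_graph=True)` and hands the linear system Hx = g to a pluggable solver (by default `Modular(p, LBFGS())`). The sketch below shows the same matrix-free idea in plain PyTorch, with ordinary conjugate gradient standing in for the pluggable solver; `newton_direction` is a name used only here.

```python
# Minimal matrix-free Newton sketch: Hessian-vector products via double
# backprop, plain conjugate gradient in place of the module's pluggable
# solver. `newton_direction` is a name used only here.
import torch

def newton_direction(loss_fn, x, iters=50, tol=1e-10):
    x = x.detach().requires_grad_(True)
    loss = loss_fn(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)

    def hvp(v):
        # H @ v without ever forming H
        (hv,) = torch.autograd.grad(g, x, grad_outputs=v, retain_graph=True)
        return hv

    # conjugate gradient on H d = g
    d = torch.zeros_like(g)
    r = g.detach().clone()
    p = r.clone()
    rs = r.dot(r)
    for _ in range(iters):
        Hp = hvp(p)
        alpha = rs / p.dot(Hp)
        d = d + alpha * p
        r = r - alpha * Hp
        rs_new = r.dot(r)
        if rs_new < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return d  # approximates H^-1 @ g

# toy check: for f(x) = 0.5 x^T A x the Newton direction at x equals x itself
A = torch.diag(torch.tensor([1.0, 4.0, 9.0]))
print(newton_direction(lambda x: 0.5 * x @ A @ x, torch.ones(3)))
```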
torchzero/modules/experimental/newtonnewton.py
@@ -0,0 +1,92 @@
+import itertools
+import warnings
+from collections.abc import Callable
+from contextlib import nullcontext
+from functools import partial
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, apply_transform
+from ...utils import TensorList, vec_to_tensors
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    jacobian_wrt,
+)
+from ..second_order.newton import (
+    cholesky_solve,
+    eigh_solve,
+    least_squares_solve,
+    lu_solve,
+)
+
+
+class NewtonNewton(Module):
+    """Applies Newton-like preconditioning to Newton step.
+
+    This is a method that I thought of and then it worked. Here is how it works:
+
+    1. Calculate newton step by solving Hx=g
+
+    2. Calculate jacobian of x wrt parameters and call it H2
+
+    3. Solve H2 x2 = x for x2.
+
+    4. Optionally, repeat (if order is higher than 3.)
+
+    Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unsable and have little benefit.
+
+    3rd order variant can minimize some convex functions with up to 100 variables in less time than Newton's method,
+    this is if pytorch can vectorize hessian computation efficiently.
+    """
+    def __init__(
+        self,
+        reg: float = 1e-6,
+        order: int = 3,
+        search_negative: bool = False,
+        vectorize: bool = True,
+        eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+    ):
+        defaults = dict(order=order, reg=reg, vectorize=vectorize, eigval_tfm=eigval_tfm, search_negative=search_negative)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        closure = var.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        reg = settings['reg']
+        vectorize = settings['vectorize']
+        order = settings['order']
+        search_negative = settings['search_negative']
+        eigval_tfm = settings['eigval_tfm']
+
+        # ------------------------ calculate grad and hessian ------------------------ #
+        with torch.enable_grad():
+            loss = var.loss = var.loss_approx = closure(False)
+            g_list = torch.autograd.grad(loss, params, create_graph=True)
+            var.grad = list(g_list)
+
+            xp = torch.cat([t.ravel() for t in g_list])
+            I = torch.eye(xp.numel(), dtype=xp.dtype, device=xp.device)
+
+            for o in range(2, order + 1):
+                is_last = o == order
+                H_list = jacobian_wrt([xp], params, create_graph=not is_last, batched=vectorize)
+                with torch.no_grad() if is_last else nullcontext():
+                    H = hessian_list_to_mat(H_list)
+                    if reg != 0: H = H + I * reg
+
+                    x = None
+                    if search_negative or (is_last and eigval_tfm is not None):
+                        x = eigh_solve(H, xp, eigval_tfm, search_negative=search_negative)
+                    if x is None: x = cholesky_solve(H, xp)
+                    if x is None: x = lu_solve(H, xp)
+                    if x is None: x = least_squares_solve(H, xp)
+                    xp = x.squeeze()
+
+        var.update = vec_to_tensors(xp.nan_to_num_(0,0,0), params)
+        return var
+
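The docstring above describes the order-3 procedure: solve Hx = g, differentiate the resulting step x with respect to the parameters to get H2, then solve H2 x2 = x. Below is a toy sketch of that idea on a flat parameter vector using only torch.autograd; the module itself goes through torchzero's `jacobian_wrt` / `hessian_list_to_mat` helpers and its chain of solver fallbacks, and `newton_newton_direction` is a hypothetical name used only here.

```python
# Toy sketch of the order-3 NewtonNewton idea on a flat parameter vector,
# using torch.autograd directly; `newton_newton_direction` is hypothetical.
import torch

def newton_newton_direction(loss_fn, x, reg=1e-6):
    x = x.detach().requires_grad_(True)
    loss = loss_fn(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
    I = reg * torch.eye(x.numel(), dtype=x.dtype, device=x.device)

    # 1. Newton step x1 from H x1 = g, kept differentiable w.r.t. the parameters
    H = torch.stack([torch.autograd.grad(g[i], x, create_graph=True)[0] for i in range(g.numel())])
    x1 = torch.linalg.solve(H + I, g)

    # 2. "H2" = Jacobian of the Newton step w.r.t. the parameters
    H2 = torch.stack([torch.autograd.grad(x1[i], x, retain_graph=True)[0] for i in range(x1.numel())])

    # 3. solve H2 x2 = x1 for the final direction
    return torch.linalg.solve(H2 + I, x1.detach())

# toy usage on a smooth convex (non-quadratic) function
f = lambda x: (x ** 4).sum() + (x ** 2).sum()
print(newton_newton_direction(f, torch.tensor([1.0, -2.0, 3.0])))
```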
torchzero/modules/experimental/parabolic_search.py
@@ -0,0 +1,220 @@
+import math
+from collections.abc import Mapping
+from operator import itemgetter
+
+import torch
+
+from ...core import Module
+from ...utils import TensorList
+
+
+
+def adaptive_tracking(
+    f,
+    f_0,
+    f_1,
+    t_0,
+    maxiter: int
+):
+
+    t = t_0
+    f_t = f(t)
+
+    # backtrack
+    if f_t > f_0:
+        if f_1 > f_0: t = min(0.5, t_0/2)
+        while f_t > f_0:
+            maxiter -= 1
+            if maxiter < 0: return 0, f_0
+            t = t/2
+            f_t = f(t) if t!=1 else f_1
+        return t, f_t
+
+    # forwardtrack
+    f_prev = f_t
+    t *= 2
+    f_t = f(t)
+    if f_prev < f_t: return t/2, f_prev
+    while f_prev >= f_t:
+        maxiter -= 1
+        if maxiter < 0: return t, f_t
+        f_prev = f_t
+        t *= 2
+        f_t = f(t)
+    return t/2, f_prev
+
+
+
+class ParabolaSearch(Module):
+    """"""
+    def __init__(
+        self,
+        step_size: float = 1e-2,
+        adaptive: bool=True,
+        normalize: bool=False,
+        # method: str | None = None,
+        maxiter: int | None = 10,
+        # bracket=None,
+        # bounds=None,
+        # tol: float | None = None,
+        # options=None,
+    ):
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+        defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
+        super().__init__(defaults)
+
+        import scipy.optimize
+        self.scopt = scipy.optimize
+
+
+    @torch.no_grad
+    def step(self, var):
+        x_0 = TensorList(var.params)
+        closure = var.closure
+        assert closure is not None
+        settings = self.settings[x_0[0]]
+        step_size = settings['step_size']
+        adaptive = settings['adaptive']
+        normalize = settings['normalize']
+        maxiter = settings['maxiter']
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+
+        grad = TensorList(var.get_grad())
+        f_0 = var.get_loss(False)
+
+        scale = 1
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        if adaptive: scale = grad.abs().mean().clip(min=1e-8)
+
+        # make step
+        v_0 = grad * (step_size/scale)
+        x_0 -= v_0
+        with torch.enable_grad():
+            f_1 = closure()
+        grad = x_0.grad
+
+        x_0 += v_0
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        v_1 = grad * (step_size/scale)
+        a = v_1 - v_0
+
+        def parabolic_objective(t: float):
+            nonlocal x_0
+
+            step = v_0*t + 0.5*a*t**2
+            x_0 -= step
+            value = closure(False)
+            x_0 += step
+            return value.detach().cpu()
+
+        prev_t = self.global_state.get('prev_t', 2)
+        t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
+        self.global_state['prev_t'] = t
+
+        # method, bracket, bounds, tol, options, maxiter = itemgetter(
+        #     'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+
+        # if maxiter is not None:
+        #     options = dict(options) if isinstance(options, Mapping) else {}
+        #     options['maxiter'] = maxiter
+
+        # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
+        # t = res.x
+
+        var.update = v_0*t + 0.5*a*t**2
+        return var
+
+class CubicParabolaSearch(Module):
+    """"""
+    def __init__(
+        self,
+        step_size: float = 1e-2,
+        adaptive: bool=True,
+        normalize: bool=False,
+        # method: str | None = None,
+        maxiter: int | None = 10,
+        # bracket=None,
+        # bounds=None,
+        # tol: float | None = None,
+        # options=None,
+    ):
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+        defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
+        super().__init__(defaults)
+
+        import scipy.optimize
+        self.scopt = scipy.optimize
+
+
+    @torch.no_grad
+    def step(self, var):
+        x_0 = TensorList(var.params)
+        closure = var.closure
+        assert closure is not None
+        settings = self.settings[x_0[0]]
+        step_size = settings['step_size']
+        adaptive = settings['adaptive']
+        maxiter = settings['maxiter']
+        normalize = settings['normalize']
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+
+        grad = TensorList(var.get_grad())
+        f_0 = var.get_loss(False)
+
+        scale = 1
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        if adaptive: scale = grad.abs().mean().clip(min=1e-8)
+
+        # make step
+        v_0 = grad * (step_size/scale)
+        x_0 -= v_0
+        with torch.enable_grad():
+            f_1 = closure()
+        grad = x_0.grad
+
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        v_1 = grad * (step_size/scale)
+        a_0 = v_1 - v_0
+
+        # make another step
+        x_0 -= v_1
+        with torch.enable_grad():
+            f_2 = closure()
+        grad = x_0.grad
+
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        v_2 = grad * (step_size/scale)
+        a_1 = v_2 - v_1
+
+        j = a_1 - a_0
+
+        x_0 += v_0
+        x_0 += v_1
+
+        def parabolic_objective(t: float):
+            nonlocal x_0
+
+            step = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
+            x_0 -= step
+            value = closure(False)
+            x_0 += step
+            return value
+
+
+        prev_t = self.global_state.get('prev_t', 2)
+        t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
+        self.global_state['prev_t'] = t
+
+        # method, bracket, bounds, tol, options, maxiter = itemgetter(
+        #     'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+
+        # if maxiter is not None:
+        #     options = dict(options) if isinstance(options, Mapping) else {}
+        #     options['maxiter'] = maxiter
+
+        # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
+        # t = res.x
+
+        var.update = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
+        return var
+
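`adaptive_tracking` picks the multiplier t for the parabolic step: starting from the previous t, it halves t while the trial loss is worse than the starting loss, otherwise keeps doubling t while the loss improves, and returns the last improving t. A small usage sketch on a 1-D model of the loss (the import path is inferred from the file list above and assumes torchzero 0.3.11 is installed):

```python
# Usage sketch for adaptive_tracking on a 1-D model of the loss along the
# parabolic path; the import path is assumed from the file list above.
from torchzero.modules.experimental.parabolic_search import adaptive_tracking

phi = lambda t: (t - 3.0) ** 2           # loss along the path, minimized near t = 3
t, f_t = adaptive_tracking(phi, f_0=phi(0.0), f_1=phi(1.0), t_0=1.0, maxiter=10)
print(t, f_t)                             # forward-tracks by doubling: here t = 4.0, f_t = 1.0
```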