torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -494
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -132
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.7.dist-info/METADATA +0 -120
- torchzero-0.1.7.dist-info/RECORD +0 -104
- torchzero-0.1.7.dist-info/top_level.txt +0 -1

torchzero/modules/second_order/newton.py
@@ -1,165 +1,142 @@
- (removed: the previous 165-line Newton implementation; only its final solve-and-validate block is preserved in this rendering)
-        else:
-            newton_step, success = self.solver(hessian, gvec)
-            if not success:
-                newton_step, success = self.fallback(hessian, gvec)
-                if not success:
-                    newton_step, success = _fallback_gd(hessian, gvec)
-
-        # apply the `_update` method
-        vars.ascent = grads.from_vec(newton_step.squeeze_().nan_to_num_(0,0,0))
-
-        # validate if newton step decreased loss
-        if self.validate:
-
-            params.sub_(vars.ascent)
-            fx1 = vars.closure(False)
-            params.add_(vars.ascent)
-
-            # if loss increases, set ascent direction to grad times lr
-            if (not fx1.isfinite()) or fx1 - vars.fx0 > vars.fx0 * self.tol: # type:ignore
-                vars.ascent = grads.div_(grads.total_vector_norm(2) / self.gd_lr)
-
-        # peform an update with the ascent direction, or pass it to the child.
-        return self._update_params_or_step_with_next(vars, params=params)
+import warnings
+from functools import partial
+from typing import Literal
+from collections.abc import Callable
+import torch
+
+from ...core import Chainable, apply, Module
+from ...utils import vec_to_tensors, TensorList
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    hessian_mat,
+    jacobian_and_hessian_wrt,
+)
+
+
+def lu_solve(H: torch.Tensor, g: torch.Tensor):
+    x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
+    if info == 0: return x
+    return None
+
+def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
+    x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
+    if info == 0:
+        g.unsqueeze_(1)
+        return torch.cholesky_solve(g, x)
+    return None
+
+def least_squares_solve(H: torch.Tensor, g: torch.Tensor):
+    return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable
+
+def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None):
+    try:
+        L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+        if tfm is not None: L = tfm(L)
+        L.reciprocal_()
+        return torch.linalg.multi_dot([Q * L.unsqueeze(-2), Q.mH, g]) # pylint:disable=not-callable
+    except torch.linalg.LinAlgError:
+        return None
+
+def tikhonov_(H: torch.Tensor, reg: float):
+    if reg!=0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(reg))
+    return H
+
+def eig_tikhonov_(H: torch.Tensor, reg: float):
+    v = torch.linalg.eigvalsh(H).min().clamp_(max=0).neg_() + reg # pylint:disable=not-callable
+    return tikhonov_(H, v)
+
+
+class Newton(Module):
+    """Exact newton via autograd.
+
+    Args:
+        reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
+        eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
+        hessian_method (str):
+            how to calculate hessian. Defaults to "autograd".
+        vectorize (bool, optional):
+            whether to enable vectorized hessian. Defaults to True.
+        inner (Chainable | None, optional): inner modules. Defaults to None.
+        H_tfm (Callable | None, optional):
+            optional hessian transforms, takes in two arguments - `(hessian, gradient)`.
+
+            must return a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
+            which must be True if transform inverted the hessian and False otherwise. Defaults to None.
+        eigval_tfm (Callable | None, optional):
+            optional eigenvalues transform, for example :code:`torch.abs` or :code:`lambda L: torch.clip(L, min=1e-8)`.
+            If this is specified, eigendecomposition will be used to solve Hx = g.
+
+    """
+    def __init__(
+        self,
+        reg: float = 1e-6,
+        eig_reg: bool = False,
+        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+        vectorize: bool = True,
+        inner: Chainable | None = None,
+        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
+        eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
+    ):
+        defaults = dict(reg=reg, eig_reg=eig_reg, abs=abs, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = TensorList(vars.params)
+        closure = vars.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        reg = settings['reg']
+        eig_reg = settings['eig_reg']
+        hessian_method = settings['hessian_method']
+        vectorize = settings['vectorize']
+        H_tfm = settings['H_tfm']
+        eigval_tfm = settings['eigval_tfm']
+
+        # ------------------------ calculate grad and hessian ------------------------ #
+        if hessian_method == 'autograd':
+            with torch.enable_grad():
+                loss = vars.loss = vars.loss_approx = closure(False)
+                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                g_list = [t[0] for t in g_list] # remove leading dim from loss
+                vars.grad = g_list
+                H = hessian_list_to_mat(H_list)
+
+        elif hessian_method in ('func', 'autograd.functional'):
+            strat = 'forward-mode' if vectorize else 'reverse-mode'
+            with torch.enable_grad():
+                g_list = vars.get_grad(retain_graph=True)
+                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
+                                              method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+        else:
+            raise ValueError(hessian_method)
+
+        # -------------------------------- inner step -------------------------------- #
+        if 'inner' in self.children:
+            g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
+        g = torch.cat([t.view(-1) for t in g_list])
+
+        # ------------------------------- regulazition ------------------------------- #
+        if eig_reg: H = eig_tikhonov_(H, reg)
+        else: H = tikhonov_(H, reg)
+
+        # ----------------------------------- solve ---------------------------------- #
+        update = None
+        if H_tfm is not None:
+            H, is_inv = H_tfm(H, g)
+            if is_inv: update = H
+
+        if eigval_tfm is not None:
+            update = eigh_solve(H, g, eigval_tfm)
+
+        if update is None: update = cholesky_solve(H, g)
+        if update is None: update = lu_solve(H, g)
+        if update is None: update = least_squares_solve(H, g)
+
+        vars.update = vec_to_tensors(update, params)
+        return vars
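The solve section of the new `Newton` module tries progressively more robust dense solvers after Tikhonov regularization. A minimal standalone sketch of that strategy in plain PyTorch, outside torchzero's `Module`/`vars` machinery (the helper name `regularized_newton_step` is made up for illustration):

```python
import torch

def regularized_newton_step(H: torch.Tensor, g: torch.Tensor, reg: float = 1e-6) -> torch.Tensor:
    """Solve (H + reg*I) x = g, falling back from Cholesky to LU to least squares."""
    H = H + reg * torch.eye(H.size(-1), dtype=H.dtype, device=H.device)

    # 1) Cholesky: cheapest, but requires the regularized Hessian to be positive definite
    L, info = torch.linalg.cholesky_ex(H)
    if info == 0:
        return torch.cholesky_solve(g.unsqueeze(1), L).squeeze(1)

    # 2) LU: handles indefinite but non-singular matrices
    x, info = torch.linalg.solve_ex(H, g)
    if info == 0:
        return x

    # 3) least squares: always produces a solution, even for singular H
    return torch.linalg.lstsq(H, g).solution

H = torch.tensor([[2.0, 0.0], [0.0, -1.0]])  # indefinite: Cholesky fails, LU succeeds
g = torch.tensor([1.0, 1.0])
print(regularized_newton_step(H, g))
```

The ordering mirrors the module's fallback chain (`cholesky_solve`, then `lu_solve`, then `least_squares_solve`), so the step degrades gracefully near saddle points and singular Hessians instead of raising.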
torchzero/modules/second_order/newton_cg.py
@@ -0,0 +1,84 @@
+from collections.abc import Callable
+from typing import Literal, overload
+import warnings
+import torch
+
+from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel
+from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+
+from ...core import Chainable, apply, Module
+from ...utils.linalg.solve import cg
+
+class NewtonCG(Module):
+    def __init__(
+        self,
+        maxiter=None,
+        tol=1e-3,
+        reg: float = 1e-8,
+        hvp_method: Literal["forward", "central", "autograd"] = "forward",
+        h=1e-3,
+        warm_start=False,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, warm_start=warm_start)
+        super().__init__(defaults,)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = TensorList(vars.params)
+        closure = vars.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        tol = settings['tol']
+        reg = settings['reg']
+        maxiter = settings['maxiter']
+        hvp_method = settings['hvp_method']
+        h = settings['h']
+        warm_start = settings['warm_start']
+
+        # ---------------------- Hessian vector product function --------------------- #
+        if hvp_method == 'autograd':
+            grad = vars.get_grad(create_graph=True)
+
+            def H_mm(x):
+                with torch.enable_grad():
+                    return TensorList(hvp(params, grad, x, retain_graph=True))
+
+        else:
+
+            with torch.enable_grad():
+                grad = vars.get_grad()
+
+            if hvp_method == 'forward':
+                def H_mm(x):
+                    return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
+
+            elif hvp_method == 'central':
+                def H_mm(x):
+                    return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
+
+            else:
+                raise ValueError(hvp_method)
+
+
+        # -------------------------------- inner step -------------------------------- #
+        b = grad
+        if 'inner' in self.children:
+            b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+
+        # ---------------------------------- run cg ---------------------------------- #
+        x0 = None
+        if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+        x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
+        if warm_start:
+            assert x0 is not None
+            x0.set_(x)
+
+        vars.update = x
+        return vars
+
+
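`NewtonCG` never materializes the Hessian: it only needs Hessian-vector products (exact via autograd, or finite-difference approximations around the current gradient) and hands them to the `cg` routine from `torchzero.utils.linalg.solve`, whose implementation is not part of this hunk. As background, a minimal sketch of a matrix-free conjugate-gradient solve on a single flat tensor; this illustrates the general technique only and is not the library's `cg`:

```python
import torch

def cg_solve(A_mm, b: torch.Tensor, tol: float = 1e-3, maxiter: int | None = None, reg: float = 0.0) -> torch.Tensor:
    """Solve (A + reg*I) x = b where A is available only through matvecs A_mm(v)."""
    if maxiter is None:
        maxiter = b.numel()
    x = torch.zeros_like(b)
    r = b.clone()              # residual of (A + reg*I) x = b at x = 0
    p = r.clone()              # search direction
    rs = r.dot(r)
    for _ in range(maxiter):
        Ap = A_mm(p) + reg * p
        alpha = rs / p.dot(Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r.dot(r)
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

# an implicit SPD "Hessian": only matrix-vector products are ever formed
A = torch.randn(50, 50)
A = A @ A.T + 50 * torch.eye(50)
b = torch.randn(50)
x = cg_solve(lambda v: A @ v, b, tol=1e-8)
print(torch.allclose(A @ x, b, atol=1e-4))
```

Each CG iteration costs one Hessian-vector product, so `maxiter` bounds the extra closure evaluations (finite-difference modes) or double-backward passes (autograd mode) per optimizer step.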
torchzero/modules/second_order/nystrom.py
@@ -0,0 +1,168 @@
+from collections.abc import Callable
+from typing import Literal, overload
+import warnings
+import torch
+
+from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel, vec_to_tensors
+from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+
+from ...core import Chainable, apply, Module
+from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg
+
+class NystromSketchAndSolve(Module):
+    def __init__(
+        self,
+        rank: int,
+        reg: float = 1e-3,
+        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        h=1e-3,
+        inner: Chainable | None = None,
+        seed: int | None = None,
+    ):
+        defaults = dict(rank=rank, reg=reg, hvp_method=hvp_method, h=h, seed=seed)
+        super().__init__(defaults,)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = TensorList(vars.params)
+
+        closure = vars.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        rank = settings['rank']
+        reg = settings['reg']
+        hvp_method = settings['hvp_method']
+        h = settings['h']
+
+        seed = settings['seed']
+        generator = None
+        if seed is not None:
+            if 'generator' not in self.global_state:
+                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+            generator = self.global_state['generator']
+
+        # ---------------------- Hessian vector product function --------------------- #
+        if hvp_method == 'autograd':
+            grad = vars.get_grad(create_graph=True)
+
+            def H_mm(x):
+                with torch.enable_grad():
+                    Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
+                    return torch.cat([t.ravel() for t in Hvp])
+
+        else:
+
+            with torch.enable_grad():
+                grad = vars.get_grad()
+
+            if hvp_method == 'forward':
+                def H_mm(x):
+                    Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
+                    return torch.cat([t.ravel() for t in Hvp])
+
+            elif hvp_method == 'central':
+                def H_mm(x):
+                    Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
+                    return torch.cat([t.ravel() for t in Hvp])
+
+            else:
+                raise ValueError(hvp_method)
+
+
+        # -------------------------------- inner step -------------------------------- #
+        b = grad
+        if 'inner' in self.children:
+            b = apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars)
+
+        # ------------------------------ sketch&n&solve ------------------------------ #
+        x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
+        vars.update = vec_to_tensors(x, reference=params)
+        return vars
+
+
+
+class NystromPCG(Module):
+    def __init__(
+        self,
+        sketch_size: int,
+        maxiter=None,
+        tol=1e-3,
+        reg: float = 1e-6,
+        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        h=1e-3,
+        inner: Chainable | None = None,
+        seed: int | None = None,
+    ):
+        defaults = dict(sketch_size=sketch_size, reg=reg, maxiter=maxiter, tol=tol, hvp_method=hvp_method, h=h, seed=seed)
+        super().__init__(defaults,)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = TensorList(vars.params)
+
+        closure = vars.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        sketch_size = settings['sketch_size']
+        maxiter = settings['maxiter']
+        tol = settings['tol']
+        reg = settings['reg']
+        hvp_method = settings['hvp_method']
+        h = settings['h']
+
+
+        seed = settings['seed']
+        generator = None
+        if seed is not None:
+            if 'generator' not in self.global_state:
+                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+            generator = self.global_state['generator']
+
+
+        # ---------------------- Hessian vector product function --------------------- #
+        if hvp_method == 'autograd':
+            grad = vars.get_grad(create_graph=True)
+
+            def H_mm(x):
+                with torch.enable_grad():
+                    Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
+                    return torch.cat([t.ravel() for t in Hvp])
+
+        else:
+
+            with torch.enable_grad():
+                grad = vars.get_grad()
+
+            if hvp_method == 'forward':
+                def H_mm(x):
+                    Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
+                    return torch.cat([t.ravel() for t in Hvp])
+
+            elif hvp_method == 'central':
+                def H_mm(x):
+                    Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
+                    return torch.cat([t.ravel() for t in Hvp])
+
+            else:
+                raise ValueError(hvp_method)
+
+
+        # -------------------------------- inner step -------------------------------- #
+        b = grad
+        if 'inner' in self.children:
+            b = apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars)
+
+        # ------------------------------ sketch&n&solve ------------------------------ #
+        x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
+        vars.update = vec_to_tensors(x, reference=params)
+        return vars
+
+
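Both Nyström modules delegate the linear algebra to `nystrom_sketch_and_solve` and `nystrom_pcg` in `torchzero/utils/linalg/solve.py`, which this hunk does not show. As background, a minimal sketch of the general randomized Nyström idea behind `NystromSketchAndSolve`: build a rank-`rank` approximation of the Hessian from `rank` Hessian-vector products, then solve the regularized system in closed form. The helper name is hypothetical and this illustrates the technique, not the library routine:

```python
import torch

def nystrom_approx_solve(A_mm, b: torch.Tensor, rank: int, reg: float, generator=None) -> torch.Tensor:
    """Approximate A ~= U diag(lam) U^T from `rank` matvecs, then solve (A_approx + reg*I) x = b."""
    n = b.numel()
    Omega = torch.randn(n, rank, dtype=b.dtype, device=b.device, generator=generator)
    Omega, _ = torch.linalg.qr(Omega)                                  # orthonormal test matrix
    Y = torch.stack([A_mm(Omega[:, i]) for i in range(rank)], dim=1)   # Y = A @ Omega
    shift = 1e-7 * torch.linalg.matrix_norm(Y)                         # small shift for numerical stability
    Y = Y + shift * Omega
    C = torch.linalg.cholesky(Omega.T @ Y)
    B = torch.linalg.solve_triangular(C, Y.T, upper=False).T           # B = Y @ C^{-T}
    U, S, _ = torch.linalg.svd(B, full_matrices=False)
    lam = (S**2 - shift).clamp_min(0)                                  # approximate eigenvalues of A
    # apply (U diag(lam) U^T + reg*I)^{-1} separately on range(U) and its complement
    Ub = U.T @ b
    return U @ (Ub / (lam + reg)) + (b - U @ Ub) / reg

A = torch.randn(100, 20)
A = A @ A.T                                # effectively rank-20 PSD "Hessian"
b = torch.randn(100)
x = nystrom_approx_solve(lambda v: A @ v, b, rank=20, reg=1e-4)
```

`NystromPCG` instead uses the same kind of low-rank approximation as a preconditioner inside conjugate gradients, keeping CG's exactness while improving its conditioning, which is why it also takes `maxiter` and `tol`.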
torchzero/modules/smoothing/__init__.py
@@ -1,5 +1,2 @@
-"""
-from .laplacian_smoothing import LaplacianSmoothing, gradient_laplacian_smoothing_
-from .gaussian_smoothing import GaussianHomotopy
+from .laplacian import LaplacianSmoothing
+from .gaussian import GaussianHomotopy