torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -494
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -132
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.7.dist-info/METADATA +0 -120
- torchzero-0.1.7.dist-info/RECORD +0 -104
- torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/utils/derivatives.py
CHANGED
@@ -1,99 +1,513 @@
from collections.abc import Iterable, Sequence

import torch
import torch.autograd.forward_ad as fwAD

from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors

def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
    flat_input = torch.cat([i.reshape(-1) for i in output])
    grad_ouputs = torch.eye(len(flat_input), device=output[0].device, dtype=output[0].dtype)
    jac = []
    for i in range(flat_input.numel()):
        jac.append(torch.autograd.grad(
            flat_input,
            wrt,
            grad_ouputs[i],
            retain_graph=True,
            create_graph=create_graph,
            allow_unused=True,
            is_grads_batched=False,
        ))
    return [torch.stack(z) for z in zip(*jac)]


def _jacobian_batched(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
    flat_input = torch.cat([i.reshape(-1) for i in output])
    return torch.autograd.grad(
        flat_input,
        wrt,
        torch.eye(len(flat_input), device=output[0].device, dtype=output[0].dtype),
        retain_graph=True,
        create_graph=create_graph,
        allow_unused=True,
        is_grads_batched=True,
    )

def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
    """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
    Returns a sequence of tensors with the same length as `wrt`.
    Each tensor will have the shape `(*input.shape, *wrt[i].shape)`.

    Args:
        input (Sequence[torch.Tensor]): input sequence of tensors.
        wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
        create_graph (bool, optional):
            pytorch option, if True, graph of the derivative will be constructed,
            allowing to compute higher order derivative products. Default: False.
        batched (bool, optional): use faster but experimental pytorch batched jacobian.
            This only has effect when `input` has more than 1 element. Defaults to True.

    Returns:
        sequence of tensors with the same length as `wrt`.
    """
    if batched: return _jacobian_batched(output, wrt, create_graph)
    return _jacobian(output, wrt, create_graph)

def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
    """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
    Calculating the hessian requires calculating the jacobian, so this function is more efficient than
    calling `jacobian` and `hessian` separately, which would calculate the jacobian twice.

    Args:
        input (Sequence[torch.Tensor]): input sequence of tensors.
        wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
        create_graph (bool, optional):
            pytorch option, if True, graph of the derivative will be constructed,
            allowing to compute higher order derivative products. Default: False.
        batched (bool, optional): use faster but experimental pytorch batched grad. Defaults to True.

    Returns:
        tuple with jacobians sequence and hessians sequence.
    """
    jac = jacobian_wrt(output, wrt, create_graph=True, batched=batched)
    return jac, jacobian_wrt(jac, wrt, batched=batched, create_graph=create_graph)


def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
    """takes output of `hessian` and returns the 2D hessian matrix.
    Note - I only tested this for cases where input is a scalar."""
    return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)

def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
    """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
    Calculating the hessian requires calculating the jacobian, so this function is more efficient than
    calling `jacobian` and `hessian` separately, which would calculate the jacobian twice.

    Args:
        input (Sequence[torch.Tensor]): input sequence of tensors.
        wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
        create_graph (bool, optional):
            pytorch option, if True, graph of the derivative will be constructed,
            allowing to compute higher order derivative products. Default: False.
        batched (bool, optional): use faster but experimental pytorch batched grad. Defaults to True.

    Returns:
        tuple with jacobians sequence and hessians sequence.
    """
    jac = jacobian_wrt(output, wrt, create_graph=True, batched=batched)
    H_list = jacobian_wrt(jac, wrt, batched=batched, create_graph=create_graph)
    return torch.cat([j.view(-1) for j in jac]), hessian_list_to_mat(H_list)

def hessian(
    fn,
    params: Iterable[torch.Tensor],
    create_graph=False,
    method="func",
    vectorize=False,
    outer_jacobian_strategy="reverse-mode",
):
    """
    returns list of lists of lists of values of hessian matrix of each param wrt each param.
    To just get a single matrix use the :code:`hessian_mat` function.

    `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.

    Example:
        .. code:: py

            model = nn.Linear(4, 2) # (2, 4) weight and (2, ) bias
            X = torch.randn(10, 4)
            y = torch.randn(10, 2)

            def fn():
                y_hat = model(X)
                loss = F.mse_loss(y_hat, y)
                return loss

            hessian_mat(fn, model.parameters()) # list of two lists of two lists of 3D and 4D tensors

    """
    params = list(params)

    def func(x: list[torch.Tensor]):
        for p, x_i in zip(params, x): swap_tensors_no_use_count_check(p, x_i)
        loss = fn()
        for p, x_i in zip(params, x): swap_tensors_no_use_count_check(p, x_i)
        return loss

    if method == 'func':
        return torch.func.hessian(func)([p.detach().requires_grad_(create_graph) for p in params])

    if method == 'autograd.functional':
        return torch.autograd.functional.hessian(
            func,
            [p.detach() for p in params],
            create_graph=create_graph,
            vectorize=vectorize,
            outer_jacobian_strategy=outer_jacobian_strategy,
        )
    raise ValueError(method)

def hessian_mat(
    fn,
    params: Iterable[torch.Tensor],
    create_graph=False,
    method="func",
    vectorize=False,
    outer_jacobian_strategy="reverse-mode",
):
    """
    returns hessian matrix for parameters (as if they were flattened and concatenated into a vector).

    `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.

    Example:
        .. code:: py

            model = nn.Linear(4, 2) # 10 parameters in total
            X = torch.randn(10, 4)
            y = torch.randn(10, 2)

            def fn():
                y_hat = model(X)
                loss = F.mse_loss(y_hat, y)
                return loss

            hessian_mat(fn, model.parameters()) # 10x10 tensor

    """
    params = list(params)

    def func(x: torch.Tensor):
        x_params = vec_to_tensors(x, params)
        for p, x_i in zip(params, x_params): swap_tensors_no_use_count_check(p, x_i)
        loss = fn()
        for p, x_i in zip(params, x_params): swap_tensors_no_use_count_check(p, x_i)
        return loss

    if method == 'func':
        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph))

    if method == 'autograd.functional':
        return torch.autograd.functional.hessian(
            func,
            torch.cat([p.view(-1) for p in params]).detach(),
            create_graph=create_graph,
            vectorize=vectorize,
            outer_jacobian_strategy=outer_jacobian_strategy,
        )
    raise ValueError(method)

def jvp(fn, params: Iterable[torch.Tensor], tangent: Iterable[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Jacobian vector product.

    Example:
        .. code:: py

            model = nn.Linear(4, 2)
            X = torch.randn(10, 4)
            y = torch.randn(10, 2)

            tangent = [torch.randn_like(p) for p in model.parameters()]

            def fn():
                y_hat = model(X)
                loss = F.mse_loss(y_hat, y)
                return loss

            jvp(fn, model.parameters(), tangent) # scalar

    """
    params = list(params)
    tangent = list(tangent)
    detached_params = [p.detach() for p in params]

    duals = []
    with fwAD.dual_level():
        for p, d, t in zip(params, detached_params, tangent):
            dual = fwAD.make_dual(d, t).requires_grad_(p.requires_grad)
            duals.append(dual)
            swap_tensors_no_use_count_check(p, dual)

        loss = fn()
        res = fwAD.unpack_dual(loss).tangent

        for p, d in zip(params, duals):
            swap_tensors_no_use_count_check(p, d)
    return loss, res


@torch.no_grad
def jvp_fd_central(
    fn,
    params: Iterable[torch.Tensor],
    tangent: Iterable[torch.Tensor],
    h=1e-3,
    normalize=False,
) -> tuple[torch.Tensor | None, torch.Tensor]:
    """Jacobian vector product using central finite difference formula.

    Example:
        .. code:: py

            model = nn.Linear(4, 2)
            X = torch.randn(10, 4)
            y = torch.randn(10, 2)

            tangent = [torch.randn_like(p) for p in model.parameters()]

            def fn():
                y_hat = model(X)
                loss = F.mse_loss(y_hat, y)
                return loss

            jvp_fd_central(fn, model.parameters(), tangent) # scalar

    """
    params = list(params)
    tangent = list(tangent)

    tangent_norm = None
    if normalize:
        tangent_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in tangent])) # pylint:disable=not-callable
        if tangent_norm == 0: return None, torch.tensor(0., device=tangent[0].device, dtype=tangent[0].dtype)
        tangent = torch._foreach_div(tangent, tangent_norm)

    tangent_h = torch._foreach_mul(tangent, h)

    torch._foreach_add_(params, tangent_h)
    v_plus = fn()
    torch._foreach_sub_(params, tangent_h)
    torch._foreach_sub_(params, tangent_h)
    v_minus = fn()
    torch._foreach_add_(params, tangent_h)

    res = (v_plus - v_minus) / (2 * h)
    if normalize: res = res * tangent_norm
    return v_plus, res

@torch.no_grad
def jvp_fd_forward(
    fn,
    params: Iterable[torch.Tensor],
    tangent: Iterable[torch.Tensor],
    h=1e-3,
    v_0=None,
    normalize=False,
) -> tuple[torch.Tensor | None, torch.Tensor]:
    """Jacobian vector product using forward finite difference formula.
    Loss at initial point can be specified in the `v_0` argument.

    Example:
        .. code:: py

            model = nn.Linear(4, 2)
            X = torch.randn(10, 4)
            y = torch.randn(10, 2)

            tangent1 = [torch.randn_like(p) for p in model.parameters()]
            tangent2 = [torch.randn_like(p) for p in model.parameters()]

            def fn():
                y_hat = model(X)
                loss = F.mse_loss(y_hat, y)
                return loss

            v_0 = fn() # pre-calculate loss at initial point

            jvp1 = jvp_fd_forward(fn, model.parameters(), tangent1, v_0=v_0) # scalar
            jvp2 = jvp_fd_forward(fn, model.parameters(), tangent2, v_0=v_0) # scalar

    """
    params = list(params)
    tangent = list(tangent)

    tangent_norm = None
    if normalize:
        tangent_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in tangent])) # pylint:disable=not-callable
        if tangent_norm == 0: return None, torch.tensor(0., device=tangent[0].device, dtype=tangent[0].dtype)
        tangent = torch._foreach_div(tangent, tangent_norm)

    tangent_h = torch._foreach_mul(tangent, h)

    if v_0 is None: v_0 = fn()

    torch._foreach_add_(params, tangent_h)
    v_plus = fn()
    torch._foreach_sub_(params, tangent_h)

    res = (v_plus - v_0) / h
    if normalize: res = res * tangent_norm
    return v_0, res

def hvp(
    params: Iterable[torch.Tensor],
    grads: Iterable[torch.Tensor],
    vec: Iterable[torch.Tensor],
    retain_graph=None,
    create_graph=False,
    allow_unused=None,
):
    """Hessian-vector product.

    Example:
        .. code:: py

            model = nn.Linear(4, 2)
            X = torch.randn(10, 4)
            y = torch.randn(10, 2)

            y_hat = model(X)
            loss = F.mse_loss(y_hat, y)
            loss.backward(create_graph=True)

            grads = [p.grad for p in model.parameters()]
            vec = [torch.randn_like(p) for p in model.parameters()]

            # list of tensors, same layout as model.parameters()
            hvp(model.parameters(), grads, vec=vec)
    """
    params = list(params)
    g = list(grads)
    vec = list(vec)

    with torch.enable_grad():
        return torch.autograd.grad(g, params, vec, create_graph=create_graph, retain_graph=retain_graph, allow_unused=allow_unused)


@torch.no_grad
def hvp_fd_central(
    closure,
    params: Iterable[torch.Tensor],
    vec: Iterable[torch.Tensor],
    h=1e-3,
    normalize=False,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
    """Hessian-vector product using central finite difference formula.

    Please note that this will clear :code:`grad` attributes in params.

    Example:
        .. code:: py

            model = nn.Linear(4, 2)
            X = torch.randn(10, 4)
            y = torch.randn(10, 2)

            def closure():
                y_hat = model(X)
                loss = F.mse_loss(y_hat, y)
                model.zero_grad()
                loss.backward()
                return loss

            vec = [torch.randn_like(p) for p in model.parameters()]

            # list of tensors, same layout as model.parameters()
            hvp_fd_central(closure, model.parameters(), vec=vec)
    """
    params = list(params)
    vec = list(vec)

    vec_norm = None
    if normalize:
        vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in vec])) # pylint:disable=not-callable
        if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
        vec = torch._foreach_div(vec, vec_norm)

    vec_h = torch._foreach_mul(vec, h)
    torch._foreach_add_(params, vec_h)
    with torch.enable_grad(): loss = closure()
    g_plus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

    torch._foreach_sub_(params, vec_h)
    torch._foreach_sub_(params, vec_h)
    with torch.enable_grad(): loss = closure()
    g_minus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

    torch._foreach_add_(params, vec_h)
    for p in params: p.grad = None

    hvp_ = g_plus
    torch._foreach_sub_(hvp_, g_minus)
    torch._foreach_div_(hvp_, 2*h)

    if normalize: torch._foreach_mul_(hvp_, vec_norm)
    return loss, hvp_

@torch.no_grad
def hvp_fd_forward(
    closure,
    params: Iterable[torch.Tensor],
    vec: Iterable[torch.Tensor],
    h=1e-3,
    g_0=None,
    normalize=False,
) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
    """Hessian-vector product using forward finite difference formula.

    Gradient at initial point can be specified in the `g_0` argument.

    Please note that this will clear :code:`grad` attributes in params.

    Example:
        .. code:: py

            model = nn.Linear(4, 2)
            X = torch.randn(10, 4)
            y = torch.randn(10, 2)

            def closure():
                y_hat = model(X)
                loss = F.mse_loss(y_hat, y)
                model.zero_grad()
                loss.backward()
                return loss

            vec = [torch.randn_like(p) for p in model.parameters()]

            # pre-compute gradient at initial point
            closure()
            g_0 = [p.grad for p in model.parameters()]

            # list of tensors, same layout as model.parameters()
            hvp_fd_forward(closure, model.parameters(), vec=vec, g_0=g_0)
    """
    params = list(params)
    vec = list(vec)
    loss = None

    vec_norm = None
    if normalize:
        vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in vec])) # pylint:disable=not-callable
        if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
        vec = torch._foreach_div(vec, vec_norm)

    vec_h = torch._foreach_mul(vec, h)

    if g_0 is None:
        with torch.enable_grad(): loss = closure()
        g_0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
    else:
        g_0 = list(g_0)

    torch._foreach_add_(params, vec_h)
    with torch.enable_grad():
        l = closure()
        if loss is None: loss = l
    g_plus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

    torch._foreach_sub_(params, vec_h)
    for p in params: p.grad = None

    hvp_ = g_plus
    torch._foreach_sub_(hvp_, g_0)
    torch._foreach_div_(hvp_, h)

    if normalize: torch._foreach_mul_(hvp_, vec_norm)
    return loss, hvp_
torchzero/utils/linalg/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .matrix_funcs import inv_sqrt_2x2, eigvals_func, singular_vals_func, matrix_power_eigh, x_inv
from .orthogonalize import gram_schmidt
from .qr import qr_householder
from .svd import randomized_svd
from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
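
The new torchzero/utils/linalg/__init__.py re-exports, among others, a conjugate-gradient solver cg from .solve; its signature is not visible in this diff. Purely as a hedged illustration of the kind of matrix-free routine such a solve module typically provides (and not torchzero's implementation), here is a generic CG sketch in plain PyTorch; cg_sketch is a hypothetical name introduced for this example.

# Generic conjugate-gradient sketch in plain PyTorch. This is NOT torchzero's `cg`;
# it only illustrates the standard matrix-free algorithm for SPD systems.
import torch

def cg_sketch(matvec, b, x0=None, tol=1e-6, maxiter=100):
    """Solve A x = b for symmetric positive-definite A, given only matvec(v) = A @ v."""
    x = torch.zeros_like(b) if x0 is None else x0.clone()
    r = b - matvec(x)   # residual
    p = r.clone()       # search direction
    rs_old = r.dot(r)
    for _ in range(maxiter):
        Ap = matvec(p)
        alpha = rs_old / p.dot(Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r.dot(r)
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

# usage: any symmetric positive-definite system, e.g. a damped curvature solve
A = torch.randn(5, 5)
A = A @ A.T + 5 * torch.eye(5)
b = torch.randn(5)
x = cg_sketch(lambda v: A @ v, b)
print(torch.allclose(A @ x, b, atol=1e-4))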