torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/utils/__init__.py
CHANGED

@@ -1,5 +1,11 @@
 from . import tensorlist as tl
-from .compile import …
+from .compile import (
+    _optional_compiler,
+    benchmark_compile_cpu,
+    benchmark_compile_cuda,
+    enable_compilation,
+    set_compilation,
+)
 from .numberlist import NumberList
 from .optimizer import (
     Init,
@@ -18,6 +24,36 @@ from .params import (
     _copy_param_groups,
     _make_param_groups,
 )
-from .python_tools import …
-…
-…
+from .python_tools import (
+    flatten,
+    generic_eq,
+    generic_ne,
+    reduce_dim,
+    safe_dict_update_,
+    unpack_dicts,
+)
+from .tensorlist import (
+    Distributions,
+    Metrics,
+    TensorList,
+    as_tensorlist,
+    generic_clamp,
+    generic_finfo,
+    generic_finfo_eps,
+    generic_finfo_tiny,
+    generic_max,
+    generic_numel,
+    generic_randn_like,
+    generic_sum,
+    generic_vector_norm,
+    generic_zeros_like,
+)
+from .torch_tools import (
+    set_storage_,
+    tofloat,
+    tolist,
+    tonumpy,
+    totensor,
+    vec_to_tensors,
+    vec_to_tensors_,
+)
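The reorganized `torchzero/utils/__init__.py` makes a number of submodule names available as package-level re-exports. A minimal sketch of the resulting import surface, assuming torchzero 0.3.14 is installed (all names taken from the import blocks in the diff above):

```python
# These names are re-exported at the torchzero.utils level in 0.3.14.
from torchzero.utils import (
    TensorList,       # container type from .tensorlist
    as_tensorlist,    # coercion helper from .tensorlist
    set_compilation,  # compile toggle from .compile
    vec_to_tensors,   # vector/tensor-list conversion from .torch_tools
)

print(TensorList, as_tensorlist, set_compilation, vec_to_tensors)
```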
torchzero/utils/compile.py
CHANGED

@@ -38,7 +38,7 @@ class _MaybeCompiledFunc:
 _optional_compiler = _OptionalCompiler()
 """this holds .enable attribute, set to True to enable compiling for a few functions that benefit from it."""
 
-def set_compilation(enable: bool):
+def set_compilation(enable: bool=True):
     """`enable` is False by default. When True, certain functions will be compiled, which may not work on some systems like Windows, but it usually improves performance."""
     _optional_compiler.enable = enable
 
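The only change here is the default argument: `set_compilation` can now be called with no arguments to opt in. Note the docstring still says `enable` is False by default, which no longer matches the new `enable: bool=True` signature. A minimal usage sketch, assuming torchzero 0.3.14:

```python
from torchzero.utils.compile import set_compilation

set_compilation()       # new in 0.3.14: a bare call enables compilation
set_compilation(False)  # explicit opt-out, e.g. on systems where compilation fails
```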
torchzero/utils/derivatives.py
CHANGED

@@ -2,7 +2,6 @@ from collections.abc import Iterable, Sequence
 
 import torch
 import torch.autograd.forward_ad as fwAD
-from typing import Literal
 
 from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
 
@@ -35,10 +34,27 @@ def _jacobian_batched(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor
         is_grads_batched=True,
     )
 
+def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
+    """Converts the output of jacobian_wrt (a list of tensors) into a single 2D matrix.
+
+    Args:
+        jacs (Sequence[torch.Tensor]):
+            output from jacobian_wrt where each tensor has the shape `(*output.shape, *wrt[i].shape)`.
+
+    Returns:
+        torch.Tensor: has the shape `(output.ndim, wrt.ndim)`.
+    """
+    if not jacs:
+        return torch.empty(0, 0)
+
+    n_out = jacs[0].shape[0]
+    return torch.cat([j.reshape(n_out, -1) for j in jacs], dim=1)
+
+
 def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
     """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
     Returns a sequence of tensors with the length as `wrt`.
-    Each tensor will have the shape `(*…
+    Each tensor will have the shape `(*output.shape, *wrt[i].shape)`.
 
     Args:
         input (Sequence[torch.Tensor]): input sequence of tensors.
@@ -75,10 +91,10 @@ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch
     return jac, jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
 
 
-def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
-…
-…
-…
+# def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
+#     """takes output of `hessian` and returns the 2D hessian matrix.
+#     Note - I only tested this for cases where input is a scalar."""
+#     return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
 
 def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
     """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
@@ -98,7 +114,7 @@ def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[t
     """
     jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
     H_list = jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
-    return …
+    return flatten_jacobian(jac), flatten_jacobian(H_list)
 
 def hessian(
     fn,
@@ -115,19 +131,18 @@ def hessian(
     `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.
 
     Example:
-…
-…
-…
-…
-        y = torch.randn(10, 2)
+        ```python
+        model = nn.Linear(4, 2) # (2, 4) weight and (2, ) bias
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)
 
-…
-…
-…
-…
-…
-        hessian_mat(fn, model.parameters()) # list of two lists of two lists of 3D and 4D tensors
+        def fn():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            return loss
 
+        hessian_mat(fn, model.parameters()) # list of two lists of two lists of 3D and 4D tensors
+        ```
 
     """
     params = list(params)
@@ -165,19 +180,18 @@ def hessian_mat(
     `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.
 
     Example:
-…
-…
-…
-…
-        y = torch.randn(10, 2)
+        ```python
+        model = nn.Linear(4, 2) # 10 parameters in total
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)
 
-…
-…
-…
-…
-…
-        hessian_mat(fn, model.parameters()) # 10x10 tensor
+        def fn():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            return loss
 
+        hessian_mat(fn, model.parameters()) # 10x10 tensor
+        ```
 
     """
     params = list(params)
@@ -206,21 +220,20 @@ def jvp(fn, params: Iterable[torch.Tensor], tangent: Iterable[torch.Tensor]) ->
     """Jacobian vector product.
 
     Example:
-…
-…
-…
-…
-        y = torch.randn(10, 2)
-…
-        tangent = [torch.randn_like(p) for p in model.parameters()]
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)
 
-…
-        y_hat = model(X)
-        loss = F.mse_loss(y_hat, y)
-        return loss
+        tangent = [torch.randn_like(p) for p in model.parameters()]
 
-…
+        def fn():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            return loss
 
+        jvp(fn, model.parameters(), tangent) # scalar
+        ```
     """
     params = list(params)
     tangent = list(tangent)
@@ -253,21 +266,20 @@ def jvp_fd_central(
     """Jacobian vector product using central finite difference formula.
 
     Example:
-…
-…
-…
-…
-        y = torch.randn(10, 2)
-…
-        tangent = [torch.randn_like(p) for p in model.parameters()]
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)
 
-…
-        y_hat = model(X)
-        loss = F.mse_loss(y_hat, y)
-        return loss
+        tangent = [torch.randn_like(p) for p in model.parameters()]
 
-…
+        def fn():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            return loss
 
+        jvp_fd_central(fn, model.parameters(), tangent) # scalar
+        ```
     """
     params = list(params)
     tangent = list(tangent)
@@ -304,24 +316,24 @@ def jvp_fd_forward(
     Loss at initial point can be specified in the `v_0` argument.
 
     Example:
-…
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)
 
-…
-…
-        y = torch.randn(10, 2)
+        tangent1 = [torch.randn_like(p) for p in model.parameters()]
+        tangent2 = [torch.randn_like(p) for p in model.parameters()]
 
-…
-…
-…
-…
-        y_hat = model(X)
-        loss = F.mse_loss(y_hat, y)
-        return loss
+        def fn():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            return loss
 
-…
+        v_0 = fn() # pre-calculate loss at initial point
 
-…
-…
+        jvp1 = jvp_fd_forward(fn, model.parameters(), tangent1, v_0=v_0) # scalar
+        jvp2 = jvp_fd_forward(fn, model.parameters(), tangent2, v_0=v_0) # scalar
+        ```
 
     """
     params = list(params)
@@ -356,21 +368,21 @@ def hvp(
     """Hessian-vector product
 
    Example:
-…
-…
-…
-…
-        y = torch.randn(10, 2)
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)
 
-…
-…
-…
+        y_hat = model(X)
+        loss = F.mse_loss(y_hat, y)
+        loss.backward(create_graph=True)
 
-…
-…
+        grads = [p.grad for p in model.parameters()]
+        vec = [torch.randn_like(p) for p in model.parameters()]
 
-…
-…
+        # list of tensors, same layout as model.parameters()
+        hvp(model.parameters(), grads, vec=vec)
+        ```
     """
     params = list(params)
     g = list(grads)
@@ -393,23 +405,23 @@ def hvp_fd_central(
     Please note that this will clear :code:`grad` attributes in params.
 
     Example:
-…
-…
-…
-…
-        y = torch.randn(10, 2)
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)
 
-…
-…
-…
-…
-…
-…
+        def closure():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            model.zero_grad()
+            loss.backward()
+            return loss
 
-…
+        vec = [torch.randn_like(p) for p in model.parameters()]
 
-…
-…
+        # list of tensors, same layout as model.parameters()
+        hvp_fd_central(closure, model.parameters(), vec=vec)
+        ```
     """
     params = list(params)
     vec = list(vec)
@@ -456,27 +468,27 @@ def hvp_fd_forward(
     Please note that this will clear :code:`grad` attributes in params.
 
     Example:
-…
+        ```python
+        model = nn.Linear(4, 2)
+        X = torch.randn(10, 4)
+        y = torch.randn(10, 2)
 
-…
-…
-…
-…
-…
-…
-        loss = F.mse_loss(y_hat, y)
-        model.zero_grad()
-        loss.backward()
-        return loss
+        def closure():
+            y_hat = model(X)
+            loss = F.mse_loss(y_hat, y)
+            model.zero_grad()
+            loss.backward()
+            return loss
 
-…
+        vec = [torch.randn_like(p) for p in model.parameters()]
 
-…
-…
-…
+        # pre-compute gradient at initial point
+        closure()
+        g_0 = [p.grad for p in model.parameters()]
 
-…
-…
+        # list of tensors, same layout as model.parameters()
+        hvp_fd_forward(closure, model.parameters(), vec=vec, g_0=g_0)
+        ```
 
     """
 
     params = list(params)
@@ -485,7 +497,7 @@ def hvp_fd_forward(
 
     vec_norm = None
     if normalize:
-        vec_norm = torch.linalg.vector_norm(torch.cat([t.…
+        vec_norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in vec])) # pylint:disable=not-callable
         if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
         vec = torch._foreach_div(vec, vec_norm)

torchzero/utils/linalg/__init__.py
CHANGED

@@ -1,5 +1,12 @@
-from .…
+from . import linear_operator
+from .matrix_funcs import (
+    eigvals_func,
+    inv_sqrt_2x2,
+    matrix_power_eigh,
+    singular_vals_func,
+    x_inv,
+)
 from .orthogonalize import gram_schmidt
 from .qr import qr_householder
+from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
 from .svd import randomized_svd
-from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve, steihaug_toint_cg
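Among the additions in `torchzero/utils/derivatives.py` above, `flatten_jacobian` has a docstring but no usage example. A minimal sketch under the conventions documented in the diff (sequence of output tensors in, one jacobian block per `wrt` tensor out), assuming torchzero 0.3.14; the toy tensors are illustrative only:

```python
import torch
from torchzero.utils.derivatives import flatten_jacobian, jacobian_wrt

a = torch.randn(3, requires_grad=True)
b = torch.randn(2, requires_grad=True)
out = torch.cat([2 * a, b.flip(0)])  # 1D output with 5 elements

# One block per wrt tensor: shapes (5, 3) and (5, 2) per the docstring
jacs = jacobian_wrt([out], [a, b])

# Concatenated into a single 2D matrix of shape (5, 3 + 2)
J = flatten_jacobian(jacs)
print(J.shape)  # torch.Size([5, 5])
```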