torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
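One structural change visible in the listing above is that the linear-algebra utilities moved out of torchzero.utils into a new top-level torchzero.linalg package (see the renamed linear_operator.py, qr.py, solve.py, and benchmark.py entries). A minimal, hypothetical sketch of what this means for downstream imports, assuming the submodules are imported directly; the re-exports defined in torchzero/linalg/__init__.py are not shown in this diff:

    # torchzero 0.3.15 (old layout)
    # from torchzero.utils.linalg import qr, solve

    # torchzero 0.4.0 (new layout): the same modules now live in a top-level subpackage
    from torchzero.linalg import qr, solve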
tests/test_identical.py CHANGED
@@ -219,8 +219,8 @@ def test_adagrad_hyperparams(initial_accumulator_value, eps, lr):
 
 @pytest.mark.parametrize('tensorwise', [True, False])
 def test_graft(tensorwise):
-    graft1 = lambda p: tz.Modular(p, tz.m.
-    graft2 = lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.
+    graft1 = lambda p: tz.Modular(p, tz.m.Graft(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
+    graft2 = lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.GraftInputToOutput([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
     _assert_identical_opts([graft1, graft2], merge=True, use_closure=True, device='cpu', steps=10)
     for fn in [graft1, graft2]:
         if tensorwise: _assert_identical_closure(fn, merge=True, device='cpu', steps=10)
tests/test_module_autograd.py ADDED
@@ -0,0 +1,586 @@
+from importlib.util import find_spec
+# pylint:disable=deprecated-method
+from typing import Any
+from collections.abc import Sequence
+
+import pytest
+import torch
+
+import torchzero as tz
+from torchzero.utils import TensorList, vec_to_tensors
+
+# ----------------------------------- utils ---------------------------------- #
+DEVICES = ["cpu"]
+if torch.cuda.is_available(): DEVICES.append("cuda")
+DEVICES = tuple(DEVICES)
+
+def _gen(device):
+    return torch.Generator(device).manual_seed(0)
+
+def cat(ts: Sequence[torch.Tensor]):
+    return torch.cat([t.flatten() for t in ts])
+
+def numel(ts: Sequence[torch.Tensor]):
+    return sum(t.numel() for t in ts)
+
+def assert_tl_equal_(tl1: Sequence[torch.Tensor | Any], tl2: Sequence[torch.Tensor | Any]):
+    assert len(tl1) == len(tl2), f"TensorLists have different lengths:\n{[t.shape for t in tl1]}\n{[t.shape for t in tl2]};"
+    for t1, t2 in zip(tl1, tl2):
+        if t1 is None and t2 is None:
+            continue
+        assert t1 is not None and t2 is not None, "One tensor is None, the other is not"
+        assert t1.shape == t2.shape, f"Tensors have different shapes:\n{t1}\nvs\n{t2}"
+        assert torch.equal(t1, t2), f"Tensors are not equal:\n{t1}\nvs\n{t2}"
+
+def assert_tl_allclose_(tl1: Sequence[torch.Tensor | Any], tl2: Sequence[torch.Tensor | Any], **kwargs):
+    assert len(tl1) == len(tl2), f"TensorLists have different lengths:\n{[t.shape for t in tl1]}\n{[t.shape for t in tl2]};"
+    for t1, t2 in zip(tl1, tl2):
+        if t1 is None and t2 is None:
+            continue
+        assert t1 is not None and t2 is not None, "One tensor is None, the other is not"
+        assert t1.shape == t2.shape, f"Tensors have different shapes:\n{t1}\nvs\n{t2}"
+        assert torch.allclose(t1, t2, equal_nan=True, **kwargs), f"Tensors are not close:\n{t1}\nvs\n{t2}"
+
+def assert_tl_same_(seq1: Sequence[torch.Tensor], seq2: Sequence[torch.Tensor]):
+    seq1=tuple(seq1)
+    seq2=tuple(seq2)
+    assert len(seq1) == len(seq2), f'lengths do not match: {len(seq1)} != {len(seq2)}'
+    for t1, t2 in zip(seq1, seq2):
+        assert t1 is t2
+
+
+def assert_tl_same_storage_(seq1: Sequence[torch.Tensor], seq2: Sequence[torch.Tensor]):
+    seq1=tuple(seq1)
+    seq2=tuple(seq2)
+    assert len(seq1) == len(seq2), f'lengths do not match: {len(seq1)} != {len(seq2)}'
+    for t1, t2 in zip(seq1, seq2):
+        assert t1.data_ptr() == t2.data_ptr()
+
+class _EvalCounter:
+    def __init__(self, closure):
+        self.closure = closure
+        self.false = 0
+        self.true = 0
+
+    def __call__(self, backward=True):
+        if backward: self.true += 1
+        else: self.false += 1
+        return self.closure(backward)
+
+    def assert_(self, true:int, false:int):
+        assert true == self.true
+        assert false == self.false
+
+    def __repr__(self):
+        return f"EvalCounter(true={self.true}, false={self.false})"
+
+# --------------------------------- objective --------------------------------
+
+def objective_value(x:torch.Tensor, A:torch.Tensor, b:torch.Tensor):
+    return 0.5 * x @ A @ x + (b @ x).exp()
+
+def analytical_gradient(x:torch.Tensor, A:torch.Tensor, b:torch.Tensor):
+    return A @ x + (b @ x).exp() * b
+
+def analytical_hessian(x:torch.Tensor, A:torch.Tensor, b:torch.Tensor):
+    return A + (b @ x).exp() * b.outer(b)
+
+def analytical_derivative(x: torch.Tensor, b:torch.Tensor, order: int) -> torch.Tensor:
+    assert order >= 3
+    # n-th order outer product
+    # n=4 -> 'i,j,k,l->ijkl'
+    indices = 'ijklmnopqrstuvwxyz'[:order]
+    b_outer = torch.einsum(f"{','.join(indices)}->{indices}", *[b] * order)
+    return (b @ x).exp() * b_outer
+
+
+def get_var(device, dtype=torch.float32):
+
+    # we cat a few tensors to make sure those methods handle multiple params correctly
+    p1 = torch.tensor(1., requires_grad=True, device=device, dtype=dtype)
+    p2 = torch.randn(1, 3, 2, requires_grad=True, device=device, generator=_gen(device), dtype=dtype)
+    p3 = torch.randn(4, requires_grad=True, device=device, generator=_gen(device), dtype=dtype)
+
+    params = [p1, p2, p3]
+    n = numel(params)
+
+    A = torch.randn(n, n, device=device, generator=_gen(device), dtype=dtype)
+    A = A.T @ A + torch.eye(n, device=device, dtype=dtype) * 1e-3
+    b = torch.randn(n, device=device, generator=_gen(device), dtype=dtype)
+
+    def closure(backward=True):
+        x = cat(params)
+        loss = objective_value(x, A, b)
+
+        if backward:
+            for p in params:
+                p.grad = None
+            loss.backward()
+
+        return loss
+
+    objective = _EvalCounter(closure)
+    var = tz.core.Objective(params=params, closure=objective, model=None, current_step=0)
+
+    return var, A, b, objective
+
+# ------------------------------------ hvp ----------------------------------- #
+@pytest.mark.parametrize("device", DEVICES)
+def test_gradient(device):
+    """makes sure gradient is correct"""
+    var, A, b, objective = get_var(device)
+    grad = var.get_grads()
+    assert torch.allclose(cat(grad), analytical_gradient(cat(var.params), A, b))
+    objective.assert_(true=1, false=0)
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd"])
+@pytest.mark.parametrize("get_grad", [True, False])
+def test_hvp_autograd(device, at_x0, hvp_method, get_grad):
+    """compares hessian-vector product with analytical"""
+
+    var, A, b, objective = get_var(device)
+
+    grad = None
+    if get_grad:
+        grad = var.get_grads(create_graph=True, at_x0=at_x0) # one false (one closure call with backward=False)
+
+    # generate random z
+    n = numel(var.params)
+    z = vec_to_tensors(torch.randn(n, device=device, generator=_gen(device)), var.params)
+
+    # Hz
+    # this is for all following autograd tests
+    # if at_x0:
+    #     one false call happens either in get_grad or here, so 1 false
+    # else:
+    #     if get_grad, both get_grad and this call with false, so 2 false
+    #     else only this calls with false, so 1 false
+    Hz, rgrad = var.hessian_vector_product(z, None, at_x0=at_x0, hvp_method=hvp_method, h=1e-3)
+
+    # check storage
+    assert rgrad is not None
+    if at_x0:
+        assert var.grads is not None
+        assert_tl_same_(var.grads, rgrad)
+        if grad is not None: assert_tl_same_(grad, rgrad)
+    else:
+        assert var.grads is None
+        if grad is not None: assert_tl_allclose_(grad, rgrad)
+
+    # check against known Hvp
+    x = cat(var.params)
+    assert torch.allclose(cat(rgrad), analytical_gradient(x, A, b))
+    assert torch.allclose(cat(Hz), analytical_hessian(x, A, b) @ cat(z))
+
+    # check evals
+    if at_x0: false = 1
+    else:
+        if get_grad: false = 2
+        else: false = 1
+    objective.assert_(true=0, false=false)
+
+# -------------------------- hessian-matrix product -------------------------- #\
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd"])
+@pytest.mark.parametrize("get_grad", [True, False])
+def test_hessian_matrix_product(device, at_x0, hvp_method, get_grad):
+    """compares hessian-matrix product with analytical"""
+
+    var, A, b, objective = get_var(device)
+    if get_grad:
+        var.get_grads(create_graph=True, at_x0=at_x0) # one false
+
+    # generate random matrix
+    n = numel(var.params)
+    Z = torch.randn(n, n*2, device=device, generator=_gen(device))
+
+    # HZ same as above
+    HZ, rgrad = var.hessian_matrix_product(Z, rgrad=None, at_x0=at_x0, hvp_method=hvp_method, h=1e-3)
+
+    # check storage
+    assert rgrad is not None
+    if at_x0:
+        assert var.grads is not None
+        assert_tl_same_(rgrad, var.grads)
+    else:
+        assert var.grads is None
+
+    # check against known HZ
+    x = cat(var.params)
+    assert torch.allclose(HZ, analytical_hessian(x, A, b) @ Z, rtol=1e-4, atol=1e-6), f"{HZ = }, {A@Z = }"
+
+    # check evals
+    if at_x0: false = 1
+    else:
+        if get_grad: false = 2
+        else: false = 1
+    objective.assert_(true=0, false=false)
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd", "fd_forward", "fd_central"])
+@pytest.mark.parametrize("h", [1e-1, 1e-2, 1e-3])
+def test_hessian_vector_vs_matrix_product(device, at_x0, hvp_method, h):
+    """compares hessian_vector_product and hessian_matrix_product, including fd"""
+
+    var, A, b, objective = get_var(device, dtype=torch.float64)
+
+    # generate random matrix
+    n = numel(var.params)
+    Z = torch.randn(n, n*2, device=device, generator=_gen(device))
+    z_vecs = [vec_to_tensors(col, var.params) for col in Z.unbind(1)]
+
+    # hessian-vector
+    rgrad = None
+    Hzs = []
+    for z in z_vecs:
+        Hz, rgrad = var.hessian_vector_product(z, rgrad=rgrad, at_x0=at_x0, hvp_method=hvp_method, h=h, retain_graph=True)
+        Hzs.append(cat(Hz))
+
+    # check evals (did n*2 hvps)
+    if hvp_method in ('autograd', 'batched_autograd'): objective.assert_(true=0, false=1)
+    elif hvp_method == 'fd_central': objective.assert_(true=n*4, false=0)
+    elif hvp_method == 'fd_forward': objective.assert_(true=n*2+1, false=0)
+    else: assert False, hvp_method
+
+    # clear evals
+    objective.true = objective.false = 0
+
+    # hessian-matrix
+    HZ, rgrad = var.hessian_matrix_product(Z, rgrad=rgrad, at_x0=at_x0, hvp_method=hvp_method, h=h)
+
+    # check evals (did n*2 hvps, initial grad is rgrad)
+    if hvp_method in ('autograd', 'batched_autograd'): objective.assert_(true=0, false=0)
+    elif hvp_method == 'fd_central': objective.assert_(true=n*4, false=0)
+    elif hvp_method == 'fd_forward': objective.assert_(true=n*2, false=0)
+    else: assert False, hvp_method
+
+    # check storage
+    if hvp_method == 'fd_central': assert rgrad is None
+    else: assert rgrad is not None
+
+    if at_x0:
+        if hvp_method == 'fd_central': assert var.grads is None
+        else:
+            assert var.grads is not None
+            assert rgrad is not None
+            assert_tl_same_(rgrad, var.grads)
+    else:
+        assert var.grads is None
+
+    # check that they match
+    assert torch.allclose(HZ, torch.stack(Hzs, dim=-1)), f"{HZ = }, {torch.stack(Hzs, dim=-1) = }"
+
+# -------------------------------- hutchinson -------------------------------- #
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd"])
+@pytest.mark.parametrize("zHz", [True, False])
+@pytest.mark.parametrize("get_grad", [True, False])
+def test_hutchinson(device, at_x0, hvp_method, zHz, get_grad):
+    """compares autograd hutchinson with one computed with analytical hessian-vector products"""
+
+    var, A, b, objective = get_var(device)
+    if get_grad:
+        var.get_grads(create_graph=True, at_x0=at_x0) # one false
+
+    # 10 random vecs
+    n = numel(var.params)
+    zs = [vec_to_tensors(torch.randn(n, device=device, generator=_gen(device)), var.params) for _ in range(10)]
+
+    # compute hutchinson estimate, same as above
+    D, rgrad = var.hutchinson_hessian(rgrad=None, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method=hvp_method, h=1e-3, zHz=zHz, generator=None)
+
+    # check storage
+    assert rgrad is not None
+    if at_x0:
+        assert var.grads is not None
+        if at_x0: assert_tl_same_(var.grads, rgrad)
+    else:
+        assert var.grads is None
+
+    # compute D via known hvp
+    x = cat(var.params)
+    z_vecs = [cat(z) for z in zs]
+    Hzs = [analytical_hessian(x, A, b) @ z for z in z_vecs]
+    D2 = torch.stack(Hzs)
+    if zHz: D2 *= torch.stack(z_vecs)
+    D2 = D2.mean(0)
+
+    # compare Ds
+    assert_tl_allclose_(D, vec_to_tensors(D2, var.params))
+
+    # check evals
+    if at_x0: false = 1
+    else:
+        if get_grad: false = 2
+        else: false = 1
+    objective.assert_(true=0, false=false)
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("zHz", [True, False])
+@pytest.mark.parametrize("get_grad", [True, False])
+@pytest.mark.parametrize("pass_rgrad", [True, False])
+def test_hutchinson_batching(device, at_x0, zHz, get_grad, pass_rgrad):
+    """compares batched and unbatched hutchinson"""
+
+    var, A, b, objective = get_var(device)
+    if get_grad:
+        var.get_grads(create_graph=True, at_x0=at_x0) # one false
+
+    # 10 random vecs
+    n = numel(var.params)
+    zs = [vec_to_tensors(torch.randn(n, device=device, generator=_gen(device)), var.params) for _ in range(10)]
+
+    # compute hutchinson estimate, same as above
+    D, rgrad = var.hutchinson_hessian(rgrad=None, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method='autograd', h=1e-3, zHz=zHz, generator=None, retain_graph=True)
+
+    # check evals
+    if at_x0: false = 1
+    else:
+        if get_grad: false = 2
+        else: false = 1
+    objective.assert_(true=0, false=false)
+
+    # reset evals
+    objective.true = objective.false = 0
+
+    # compute batched hutchinson estimate, if not at x0, one false if not pass_rgrad
+    D2, rgrad2 = var.hutchinson_hessian(rgrad=rgrad if pass_rgrad else None, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method='batched_autograd', h=1e-3, zHz=zHz, generator=None)
+
+    # check storage
+    assert rgrad is not None
+    assert rgrad2 is not None
+    if at_x0:
+        assert var.grads is not None
+        assert_tl_same_(var.grads, rgrad2)
+    else:
+        assert var.grads is None
+    if at_x0 or pass_rgrad: assert_tl_same_(rgrad, rgrad2)
+
+    # make sure Ds match
+    assert_tl_allclose_(D, D2)
+
+    # check evals
+    if at_x0 or pass_rgrad: false = 0
+    else: false = 1
+    objective.assert_(true=0, false=false)
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd"])
+@pytest.mark.parametrize("hvp_fd_method", ["fd_forward", "fd_central"])
+@pytest.mark.parametrize("zHz", [True, False])
+def test_hutchinson_fd(device, at_x0, hvp_method, hvp_fd_method, zHz):
+    """compares exact and FD hutchinson"""
+
+    var, A, b, objective = get_var(device)
+
+    # 10 random vecs
+    n = numel(var.params)
+    zs = [vec_to_tensors(torch.randn(n, device=device, generator=_gen(device)), var.params) for _ in range(10)]
+
+    # compute hutchinson D, always one false
+    D, rgrad = var.hutchinson_hessian(rgrad=None, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method=hvp_method, h=1e-3, zHz=zHz, generator=None)
+
+    # compute finite difference hutchinson D
+    # rgrad is already computed
+    # fd_forward 10 true, fd_central 20 true
+    D_fd, rgrad = var.hutchinson_hessian(rgrad=rgrad, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method=hvp_fd_method, h=1e-3, zHz=zHz, generator=None)
+
+    # make sure they are close
+    assert_tl_allclose_(D, D_fd, rtol=1e-2, atol=1e-2)
+
+    # check evals
+    assert objective.false == 1
+    if hvp_fd_method == 'fd_forward':
+        assert objective.true == 10
+    else:
+        assert objective.true == 20
+
+
+
@pytest.mark.parametrize("device", DEVICES)
|
|
408
|
+
@pytest.mark.parametrize("at_x0", [True, False])
|
|
409
|
+
@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd", "fd_forward", "fd_central"])
|
|
410
|
+
@pytest.mark.parametrize("h", [1e-1, 1e-2, 1e-3])
|
|
411
|
+
@pytest.mark.parametrize("zHz", [True, False])
|
|
412
|
+
@pytest.mark.parametrize("get_grad", [True, False])
|
|
413
|
+
@pytest.mark.parametrize("pass_rgrad", [True, False])
|
|
414
|
+
def test_hvp_vs_hutchinson(device, at_x0, hvp_method, h, zHz, get_grad, pass_rgrad):
|
|
415
|
+
"""compares hutchinson via hessian_vector_product and via hutchinson methods, including fd"""
|
|
416
|
+
|
|
417
|
+
var, A, b, objective = get_var(device)
|
|
418
|
+
if get_grad:
|
|
419
|
+
var.get_grads(create_graph=hvp_method in ("autograd", "batched_autograd"), at_x0=at_x0) # one false or true
|
|
420
|
+
|
|
421
|
+
# generate 10 vecs
|
|
422
|
+
n = numel(var.params)
|
|
423
|
+
zs = [vec_to_tensors(torch.randn(n, device=device, generator=_gen(device)), var.params) for _ in range(10)]
|
|
424
|
+
|
|
425
|
+
# mean of 10 z * Hz
|
|
426
|
+
# autograd and batched autograd - same as above
|
|
427
|
+
# fd forward
|
|
428
|
+
# if at_x0, first true either here or in get_grad, then 10 true, so total always 11 true
|
|
429
|
+
# else extra true in get_grad so 12 true
|
|
430
|
+
# fd central - 20 true plus one if get_grad
|
|
431
|
+
D = [torch.zeros_like(t) for t in var.params]
|
|
432
|
+
rgrad = None
|
|
433
|
+
for z in zs:
|
|
434
|
+
Hz, rgrad = var.hessian_vector_product(z, rgrad, at_x0=at_x0, hvp_method=hvp_method, h=h, retain_graph=True)
|
|
435
|
+
|
|
436
|
+
if zHz: torch._foreach_mul_(Hz, z)
|
|
437
|
+
torch._foreach_add_(D, Hz, alpha = 1/10)
|
|
438
|
+
|
|
439
|
+
# check storage
|
|
440
|
+
if not at_x0: assert var.grads is None
|
|
441
|
+
else:
|
|
442
|
+
if hvp_method == 'fd_central':
|
|
443
|
+
assert rgrad is None
|
|
444
|
+
if get_grad: assert var.grads is not None
|
|
445
|
+
|
|
446
|
+
else:
|
|
447
|
+
assert var.grads is not None
|
|
448
|
+
assert rgrad is not None
|
|
449
|
+
assert_tl_same_(var.grads, rgrad)
|
|
450
|
+
|
|
451
|
+
# check number of evals
|
|
452
|
+
if hvp_method in ('autograd', 'batched_autograd'):
|
|
453
|
+
if at_x0: false = 1
|
|
454
|
+
else:
|
|
455
|
+
if get_grad: false = 2
|
|
456
|
+
else: false = 1
|
|
457
|
+
objective.assert_(true=0, false=false)
|
|
458
|
+
|
|
459
|
+
elif hvp_method == "fd_forward":
|
|
460
|
+
if get_grad and not at_x0: true = 12
|
|
461
|
+
else: true = 11
|
|
462
|
+
objective.assert_(true=true, false=0)
|
|
463
|
+
|
|
464
|
+
elif hvp_method == 'fd_central':
|
|
465
|
+
if get_grad: objective.assert_(true=21, false=0)
|
|
466
|
+
else: objective.assert_(true=20, false=0)
|
|
467
|
+
|
|
468
|
+
else:
|
|
469
|
+
assert False, hvp_method
|
|
470
|
+
|
|
471
|
+
# reset evals
|
|
472
|
+
objective.true = objective.false = 0
|
|
473
|
+
|
|
474
|
+
# compute hutchinson hessian
|
|
475
|
+
# number of evals
|
|
476
|
+
# autograd/batched autograd - one false only if both pass_rgrad and at_x0 are False, else 0
|
|
477
|
+
# fd_forward - 11 true if both pass_rgrad and at_x0 are False, else 10 true
|
|
478
|
+
# fd_central - always 20 true
|
|
479
|
+
D2, rgrad2 = var.hutchinson_hessian(rgrad=rgrad if pass_rgrad else None, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method=hvp_method, h=h, zHz=zHz, generator=None)
|
|
480
|
+
|
|
481
|
+
# check storage
|
|
482
|
+
if hvp_method != "fd_central":
|
|
483
|
+
assert rgrad is not None
|
|
484
|
+
assert rgrad2 is not None
|
|
485
|
+
if at_x0 or pass_rgrad: assert_tl_same_(rgrad, rgrad2)
|
|
486
|
+
else: assert_tl_allclose_(rgrad, rgrad2)
|
|
487
|
+
|
|
488
|
+
# check that Ds match
|
|
489
|
+
assert_tl_allclose_(D, D2)
|
|
490
|
+
|
|
491
|
+
# check evals
|
|
492
|
+
# check number of evals
|
|
493
|
+
if hvp_method in ('autograd', 'batched_autograd'):
|
|
494
|
+
if at_x0 or pass_rgrad: false = 0
|
|
495
|
+
else: false = 1
|
|
496
|
+
objective.assert_(true=0, false=false)
|
|
497
|
+
|
|
498
|
+
elif hvp_method == "fd_forward":
|
|
499
|
+
if at_x0 or pass_rgrad: objective.assert_(true=10, false=0)
|
|
500
|
+
else: objective.assert_(true=11, false=0)
|
|
501
|
+
elif hvp_method == 'fd_central':
|
|
502
|
+
objective.assert_(true=20, false=0)
|
|
503
|
+
else:
|
|
504
|
+
assert False, hvp_method
|
|
505
|
+
|
|
506
|
+
# update should be none after all of this
|
|
507
|
+
assert var.updates is None
|
|
508
|
+
|
|
509
|
+
+_HESSIAN_METHODS = [
+    "batched_autograd",
+    "autograd",
+    "functional_revrev",
+    # "functional_fwdrev", # has shape issue
+    "func",
+    "gfd_forward",
+    "gfd_central",
+    "fd",
+    "fd_full",
+]
+
+# if find_spec("thoad") is not None: _HESSIAN_METHODS.append("thoad")
+# SqueezeBackward4 is not supported.
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hessian_method", _HESSIAN_METHODS)
+def test_hessian(device, at_x0, hessian_method):
+    """compares hessian with analytical, including gfd and fd"""
+
+    var, A, b, objective = get_var(device, dtype=torch.float64)
+    n = numel(var.params)
+
+    # compute hessian
+    if hessian_method in ("fd", "fd_full"): h = 1e-2
+    else: h = 1e-5
+    f, g_list, H = var.hessian(hessian_method=hessian_method, h=h, at_x0=at_x0)
+
+    # check storages
+    if hessian_method in ("batched_autograd", "autograd", "gfd_forward", "fd", "fd_full"):
+        if hessian_method == "gfd_forward": assert f is None
+        else: assert f == objective.closure(False)
+        assert g_list is not None
+        if at_x0:
+            assert var.grads is not None
+            assert_tl_same_(g_list, var.grads)
+        else:
+            assert var.grads is None
+    else:
+        assert f is None
+        assert g_list is None
+        assert var.grads is None
+
+    # compare with analytical
+    x = cat(var.params)
+    H_real = analytical_hessian(x, A, b)
+    if hessian_method in ("gfd_forward", "gfd_central"):
+        assert torch.allclose(H, H_real, rtol=1e-1, atol=1e-1), f"{H = }, {H_real = }"
+
+    elif hessian_method in ("fd", "fd_full"):
+        # assert torch.allclose(H, H_real, rtol=1e-1, atol=1e-1), f"{H = }, {H_real = }"
+        # TODO find a good test
+
+        # compare gradient with analytical
+        g_real = analytical_gradient(x, A, b)
+        assert g_list is not None
+        assert torch.allclose(cat(g_list), g_real, rtol=1e-2, atol=1e-2), f"{cat(g_list) = }, {g_real = }"
+
+    else:
+        assert torch.allclose(H, H_real), f"{H = }, {H_real = }"
+
+
+    # check evals
+    if hessian_method == "gfd_forward":
+        objective.assert_(true=n+1, false=0)
+
+    elif hessian_method == "gfd_central":
+        objective.assert_(true=n*2, false=0)
+
+    elif hessian_method == "fd":
+        objective.assert_(true=0, false=2*n**2 + 1)
+
+    elif hessian_method == "fd_full":
+        objective.assert_(true=0, false=4*n**2 - 2*n + 1)
+
+    else:
+        objective.assert_(true=0, false=1)