torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
tests/test_identical.py CHANGED
@@ -219,8 +219,8 @@ def test_adagrad_hyperparams(initial_accumulator_value, eps, lr):
 
 @pytest.mark.parametrize('tensorwise', [True, False])
 def test_graft(tensorwise):
-    graft1 = lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
-    graft2 = lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.Graft([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
+    graft1 = lambda p: tz.Modular(p, tz.m.Graft(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
+    graft2 = lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.GraftInputToOutput([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
     _assert_identical_opts([graft1, graft2], merge=True, use_closure=True, device='cpu', steps=10)
     for fn in [graft1, graft2]:
         if tensorwise: _assert_identical_closure(fn, merge=True, device='cpu', steps=10)
tests/test_module_autograd.py ADDED
@@ -0,0 +1,586 @@
+from importlib.util import find_spec
+# pylint:disable=deprecated-method
+from typing import Any
+from collections.abc import Sequence
+
+import pytest
+import torch
+
+import torchzero as tz
+from torchzero.utils import TensorList, vec_to_tensors
+
+# ----------------------------------- utils ---------------------------------- #
+DEVICES = ["cpu"]
+if torch.cuda.is_available(): DEVICES.append("cuda")
+DEVICES = tuple(DEVICES)
+
+def _gen(device):
+    return torch.Generator(device).manual_seed(0)
+
+def cat(ts: Sequence[torch.Tensor]):
+    return torch.cat([t.flatten() for t in ts])
+
+def numel(ts: Sequence[torch.Tensor]):
+    return sum(t.numel() for t in ts)
+
+def assert_tl_equal_(tl1: Sequence[torch.Tensor | Any], tl2: Sequence[torch.Tensor | Any]):
+    assert len(tl1) == len(tl2), f"TensorLists have different lengths:\n{[t.shape for t in tl1]}\n{[t.shape for t in tl2]};"
+    for t1, t2 in zip(tl1, tl2):
+        if t1 is None and t2 is None:
+            continue
+        assert t1 is not None and t2 is not None, "One tensor is None, the other is not"
+        assert t1.shape == t2.shape, f"Tensors have different shapes:\n{t1}\nvs\n{t2}"
+        assert torch.equal(t1, t2), f"Tensors are not equal:\n{t1}\nvs\n{t2}"
+
+def assert_tl_allclose_(tl1: Sequence[torch.Tensor | Any], tl2: Sequence[torch.Tensor | Any], **kwargs):
+    assert len(tl1) == len(tl2), f"TensorLists have different lengths:\n{[t.shape for t in tl1]}\n{[t.shape for t in tl2]};"
+    for t1, t2 in zip(tl1, tl2):
+        if t1 is None and t2 is None:
+            continue
+        assert t1 is not None and t2 is not None, "One tensor is None, the other is not"
+        assert t1.shape == t2.shape, f"Tensors have different shapes:\n{t1}\nvs\n{t2}"
+        assert torch.allclose(t1, t2, equal_nan=True, **kwargs), f"Tensors are not close:\n{t1}\nvs\n{t2}"
+
+def assert_tl_same_(seq1: Sequence[torch.Tensor], seq2: Sequence[torch.Tensor]):
+    seq1=tuple(seq1)
+    seq2=tuple(seq2)
+    assert len(seq1) == len(seq2), f'lengths do not match: {len(seq1)} != {len(seq2)}'
+    for t1, t2 in zip(seq1, seq2):
+        assert t1 is t2
+
+
+def assert_tl_same_storage_(seq1: Sequence[torch.Tensor], seq2: Sequence[torch.Tensor]):
+    seq1=tuple(seq1)
+    seq2=tuple(seq2)
+    assert len(seq1) == len(seq2), f'lengths do not match: {len(seq1)} != {len(seq2)}'
+    for t1, t2 in zip(seq1, seq2):
+        assert t1.data_ptr() == t2.data_ptr()
+
+class _EvalCounter:
+    def __init__(self, closure):
+        self.closure = closure
+        self.false = 0
+        self.true = 0
+
+    def __call__(self, backward=True):
+        if backward: self.true += 1
+        else: self.false += 1
+        return self.closure(backward)
+
+    def assert_(self, true:int, false:int):
+        assert true == self.true
+        assert false == self.false
+
+    def __repr__(self):
+        return f"EvalCounter(true={self.true}, false={self.false})"
+
+# --------------------------------- objective --------------------------------
+
+def objective_value(x:torch.Tensor, A:torch.Tensor, b:torch.Tensor):
+    return 0.5 * x @ A @ x + (b @ x).exp()
+
+def analytical_gradient(x:torch.Tensor, A:torch.Tensor, b:torch.Tensor):
+    return A @ x + (b @ x).exp() * b
+
+def analytical_hessian(x:torch.Tensor, A:torch.Tensor, b:torch.Tensor):
+    return A + (b @ x).exp() * b.outer(b)
+
+def analytical_derivative(x: torch.Tensor, b:torch.Tensor, order: int) -> torch.Tensor:
+    assert order >= 3
+    # n-th order outer product
+    # n=4 -> 'i,j,k,l->ijkl'
+    indices = 'ijklmnopqrstuvwxyz'[:order]
+    b_outer = torch.einsum(f"{','.join(indices)}->{indices}", *[b] * order)
+    return (b @ x).exp() * b_outer
+
+
+def get_var(device, dtype=torch.float32):
+
+    # we cat a few tensors to make sure those methods handle multiple params correctly
+    p1 = torch.tensor(1., requires_grad=True, device=device, dtype=dtype)
+    p2 = torch.randn(1, 3, 2, requires_grad=True, device=device, generator=_gen(device), dtype=dtype)
+    p3 = torch.randn(4, requires_grad=True, device=device, generator=_gen(device), dtype=dtype)
+
+    params = [p1, p2, p3]
+    n = numel(params)
+
+    A = torch.randn(n, n, device=device, generator=_gen(device), dtype=dtype)
+    A = A.T @ A + torch.eye(n, device=device, dtype=dtype) * 1e-3
+    b = torch.randn(n, device=device, generator=_gen(device), dtype=dtype)
+
+    def closure(backward=True):
+        x = cat(params)
+        loss = objective_value(x, A, b)
+
+        if backward:
+            for p in params:
+                p.grad = None
+            loss.backward()
+
+        return loss
+
+    objective = _EvalCounter(closure)
+    var = tz.core.Objective(params=params, closure=objective, model=None, current_step=0)
+
+    return var, A, b, objective
+
+# ------------------------------------ hvp ----------------------------------- #
+@pytest.mark.parametrize("device", DEVICES)
+def test_gradient(device):
+    """makes sure gradient is correct"""
+    var, A, b, objective = get_var(device)
+    grad = var.get_grads()
+    assert torch.allclose(cat(grad), analytical_gradient(cat(var.params), A, b))
+    objective.assert_(true=1, false=0)
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd"])
+@pytest.mark.parametrize("get_grad", [True, False])
+def test_hvp_autograd(device, at_x0, hvp_method, get_grad):
+    """compares hessian-vector product with analytical"""
+
+    var, A, b, objective = get_var(device)
+
+    grad = None
+    if get_grad:
+        grad = var.get_grads(create_graph=True, at_x0=at_x0) # one false (one closure call with backward=False)
+
+    # generate random z
+    n = numel(var.params)
+    z = vec_to_tensors(torch.randn(n, device=device, generator=_gen(device)), var.params)
+
+    # Hz
+    # this is for all following autograd tests
+    # if at_x0:
+    #     one false call happens either in get_grad or here, so 1 false
+    # else:
+    #     if get_grad, both get_grad and this call with false, so 2 false
+    #     else only this calls with false, so 1 false
+    Hz, rgrad = var.hessian_vector_product(z, None, at_x0=at_x0, hvp_method=hvp_method, h=1e-3)
+
+    # check storage
+    assert rgrad is not None
+    if at_x0:
+        assert var.grads is not None
+        assert_tl_same_(var.grads, rgrad)
+        if grad is not None: assert_tl_same_(grad, rgrad)
+    else:
+        assert var.grads is None
+        if grad is not None: assert_tl_allclose_(grad, rgrad)
+
+    # check against known Hvp
+    x = cat(var.params)
+    assert torch.allclose(cat(rgrad), analytical_gradient(x, A, b))
+    assert torch.allclose(cat(Hz), analytical_hessian(x, A, b) @ cat(z))
+
+    # check evals
+    if at_x0: false = 1
+    else:
+        if get_grad: false = 2
+        else: false = 1
+    objective.assert_(true=0, false=false)
+
+# -------------------------- hessian-matrix product -------------------------- #
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd"])
+@pytest.mark.parametrize("get_grad", [True, False])
+def test_hessian_matrix_product(device, at_x0, hvp_method, get_grad):
+    """compares hessian-matrix product with analytical"""
+
+    var, A, b, objective = get_var(device)
+    if get_grad:
+        var.get_grads(create_graph=True, at_x0=at_x0) # one false
+
+    # generate random matrix
+    n = numel(var.params)
+    Z = torch.randn(n, n*2, device=device, generator=_gen(device))
+
+    # HZ same as above
+    HZ, rgrad = var.hessian_matrix_product(Z, rgrad=None, at_x0=at_x0, hvp_method=hvp_method, h=1e-3)
+
+    # check storage
+    assert rgrad is not None
+    if at_x0:
+        assert var.grads is not None
+        assert_tl_same_(rgrad, var.grads)
+    else:
+        assert var.grads is None
+
+    # check against known HZ
+    x = cat(var.params)
+    assert torch.allclose(HZ, analytical_hessian(x, A, b) @ Z, rtol=1e-4, atol=1e-6), f"{HZ = }, {A@Z = }"
+
+    # check evals
+    if at_x0: false = 1
+    else:
+        if get_grad: false = 2
+        else: false = 1
+    objective.assert_(true=0, false=false)
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd", "fd_forward", "fd_central"])
+@pytest.mark.parametrize("h", [1e-1, 1e-2, 1e-3])
+def test_hessian_vector_vs_matrix_product(device, at_x0, hvp_method, h):
+    """compares hessian_vector_product and hessian_matrix_product, including fd"""
+
+    var, A, b, objective = get_var(device, dtype=torch.float64)
+
+    # generate random matrix
+    n = numel(var.params)
+    Z = torch.randn(n, n*2, device=device, generator=_gen(device))
+    z_vecs = [vec_to_tensors(col, var.params) for col in Z.unbind(1)]
+
+    # hessian-vector
+    rgrad = None
+    Hzs = []
+    for z in z_vecs:
+        Hz, rgrad = var.hessian_vector_product(z, rgrad=rgrad, at_x0=at_x0, hvp_method=hvp_method, h=h, retain_graph=True)
+        Hzs.append(cat(Hz))
+
+    # check evals (did n*2 hvps)
+    if hvp_method in ('autograd', 'batched_autograd'): objective.assert_(true=0, false=1)
+    elif hvp_method == 'fd_central': objective.assert_(true=n*4, false=0)
+    elif hvp_method == 'fd_forward': objective.assert_(true=n*2+1, false=0)
+    else: assert False, hvp_method
+
+    # clear evals
+    objective.true = objective.false = 0
+
+    # hessian-matrix
+    HZ, rgrad = var.hessian_matrix_product(Z, rgrad=rgrad, at_x0=at_x0, hvp_method=hvp_method, h=h)
+
+    # check evals (did n*2 hvps, initial grad is rgrad)
+    if hvp_method in ('autograd', 'batched_autograd'): objective.assert_(true=0, false=0)
+    elif hvp_method == 'fd_central': objective.assert_(true=n*4, false=0)
+    elif hvp_method == 'fd_forward': objective.assert_(true=n*2, false=0)
+    else: assert False, hvp_method
+
+    # check storage
+    if hvp_method == 'fd_central': assert rgrad is None
+    else: assert rgrad is not None
+
+    if at_x0:
+        if hvp_method == 'fd_central': assert var.grads is None
+        else:
+            assert var.grads is not None
+            assert rgrad is not None
+            assert_tl_same_(rgrad, var.grads)
+    else:
+        assert var.grads is None
+
+    # check that they match
+    assert torch.allclose(HZ, torch.stack(Hzs, dim=-1)), f"{HZ = }, {torch.stack(Hzs, dim=-1) = }"
+
+# -------------------------------- hutchinson -------------------------------- #
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd"])
+@pytest.mark.parametrize("zHz", [True, False])
+@pytest.mark.parametrize("get_grad", [True, False])
+def test_hutchinson(device, at_x0, hvp_method, zHz, get_grad):
+    """compares autograd hutchinson with one computed with analytical hessian-vector products"""
+
+    var, A, b, objective = get_var(device)
+    if get_grad:
+        var.get_grads(create_graph=True, at_x0=at_x0) # one false
+
+    # 10 random vecs
+    n = numel(var.params)
+    zs = [vec_to_tensors(torch.randn(n, device=device, generator=_gen(device)), var.params) for _ in range(10)]
+
+    # compute hutchinson estimate, same as above
+    D, rgrad = var.hutchinson_hessian(rgrad=None, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method=hvp_method, h=1e-3, zHz=zHz, generator=None)
+
+    # check storage
+    assert rgrad is not None
+    if at_x0:
+        assert var.grads is not None
+        if at_x0: assert_tl_same_(var.grads, rgrad)
+    else:
+        assert var.grads is None
+
+    # compute D via known hvp
+    x = cat(var.params)
+    z_vecs = [cat(z) for z in zs]
+    Hzs = [analytical_hessian(x, A, b) @ z for z in z_vecs]
+    D2 = torch.stack(Hzs)
+    if zHz: D2 *= torch.stack(z_vecs)
+    D2 = D2.mean(0)
+
+    # compare Ds
+    assert_tl_allclose_(D, vec_to_tensors(D2, var.params))
+
+    # check evals
+    if at_x0: false = 1
+    else:
+        if get_grad: false = 2
+        else: false = 1
+    objective.assert_(true=0, false=false)
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("zHz", [True, False])
+@pytest.mark.parametrize("get_grad", [True, False])
+@pytest.mark.parametrize("pass_rgrad", [True, False])
+def test_hutchinson_batching(device, at_x0, zHz, get_grad, pass_rgrad):
+    """compares batched and unbatched hutchinson"""
+
+    var, A, b, objective = get_var(device)
+    if get_grad:
+        var.get_grads(create_graph=True, at_x0=at_x0) # one false
+
+    # 10 random vecs
+    n = numel(var.params)
+    zs = [vec_to_tensors(torch.randn(n, device=device, generator=_gen(device)), var.params) for _ in range(10)]
+
+    # compute hutchinson estimate, same as above
+    D, rgrad = var.hutchinson_hessian(rgrad=None, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method='autograd', h=1e-3, zHz=zHz, generator=None, retain_graph=True)
+
+    # check evals
+    if at_x0: false = 1
+    else:
+        if get_grad: false = 2
+        else: false = 1
+    objective.assert_(true=0, false=false)
+
+    # reset evals
+    objective.true = objective.false = 0
+
+    # compute batched hutchinson estimate, if not at x0, one false if not pass_rgrad
+    D2, rgrad2 = var.hutchinson_hessian(rgrad=rgrad if pass_rgrad else None, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method='batched_autograd', h=1e-3, zHz=zHz, generator=None)
+
+    # check storage
+    assert rgrad is not None
+    assert rgrad2 is not None
+    if at_x0:
+        assert var.grads is not None
+        assert_tl_same_(var.grads, rgrad2)
+    else:
+        assert var.grads is None
+    if at_x0 or pass_rgrad: assert_tl_same_(rgrad, rgrad2)
+
+    # make sure Ds match
+    assert_tl_allclose_(D, D2)
+
+    # check evals
+    if at_x0 or pass_rgrad: false = 0
+    else: false = 1
+    objective.assert_(true=0, false=false)
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd"])
+@pytest.mark.parametrize("hvp_fd_method", ["fd_forward", "fd_central"])
+@pytest.mark.parametrize("zHz", [True, False])
+def test_hutchinson_fd(device, at_x0, hvp_method, hvp_fd_method, zHz):
+    """compares exact and FD hutchinson"""
+
+    var, A, b, objective = get_var(device)
+
+    # 10 random vecs
+    n = numel(var.params)
+    zs = [vec_to_tensors(torch.randn(n, device=device, generator=_gen(device)), var.params) for _ in range(10)]
+
+    # compute hutchinson D, always one false
+    D, rgrad = var.hutchinson_hessian(rgrad=None, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method=hvp_method, h=1e-3, zHz=zHz, generator=None)
+
+    # compute finite difference hutchinson D
+    # rgrad is already computed
+    # fd_forward 10 true, fd_central 20 true
+    D_fd, rgrad = var.hutchinson_hessian(rgrad=rgrad, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method=hvp_fd_method, h=1e-3, zHz=zHz, generator=None)
+
+    # make sure they are close
+    assert_tl_allclose_(D, D_fd, rtol=1e-2, atol=1e-2)
+
+    # check evals
+    assert objective.false == 1
+    if hvp_fd_method == 'fd_forward':
+        assert objective.true == 10
+    else:
+        assert objective.true == 20
+
+
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hvp_method", ["autograd", "batched_autograd", "fd_forward", "fd_central"])
+@pytest.mark.parametrize("h", [1e-1, 1e-2, 1e-3])
+@pytest.mark.parametrize("zHz", [True, False])
+@pytest.mark.parametrize("get_grad", [True, False])
+@pytest.mark.parametrize("pass_rgrad", [True, False])
+def test_hvp_vs_hutchinson(device, at_x0, hvp_method, h, zHz, get_grad, pass_rgrad):
+    """compares hutchinson via hessian_vector_product and via hutchinson methods, including fd"""
+
+    var, A, b, objective = get_var(device)
+    if get_grad:
+        var.get_grads(create_graph=hvp_method in ("autograd", "batched_autograd"), at_x0=at_x0) # one false or true
+
+    # generate 10 vecs
+    n = numel(var.params)
+    zs = [vec_to_tensors(torch.randn(n, device=device, generator=_gen(device)), var.params) for _ in range(10)]
+
+    # mean of 10 z * Hz
+    # autograd and batched autograd - same as above
+    # fd forward
+    #     if at_x0, first true either here or in get_grad, then 10 true, so total always 11 true
+    #     else extra true in get_grad so 12 true
+    # fd central - 20 true plus one if get_grad
+    D = [torch.zeros_like(t) for t in var.params]
+    rgrad = None
+    for z in zs:
+        Hz, rgrad = var.hessian_vector_product(z, rgrad, at_x0=at_x0, hvp_method=hvp_method, h=h, retain_graph=True)
+
+        if zHz: torch._foreach_mul_(Hz, z)
+        torch._foreach_add_(D, Hz, alpha = 1/10)
+
+    # check storage
+    if not at_x0: assert var.grads is None
+    else:
+        if hvp_method == 'fd_central':
+            assert rgrad is None
+            if get_grad: assert var.grads is not None
+
+        else:
+            assert var.grads is not None
+            assert rgrad is not None
+            assert_tl_same_(var.grads, rgrad)
+
+    # check number of evals
+    if hvp_method in ('autograd', 'batched_autograd'):
+        if at_x0: false = 1
+        else:
+            if get_grad: false = 2
+            else: false = 1
+        objective.assert_(true=0, false=false)
+
+    elif hvp_method == "fd_forward":
+        if get_grad and not at_x0: true = 12
+        else: true = 11
+        objective.assert_(true=true, false=0)
+
+    elif hvp_method == 'fd_central':
+        if get_grad: objective.assert_(true=21, false=0)
+        else: objective.assert_(true=20, false=0)
+
+    else:
+        assert False, hvp_method
+
+    # reset evals
+    objective.true = objective.false = 0
+
+    # compute hutchinson hessian
+    # number of evals
+    # autograd/batched autograd - one false only if both pass_rgrad and at_x0 are False, else 0
+    # fd_forward - 11 true if both pass_rgrad and at_x0 are False, else 10 true
+    # fd_central - always 20 true
+    D2, rgrad2 = var.hutchinson_hessian(rgrad=rgrad if pass_rgrad else None, at_x0=at_x0, n_samples=None, distribution=zs, hvp_method=hvp_method, h=h, zHz=zHz, generator=None)
+
+    # check storage
+    if hvp_method != "fd_central":
+        assert rgrad is not None
+        assert rgrad2 is not None
+        if at_x0 or pass_rgrad: assert_tl_same_(rgrad, rgrad2)
+        else: assert_tl_allclose_(rgrad, rgrad2)
+
+    # check that Ds match
+    assert_tl_allclose_(D, D2)
+
+    # check evals
+    # check number of evals
+    if hvp_method in ('autograd', 'batched_autograd'):
+        if at_x0 or pass_rgrad: false = 0
+        else: false = 1
+        objective.assert_(true=0, false=false)
+
+    elif hvp_method == "fd_forward":
+        if at_x0 or pass_rgrad: objective.assert_(true=10, false=0)
+        else: objective.assert_(true=11, false=0)
+    elif hvp_method == 'fd_central':
+        objective.assert_(true=20, false=0)
+    else:
+        assert False, hvp_method
+
+    # update should be none after all of this
+    assert var.updates is None
+
+_HESSIAN_METHODS = [
+    "batched_autograd",
+    "autograd",
+    "functional_revrev",
+    # "functional_fwdrev", # has shape issue
+    "func",
+    "gfd_forward",
+    "gfd_central",
+    "fd",
+    "fd_full",
+]
+
+# if find_spec("thoad") is not None: _HESSIAN_METHODS.append("thoad")
+# SqueezeBackward4 is not supported.
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("at_x0", [True, False])
+@pytest.mark.parametrize("hessian_method", _HESSIAN_METHODS)
+def test_hessian(device, at_x0, hessian_method):
+    """compares hessian with analytical, including gfd and fd"""
+
+    var, A, b, objective = get_var(device, dtype=torch.float64)
+    n = numel(var.params)
+
+    # compute hessian
+    if hessian_method in ("fd", "fd_full"): h = 1e-2
+    else: h = 1e-5
+    f, g_list, H = var.hessian(hessian_method=hessian_method, h=h, at_x0=at_x0)
+
+    # check storages
+    if hessian_method in ("batched_autograd", "autograd", "gfd_forward", "fd", "fd_full"):
+        if hessian_method == "gfd_forward": assert f is None
+        else: assert f == objective.closure(False)
+        assert g_list is not None
+        if at_x0:
+            assert var.grads is not None
+            assert_tl_same_(g_list, var.grads)
+        else:
+            assert var.grads is None
+    else:
+        assert f is None
+        assert g_list is None
+        assert var.grads is None
+
+    # compare with analytical
+    x = cat(var.params)
+    H_real = analytical_hessian(x, A, b)
+    if hessian_method in ("gfd_forward", "gfd_central"):
+        assert torch.allclose(H, H_real, rtol=1e-1, atol=1e-1), f"{H = }, {H_real = }"
+
+    elif hessian_method in ("fd", "fd_full"):
+        # assert torch.allclose(H, H_real, rtol=1e-1, atol=1e-1), f"{H = }, {H_real = }"
+        # TODO find a good test
+
+        # compare gradient with analytical
+        g_real = analytical_gradient(x, A, b)
+        assert g_list is not None
+        assert torch.allclose(cat(g_list), g_real, rtol=1e-2, atol=1e-2), f"{cat(g_list) = }, {g_real = }"
+
+    else:
+        assert torch.allclose(H, H_real), f"{H = }, {H_real = }"
+
+
+    # check evals
+    if hessian_method == "gfd_forward":
+        objective.assert_(true=n+1, false=0)
+
+    elif hessian_method == "gfd_central":
+        objective.assert_(true=n*2, false=0)
+
+    elif hessian_method == "fd":
+        objective.assert_(true=0, false=2*n**2 + 1)
+
+    elif hessian_method == "fd_full":
+        objective.assert_(true=0, false=4*n**2 - 2*n + 1)
+
+    else:
+        objective.assert_(true=0, false=1)