torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/utils/derivatives.py
@@ -1,99 +1,513 @@
- from collections.abc import Sequence, Iterable
-
- import torch
-
- def _jacobian(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-     flat_input = torch.cat([i.reshape(-1) for i in input])
-     grad_ouputs = torch.eye(len(flat_input), device=input[0].device, dtype=input[0].dtype)
-     jac = []
-     for i in range(flat_input.numel()):
-         jac.append(torch.autograd.grad(
-             flat_input,
-             wrt,
-             grad_ouputs[i],
-             retain_graph=True,
-             create_graph=create_graph,
-             allow_unused=True,
-             is_grads_batched=False,
-         ))
-     return [torch.stack(z) for z in zip(*jac)]
-
-
-
- def _jacobian_batched(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-     flat_input = torch.cat([i.reshape(-1) for i in input])
-     return torch.autograd.grad(
-         flat_input,
-         wrt,
-         torch.eye(len(flat_input), device=input[0].device, dtype=input[0].dtype),
-         retain_graph=True,
-         create_graph=create_graph,
-         allow_unused=True,
-         is_grads_batched=True,
-     )
-
- def jacobian(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
-     """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
-     Returns a sequence of tensors with the length as `wrt`.
-     Each tensor will have the shape `(*input.shape, *wrt[i].shape)`.
-
-     Args:
-         input (Sequence[torch.Tensor]): input sequence of tensors.
-         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
-         create_graph (bool, optional):
-             pytorch option, if True, graph of the derivative will be constructed,
-             allowing to compute higher order derivative products. Default: False.
-         batched (bool, optional): use faster but experimental pytorch batched jacobian
-             This only has effect when `input` has more than 1 element. Defaults to True.
-
-     Returns:
-         sequence of tensors with the length as `wrt`.
-     """
-     if batched: return _jacobian_batched(input, wrt, create_graph)
-     return _jacobian(input, wrt, create_graph)
-
- def hessian(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
-     """Calculate hessian of a sequence of tensors w.r.t another sequence of tensors.
-     Returns a sequence of tensors with the length as `wrt`.
-     If you need a hessian matrix out of that sequence, pass it to `hessian_list_to_mat`.
-
-     Args:
-         input (Sequence[torch.Tensor]): input sequence of tensors.
-         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
-         create_graph (bool, optional):
-             pytorch option, if True, graph of the derivative will be constructed,
-             allowing to compute higher order derivative products. Default: False.
-         batched (bool, optional): use faster but experimental pytorch batched grad. Defaults to True.
-
-     Returns:
-         sequence of tensors with the length as `wrt`.
-     """
-     return jacobian(jacobian(input, wrt, create_graph=True, batched=batched), wrt, create_graph=create_graph, batched=batched)
-
- def jacobian_and_hessian(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
-     """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
-     Calculating hessian requires calculating the jacobian. So this function is more efficient than
-     calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
-
-     Args:
-         input (Sequence[torch.Tensor]): input sequence of tensors.
-         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
-         create_graph (bool, optional):
-             pytorch option, if True, graph of the derivative will be constructed,
-             allowing to compute higher order derivative products. Default: False.
-         batched (bool, optional): use faster but experimental pytorch batched grad. Defaults to True.
-
-     Returns:
-         tuple with jacobians sequence and hessians sequence.
-     """
-     jac = jacobian(input, wrt, create_graph=True, batched = batched)
-     return jac, jacobian(jac, wrt, batched = batched, create_graph=create_graph)
-
- def jacobian_list_to_vec(jacobians: Iterable[torch.Tensor]):
-     """flattens and concatenates a sequence of tensors."""
-     return torch.cat([i.ravel() for i in jacobians], 0)
-
- def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
-     """takes output of `hessian` and returns the 2D hessian matrix.
-     Note - I only tested this for cases where input is a scalar."""
-     return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
+ from collections.abc import Iterable, Sequence
+
+ import torch
+ import torch.autograd.forward_ad as fwAD
+
+ from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
+
+ def _jacobian(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
+     flat_input = torch.cat([i.reshape(-1) for i in output])
+     grad_ouputs = torch.eye(len(flat_input), device=output[0].device, dtype=output[0].dtype)
+     jac = []
+     for i in range(flat_input.numel()):
+         jac.append(torch.autograd.grad(
+             flat_input,
+             wrt,
+             grad_ouputs[i],
+             retain_graph=True,
+             create_graph=create_graph,
+             allow_unused=True,
+             is_grads_batched=False,
+         ))
+     return [torch.stack(z) for z in zip(*jac)]
+
+
+ def _jacobian_batched(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
+     flat_input = torch.cat([i.reshape(-1) for i in output])
+     return torch.autograd.grad(
+         flat_input,
+         wrt,
+         torch.eye(len(flat_input), device=output[0].device, dtype=output[0].dtype),
+         retain_graph=True,
+         create_graph=create_graph,
+         allow_unused=True,
+         is_grads_batched=True,
+     )
+
+ def jacobian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True) -> Sequence[torch.Tensor]:
+     """Calculate jacobian of a sequence of tensors w.r.t another sequence of tensors.
+     Returns a sequence of tensors with the length as `wrt`.
+     Each tensor will have the shape `(*input.shape, *wrt[i].shape)`.
+
+     Args:
+         input (Sequence[torch.Tensor]): input sequence of tensors.
+         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
+         create_graph (bool, optional):
+             pytorch option, if True, graph of the derivative will be constructed,
+             allowing to compute higher order derivative products. Default: False.
+         batched (bool, optional): use faster but experimental pytorch batched jacobian
+             This only has effect when `input` has more than 1 element. Defaults to True.
+
+     Returns:
+         sequence of tensors with the length as `wrt`.
+     """
+     if batched: return _jacobian_batched(output, wrt, create_graph)
+     return _jacobian(output, wrt, create_graph)
+
+ def jacobian_and_hessian_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
+     """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
+     Calculating hessian requires calculating the jacobian. So this function is more efficient than
+     calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
+
+     Args:
+         input (Sequence[torch.Tensor]): input sequence of tensors.
+         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
+         create_graph (bool, optional):
+             pytorch option, if True, graph of the derivative will be constructed,
+             allowing to compute higher order derivative products. Default: False.
+         batched (bool, optional): use faster but experimental pytorch batched grad. Defaults to True.
+
+     Returns:
+         tuple with jacobians sequence and hessians sequence.
+     """
+     jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
+     return jac, jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
+
+
+ def hessian_list_to_mat(hessians: Sequence[torch.Tensor]):
+     """takes output of `hessian` and returns the 2D hessian matrix.
+     Note - I only tested this for cases where input is a scalar."""
+     return torch.cat([h.reshape(h.size(0), h[1].numel()) for h in hessians], 1)
+
+ def jacobian_and_hessian_mat_wrt(output: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False, batched=True):
+     """Calculate jacobian and hessian of a sequence of tensors w.r.t another sequence of tensors.
+     Calculating hessian requires calculating the jacobian. So this function is more efficient than
+     calling `jacobian` and `hessian` separately, which would calculate jacobian twice.
+
+     Args:
+         input (Sequence[torch.Tensor]): input sequence of tensors.
+         wrt (Sequence[torch.Tensor]): sequence of tensors to differentiate w.r.t.
+         create_graph (bool, optional):
+             pytorch option, if True, graph of the derivative will be constructed,
+             allowing to compute higher order derivative products. Default: False.
+         batched (bool, optional): use faster but experimental pytorch batched grad. Defaults to True.
+
+     Returns:
+         tuple with jacobians sequence and hessians sequence.
+     """
+     jac = jacobian_wrt(output, wrt, create_graph=True, batched = batched)
+     H_list = jacobian_wrt(jac, wrt, batched = batched, create_graph=create_graph)
+     return torch.cat([j.view(-1) for j in jac]), hessian_list_to_mat(H_list)
+
+ def hessian(
+     fn,
+     params: Iterable[torch.Tensor],
+     create_graph=False,
+     method="func",
+     vectorize=False,
+     outer_jacobian_strategy="reverse-mode",
+ ):
+     """
+     returns list of lists of lists of values of hessian matrix of each param wrt each param.
+     To just get a single matrix use the :code:`hessian_mat` function.
+
+     `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.
+
+     Example:
+         .. code:: py
+
+             model = nn.Linear(4, 2) # (2, 4) weight and (2, ) bias
+             X = torch.randn(10, 4)
+             y = torch.randn(10, 2)
+
+             def fn():
+                 y_hat = model(X)
+                 loss = F.mse_loss(y_hat, y)
+                 return loss
+
+             hessian_mat(fn, model.parameters()) # list of two lists of two lists of 3D and 4D tensors
+
+
+     """
+     params = list(params)
+
+     def func(x: list[torch.Tensor]):
+         for p, x_i in zip(params, x): swap_tensors_no_use_count_check(p, x_i)
+         loss = fn()
+         for p, x_i in zip(params, x): swap_tensors_no_use_count_check(p, x_i)
+         return loss
+
+     if method == 'func':
+         return torch.func.hessian(func)([p.detach().requires_grad_(create_graph) for p in params])
+
+     if method == 'autograd.functional':
+         return torch.autograd.functional.hessian(
+             func,
+             [p.detach() for p in params],
+             create_graph=create_graph,
+             vectorize=vectorize,
+             outer_jacobian_strategy=outer_jacobian_strategy,
+         )
+     raise ValueError(method)
+
+ def hessian_mat(
+     fn,
+     params: Iterable[torch.Tensor],
+     create_graph=False,
+     method="func",
+     vectorize=False,
+     outer_jacobian_strategy="reverse-mode",
+ ):
+     """
+     returns hessian matrix for parameters (as if they were flattened and concatenated into a vector).
+
+     `vectorize` and `outer_jacobian_strategy` are only for `method = "torch.autograd"`, refer to its documentation.
+
+     Example:
+         .. code:: py
+
+             model = nn.Linear(4, 2) # 10 parameters in total
+             X = torch.randn(10, 4)
+             y = torch.randn(10, 2)
+
+             def fn():
+                 y_hat = model(X)
+                 loss = F.mse_loss(y_hat, y)
+                 return loss
+
+             hessian_mat(fn, model.parameters()) # 10x10 tensor
+
+
+     """
+     params = list(params)
+
+     def func(x: torch.Tensor):
+         x_params = vec_to_tensors(x, params)
+         for p, x_i in zip(params, x_params): swap_tensors_no_use_count_check(p, x_i)
+         loss = fn()
+         for p, x_i in zip(params, x_params): swap_tensors_no_use_count_check(p, x_i)
+         return loss
+
+     if method == 'func':
+         return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph))
+
+     if method == 'autograd.functional':
+         return torch.autograd.functional.hessian(
+             func,
+             torch.cat([p.view(-1) for p in params]).detach(),
+             create_graph=create_graph,
+             vectorize=vectorize,
+             outer_jacobian_strategy=outer_jacobian_strategy,
+         )
+     raise ValueError(method)
+
+ def jvp(fn, params: Iterable[torch.Tensor], tangent: Iterable[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+     """Jacobian vector product.
+
+     Example:
+         .. code:: py
+
+             model = nn.Linear(4, 2)
+             X = torch.randn(10, 4)
+             y = torch.randn(10, 2)
+
+             tangent = [torch.randn_like(p) for p in model.parameters()]
+
+             def fn():
+                 y_hat = model(X)
+                 loss = F.mse_loss(y_hat, y)
+                 return loss
+
+             jvp(fn, model.parameters(), tangent) # scalar
+
+     """
+     params = list(params)
+     tangent = list(tangent)
+     detached_params = [p.detach() for p in params]
+
+     duals = []
+     with fwAD.dual_level():
+         for p, d, t in zip(params, detached_params, tangent):
+             dual = fwAD.make_dual(d, t).requires_grad_(p.requires_grad)
+             duals.append(dual)
+             swap_tensors_no_use_count_check(p, dual)
+
+         loss = fn()
+         res = fwAD.unpack_dual(loss).tangent
+
+     for p, d in zip(params, duals):
+         swap_tensors_no_use_count_check(p, d)
+     return loss, res
+
+
+
+ @torch.no_grad
+ def jvp_fd_central(
+     fn,
+     params: Iterable[torch.Tensor],
+     tangent: Iterable[torch.Tensor],
+     h=1e-3,
+     normalize=False,
+ ) -> tuple[torch.Tensor | None, torch.Tensor]:
+     """Jacobian vector product using central finite difference formula.
+
+     Example:
+         .. code:: py
+
+             model = nn.Linear(4, 2)
+             X = torch.randn(10, 4)
+             y = torch.randn(10, 2)
+
+             tangent = [torch.randn_like(p) for p in model.parameters()]
+
+             def fn():
+                 y_hat = model(X)
+                 loss = F.mse_loss(y_hat, y)
+                 return loss
+
+             jvp_fd_central(fn, model.parameters(), tangent) # scalar
+
+     """
+     params = list(params)
+     tangent = list(tangent)
+
+     tangent_norm = None
+     if normalize:
+         tangent_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in tangent])) # pylint:disable=not-callable
+         if tangent_norm == 0: return None, torch.tensor(0., device=tangent[0].device, dtype=tangent[0].dtype)
+         tangent = torch._foreach_div(tangent, tangent_norm)
+
+     tangent_h= torch._foreach_mul(tangent, h)
+
+     torch._foreach_add_(params, tangent_h)
+     v_plus = fn()
+     torch._foreach_sub_(params, tangent_h)
+     torch._foreach_sub_(params, tangent_h)
+     v_minus = fn()
+     torch._foreach_add_(params, tangent_h)
+
+     res = (v_plus - v_minus) / (2 * h)
+     if normalize: res = res * tangent_norm
+     return v_plus, res
+
+ @torch.no_grad
+ def jvp_fd_forward(
+     fn,
+     params: Iterable[torch.Tensor],
+     tangent: Iterable[torch.Tensor],
+     h=1e-3,
+     v_0=None,
+     normalize=False,
+ ) -> tuple[torch.Tensor | None, torch.Tensor]:
+     """Jacobian vector product using forward finite difference formula.
+     Loss at initial point can be specified in the `v_0` argument.
+
+     Example:
+         .. code:: py
+
+             model = nn.Linear(4, 2)
+             X = torch.randn(10, 4)
+             y = torch.randn(10, 2)
+
+             tangent1 = [torch.randn_like(p) for p in model.parameters()]
+             tangent2 = [torch.randn_like(p) for p in model.parameters()]
+
+             def fn():
+                 y_hat = model(X)
+                 loss = F.mse_loss(y_hat, y)
+                 return loss
+
+             v_0 = fn() # pre-calculate loss at initial point
+
+             jvp1 = jvp_fd_forward(fn, model.parameters(), tangent1, v_0=v_0) # scalar
+             jvp2 = jvp_fd_forward(fn, model.parameters(), tangent2, v_0=v_0) # scalar
+
+     """
+     params = list(params)
+     tangent = list(tangent)
+
+     tangent_norm = None
+     if normalize:
+         tangent_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in tangent])) # pylint:disable=not-callable
+         if tangent_norm == 0: return None, torch.tensor(0., device=tangent[0].device, dtype=tangent[0].dtype)
+         tangent = torch._foreach_div(tangent, tangent_norm)
+
+     tangent_h= torch._foreach_mul(tangent, h)
+
+     if v_0 is None: v_0 = fn()
+
+     torch._foreach_add_(params, tangent_h)
+     v_plus = fn()
+     torch._foreach_sub_(params, tangent_h)
+
+     res = (v_plus - v_0) / h
+     if normalize: res = res * tangent_norm
+     return v_0, res
+
+ def hvp(
+     params: Iterable[torch.Tensor],
+     grads: Iterable[torch.Tensor],
+     vec: Iterable[torch.Tensor],
+     retain_graph=None,
+     create_graph=False,
+     allow_unused=None,
+ ):
+     """Hessian-vector product
+
+     Example:
+         .. code:: py
+
+             model = nn.Linear(4, 2)
+             X = torch.randn(10, 4)
+             y = torch.randn(10, 2)
+
+             y_hat = model(X)
+             loss = F.mse_loss(y_hat, y)
+             loss.backward(create_graph=True)
+
+             grads = [p.grad for p in model.parameters()]
+             vec = [torch.randn_like(p) for p in model.parameters()]
+
+             # list of tensors, same layout as model.parameters()
+             hvp(model.parameters(), grads, vec=vec)
+     """
+     params = list(params)
+     g = list(grads)
+     vec = list(vec)
+
+     with torch.enable_grad():
+         return torch.autograd.grad(g, params, vec, create_graph=create_graph, retain_graph=retain_graph, allow_unused=allow_unused)
+
+
+ @torch.no_grad
+ def hvp_fd_central(
+     closure,
+     params: Iterable[torch.Tensor],
+     vec: Iterable[torch.Tensor],
+     h=1e-3,
+     normalize=False,
+ ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
+     """Hessian-vector product using central finite difference formula.
+
+     Please note that this will clear :code:`grad` attributes in params.
+
+     Example:
+         .. code:: py
+
+             model = nn.Linear(4, 2)
+             X = torch.randn(10, 4)
+             y = torch.randn(10, 2)
+
+             def closure():
+                 y_hat = model(X)
+                 loss = F.mse_loss(y_hat, y)
+                 model.zero_grad()
+                 loss.backward()
+                 return loss
+
+             vec = [torch.randn_like(p) for p in model.parameters()]
+
+             # list of tensors, same layout as model.parameters()
+             hvp_fd_central(closure, model.parameters(), vec=vec)
+     """
+     params = list(params)
+     vec = list(vec)
+
+     vec_norm = None
+     if normalize:
+         vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in vec])) # pylint:disable=not-callable
+         if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
+         vec = torch._foreach_div(vec, vec_norm)
+
+     vec_h = torch._foreach_mul(vec, h)
+     torch._foreach_add_(params, vec_h)
+     with torch.enable_grad(): loss = closure()
+     g_plus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+     torch._foreach_sub_(params, vec_h)
+     torch._foreach_sub_(params, vec_h)
+     with torch.enable_grad(): loss = closure()
+     g_minus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+     torch._foreach_add_(params, vec_h)
+     for p in params: p.grad = None
+
+     hvp_ = g_plus
+     torch._foreach_sub_(hvp_, g_minus)
+     torch._foreach_div_(hvp_, 2*h)
+
+     if normalize: torch._foreach_mul_(hvp_, vec_norm)
+     return loss, hvp_
+
+ @torch.no_grad
+ def hvp_fd_forward(
+     closure,
+     params: Iterable[torch.Tensor],
+     vec: Iterable[torch.Tensor],
+     h=1e-3,
+     g_0=None,
+     normalize=False,
+ ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
+     """Hessian-vector product using forward finite difference formula.
+
+     Gradient at initial point can be specified in the `g_0` argument.
+
+     Please note that this will clear :code:`grad` attributes in params.
+
+     Example:
+         .. code:: py
+
+             model = nn.Linear(4, 2)
+             X = torch.randn(10, 4)
+             y = torch.randn(10, 2)
+
+             def closure():
+                 y_hat = model(X)
+                 loss = F.mse_loss(y_hat, y)
+                 model.zero_grad()
+                 loss.backward()
+                 return loss
+
+             vec = [torch.randn_like(p) for p in model.parameters()]
+
+             # pre-compute gradient at initial point
+             closure()
+             g_0 = [p.grad for p in model.parameters()]
+
+             # list of tensors, same layout as model.parameters()
+             hvp_fd_forward(closure, model.parameters(), vec=vec, g_0=g_0)
+     """
+
+     params = list(params)
+     vec = list(vec)
+     loss = None
+
+     vec_norm = None
+     if normalize:
+         vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in vec])) # pylint:disable=not-callable
+         if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
+         vec = torch._foreach_div(vec, vec_norm)
+
+     vec_h = torch._foreach_mul(vec, h)
+
+     if g_0 is None:
+         with torch.enable_grad(): loss = closure()
+         g_0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+     else:
+         g_0 = list(g_0)
+
+     torch._foreach_add_(params, vec_h)
+     with torch.enable_grad():
+         l = closure()
+         if loss is None: loss = l
+     g_plus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+     torch._foreach_sub_(params, vec_h)
+     for p in params: p.grad = None
+
+     hvp_ = g_plus
+     torch._foreach_sub_(hvp_, g_0)
+     torch._foreach_div_(hvp_, h)
+
+     if normalize: torch._foreach_mul_(hvp_, vec_norm)
+     return loss, hvp_
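
Note (not part of the diff): the hunk above corresponds to torchzero/utils/derivatives.py from the file list. It renames jacobian/hessian/jacobian_and_hessian to *_wrt variants and adds JVP/HVP helpers. The snippet below is a minimal usage sketch based only on the signatures and docstring examples visible in this hunk; it assumes these names are importable from torchzero.utils.derivatives and that a scalar loss is passed as a one-element sequence.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchzero.utils.derivatives import jacobian_and_hessian_mat_wrt, hvp_fd_forward

model = nn.Linear(4, 2)
X, y = torch.randn(10, 4), torch.randn(10, 2)
params = list(model.parameters())

# flat gradient and dense Hessian via the renamed *_wrt helpers
loss = F.mse_loss(model(X), y)
g, H = jacobian_and_hessian_mat_wrt([loss], wrt=params)

# finite-difference Hessian-vector product driven by a gradient closure
def closure():
    loss = F.mse_loss(model(X), y)
    model.zero_grad()
    loss.backward()
    return loss

vec = [torch.randn_like(p) for p in params]
loss_0, Hv = hvp_fd_forward(closure, params, vec=vec, h=1e-3)  # Hv has the same layout as params
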
torchzero/utils/linalg/__init__.py
@@ -0,0 +1,5 @@
+ from .matrix_funcs import inv_sqrt_2x2, eigvals_func, singular_vals_func, matrix_power_eigh, x_inv
+ from .orthogonalize import gram_schmidt
+ from .qr import qr_householder
+ from .svd import randomized_svd
+ from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
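
Note (not part of the diff): this five-line hunk matches the new torchzero/utils/linalg/__init__.py from the file list, which exposes matrix helpers such as gram_schmidt, qr_householder, randomized_svd and cg. Their signatures are not shown in this diff, so the snippet below is a generic, hypothetical Gram-Schmidt sketch in PyTorch, included only to illustrate the kind of routine the subpackage bundles; it is not torchzero's implementation.

import torch

def gram_schmidt_sketch(A: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical illustration: orthonormalize the columns of A (classical Gram-Schmidt)."""
    Q = torch.zeros_like(A)
    for j in range(A.shape[1]):
        v = A[:, j].clone()
        for i in range(j):
            v -= (Q[:, i] @ A[:, j]) * Q[:, i]  # remove the component along each previous column
        norm = v.norm()
        if norm > eps:
            Q[:, j] = v / norm  # keep a unit-norm column; skip (near-)dependent columns
    return Q

Q = gram_schmidt_sketch(torch.randn(6, 3))
print(torch.allclose(Q.T @ Q, torch.eye(3), atol=1e-5))  # columns are approximately orthonormal
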