torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
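Note on the ``torchzero/{utils/linalg → linalg}`` moves listed above: downstream code that imports these helpers by module path will likely need updating. A minimal compatibility sketch follows; the module names are taken from the file list, while the actual public re-exports in 0.4.0 may differ.

```python
# A minimal sketch, not part of the package: module names are inferred from the
# file list above; actual re-exports in torchzero 0.4.0 may differ.
try:
    # 0.4.0 layout
    from torchzero.linalg import linear_operator, qr, solve
except ImportError:
    # 0.3.14 layout
    from torchzero.utils.linalg import linear_operator, qr, solve
```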
@@ -0,0 +1,948 @@
1
+
2
+ import warnings
3
+ from collections.abc import Callable, Sequence, Iterable
4
+ from contextlib import nullcontext
5
+ from functools import partial
6
+ from typing import TYPE_CHECKING, Any, Literal, cast
7
+
8
+ import torch
9
+
10
+ from ..utils import Distributions, TensorList, vec_to_tensors, set_storage_
11
+ from ..utils.derivatives import (
12
+ flatten_jacobian,
13
+ hessian_mat,
14
+ hvp_fd_central,
15
+ hvp_fd_forward,
16
+ jacobian_and_hessian_wrt,
17
+ jacobian_wrt,
18
+ hessian_fd,
19
+ )
20
+ from ..utils.thoad_tools import thoad_derivatives, thoad_single_tensor, lazy_thoad
21
+
22
+ if TYPE_CHECKING:
23
+ from .modular import Modular
24
+ from .module import Module
25
+
26
+ def _closure_backward(closure, params, backward, retain_graph, create_graph):
27
+ """Calls closure with specified ``backward``, ``retain_graph`` and ``create_graph``.
28
+
29
+ Returns loss and sets ``param.grad`` attributes.
30
+
31
+ If ``backward=True``, this uses ``torch.enable_grad()`` context.
32
+ """
33
+ if not backward:
34
+ return closure(False)
35
+
36
+ with torch.enable_grad():
37
+ if not (retain_graph or create_graph):
38
+ return closure()
39
+
40
+ # zero grad (because closure called with backward=False)
41
+ for p in params: p.grad = None
42
+
43
+ # loss
44
+ loss = closure(False).ravel()
45
+
46
+ # grad
47
+ grad = torch.autograd.grad(
48
+ loss,
49
+ params,
50
+ retain_graph=retain_graph,
51
+ create_graph=create_graph,
52
+ allow_unused=True,
53
+ materialize_grads=True,
54
+ )
55
+
56
+ # set p.grad
57
+ for p,g in zip(params,grad): p.grad = g
58
+ return loss
59
+
60
+ @torch.enable_grad
61
+ def _closure_loss_grad(closure, params, retain_graph, create_graph) -> tuple[torch.Tensor, list[torch.Tensor]]:
62
+ """Calls closure with specified ``backward``, ``retain_graph`` and ``create_graph``
63
+ within the ``torch.enable_grad()`` context.
64
+
65
+ Returns ``(loss, grad)``. Unlike ``_closure_backward``, this won't always set ``p.grad``.
66
+ """
67
+ if closure is None: raise RuntimeError("closure is None")
68
+
69
+ # use torch.autograd.grad
70
+ if retain_graph or create_graph:
71
+ loss = closure(False).ravel()
72
+ return loss, list(
73
+ torch.autograd.grad(loss, params, retain_graph=retain_graph, create_graph=create_graph, allow_unused=True, materialize_grads=True)
74
+ )
75
+
76
+ # use backward
77
+ loss = closure()
78
+ return loss, [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
79
+
80
+ HVPMethod = Literal["batched_autograd", "autograd", "fd_forward", "fd_central"]
81
+ """
82
+ Determines how hessian-vector products are computed.
83
+
84
+ - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
85
+ - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
86
+ - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
87
+ - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
88
+
89
+ Defaults to ``"autograd"``.
90
+ """
91
+
92
+ HessianMethod = Literal[
93
+ "batched_autograd",
94
+ "autograd",
95
+ "functional_revrev",
96
+ "functional_fwdrev",
97
+ "func",
98
+ "gfd_forward",
99
+ "gfd_central",
100
+ "fd",
101
+ "fd_full",
102
+ "thoad",
103
+ ]
104
+ """
105
+ Determines how hessian is computed.
106
+
107
+ - ``"batched_autograd"`` - uses autograd to compute ``ndim`` batched hessian-vector products. Faster than ``"autograd"`` but uses more memory.
108
+ - ``"autograd"`` - uses autograd to compute ``ndim`` hessian-vector products using for loop. Slower than ``"batched_autograd"`` but uses less memory.
109
+ - ``"functional_revrev"`` - uses ``torch.autograd.functional`` with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to ``"autograd"``.
110
+ - ``"functional_fwdrev"`` - uses ``torch.autograd.functional`` with vectorized "forward-over-reverse" strategy. Faster than ``"functional_fwdrev"`` but uses more memory (``"batched_autograd"`` seems to be faster)
111
+ - ``"func"`` - uses ``torch.func.hessian`` which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
112
+ - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
113
+ - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
114
+ - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. Only computes upper triangle of the hessian, requires ``2n^2 + 1`` function evaluations. This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
115
+ - ``"fd_full"`` - uses function values to estimate gradient and hessian via finite difference. Computes both upper and lower triangles and averages them, requires ``4n^2 - 2n + 1`` function evaluations This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
116
+ - ``"thoad"`` - uses [thoad](https://github.com/mntsx/thoad) library (experimental).
117
+
118
+ Defaults to ``"batched_autograd"``.
119
+ """
120
+
121
+ DerivativesMethod = Literal["autograd", "batched_autograd", "thoad"]
122
+ """
123
+ Determines how higher order derivatives are computed.
124
+ """
125
+
126
+ class Objective:
127
+ """
128
+ Holds parameters, gradient, update, objective function (closure) if supplied, loss, and some other info.
129
+ Modules take in an ``Objective`` object, modify it, and pass it to the next module.
130
+
131
+ Args:
132
+ params (Iterable[torch.Tensor]): iterable of parameters that are being optimized.
133
+ closure (Callable | None, optional): callable that re-evaluates loss. Defaults to None.
134
+ loss (torch.Tensor | None, optional): loss at ``params``. Defaults to None.
135
+ model (torch.nn.Module | None, optional):
136
+ ``torch.nn.Module`` object, needed for a few modules that require access to the model. Defaults to None.
137
+ current_step (int, optional):
138
+ number of times ``Modular.step()`` has been called, starting at 0. Defaults to 0.
139
+ parent (Objective | None, optional):
140
+ parent ``Objective`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
141
+ Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
142
+ e.g. when projecting. Defaults to None.
143
+ modular (Modular | None, optional):
144
+ Top-level ``Modular`` optimizer. Defaults to None.
145
+ storage (dict | None, optional):
146
+ additional kwargs passed to ``step`` to control some module-specific behavior. Defaults to None.
147
+
148
+ """
149
+ def __init__(
150
+ self,
151
+ params: Iterable[torch.Tensor],
152
+ closure: Callable | None = None,
153
+ loss: torch.Tensor | None = None,
154
+ model: torch.nn.Module | None = None,
155
+ current_step: int = 0,
156
+ parent: "Objective | None" = None,
157
+ modular: "Modular | None" = None,
158
+ storage: dict | None = None,
159
+ ):
160
+ self.params: list[torch.Tensor] = list(params)
161
+ """List of all parameters with ``requires_grad = True``."""
162
+
163
+ self.closure = closure
164
+ """A closure that reevaluates the model and returns the loss, None if it wasn't specified"""
165
+
166
+ self.model = model
167
+ """``torch.nn.Module`` object of the model, ``None`` if it wasn't specified."""
168
+
169
+ self.current_step: int = current_step
170
+ """global current step, starts at 0. This may not correspond to module current step,
171
+ for example a module may step every 10 global steps."""
172
+
173
+ self.parent: "Objective | None" = parent
174
+ """parent ``Objective`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
175
+ Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
176
+ e.g. when projecting."""
177
+
178
+ self.modular: "Modular | None" = modular
179
+ """Top-level ``Modular`` optimizer, ``None`` if it wasn't specified."""
180
+
181
+ self.updates: list[torch.Tensor] | None = None
182
+ """
183
+ current updates list. Update is assumed to be a transformed gradient, therefore it is subtracted.
184
+
185
+ If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
186
+
187
+ At the end ``objective.get_updates()`` is subtracted from the parameters.
188
+ Therefore if ``objective.updates`` is ``None``, the gradient will be used and calculated if needed.
189
+ """
190
+
191
+ self.grads: list[torch.Tensor] | None = None
192
+ """gradient with current parameters. If closure is not ``None``,
193
+ this is set to ``None`` and can be calculated if needed."""
194
+
195
+ self.loss: torch.Tensor | Any | None = loss
196
+ """loss with current parameters."""
197
+
198
+ self.loss_approx: torch.Tensor | Any | None = None
199
+ """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
200
+ whereas some other modules require loss strictly at current point."""
201
+
202
+ self.post_step_hooks: "list[Callable[[Objective, tuple[Module, ...]], None]]" = []
203
+ """list of functions to be called after optimizer step.
204
+
205
+ This attribute should always be modified in-place (using ``append`` or ``extend``).
206
+
207
+ The signature is:
208
+
209
+ ```python
210
+ def hook(objective: Objective, modules: tuple[Module]): ...
211
+ ```
212
+ """
213
+
214
+ self.stop: bool = False
215
+ """if True, all following modules will be skipped.
216
+ If this module is a child, it only affects modules at the same level (in the same Chain)."""
217
+
218
+ self.skip_update: bool = False
219
+ """if True, the parameters will not be updated."""
220
+
221
+ # self.storage: dict = {}
222
+ # """Storage for any other data, such as hessian estimates, etc."""
223
+
224
+ self.attrs: dict = {}
225
+ """attributes, ``Modular.attrs`` is updated with this after each step.
226
+ This attribute should always be modified in-place"""
227
+
228
+ if storage is None: storage = {}
229
+ self.storage: dict = storage
230
+ """additional kwargs passed to ``step`` to control some module-specific behavior.
231
+ This attribute should always be modified in-place"""
232
+
233
+ self.should_terminate: bool | None = None
234
+ """termination criteria, ``Modular.should_terminate`` is set to this after each step if not ``None``"""
235
+
236
+ self.temp: Any = cast(Any, None)
237
+ """temporary storage, ``Module.update`` can set this and ``Module.apply`` access via ``objective.poptemp()``.
238
+ This doesn't get cloned."""
239
+
240
+ def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False, at_x0:bool=True) -> torch.Tensor:
241
+ """Returns the loss at current parameters, computing it if it hasn't been computed already
242
+ and assigning ``objective.loss``. Do not call this at perturbed parameters.
243
+ Backward always sets grads to None before recomputing.
244
+
245
+ If ``backward==True``, closure is called within ``torch.enable_grad()``
246
+ """
247
+
248
+ # at non-x0 point just call closure and return
249
+ if not at_x0:
250
+ if self.closure is None: raise RuntimeError("closure is None")
251
+ return _closure_backward(
252
+ self.closure, self.params, backward=backward, retain_graph=retain_graph, create_graph=create_graph,
253
+ )
254
+
255
+ # at x0 set self.loss and self.grads
256
+ if self.loss is None:
257
+
258
+ if self.closure is None: raise RuntimeError("closure is None")
259
+
260
+ # backward
261
+ if backward:
262
+ self.loss = self.loss_approx = _closure_backward(
263
+ closure=self.closure, params=self.params, backward=True, retain_graph=retain_graph, create_graph=create_graph
264
+ )
265
+
266
+ # next time closure() is called, it will set grad to None.
267
+ # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
268
+ # because otherwise it will zero self.grads in-place
269
+ self.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
270
+
271
+ # no backward
272
+ else:
273
+ self.loss = self.loss_approx = _closure_backward(
274
+ closure=self.closure, params=self.params, backward=False, retain_graph=False, create_graph=False
275
+ )
276
+
277
+ # if self.loss was not None, above branch wasn't executed because loss has already been evaluated, but without backward since self.grad is None.
278
+ # and now it is requested to be evaluated with backward.
279
+ if backward and self.grads is None:
280
+ warnings.warn('get_loss was called with backward=False, and then with backward=True so it had to be re-evaluated, so the closure was evaluated twice where it could have been evaluated once.')
281
+ if self.closure is None: raise RuntimeError("closure is None")
282
+
283
+ self.loss = self.loss_approx = _closure_backward(
284
+ closure=self.closure, params=self.params, backward=True, retain_graph=retain_graph, create_graph=create_graph
285
+ )
286
+ self.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
287
+
288
+ # set parent grad
289
+ if self.parent is not None:
290
+ # the way projections/split work, they make a new closure which evaluates original
291
+ # closure and projects the gradient, and set it as their objective.closure.
292
+ # then on `get_loss(backward=True)` it is called, so it also sets original parameters gradient.
293
+ # and we set it to parent objective here.
294
+ if self.parent.loss is None: self.parent.loss = self.loss
295
+ if self.parent.grads is None and backward:
296
+ if all(p.grad is None for p in self.parent.params):
297
+ warnings.warn("Parent grad is None after backward.")
298
+ self.parent.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]
299
+
300
+ return self.loss # type:ignore
301
+
302
+ def get_grads(self, retain_graph: bool | None = None, create_graph: bool = False, at_x0: bool = True) -> list[torch.Tensor]:
303
+ """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning ``objective.grad`` and potentially ``objective.loss``. Do not call this at perturbed parameters."""
304
+ # at non-x0 point just call closure and return grads
305
+ if not at_x0:
306
+ _, grads = _closure_loss_grad(self.closure, self.params, retain_graph=retain_graph, create_graph=create_graph)
307
+ return grads
308
+
309
+ # at x0 get_loss sets self.loss and self.grads
310
+ if self.grads is None:
311
+ if self.closure is None: raise RuntimeError("closure is None")
312
+ self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
313
+
314
+ assert self.grads is not None
315
+ return self.grads
316
+
317
+
318
+ def get_loss_grads(self, retain_graph: bool | None = None, create_graph: bool = False, at_x0: bool = True) -> tuple[torch.Tensor, list[torch.Tensor]]:
319
+ """returns ``(loss, grads)``. Useful when you need both not at x0."""
320
+ # at non-x0 point just call closure and return (loss, grads)
321
+ if not at_x0:
322
+ return _closure_loss_grad(self.closure, self.params, retain_graph=retain_graph, create_graph=create_graph)
323
+
324
+ # at x0 get_grads sets self.loss and self.grads, then get_loss returns self.loss.
325
+ grad = self.get_grads(retain_graph=retain_graph, create_graph=create_graph)
326
+ loss = self.get_loss(False)
327
+ return loss, grad
328
+
329
+ def get_updates(self) -> list[torch.Tensor]:
330
+ """Returns the update. If update is None, it is initialized by cloning the gradients
331
+ and assigning to ``objective.updates``. Computing the gradients may assign ``objective.grads``
332
+ and ``objective.loss`` if they haven't been computed. Do not call this at perturbed parameters."""
333
+ if self.updates is None: self.updates = [g.clone() for g in self.get_grads()]
334
+ return self.updates
335
+
336
+ def clone(self, clone_updates: bool, parent: "Objective | None" = None):
337
+ """Creates a shallow copy of this ``Objective``, update can optionally be deep-copied (via ``torch.clone``).
338
+
339
+ This copies over all attributes except ``temp``.
340
+
341
+ Setting ``parent`` is only needed if the clone's parameters are different,
342
+ while the clone's closure refers to the same objective but with a "view" on the parameters.
343
+ """
344
+ copy = Objective(
345
+ params=self.params, closure=self.closure, model=self.model, current_step=self.current_step,
346
+ parent=parent, modular=self.modular, loss=self.loss, storage=self.storage
347
+ )
348
+
349
+ if clone_updates and self.updates is not None:
350
+ copy.updates = [u.clone() for u in self.updates]
351
+ else:
352
+ copy.updates = self.updates
353
+
354
+ copy.grads = self.grads
355
+ copy.loss_approx = self.loss_approx
356
+ copy.post_step_hooks = self.post_step_hooks
357
+ copy.stop = self.stop
358
+ copy.skip_update = self.skip_update
359
+
360
+ copy.attrs = self.attrs
361
+ copy.should_terminate = self.should_terminate
362
+
363
+ return copy
364
+
365
+ def update_attrs_from_clone_(self, objective: "Objective"):
366
+ """Updates attributes of this ``Objective`` instance from a cloned instance.
367
+ Typically called after a child module has processed a cloned ``Objective``
368
+ object. This propagates any newly computed loss or gradient values
369
+ from the child's context back to the parent ``Objective`` if the parent
370
+ didn't have them computed already.
371
+
372
+ This copies over ``loss``, ``loss_approx``, ``grads``, ``should_terminate`` and ``skip_update``.
373
+
374
+ Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
375
+ if the child updates them, the update will affect the parent too.
376
+ """
377
+ if self.loss is None: self.loss = objective.loss
378
+ if self.loss_approx is None: self.loss_approx = objective.loss_approx
379
+ if self.grads is None: self.grads = objective.grads
380
+
381
+ if objective.should_terminate is not None: self.should_terminate = objective.should_terminate
382
+ if objective.skip_update: self.skip_update = objective.skip_update
383
+
384
+ @torch.no_grad
385
+ def zero_grad(self, set_to_none=True):
386
+ """In most cases not call with ``set_to_none=False``, as that will zero ``self.grads`` in-place."""
387
+ if set_to_none:
388
+ for p in self.params: p.grad = None
389
+ else:
390
+ grads = [p.grad for p in self.params if p.grad is not None]
391
+ if len(grads) != 0: torch._foreach_zero_(grads)
392
+
393
+ def poptemp(self):
394
+ """to pass information from ``update`` to ``apply``."""
395
+ temp = self.temp
396
+ self.temp = None
397
+ return temp
398
+
399
+ @torch.no_grad
400
+ def update_parameters(self):
401
+ """subtracts ``self.get_updates()`` from parameters, unless ``self.skip_update = True``, then does nothing."""
402
+ if self.skip_update: return
403
+ torch._foreach_sub_(self.params, self.get_updates())
404
+
405
+ def apply_post_step_hooks(self, modules: "Sequence[Module]"):
406
+ """Runs hooks that a few modules use. This should be called **after** updating parameters."""
407
+ modules = tuple(modules)
408
+ for hook in self.post_step_hooks:
409
+ hook(self, modules)
410
+
411
+
412
+ # ------------------------------ HELPER METHODS ------------------------------ #
413
+ @torch.no_grad
414
+ def hessian_vector_product(
415
+ self,
416
+ z: Sequence[torch.Tensor],
417
+ rgrad: Sequence[torch.Tensor] | None,
418
+ at_x0: bool,
419
+ hvp_method: HVPMethod,
420
+ h: float,
421
+ retain_graph: bool = False,
422
+ ) -> tuple[list[torch.Tensor], Sequence[torch.Tensor] | None]:
423
+ """
424
+ Returns ``(Hz, rgrad)``, where ``rgrad`` is gradient at current parameters but it may be None.
425
+
426
+ Gradient is set on ``objective`` automatically if ``at_x0`` and can be accessed with ``objective.get_grads()``.
427
+
428
+ Single hessian vector product example:
429
+
430
+ ```python
431
+ Hz, _ = self.hessian_vector_product(z, rgrad=None, at_x0=True, ..., retain_graph=False)
432
+ ```
433
+
434
+ Multiple hessian-vector products example:
435
+
436
+ ```python
437
+ rgrad = None
438
+ for z in vecs:
439
+ retain_graph = i < len(vecs) - 1
440
+ Hz, rgrad = self.hessian_vector_product(z, rgrad=rgrad, ..., retain_graph=retain_graph)
441
+
442
+ ```
443
+
444
+ Args:
445
+ z (Sequence[torch.Tensor]): vector in hessian-vector product
446
+ rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
447
+ at_x0 (bool): whether this is being called at original or perturbed parameters.
448
+ hvp_method (str): hvp method.
449
+ h (float): finite difference step size
450
+ retain_graph (bool): whether to retain the autograd graph. Defaults to False.
451
+ """
452
+ if hvp_method in ('batched_autograd', "autograd"):
453
+ with torch.enable_grad():
454
+ if rgrad is None: rgrad = self.get_grads(create_graph=True, at_x0=at_x0)
455
+ Hz = torch.autograd.grad(rgrad, self.params, z, retain_graph=retain_graph)
456
+
457
+ # loss returned by fd hvp is not guaranteed to be at x0 so we don't use/return it
458
+ elif hvp_method == 'fd_forward':
459
+ if rgrad is None: rgrad = self.get_grads(at_x0=at_x0)
460
+ _, Hz = hvp_fd_forward(self.closure, self.params, z, h=h, g_0=rgrad)
461
+
462
+ elif hvp_method == 'fd_central':
463
+ _, Hz = hvp_fd_central(self.closure, self.params, z, h=h)
464
+
465
+ else:
466
+ raise ValueError(hvp_method)
467
+
468
+ return list(Hz), rgrad
469
+
470
+ @torch.no_grad
471
+ def hessian_matrix_product(
472
+ self,
473
+ Z: torch.Tensor,
474
+ rgrad: Sequence[torch.Tensor] | None,
475
+ at_x0: bool,
476
+ hvp_method: HVPMethod,
477
+ h: float,
478
+ retain_graph: bool = False,
479
+ ) -> tuple[torch.Tensor, Sequence[torch.Tensor] | None]:
480
+ """Z is ``(n_dim, n_hvps)``, computes ``H @ Z`` of shape ``(n_dim, n_hvps)``.
481
+
482
+ Returns ``(HZ, rgrad)`` where ``rgrad`` is gradient at current parameters but it may be None.
483
+
484
+ Gradient is set on ``objective`` automatically if ``at_x0`` and can be accessed with ``objective.get_grads()``.
485
+
486
+ Unlike ``hessian_vector_product`` this returns a single matrix, not a per-parameter list.
487
+
488
+ Args:
489
+ Z (torch.Tensor): matrix in hessian-matrix product
490
+ rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
491
+ at_x0 (bool): whether this is being called at original or perturbed parameters.
492
+ hvp_method (str): hvp method.
493
+ h (float): finite difference step size
494
+ retain_graph (bool): whether to retain the autograd graph. Defaults to False.
495
+
496
+ """
497
+ # compute
498
+ if hvp_method == "batched_autograd":
499
+ with torch.enable_grad():
500
+ if rgrad is None: rgrad = self.get_grads(create_graph=True, at_x0=at_x0)
501
+ flat_inputs = torch.cat([g.ravel() for g in rgrad])
502
+ HZ_list = torch.autograd.grad(
503
+ flat_inputs,
504
+ self.params,
505
+ grad_outputs=Z.T,
506
+ is_grads_batched=True,
507
+ retain_graph=retain_graph,
508
+ )
509
+
510
+ HZ = flatten_jacobian(HZ_list).T
511
+
512
+ elif hvp_method == 'autograd':
513
+ with torch.enable_grad():
514
+ if rgrad is None: rgrad = self.get_grads(create_graph=True, at_x0=at_x0)
515
+ flat_inputs = torch.cat([g.ravel() for g in rgrad])
516
+ HZ_tensors = [
517
+ torch.autograd.grad(
518
+ flat_inputs,
519
+ self.params,
520
+ grad_outputs=col,
521
+ retain_graph=retain_graph or (i < Z.size(1) - 1),
522
+ )
523
+ for i, col in enumerate(Z.unbind(1))
524
+ ]
525
+
526
+ HZ_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HZ_tensors]
527
+ HZ = torch.stack(HZ_list, 1)
528
+
529
+ elif hvp_method == 'fd_forward':
530
+ if rgrad is None: rgrad = self.get_grads(at_x0=at_x0)
531
+ HZ_tensors = [
532
+ hvp_fd_forward(
533
+ self.closure,
534
+ self.params,
535
+ vec_to_tensors(col, self.params),
536
+ h=h,
537
+ g_0=rgrad,
538
+ )[1]
539
+ for col in Z.unbind(1)
540
+ ]
541
+ HZ_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HZ_tensors]
542
+ HZ = flatten_jacobian(HZ_list)
543
+
544
+ elif hvp_method == 'fd_central':
545
+ HZ_tensors = [
546
+ hvp_fd_central(
547
+ self.closure, self.params, vec_to_tensors(col, self.params), h=h
548
+ )[1]
549
+ for col in Z.unbind(1)
550
+ ]
551
+ HZ_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HZ_tensors]
552
+ HZ = flatten_jacobian(HZ_list)
553
+
554
+ else:
555
+ raise ValueError(hvp_method)
556
+
557
+ return HZ, rgrad
558
+
559
+ @torch.no_grad
560
+ def hutchinson_hessian(
561
+ self,
562
+ rgrad: Sequence[torch.Tensor] | None,
563
+ at_x0: bool,
564
+ n_samples: int | None,
565
+ distribution: Distributions | Sequence[Sequence[torch.Tensor]],
566
+ hvp_method: HVPMethod,
567
+ h: float,
568
+ generator,
569
+ variance: int | None = 1,
570
+ zHz: bool = True,
571
+ retain_graph: bool = False,
572
+ ) -> tuple[list[torch.Tensor], Sequence[torch.Tensor] | None]:
573
+ """
574
+ Returns ``(D, rgrad)``, where ``rgrad`` is gradient at current parameters but it may be None.
575
+
576
+ Gradient is set on ``objective`` automatically if ``at_x0`` and can be accessed with ``objective.get_grads()``.
577
+
578
+ Args:
579
+ rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
580
+ at_x0 (bool): whether this is being called at original or perturbed parameters.
581
+ n_samples (int | None): number of random vectors.
582
+ distribution (Distributions | Sequence[Sequence[torch.Tensor]]):
583
+ distribution, this can also be a sequence of tensor sequences.
584
+ hvp_method (str): how to compute hessian-vector products.
585
+ h (float): finite difference step size.
586
+ generator (Any): generator
587
+ variance (int | None, optional): variance of random vectors. Defaults to 1.
588
+ zHz (bool, optional): whether to compute z ⊙ Hz. If False, computes Hz. Defaults to True.
589
+ retain_graph (bool, optional): whether to retain graph. Defaults to False.
590
+ """
591
+
592
+ params = TensorList(self.params)
593
+ samples = None
594
+
595
+ # check when distribution is sequence of tensors
596
+ if not isinstance(distribution, str):
597
+ if n_samples is not None and n_samples != len(distribution):
598
+ raise RuntimeError("when passing sequence of z to `hutchinson_hessian`, set `n_samples` to None")
599
+
600
+ n_samples = len(distribution)
601
+ samples = distribution
602
+
603
+ # use non-batched with single sample
604
+ if n_samples == 1 and hvp_method == 'batched_autograd':
605
+ hvp_method = 'autograd'
606
+
607
+ # -------------------------- non-batched hutchinson -------------------------- #
608
+ if hvp_method in ('autograd', 'fd_forward', 'fd_central'):
609
+
610
+ D = None
611
+ assert n_samples is not None
612
+
613
+ for i in range(n_samples):
614
+
615
+ # sample
616
+ if samples is not None: z = samples[i]
617
+ else: z = params.sample_like(cast(Distributions, distribution), variance, generator=generator)
618
+
619
+ # compute
620
+ Hz, rgrad = self.hessian_vector_product(
621
+ z=z,
622
+ rgrad=rgrad,
623
+ at_x0=at_x0,
624
+ hvp_method=hvp_method,
625
+ h=h,
626
+ retain_graph=(i < n_samples - 1) or retain_graph,
627
+ )
628
+
629
+ # add
630
+ if zHz: torch._foreach_mul_(Hz, tuple(z))
631
+
632
+ if D is None: D = Hz
633
+ else: torch._foreach_add_(D, Hz)
634
+
635
+
636
+ assert D is not None
637
+ if n_samples > 1: torch._foreach_div_(D, n_samples)
638
+ return D, rgrad
639
+
640
+ # ---------------------------- batched hutchinson ---------------------------- #
641
+ if hvp_method != 'batched_autograd':
642
+ raise RuntimeError(f"Unknown hvp_method: `{hvp_method}`")
643
+
644
+ # generate and vectorize samples
645
+ if samples is None:
646
+ samples = [params.sample_like(cast(Distributions, distribution), variance, generator=generator).to_vec() for _ in range(n_samples)]
647
+
648
+ else:
649
+ samples = [torch.cat([t.ravel() for t in s]) for s in samples]
650
+
651
+ # compute Hz
652
+ Z = torch.stack(samples, -1)
653
+ HZ, rgrad = self.hessian_matrix_product(
654
+ Z,
655
+ rgrad=rgrad,
656
+ at_x0=at_x0,
657
+ hvp_method='batched_autograd',
658
+ h=h, # not used
659
+ retain_graph=retain_graph,
660
+ )
661
+
662
+ if zHz: HZ *= Z
663
+ D_vec = HZ.mean(-1)
664
+ return vec_to_tensors(D_vec, params), rgrad
665
+
666
+ @torch.no_grad
667
+ def hessian(
668
+ self,
669
+ hessian_method: HessianMethod,
670
+ h: float,
671
+ at_x0: bool,
672
+ ) -> tuple[torch.Tensor | None, Sequence[torch.Tensor] | None, torch.Tensor]:
673
+ """returns ``(f, g_list, H)``. Also sets ``objective.loss`` and ``objective.grad`` if ``at_x0``.
674
+
675
+ ``f`` and ``g_list`` may be None if they aren't computed with ``hessian_method``.
676
+
677
+ Args:
678
+ hessian_method: how to compute hessian
679
+ h (float): finite difference step size
680
+ at_x0 (bool): whether this is being called at x0.
682
+ """
683
+ closure = self.closure
684
+ if closure is None:
685
+ raise RuntimeError("Computing hessian requires a closure to be provided to the `step` method.")
686
+
687
+ params = self.params
688
+ numel = sum(p.numel() for p in params)
689
+
690
+ f = None
691
+ g_list = None
692
+
693
+ # autograd hessian
694
+ if hessian_method in ("batched_autograd", "autograd"):
695
+ with torch.enable_grad():
696
+ f = self.get_loss(False, at_x0=at_x0)
697
+
698
+ batched = hessian_method == "batched_autograd"
699
+ g_list, H_list = jacobian_and_hessian_wrt([f.ravel()], params, batched=batched)
700
+ g_list = [t[0] for t in g_list] # remove leading dim from loss
701
+
702
+ H = flatten_jacobian(H_list)
703
+
704
+ # functional autograd hessian
705
+ elif hessian_method in ('func', 'functional_revrev', 'functional_fwdrev'):
706
+ if hessian_method == 'functional_fwdrev':
707
+ method = "autograd.functional"
708
+ outer_jacobian_strategy = "forward-mode"
709
+ vectorize=True
710
+ elif hessian_method == 'functional_revrev':
711
+ method = "autograd.functional"
712
+ outer_jacobian_strategy = "reverse-mode"
713
+ vectorize=False
714
+ else:
715
+ method = 'func'
716
+ outer_jacobian_strategy = "forward-mode" # unused
717
+ vectorize=True # unused
718
+
719
+ with torch.enable_grad():
720
+ H = hessian_mat(partial(closure, backward=False), params,
721
+ method=method, vectorize=vectorize,
722
+ outer_jacobian_strategy=outer_jacobian_strategy)
723
+
724
+ # thoad
725
+ elif hessian_method == "thoad":
726
+ with torch.enable_grad():
727
+ f = self.get_loss(False, at_x0=at_x0)
728
+ ctrl = lazy_thoad.backward(f, 2, crossings=True)
729
+
730
+ g_list = [p.hgrad[0].squeeze(0) for p in params] # pyright:ignore[reportAttributeAccessIssue]
731
+ H = thoad_single_tensor(ctrl, params, 2)
732
+
733
+
734
+ # gradient finite difference
735
+ elif hessian_method in ('gfd_forward', 'gfd_central'):
736
+
737
+ if hessian_method == 'gfd_central': hvp_method = 'fd_central'
738
+ else: hvp_method = 'fd_forward'
739
+
740
+ I = torch.eye(numel, device=params[0].device, dtype=params[0].dtype)
741
+ H, g_list = self.hessian_matrix_product(I, rgrad=None, at_x0=at_x0, hvp_method=hvp_method, h=h)
742
+
743
+ # function value finite difference
744
+ elif hessian_method in ('fd', "fd_full"):
745
+ full = hessian_method == "fd_full"
746
+ f, g_list, H = hessian_fd(partial(closure, False), params=params, eps=h, full=full)
747
+
748
+ else:
749
+ raise ValueError(hessian_method)
750
+
751
+ # set objective attributes if at x0
752
+ if at_x0:
753
+ if f is not None and self.loss is None:
754
+ self.loss = self.loss_approx = f
755
+
756
+ if g_list is not None and self.grads is None:
757
+ self.grads = list(g_list)
758
+
759
+ return f, g_list, H
760
+
761
+ @torch.no_grad
762
+ def derivatives(self, order: int, at_x0: bool, method:DerivativesMethod="batched_autograd"):
763
+ """
764
+ returns a tuple of tensors of function value and derivatives up to ``order``
765
+
766
+ ``order = 0`` returns ``(f,)``;
767
+
768
+ ``order = 1`` returns ``(f, g)``;
769
+
770
+ ``order = 2`` returns ``(f, g, H)``;
771
+
772
+ ``order = 3`` returns ``(f, g, H, T3)``;
773
+
774
+ etc.
775
+ """
776
+ closure = self.closure
777
+ if closure is None:
778
+ raise RuntimeError("Computing hessian requires a closure to be provided to the `step` method.")
779
+
780
+ # just loss
781
+ if order == 0:
782
+ f = self.get_loss(False, at_x0=at_x0)
783
+ return (f, )
784
+
785
+ # loss and grad
786
+ if order == 1:
787
+ f, g_list = self.get_loss_grads(at_x0=at_x0)
788
+ g = torch.cat([t.ravel() for t in g_list])
789
+
790
+ return f, g
791
+
792
+ if method in ("autograd", "batched_autograd"):
793
+ batched = method == "batched_autograd"
794
+
795
+ # recursively compute derivatives up to order
796
+ with torch.enable_grad():
797
+ f, g_list = self.get_loss_grads(at_x0=at_x0, create_graph=True)
798
+ g = torch.cat([t.ravel() for t in g_list])
799
+
800
+ n = g.numel()
801
+ ret = [f, g]
802
+ T = g # current derivatives tensor
803
+
804
+ # get all derivative up to order
805
+ for o in range(2, order + 1):
806
+ is_last = o == order
807
+ T_list = jacobian_wrt([T], self.params, create_graph=not is_last, batched=batched)
808
+ with torch.no_grad() if is_last else nullcontext():
809
+
810
+ # the shape is (ndim, ) * order
811
+ T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
812
+ ret.append(T)
813
+
814
+ return tuple(ret)
815
+
816
+ if method == "thoad":
817
+ with torch.enable_grad():
818
+ f = self.get_loss(False, at_x0=at_x0)
819
+ ctrl = lazy_thoad.backward(f, order, crossings=True)
820
+
821
+ return tuple([f, *thoad_derivatives(ctrl, self.params, order=order)])
822
+
823
+ raise ValueError(method)
824
+
825
+ @torch.no_grad
826
+ def derivatives_at(
827
+ self,
828
+ x: torch.Tensor | Sequence[torch.Tensor],
829
+ order: int,
830
+ method:DerivativesMethod="batched_autograd"
831
+ ):
832
+ """
833
+ returns a tuple of tensors of function value and derivatives up to ``order`` at ``x``,
834
+ then sets original parameters.
835
+
836
+ ``x`` can be a vector or a list of tensors.
837
+
838
+ ``order = 0`` returns ``(f,)``;
839
+
840
+ ``order = 1`` returns ``(f, g)``;
841
+
842
+ ``order = 2`` returns ``(f, g, H)``;
843
+
844
+ ``order = 3`` returns ``(f, g, H, T3)``;
845
+
846
+ etc.
847
+ """
848
+ if isinstance(x, torch.Tensor): x = vec_to_tensors(x, self.params)
849
+
850
+ x0 = [p.clone() for p in self.params]
851
+
852
+ # set params to x
853
+ for p, x_i in zip(self.params, x):
854
+ set_storage_(p, x_i)
855
+
856
+ ret = self.derivatives(order=order, at_x0=False, method=method)
857
+
858
+ # set params to x0
859
+ for p, x0_i in zip(self.params, x0):
860
+ set_storage_(p, x0_i)
861
+
862
+ return ret
863
+
864
+
865
+ def list_Hvp_function(self, hvp_method: HVPMethod, h: float, at_x0:bool):
866
+ """returns ``(grad, H_mv)`` where ``H_mv`` is a callable that accepts and returns lists of tensors.
867
+
868
+ ``grad`` may be None, and this sets ``objective.grads`` if ``at_x0``, so at x0 just use ``objective.get_grads()``.
869
+ """
870
+ params = TensorList(self.params)
871
+ closure = self.closure
872
+
873
+ if hvp_method in ('batched_autograd', 'autograd'):
874
+ grad = self.get_grads(create_graph=True, at_x0=at_x0)
875
+
876
+ def H_mv(x: torch.Tensor | Sequence[torch.Tensor]):
877
+ if isinstance(x, torch.Tensor): x = params.from_vec(x)
878
+ with torch.enable_grad():
879
+ return TensorList(torch.autograd.grad(grad, params, x, retain_graph=True))
880
+
881
+ else:
882
+
883
+ if hvp_method == 'fd_forward':
884
+ grad = self.get_grads(at_x0=at_x0)
885
+ def H_mv(x: torch.Tensor | Sequence[torch.Tensor]):
886
+ if isinstance(x, torch.Tensor): x = params.from_vec(x)
887
+ _, Hx = hvp_fd_forward(closure, params, x, h=h, g_0=grad)
888
+ return TensorList(Hx)
889
+
890
+ elif hvp_method == 'fd_central':
891
+ grad = None
892
+ def H_mv(x: torch.Tensor | Sequence[torch.Tensor]):
893
+ if isinstance(x, torch.Tensor): x = params.from_vec(x)
894
+ _, Hx = hvp_fd_central(closure, params, x, h=h)
895
+ return TensorList(Hx)
896
+
897
+ else:
898
+ raise ValueError(hvp_method)
899
+
900
+
901
+ return grad, H_mv
902
+
903
+ def tensor_Hvp_function(self, hvp_method: HVPMethod, h: float, at_x0:bool):
904
+ """returns ``(grad, H_mv, H_mm)``, where ``H_mv`` and ``H_mm`` accept and return single tensors.
905
+
906
+ ``grad`` may be None, and this sets ``objective.grads`` if ``at_x0``, so at x0 just use ``objective.get_grads()``.
907
+ """
908
+ if hvp_method in ('fd_forward', "fd_central", "autograd"):
909
+ grad, list_H_mv = self.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=at_x0)
910
+
911
+ def H_mv_loop(x: torch.Tensor):
912
+ Hx_list = list_H_mv(x)
913
+ return torch.cat([t.ravel() for t in Hx_list])
914
+
915
+ def H_mm_loop(X: torch.Tensor):
916
+ return torch.stack([H_mv_loop(col) for col in X.unbind(-1)], -1)
917
+
918
+ return grad, H_mv_loop, H_mm_loop
919
+
920
+ # for batched we need grad
921
+ if hvp_method != 'batched_autograd':
922
+ raise RuntimeError(f"Unknown hvp_method `{hvp_method}`")
923
+
924
+ params = TensorList(self.params)
925
+ grad = self.get_grads(create_graph=True, at_x0=at_x0)
926
+
927
+ def H_mv_batched(x: torch.Tensor):
928
+ with torch.enable_grad():
929
+ Hx_list = torch.autograd.grad(grad, params, params.from_vec(x), retain_graph=True)
930
+
931
+ return torch.cat([t.ravel() for t in Hx_list])
932
+
933
+ def H_mm_batched(X: torch.Tensor):
934
+ with torch.enable_grad():
935
+ flat_inputs = torch.cat([g.ravel() for g in grad])
936
+ HX_list = torch.autograd.grad(
937
+ flat_inputs,
938
+ self.params,
939
+ grad_outputs=X.T,
940
+ is_grads_batched=True,
941
+ retain_graph=True,
942
+ )
943
+ return flatten_jacobian(HX_list).T
944
+
945
+ return grad, H_mv_batched, H_mm_batched
946
+
947
+
948
+ # endregion
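For orientation, below is a minimal usage sketch of the new ``Objective`` API based on the signatures shown in the diff above. The closure convention (a callable taking a ``backward`` flag that computes gradients itself) is inferred from ``_closure_backward``, and the import path mirrors the file location; the public re-export may differ.

```python
# Minimal sketch, assuming the conventions shown in the diff above.
import torch
from torchzero.core.objective import Objective  # path taken from the diff; public re-export may differ

x = torch.nn.Parameter(torch.randn(3))  # toy quadratic objective

def closure(backward: bool = True):
    loss = (x ** 2).sum()
    if backward:
        x.grad = None
        loss.backward()
    return loss

obj = Objective(params=[x], closure=closure)
loss = obj.get_loss(backward=True)  # evaluates the closure and caches obj.loss / obj.grads
grads = obj.get_grads()             # reuses the cached gradients
update = obj.get_updates()          # defaults to a clone of the gradients
obj.update_parameters()             # x <- x - update (a plain gradient step here)
```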