torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
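
The diff below is the new `torchzero/core/objective.py` (948 added lines), which introduces the `Objective` class; `torchzero/core/var.py` is removed in this release. As a quick orientation, here is a minimal usage sketch based on the constructor and methods shown in the diff — the parameter, the quadratic closure, and the direct import path are illustrative assumptions, and the closure follows the `closure(backward: bool)` convention used throughout the file:

```python
import torch
from torchzero.core.objective import Objective  # assumed import path for the new module

# Illustrative setup: one parameter and a quadratic loss.
x = torch.nn.Parameter(torch.tensor([1.5, -2.0]))

def closure(backward: bool = True):
    loss = (x ** 2).sum()
    if backward:
        x.grad = None      # the closure is responsible for clearing old grads
        loss.backward()
    return loss

obj = Objective(params=[x], closure=closure)
loss = obj.get_loss(backward=True)   # evaluates the closure, caches obj.loss and obj.grads
grads = obj.get_grads()              # returns the cached gradients, no second evaluation
update = obj.get_updates()           # defaults to a clone of the gradients
obj.update_parameters()              # subtracts the update from the parameters
```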
@@ -0,0 +1,948 @@

import warnings
from collections.abc import Callable, Sequence, Iterable
from contextlib import nullcontext
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, cast

import torch

from ..utils import Distributions, TensorList, vec_to_tensors, set_storage_
from ..utils.derivatives import (
    flatten_jacobian,
    hessian_mat,
    hvp_fd_central,
    hvp_fd_forward,
    jacobian_and_hessian_wrt,
    jacobian_wrt,
    hessian_fd,
)
from ..utils.thoad_tools import thoad_derivatives, thoad_single_tensor, lazy_thoad

if TYPE_CHECKING:
    from .modular import Optimizer
    from .module import Module

def _closure_backward(closure, params, backward, retain_graph, create_graph):
    """Calls closure with specified ``backward``, ``retain_graph`` and ``create_graph``.

    Returns loss and sets ``param.grad`` attributes.

    If ``backward=True``, this uses the ``torch.enable_grad()`` context.
    """
    if not backward:
        return closure(False)

    with torch.enable_grad():
        if not (retain_graph or create_graph):
            return closure()

        # zero grad (the closure is called with backward=False below, so it won't do it)
        for p in params: p.grad = None

        # loss
        loss = closure(False).ravel()

        # grad
        grad = torch.autograd.grad(
            loss,
            params,
            retain_graph=retain_graph,
            create_graph=create_graph,
            allow_unused=True,
            materialize_grads=True,
        )

        # set p.grad
        for p, g in zip(params, grad): p.grad = g
        return loss

@torch.enable_grad
def _closure_loss_grad(closure, params, retain_graph, create_graph) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Calls closure with the specified ``retain_graph`` and ``create_graph``
    within a ``torch.enable_grad()`` context.

    Returns ``(loss, grad)``. Unlike ``_closure_backward``, this won't always set ``p.grad``.
    """
    if closure is None: raise RuntimeError("closure is None")

    # use torch.autograd.grad
    if retain_graph or create_graph:
        loss = closure(False).ravel()
        return loss, list(
            torch.autograd.grad(loss, params, retain_graph=retain_graph, create_graph=create_graph, allow_unused=True, materialize_grads=True)
        )

    # use backward
    loss = closure()
    return loss, [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

HVPMethod = Literal["batched_autograd", "autograd", "fd_forward", "fd_central"]
"""
Determines how hessian-vector products are computed.

- ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector product is evaluated, this is equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
- ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
- ``"fd_forward"`` - uses a gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
- ``"fd_central"`` - uses a gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

Defaults to ``"autograd"``.
"""

HessianMethod = Literal[
    "batched_autograd",
    "autograd",
    "functional_revrev",
    "functional_fwdrev",
    "func",
    "gfd_forward",
    "gfd_central",
    "fd",
    "fd_full",
    "thoad",
]
"""
Determines how the hessian is computed.

- ``"batched_autograd"`` - uses autograd to compute ``ndim`` batched hessian-vector products. Faster than ``"autograd"`` but uses more memory.
- ``"autograd"`` - uses autograd to compute ``ndim`` hessian-vector products using a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
- ``"functional_revrev"`` - uses ``torch.autograd.functional`` with the "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to ``"autograd"``.
- ``"functional_fwdrev"`` - uses ``torch.autograd.functional`` with the vectorized "forward-over-reverse" strategy. Faster than ``"functional_revrev"`` but uses more memory (``"batched_autograd"`` seems to be faster).
- ``"func"`` - uses ``torch.func.hessian``, which uses the "forward-over-reverse" strategy. This method is the fastest and is recommended; however, it is more restrictive and fails with some operators, which is why it isn't the default.
- ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite differences using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
- ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite differences using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
- ``"fd"`` - uses function values to estimate the gradient and hessian via finite differences. Only computes the upper triangle of the hessian and requires ``2n^2 + 1`` function evaluations. This uses fewer evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
- ``"fd_full"`` - uses function values to estimate the gradient and hessian via finite differences. Computes both the upper and lower triangles and averages them; requires ``4n^2 - 2n + 1`` function evaluations. This uses fewer evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
- ``"thoad"`` - uses the [thoad](https://github.com/mntsx/thoad) library (experimental).

Defaults to ``"batched_autograd"``.
"""

DerivativesMethod = Literal["autograd", "batched_autograd", "thoad"]
"""
Determines how higher-order derivatives are computed.
"""

class Objective:
    """
    Holds parameters, gradient, update, the objective function (closure) if supplied, loss, and some other info.
    Modules take in an ``Objective`` object, modify it, and pass it on to the next module.

    Args:
        params (Iterable[torch.Tensor]): iterable of parameters that are being optimized.
        closure (Callable | None, optional): callable that re-evaluates the loss. Defaults to None.
        loss (torch.Tensor | None, optional): loss at ``params``. Defaults to None.
        model (torch.nn.Module | None, optional):
            ``torch.nn.Module`` object, needed for a few modules that require access to the model. Defaults to None.
        current_step (int, optional):
            number of times ``Optimizer.step()`` has been called, starting at 0. Defaults to 0.
        parent (Objective | None, optional):
            parent ``Objective`` object. When ``self.get_grads()`` is called, it will also set ``parent.grads``.
            Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
            e.g. when projecting. Defaults to None.
        modular (Optimizer | None, optional):
            top-level ``Optimizer``. Defaults to None.
        storage (dict | None, optional):
            additional kwargs passed to ``step`` to control some module-specific behavior. Defaults to None.

    """
    def __init__(
        self,
        params: Iterable[torch.Tensor],
        closure: Callable | None = None,
        loss: torch.Tensor | None = None,
        model: torch.nn.Module | None = None,
        current_step: int = 0,
        parent: "Objective | None" = None,
        modular: "Optimizer | None" = None,
        storage: dict | None = None,
    ):
        self.params: list[torch.Tensor] = list(params)
        """List of all parameters with ``requires_grad = True``."""

        self.closure = closure
        """A closure that re-evaluates the model and returns the loss; ``None`` if it wasn't specified."""

        self.model = model
        """``torch.nn.Module`` object of the model, ``None`` if it wasn't specified."""

        self.current_step: int = current_step
        """global current step, starts at 0. This may not correspond to a module's current step,
        for example a module may step every 10 global steps."""

        self.parent: "Objective | None" = parent
        """parent ``Objective`` object. When ``self.get_grads()`` is called, it will also set ``parent.grads``.
        Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
        e.g. when projecting."""

        self.modular: "Optimizer | None" = modular
        """Top-level ``Optimizer``, ``None`` if it wasn't specified."""

        self.updates: list[torch.Tensor] | None = None
        """
        current updates list. The update is assumed to be a transformed gradient, therefore it is subtracted.

        If closure is None, this is initially set to a cloned gradient. Otherwise this is set to None.

        At the end ``objective.get_updates()`` is subtracted from the parameters.
        Therefore if ``objective.updates`` is ``None``, the gradient will be used and calculated if needed.
        """

        self.grads: list[torch.Tensor] | None = None
        """gradient with current parameters. If closure is not ``None``,
        this is set to ``None`` and can be calculated if needed."""

        self.loss: torch.Tensor | Any | None = loss
        """loss with current parameters."""

        self.loss_approx: torch.Tensor | Any | None = None
        """loss at a point near the current point. This can be useful as some modules only calculate loss at perturbed points,
        whereas some other modules require loss strictly at the current point."""

        self.post_step_hooks: "list[Callable[[Objective, tuple[Module, ...]], None]]" = []
        """list of functions to be called after the optimizer step.

        This attribute should always be modified in-place (using ``append`` or ``extend``).

        The signature is:

        ```python
        def hook(objective: Objective, modules: tuple[Module, ...]): ...
        ```
        """

        self.stop: bool = False
        """if True, all following modules will be skipped.
        If this module is a child, it only affects modules at the same level (in the same Chain)."""

        self.skip_update: bool = False
        """if True, the parameters will not be updated."""

        # self.storage: dict = {}
        # """Storage for any other data, such as hessian estimates, etc."""

        self.attrs: dict = {}
        """attributes, ``Optimizer.attrs`` is updated with this after each step.
        This attribute should always be modified in-place."""

        if storage is None: storage = {}
        self.storage: dict = storage
        """additional kwargs passed to ``step`` to control some module-specific behavior.
        This attribute should always be modified in-place."""

        self.should_terminate: bool | None = None
        """termination criterion, ``Optimizer.should_terminate`` is set to this after each step if not ``None``."""

        self.temp: Any = cast(Any, None)
        """temporary storage, ``Module.update`` can set this and ``Module.apply`` can access it via ``objective.poptemp()``.
        This doesn't get cloned."""

    def get_loss(self, backward: bool, retain_graph=None, create_graph: bool = False, at_x0: bool = True) -> torch.Tensor:
        """Returns the loss at current parameters, computing it if it hasn't been computed already
        and assigning ``objective.loss``. Do not call this at perturbed parameters.
        Backward always sets grads to None before recomputing.

        If ``backward=True``, the closure is called within ``torch.enable_grad()``.
        """

        # at non-x0 point just call closure and return
        if not at_x0:
            if self.closure is None: raise RuntimeError("closure is None")
            return _closure_backward(
                self.closure, self.params, backward=backward, retain_graph=retain_graph, create_graph=create_graph,
            )

        # at x0 set self.loss and self.grads
        if self.loss is None:

            if self.closure is None: raise RuntimeError("closure is None")

            # backward
            if backward:
                self.loss = self.loss_approx = _closure_backward(
                    closure=self.closure, params=self.params, backward=True, retain_graph=retain_graph, create_graph=create_graph
                )

                # next time closure() is called, it will set grad to None.
                # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
                # because otherwise it will zero self.grads in-place
                self.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]

            # no backward
            else:
                self.loss = self.loss_approx = _closure_backward(
                    closure=self.closure, params=self.params, backward=False, retain_graph=False, create_graph=False
                )

        # if self.loss was not None, the branch above wasn't executed because the loss has already been evaluated,
        # but without backward since self.grads is None, and now it is requested to be evaluated with backward.
        if backward and self.grads is None:
            warnings.warn('get_loss was called with backward=False and then with backward=True, so the closure had to be evaluated twice where it could have been evaluated once.')
            if self.closure is None: raise RuntimeError("closure is None")

            self.loss = self.loss_approx = _closure_backward(
                closure=self.closure, params=self.params, backward=True, retain_graph=retain_graph, create_graph=create_graph
            )
            self.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]

        # set parent grad
        if self.parent is not None:
            # the way projections/split work, they make a new closure which evaluates the original
            # closure and projects the gradient, and set it as their objective.closure.
            # then on `get_loss(backward=True)` it is called, so it also sets the original parameters' gradient,
            # and we set it on the parent objective here.
            if self.parent.loss is None: self.parent.loss = self.loss
            if self.parent.grads is None and backward:
                if all(p.grad is None for p in self.parent.params):
                    warnings.warn("Parent grad is None after backward.")
                self.parent.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]

        return self.loss # type:ignore

    def get_grads(self, retain_graph: bool | None = None, create_graph: bool = False, at_x0: bool = True) -> list[torch.Tensor]:
        """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning ``objective.grads`` and potentially ``objective.loss``. Do not call this at perturbed parameters."""
        # at non-x0 point just call closure and return grads
        if not at_x0:
            _, grads = _closure_loss_grad(self.closure, self.params, retain_graph=retain_graph, create_graph=create_graph)
            return grads

        # at x0 get_loss sets self.loss and self.grads
        if self.grads is None:
            if self.closure is None: raise RuntimeError("closure is None")
            self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grads

        assert self.grads is not None
        return self.grads


    def get_loss_grads(self, retain_graph: bool | None = None, create_graph: bool = False, at_x0: bool = True) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Returns ``(loss, grads)``. Useful when you need both, in particular when not at x0."""
        # at non-x0 point just call closure and return (loss, grads)
        if not at_x0:
            return _closure_loss_grad(self.closure, self.params, retain_graph=retain_graph, create_graph=create_graph)

        # at x0 get_grads sets self.loss and self.grads, then get_loss returns self.loss.
        grad = self.get_grads(retain_graph=retain_graph, create_graph=create_graph)
        loss = self.get_loss(False)
        return loss, grad

    def get_updates(self) -> list[torch.Tensor]:
        """Returns the update. If the update is None, it is initialized by cloning the gradients
        and assigned to ``objective.updates``. Computing the gradients may assign ``objective.grads``
        and ``objective.loss`` if they haven't been computed. Do not call this at perturbed parameters."""
        if self.updates is None: self.updates = [g.clone() for g in self.get_grads()]
        return self.updates

    def clone(self, clone_updates: bool, parent: "Objective | None" = None):
        """Creates a shallow copy of this ``Objective``; the updates can optionally be deep-copied (via ``torch.clone``).

        This copies over all attributes except ``temp``.

        Set ``parent`` only if the clone's parameters are something different,
        while the clone's closure refers to the same objective but with a "view" on the parameters.
        """
        copy = Objective(
            params=self.params, closure=self.closure, model=self.model, current_step=self.current_step,
            parent=parent, modular=self.modular, loss=self.loss, storage=self.storage
        )

        if clone_updates and self.updates is not None:
            copy.updates = [u.clone() for u in self.updates]
        else:
            copy.updates = self.updates

        copy.grads = self.grads
        copy.loss_approx = self.loss_approx
        copy.post_step_hooks = self.post_step_hooks
        copy.stop = self.stop
        copy.skip_update = self.skip_update

        copy.attrs = self.attrs
        copy.should_terminate = self.should_terminate

        return copy

    def update_attrs_from_clone_(self, objective: "Objective"):
        """Updates attributes of this ``Objective`` instance from a cloned instance.
        Typically called after a child module has processed a cloned ``Objective``
        object. This propagates any newly computed loss or gradient values
        from the child's context back to the parent ``Objective`` if the parent
        didn't have them computed already.

        This copies over ``loss``, ``loss_approx``, ``grads``, ``should_terminate`` and ``skip_update``.

        Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
        if the child updates them, the update will affect the parent too.
        """
        if self.loss is None: self.loss = objective.loss
        if self.loss_approx is None: self.loss_approx = objective.loss_approx
        if self.grads is None: self.grads = objective.grads

        if objective.should_terminate is not None: self.should_terminate = objective.should_terminate
        if objective.skip_update: self.skip_update = objective.skip_update

    @torch.no_grad
    def zero_grad(self, set_to_none=True):
        """In most cases do not call this with ``set_to_none=False``, as that will zero ``self.grads`` in-place."""
        if set_to_none:
            for p in self.params: p.grad = None
        else:
            grads = [p.grad for p in self.params if p.grad is not None]
            if len(grads) != 0: torch._foreach_zero_(grads)

    def poptemp(self):
        """Returns and clears ``self.temp``; used to pass information from ``update`` to ``apply``."""
        temp = self.temp
        self.temp = None
        return temp

    @torch.no_grad
    def update_parameters(self):
        """Subtracts ``self.get_updates()`` from the parameters, unless ``self.skip_update = True``, in which case it does nothing."""
        if self.skip_update: return
        torch._foreach_sub_(self.params, self.get_updates())

    def apply_post_step_hooks(self, modules: "Sequence[Module]"):
        """Runs hooks that a few modules use. This should be called **after** updating the parameters."""
        modules = tuple(modules)
        for hook in self.post_step_hooks:
            hook(self, modules)


    # ------------------------------ HELPER METHODS ------------------------------ #
    @torch.no_grad
    def hessian_vector_product(
        self,
        z: Sequence[torch.Tensor],
        rgrad: Sequence[torch.Tensor] | None,
        at_x0: bool,
        hvp_method: HVPMethod,
        h: float,
        retain_graph: bool = False,
    ) -> tuple[list[torch.Tensor], Sequence[torch.Tensor] | None]:
        """
        Returns ``(Hz, rgrad)``, where ``rgrad`` is the gradient at current parameters, but it may be None.

        The gradient is set on the ``objective`` automatically if ``at_x0`` and can be accessed with ``objective.get_grads()``.

        Single hessian-vector product example:

        ```python
        Hz, _ = self.hessian_vector_product(z, rgrad=None, at_x0=True, ..., retain_graph=False)
        ```

        Multiple hessian-vector products example:

        ```python
        rgrad = None
        for i, z in enumerate(vecs):
            retain_graph = i < len(vecs) - 1
            Hz, rgrad = self.hessian_vector_product(z, rgrad=rgrad, ..., retain_graph=retain_graph)
        ```

        Args:
            z (Sequence[torch.Tensor]): vector in the hessian-vector product.
            rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
            at_x0 (bool): whether this is being called at original or perturbed parameters.
            hvp_method (str): how to compute hessian-vector products.
            h (float): finite difference step size.
            retain_graph (bool): whether to retain the autograd graph.
        """
        if hvp_method in ('batched_autograd', "autograd"):
            with torch.enable_grad():
                if rgrad is None: rgrad = self.get_grads(create_graph=True, at_x0=at_x0)
                Hz = torch.autograd.grad(rgrad, self.params, z, retain_graph=retain_graph)

        # loss returned by fd hvp is not guaranteed to be at x0 so we don't use/return it
        elif hvp_method == 'fd_forward':
            if rgrad is None: rgrad = self.get_grads(at_x0=at_x0)
            _, Hz = hvp_fd_forward(self.closure, self.params, z, h=h, g_0=rgrad)

        elif hvp_method == 'fd_central':
            _, Hz = hvp_fd_central(self.closure, self.params, z, h=h)

        else:
            raise ValueError(hvp_method)

        return list(Hz), rgrad

    @torch.no_grad
    def hessian_matrix_product(
        self,
        Z: torch.Tensor,
        rgrad: Sequence[torch.Tensor] | None,
        at_x0: bool,
        hvp_method: HVPMethod,
        h: float,
        retain_graph: bool = False,
    ) -> tuple[torch.Tensor, Sequence[torch.Tensor] | None]:
        """``Z`` is ``(n_dim, n_hvps)``; computes ``H @ Z`` of shape ``(n_dim, n_hvps)``.

        Returns ``(HZ, rgrad)``, where ``rgrad`` is the gradient at current parameters, but it may be None.

        The gradient is set on the ``objective`` automatically if ``at_x0`` and can be accessed with ``objective.get_grads()``.

        Unlike ``hessian_vector_product``, this returns a single matrix, not a per-parameter list.

        Args:
            Z (torch.Tensor): matrix in the hessian-matrix product.
            rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
            at_x0 (bool): whether this is being called at original or perturbed parameters.
            hvp_method (str): how to compute hessian-vector products.
            h (float): finite difference step size.
            retain_graph (bool): whether to retain the autograd graph.

        """
        # compute
        if hvp_method == "batched_autograd":
            with torch.enable_grad():
                if rgrad is None: rgrad = self.get_grads(create_graph=True, at_x0=at_x0)
                flat_inputs = torch.cat([g.ravel() for g in rgrad])
                HZ_list = torch.autograd.grad(
                    flat_inputs,
                    self.params,
                    grad_outputs=Z.T,
                    is_grads_batched=True,
                    retain_graph=retain_graph,
                )

            HZ = flatten_jacobian(HZ_list).T

        elif hvp_method == 'autograd':
            with torch.enable_grad():
                if rgrad is None: rgrad = self.get_grads(create_graph=True, at_x0=at_x0)
                flat_inputs = torch.cat([g.ravel() for g in rgrad])
                HZ_tensors = [
                    torch.autograd.grad(
                        flat_inputs,
                        self.params,
                        grad_outputs=col,
                        retain_graph=retain_graph or (i < Z.size(1) - 1),
                    )
                    for i, col in enumerate(Z.unbind(1))
                ]

            HZ_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HZ_tensors]
            HZ = torch.stack(HZ_list, 1)

        elif hvp_method == 'fd_forward':
            if rgrad is None: rgrad = self.get_grads(at_x0=at_x0)
            HZ_tensors = [
                hvp_fd_forward(
                    self.closure,
                    self.params,
                    vec_to_tensors(col, self.params),
                    h=h,
                    g_0=rgrad,
                )[1]
                for col in Z.unbind(1)
            ]
            HZ_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HZ_tensors]
            HZ = flatten_jacobian(HZ_list)

        elif hvp_method == 'fd_central':
            HZ_tensors = [
                hvp_fd_central(
                    self.closure, self.params, vec_to_tensors(col, self.params), h=h
                )[1]
                for col in Z.unbind(1)
            ]
            HZ_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HZ_tensors]
            HZ = flatten_jacobian(HZ_list)

        else:
            raise ValueError(hvp_method)

        return HZ, rgrad

    @torch.no_grad
    def hutchinson_hessian(
        self,
        rgrad: Sequence[torch.Tensor] | None,
        at_x0: bool,
        n_samples: int | None,
        distribution: Distributions | Sequence[Sequence[torch.Tensor]],
        hvp_method: HVPMethod,
        h: float,
        generator,
        variance: int | None = 1,
        zHz: bool = True,
        retain_graph: bool = False,
    ) -> tuple[list[torch.Tensor], Sequence[torch.Tensor] | None]:
        """
        Returns ``(D, rgrad)``, where ``rgrad`` is the gradient at current parameters, but it may be None.

        The gradient is set on the ``objective`` automatically if ``at_x0`` and can be accessed with ``objective.get_grads()``.

        Args:
            rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
            at_x0 (bool): whether this is being called at original or perturbed parameters.
            n_samples (int | None): number of random vectors.
            distribution (Distributions | Sequence[Sequence[torch.Tensor]]):
                distribution; this can also be a sequence of tensor sequences.
            hvp_method (str): how to compute hessian-vector products.
            h (float): finite difference step size.
            generator (Any): random generator.
            variance (int | None, optional): variance of random vectors. Defaults to 1.
            zHz (bool, optional): whether to compute z ⊙ Hz. If False, computes Hz. Defaults to True.
            retain_graph (bool, optional): whether to retain the graph. Defaults to False.
        """

        params = TensorList(self.params)
        samples = None

        # check when distribution is a sequence of tensors
        if not isinstance(distribution, str):
            if n_samples is not None and n_samples != len(distribution):
                raise RuntimeError("when passing a sequence of z to `hutchinson_hessian`, set `n_samples` to None")

            n_samples = len(distribution)
            samples = distribution

        # use non-batched with a single sample
        if n_samples == 1 and hvp_method == 'batched_autograd':
            hvp_method = 'autograd'

        # -------------------------- non-batched hutchinson -------------------------- #
        if hvp_method in ('autograd', 'fd_forward', 'fd_central'):

            D = None
            assert n_samples is not None

            for i in range(n_samples):

                # sample
                if samples is not None: z = samples[i]
                else: z = params.sample_like(cast(Distributions, distribution), variance, generator=generator)

                # compute
                Hz, rgrad = self.hessian_vector_product(
                    z=z,
                    rgrad=rgrad,
                    at_x0=at_x0,
                    hvp_method=hvp_method,
                    h=h,
                    retain_graph=(i < n_samples - 1) or retain_graph,
                )

                # add
                if zHz: torch._foreach_mul_(Hz, tuple(z))

                if D is None: D = Hz
                else: torch._foreach_add_(D, Hz)


            assert D is not None
            if n_samples > 1: torch._foreach_div_(D, n_samples)
            return D, rgrad

        # ---------------------------- batched hutchinson ---------------------------- #
        if hvp_method != 'batched_autograd':
            raise RuntimeError(f"Unknown hvp_method: `{hvp_method}`")

        # generate and vectorize samples
        if samples is None:
            samples = [params.sample_like(cast(Distributions, distribution), variance, generator=generator).to_vec()]

        else:
            samples = [torch.cat([t.ravel() for t in s]) for s in samples]

        # compute Hz
        Z = torch.stack(samples, -1)
        HZ, rgrad = self.hessian_matrix_product(
            Z,
            rgrad=rgrad,
            at_x0=at_x0,
            hvp_method='batched_autograd',
            h=h, # not used
            retain_graph=retain_graph,
        )

        if zHz: HZ *= Z
        D_vec = HZ.mean(-1)
        return vec_to_tensors(D_vec, params), rgrad

    @torch.no_grad
    def hessian(
        self,
        hessian_method: HessianMethod,
        h: float,
        at_x0: bool,
    ) -> tuple[torch.Tensor | None, Sequence[torch.Tensor] | None, torch.Tensor]:
        """Returns ``(f, g_list, H)``. Also sets ``objective.loss`` and ``objective.grads`` if ``at_x0``.

        ``f`` and ``g_list`` may be None if they aren't computed by the chosen ``hessian_method``.

        Args:
            hessian_method: how to compute the hessian.
            h (float): finite difference step size.
            at_x0 (bool): whether it's at x0.
        """
        closure = self.closure
        if closure is None:
            raise RuntimeError("Computing hessian requires a closure to be provided to the `step` method.")

        params = self.params
        numel = sum(p.numel() for p in params)

        f = None
        g_list = None

        # autograd hessian
        if hessian_method in ("batched_autograd", "autograd"):
            with torch.enable_grad():
                f = self.get_loss(False, at_x0=at_x0)

                batched = hessian_method == "batched_autograd"
                g_list, H_list = jacobian_and_hessian_wrt([f.ravel()], params, batched=batched)
                g_list = [t[0] for t in g_list] # remove leading dim from loss

                H = flatten_jacobian(H_list)

        # functional autograd hessian
        elif hessian_method in ('func', 'functional_revrev', 'functional_fwdrev'):
            if hessian_method == 'functional_fwdrev':
                method = "autograd.functional"
                outer_jacobian_strategy = "forward-mode"
                vectorize = True
            elif hessian_method == 'functional_revrev':
                method = "autograd.functional"
                outer_jacobian_strategy = "reverse-mode"
                vectorize = False
            else:
                method = 'func'
                outer_jacobian_strategy = "forward-mode" # unused
                vectorize = True # unused

            with torch.enable_grad():
                H = hessian_mat(partial(closure, backward=False), params,
                                method=method, vectorize=vectorize,
                                outer_jacobian_strategy=outer_jacobian_strategy)

        # thoad
        elif hessian_method == "thoad":
            with torch.enable_grad():
                f = self.get_loss(False, at_x0=at_x0)
                ctrl = lazy_thoad.backward(f, 2, crossings=True)

            g_list = [p.hgrad[0].squeeze(0) for p in params] # pyright:ignore[reportAttributeAccessIssue]
            H = thoad_single_tensor(ctrl, params, 2)


        # gradient finite difference
        elif hessian_method in ('gfd_forward', 'gfd_central'):

            if hessian_method == 'gfd_central': hvp_method = 'fd_central'
            else: hvp_method = 'fd_forward'

            I = torch.eye(numel, device=params[0].device, dtype=params[0].dtype)
            H, g_list = self.hessian_matrix_product(I, rgrad=None, at_x0=at_x0, hvp_method=hvp_method, h=h)

        # function value finite difference
        elif hessian_method in ('fd', "fd_full"):
            full = hessian_method == "fd_full"
            f, g_list, H = hessian_fd(partial(closure, False), params=params, eps=h, full=full)

        else:
            raise ValueError(hessian_method)

        # set objective attributes if at x0
        if at_x0:
            if f is not None and self.loss is None:
                self.loss = self.loss_approx = f

            if g_list is not None and self.grads is None:
                self.grads = list(g_list)

        return f, g_list, H.detach()

    @torch.no_grad
    def derivatives(self, order: int, at_x0: bool, method: DerivativesMethod = "batched_autograd"):
        """
        Returns a tuple of the function value and derivative tensors up to ``order``.

        ``order = 0`` returns ``(f,)``;

        ``order = 1`` returns ``(f, g)``;

        ``order = 2`` returns ``(f, g, H)``;

        ``order = 3`` returns ``(f, g, H, T3)``;

        etc.
        """
        closure = self.closure
        if closure is None:
            raise RuntimeError("Computing derivatives requires a closure to be provided to the `step` method.")

        # just loss
        if order == 0:
            f = self.get_loss(False, at_x0=at_x0)
            return (f, )

        # loss and grad
        if order == 1:
            f, g_list = self.get_loss_grads(at_x0=at_x0)
            g = torch.cat([t.ravel() for t in g_list])

            return f, g

        if method in ("autograd", "batched_autograd"):
            batched = method == "batched_autograd"

            # recursively compute derivatives up to order
            with torch.enable_grad():
                f, g_list = self.get_loss_grads(at_x0=at_x0, create_graph=True)
                g = torch.cat([t.ravel() for t in g_list])

                n = g.numel()
                ret = [f, g]
                T = g # current derivatives tensor

                # get all derivatives up to order
                for o in range(2, order + 1):
                    is_last = o == order
                    T_list = jacobian_wrt([T], self.params, create_graph=not is_last, batched=batched)
                    with torch.no_grad() if is_last else nullcontext():

                        # the shape is (ndim, ) * order
                        T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
                        ret.append(T)

            return tuple(ret)

        if method == "thoad":
            with torch.enable_grad():
                f = self.get_loss(False, at_x0=at_x0)
                ctrl = lazy_thoad.backward(f, order, crossings=True)

            return tuple([f, *thoad_derivatives(ctrl, self.params, order=order)])

        raise ValueError(method)

    @torch.no_grad
    def derivatives_at(
        self,
        x: torch.Tensor | Sequence[torch.Tensor],
        order: int,
        method: DerivativesMethod = "batched_autograd"
    ):
        """
        Returns a tuple of the function value and derivative tensors up to ``order`` at ``x``,
        then restores the original parameters.

        ``x`` can be a vector or a list of tensors.

        ``order = 0`` returns ``(f,)``;

        ``order = 1`` returns ``(f, g)``;

        ``order = 2`` returns ``(f, g, H)``;

        ``order = 3`` returns ``(f, g, H, T3)``;

        etc.
        """
        if isinstance(x, torch.Tensor): x = vec_to_tensors(x, self.params)

        x0 = [p.clone() for p in self.params]

        # set params to x
        for p, x_i in zip(self.params, x):
            set_storage_(p, x_i)

        ret = self.derivatives(order=order, at_x0=False, method=method)

        # restore params to x0
        for p, x0_i in zip(self.params, x0):
            set_storage_(p, x0_i)

        return ret


    def list_Hvp_function(self, hvp_method: HVPMethod, h: float, at_x0: bool):
        """Returns ``(grad, H_mv)``, where ``H_mv`` is a callable that accepts and returns lists of tensors.

        ``grad`` may be None, and this sets ``objective.grads`` if ``at_x0``, so at x0 just use ``objective.get_grads()``.
        """
        params = TensorList(self.params)
        closure = self.closure

        if hvp_method in ('batched_autograd', 'autograd'):
            grad = self.get_grads(create_graph=True, at_x0=at_x0)

            def H_mv(x: torch.Tensor | Sequence[torch.Tensor]):
                if isinstance(x, torch.Tensor): x = params.from_vec(x)
                with torch.enable_grad():
                    return TensorList(torch.autograd.grad(grad, params, x, retain_graph=True))

        else:

            if hvp_method == 'fd_forward':
                grad = self.get_grads(at_x0=at_x0)
                def H_mv(x: torch.Tensor | Sequence[torch.Tensor]):
                    if isinstance(x, torch.Tensor): x = params.from_vec(x)
                    _, Hx = hvp_fd_forward(closure, params, x, h=h, g_0=grad)
                    return TensorList(Hx)

            elif hvp_method == 'fd_central':
                grad = None
                def H_mv(x: torch.Tensor | Sequence[torch.Tensor]):
                    if isinstance(x, torch.Tensor): x = params.from_vec(x)
                    _, Hx = hvp_fd_central(closure, params, x, h=h)
                    return TensorList(Hx)

            else:
                raise ValueError(hvp_method)


        return grad, H_mv

    def tensor_Hvp_function(self, hvp_method: HVPMethod, h: float, at_x0: bool):
        """Returns ``(grad, H_mv, H_mm)``, where ``H_mv`` and ``H_mm`` accept and return single tensors.

        ``grad`` may be None, and this sets ``objective.grads`` if ``at_x0``, so at x0 just use ``objective.get_grads()``.
        """
        if hvp_method in ('fd_forward', "fd_central", "autograd"):
            grad, list_H_mv = self.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=at_x0)

            def H_mv_loop(x: torch.Tensor):
                Hx_list = list_H_mv(x)
                return torch.cat([t.ravel() for t in Hx_list])

            def H_mm_loop(X: torch.Tensor):
                return torch.stack([H_mv_loop(col) for col in X.unbind(-1)], -1)

            return grad, H_mv_loop, H_mm_loop

        # for batched we need grad
        if hvp_method != 'batched_autograd':
            raise RuntimeError(f"Unknown hvp_method `{hvp_method}`")

        params = TensorList(self.params)
        grad = self.get_grads(create_graph=True, at_x0=at_x0)

        def H_mv_batched(x: torch.Tensor):
            with torch.enable_grad():
                Hx_list = torch.autograd.grad(grad, params, params.from_vec(x), retain_graph=True)

            return torch.cat([t.ravel() for t in Hx_list])

        def H_mm_batched(X: torch.Tensor):
            with torch.enable_grad():
                flat_inputs = torch.cat([g.ravel() for g in grad])
                HX_list = torch.autograd.grad(
                    flat_inputs,
                    self.params,
                    grad_outputs=X.T,
                    is_grads_batched=True,
                    retain_graph=True,
                )
            return flatten_jacobian(HX_list).T

        return grad, H_mv_batched, H_mm_batched


# endregion
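
For reference, a short sketch of the second-order helpers defined above (`hessian` and `hessian_vector_product`), under the same illustrative closure convention as before. The parameter, the quartic closure, and the import path are assumptions for illustration only; the call signatures follow the definitions in the diff:

```python
import torch
from torchzero.core.objective import Objective  # assumed import path

x = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))

def closure(backward: bool = True):
    loss = (x ** 4).sum()
    if backward:
        x.grad = None
        loss.backward()
    return loss

# Dense hessian with the default "batched_autograd" method (h is only used by the fd methods).
obj = Objective(params=[x], closure=closure)
f, g_list, H = obj.hessian(hessian_method="batched_autograd", h=1e-3, at_x0=True)

# Several hessian-vector products reusing the differentiable gradient via ``rgrad``,
# mirroring the loop shown in the ``hessian_vector_product`` docstring.
obj = Objective(params=[x], closure=closure)
vecs = [[torch.ones_like(x)], [torch.randn_like(x)]]
rgrad = None
for i, z in enumerate(vecs):
    Hz, rgrad = obj.hessian_vector_product(
        z, rgrad=rgrad, at_x0=True, hvp_method="autograd", h=1e-3,
        retain_graph=i < len(vecs) - 1,
    )
```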