torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
@@ -0,0 +1,128 @@
+from collections import deque
+
+import torch
+# import visualbench as vb
+
+# import torchzero as tz
+
+from ...core import Transform, Chainable, apply
+from ...utils.linalg import inv_sqrt_2x2, matrix_power_eigh, gram_schmidt
+from ...utils import TensorList, vec_to_tensors_
+
+
+def inverse_sqrt(M):
+    if M.shape[-1] == 2: return inv_sqrt_2x2(M, force_pd=True) # general formula for 2x2 matrices
+    return matrix_power_eigh(M, -1/2)
+
+def update_subspace_preconditioner_(
+    grad: torch.Tensor, # store grads and basis as vectors for matmul
+    basis: torch.Tensor, # ndim, k
+    accumulator_: torch.Tensor, # k, k
+    beta: float | None,
+):
+    projected = basis.T @ grad # k
+    outer = torch.outer(projected, projected)
+
+    if beta is None: accumulator_.add_(outer)
+    else: accumulator_.lerp_(outer, 1-beta)
+
+def apply_subspace_preconditioner(
+    tensor: torch.Tensor,
+    basis: torch.Tensor, # ndim, k
+    accumulator: torch.Tensor,
+):
+    preconditioner = inverse_sqrt(accumulator) # k,k
+
+    tensor_projected = basis.T @ tensor # k
+    update_projected = preconditioner @ tensor_projected # k
+    return basis @ update_projected # d
+
+class RandomSubspacePreconditioning(Transform):
+    """full matrix rmsprop in random subspace"""
+    def __init__(self, k: int, beta: float | None = 0.99):
+        defaults = dict(k=k, beta=beta)
+        super().__init__(defaults, uses_grad=False)
+
+    def transform(self, tensors, params, grads, vars):
+        settings = self.settings[params[0]]
+        g = torch.cat([t.view(-1) for t in tensors])
+        k = settings['k']
+        beta = settings['beta']
+
+        if 'basis' not in self.global_state:
+            self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
+            self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
+
+        basis = self.global_state['basis']
+        accumulator = self.global_state['accumulator']
+
+        update_subspace_preconditioner_(g, basis, accumulator, beta)
+        try:
+            preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
+        except torch.linalg.LinAlgError:
+            denom = g.abs().sum()
+            if denom <= 1e-10: denom = torch.ones_like(denom)
+            preconditioned = g / denom # normalized-gradient fallback
+        vec_to_tensors_(preconditioned, tensors)
+
+        return tensors
+
+
+class HistorySubspacePreconditioning(Transform):
+    """full matrix rmsprop in subspace spanned by history of gradient differences
+
+    basis_beta is how much the basis is allowed to change, and beta is for the preconditioner itself in that basis.
+    """
+    def __init__(self, k: int, beta: float | None = 0.99, basis_beta=0.99, inner: Chainable | None = None):
+        defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
+        super().__init__(defaults, uses_grad=False)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    def transform(self, tensors, params, grads, vars):
+        settings = self.settings[params[0]]
+
+        g = torch.cat([t.view(-1) for t in tensors])
+        k = settings['k']
+        beta = settings['beta']
+        basis_beta = settings['basis_beta']
+
+        if 'history' not in self.global_state:
+            self.global_state['history'] = deque(maxlen=k)
+            self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
+            self.global_state['basis'] = torch.ones(g.numel(), k, device=g.device, dtype=g.dtype)
+
+
+        history: deque = self.global_state['history']
+        accumulator = self.global_state['accumulator']
+        basis = self.global_state['basis']
+
+        history.append(g)
+        if len(history) < k:
+            basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
+            history_basis = torch.stack(tuple(history), -1)
+            basis_t[:, -len(history):] = history_basis
+
+        else:
+            basis_t = torch.stack(tuple(history), -1)
+
+        basis_t[:,:-1] = basis_t[:, :-1] - basis_t[:, 1:]
+        basis_t = (basis_t - basis_t.mean()) / basis_t.std()
+
+        basis.lerp_(basis_t, 1-basis_beta)
+        update_subspace_preconditioner_(g, basis, accumulator, beta)
+
+        if 'inner' in self.children:
+            tensors = apply(self.children['inner'], tensors, params, grads, vars)
+            g = torch.cat([t.view(-1) for t in tensors])
+
+        try:
+            preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
+        except torch.linalg.LinAlgError:
+            denom = g.abs().sum()
+            if denom <= 1e-10: denom = torch.ones_like(denom)
+            preconditioned = g / denom # normalized-gradient fallback
+        vec_to_tensors_(preconditioned, tensors)
+
+        return tensors
+
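The hunk above (apparently the new `torchzero/modules/experimental/subspace_preconditioners.py`, judging by the +128 entry in the file list) keeps a k x k second-moment accumulator of gradients projected into a fixed basis and preconditions by its inverse square root. Below is a minimal standalone sketch of that math in plain `torch`, assuming nothing about torchzero's `Transform`/`TensorList` machinery; the helper names (`inverse_sqrt_psd`, `precondition`) are illustrative only.

```python
import torch

torch.manual_seed(0)
d, k, beta = 100, 8, 0.99
basis = torch.randn(d, k)           # fixed random d x k basis
accumulator = torch.eye(k)          # k x k second-moment accumulator

def inverse_sqrt_psd(M: torch.Tensor) -> torch.Tensor:
    # inverse square root of a symmetric PSD matrix via eigendecomposition
    L, Q = torch.linalg.eigh(M)
    return Q @ torch.diag(L.clamp(min=1e-12).rsqrt()) @ Q.T

def precondition(g: torch.Tensor) -> torch.Tensor:
    projected = basis.T @ g                                         # k
    accumulator.lerp_(torch.outer(projected, projected), 1 - beta)  # EMA of outer products
    return basis @ (inverse_sqrt_psd(accumulator) @ projected)      # lift back to d

g = torch.randn(d)
print(precondition(g).shape)        # torch.Size([100])
```

This is the "full-matrix RMSprop in a subspace" idea the docstrings describe: the expensive d x d statistics are replaced by k x k statistics in the projected space.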
@@ -0,0 +1,136 @@
+import warnings
+from functools import partial
+from typing import Literal
+from collections.abc import Callable
+import torch
+
+from ...core import Chainable, apply, Module
+from ...utils import vec_to_tensors, TensorList
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    hessian_mat,
+    jacobian_and_hessian_wrt,
+)
+from ..second_order.newton import lu_solve, cholesky_solve, least_squares_solve
+
+def tropical_sum(x, dim): return torch.amax(x, dim=dim)
+def tropical_mul(x, y): return x+y
+
+def tropical_matmul(x: torch.Tensor, y: torch.Tensor):
+    # this implements matmul by calling mul and sum
+
+    x_squeeze = False
+    y_squeeze = False
+
+    if x.ndim == 1:
+        x_squeeze = True
+        x = x.unsqueeze(0)
+
+    if y.ndim == 1:
+        y_squeeze = True
+        y = y.unsqueeze(1)
+
+    res = tropical_sum(tropical_mul(x.unsqueeze(-1), y.unsqueeze(-3)), dim = -2)
+
+    if x_squeeze: res = res.squeeze(-2)
+    if y_squeeze: res = res.squeeze(-1)
+
+    return res
+
+def tropical_dot(x:torch.Tensor, y:torch.Tensor):
+    assert x.ndim == 1 and y.ndim == 1
+    return tropical_matmul(x.unsqueeze(0), y.unsqueeze(1))
+
+def tropical_outer(x:torch.Tensor, y:torch.Tensor):
+    assert x.ndim == 1 and y.ndim == 1
+    return tropical_matmul(x.unsqueeze(1), y.unsqueeze(0))
+
+
+def tropical_solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+    r = b.unsqueeze(1) - A
+    return r.amin(dim=-2)
+
+def tropical_solve_and_reconstruct(A: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    r = b.unsqueeze(1) - A
+    x = r.amin(dim=-2)
+    b_hat = tropical_matmul(A, x)
+    return x, b_hat
+
+def tikhonov(H: torch.Tensor, reg: float):
+    if reg!=0: H += torch.eye(H.size(-1), dtype=H.dtype, device=H.device) * reg
+    return H
+
+
+class TropicalNewton(Module):
+    """suston - a Newton-like step solved in the tropical (max-plus) semiring"""
+    def __init__(
+        self,
+        reg: float | None = None,
+        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+        vectorize: bool = True,
+        interpolate:bool=False,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize, interpolate=interpolate)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = TensorList(vars.params)
+        closure = vars.closure
+        if closure is None: raise RuntimeError('TropicalNewton requires closure')
+
+        settings = self.settings[params[0]]
+        reg = settings['reg']
+        hessian_method = settings['hessian_method']
+        vectorize = settings['vectorize']
+        interpolate = settings['interpolate']
+
+        # ------------------------ calculate grad and hessian ------------------------ #
+        if hessian_method == 'autograd':
+            with torch.enable_grad():
+                loss = vars.loss = vars.loss_approx = closure(False)
+                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                g_list = [t[0] for t in g_list] # remove leading dim from loss
+                vars.grad = g_list
+                H = hessian_list_to_mat(H_list)
+
+        elif hessian_method in ('func', 'autograd.functional'):
+            strat = 'forward-mode' if vectorize else 'reverse-mode'
+            with torch.enable_grad():
+                g_list = vars.get_grad(retain_graph=True)
+                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
+                                              method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+        else:
+            raise ValueError(hessian_method)
+
+        # -------------------------------- inner step -------------------------------- #
+        if 'inner' in self.children:
+            g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
+        g = torch.cat([t.view(-1) for t in g_list])
+
+        # ------------------------------ regularization ------------------------------ #
+        if reg is not None: H = tikhonov(H, reg)
+
+        # ----------------------------------- solve ---------------------------------- #
+        tropical_update, g_hat = tropical_solve_and_reconstruct(H, g)
+
+        g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
+        abs_error = torch.linalg.vector_norm(g-g_hat) # pylint:disable=not-callable
+        rel_error = abs_error/g_norm.clip(min=1e-8)
+
+        if interpolate:
+            if rel_error > 1e-8:
+
+                update = cholesky_solve(H, g)
+                if update is None: update = lu_solve(H, g)
+                if update is None: update = least_squares_solve(H, g)
+
+                tropical_update.lerp_(update.ravel(), rel_error.clip(max=1))
+
+        vars.update = vec_to_tensors(tropical_update, params)
+        return vars
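A tiny standalone check of the max-plus ("tropical") primitives defined above: tropical matmul replaces multiply with `+` and sum with `max`, and `tropical_solve` returns the greatest sub-solution of A ⊗ x <= b by residuation. Plain `torch` only; the example matrices are made up.

```python
import torch

def tropical_matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # (n, m) x (m, p) -> (n, p) with "+" as multiply and "max" as sum
    return torch.amax(x.unsqueeze(-1) + y.unsqueeze(-3), dim=-2)

def tropical_solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # greatest x with tropical_matmul(A, x) <= b, via elementwise residuation
    return (b.unsqueeze(1) - A).amin(dim=-2)

A = torch.tensor([[0., 2.], [1., 0.]])
b = torch.tensor([3., 2.])
x = tropical_solve(A, b)
b_hat = tropical_matmul(A, x.unsqueeze(1)).squeeze(1)
print(x, b_hat)                      # for this A and b the reconstruction is exact
assert torch.all(b_hat <= b + 1e-6)
```

The `rel_error` computed in `TropicalNewton.step` measures exactly this reconstruction gap between `g_hat` and `g`, and the optional `interpolate` branch blends in an ordinary linear solve when the gap is large.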
@@ -0,0 +1,209 @@
+"""
+Arguments that are modified in-place are denoted with "_" at the end.
+
+Some functions return one of the arguments which was modified in-place, some return new tensors.
+Make sure to keep track of that to avoid unexpected in-place modifications of buffers. The returned
+storage is always indicated in the docstring.
+
+Additional functional variants are present in most module files, e.g. `adam_`, `rmsprop_`, `lion_`, etc.
+"""
+
+from collections.abc import Callable, Sequence
+
+from ..utils import NumberList, TensorList
+
+inf = float('inf')
+
+def debiased_step_size(
+    step,
+    beta1: float | NumberList | None = None,
+    beta2: float | NumberList | None = None,
+    pow: float = 2,
+    alpha: float | NumberList = 1,
+):
+    """returns a multiplier to the step size"""
+    if isinstance(beta1, NumberList): beta1 = beta1.fill_none(0)
+    if isinstance(beta2, NumberList): beta2 = beta2.fill_none(0)
+
+    step_size = alpha
+    if beta1 is not None:
+        bias_correction1 = 1.0 - (beta1 ** step)
+        step_size /= bias_correction1
+    if beta2 is not None:
+        bias_correction2 = 1.0 - (beta2 ** step)
+        step_size *= bias_correction2 ** (1/pow)
+    return step_size
+
+def debias(
+    tensors_: TensorList,
+    step: int,
+    inplace: bool,
+    beta1: float | NumberList | None = None,
+    beta2: float | NumberList | None = None,
+    alpha: float | NumberList = 1,
+    pow: float = 2,
+):
+    step_size = debiased_step_size(step=step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+    if inplace: return tensors_.mul_(step_size)
+    return tensors_ * step_size
+
+def debias_second_momentum(tensors_:TensorList, step: int, beta: float | NumberList, pow: float, inplace:bool):
+    """debias 2nd momentum, optionally in-place"""
+    bias_correction2 = (1.0 - (beta ** step)) ** (1/pow)
+    if inplace: return tensors_.div_(bias_correction2)
+    return tensors_ / bias_correction2
+
+def lerp_power_(tensors:TensorList, exp_avg_pow_:TensorList, beta:float|NumberList, pow:float) -> TensorList:
+    """
+    Lerp `exp_avg_pow_` with `tensors ^ pow`.
+
+    Returns `exp_avg_pow_`.
+    """
+    if pow == 1: return exp_avg_pow_.lerp_(tensors.abs(), 1-beta)
+    if pow == 2: return exp_avg_pow_.mul_(beta).addcmul_(tensors, tensors, value = 1-beta)
+    if pow % 2 == 0: return exp_avg_pow_.lerp_(tensors.pow(pow), 1-beta)
+    return exp_avg_pow_.lerp_(tensors.pow(pow).abs_(), 1-beta)
+
+def add_power_(tensors:TensorList, sum_:TensorList, pow:float) -> TensorList:
+    """
+    Add `tensors ^ pow` to `sum_`.
+
+    Returns `sum_`.
+    """
+    if pow == 1: return sum_.add_(tensors.abs())
+    if pow == 2: return sum_.addcmul_(tensors, tensors)
+    if pow % 2 == 0: return sum_.add_(tensors.pow(pow))
+    return sum_.add_(tensors.pow(pow).abs_())
+
+
+def root(tensors_:TensorList, p:float, inplace: bool):
+    """
+    Root of tensors, optionally in-place.
+
+    Returns `tensors_` if `inplace`, else new tensors.
+    """
+    if inplace:
+        if p == 1: return tensors_.abs_()
+        if p == 2: return tensors_.sqrt_()
+        return tensors_.pow_(1/p)
+    else:
+        if p == 1: return tensors_.abs()
+        if p == 2: return tensors_.sqrt()
+        return tensors_.pow(1/p)
+
+
+def ema_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    beta: float | NumberList,
+    dampening: float | NumberList = 0,
+    lerp: bool = True,
+):
+    """
+    Updates `exp_avg_` with EMA of `tensors`.
+
+    Returns `exp_avg_`.
+    """
+    tensors.lazy_mul_(1 - dampening)
+    if lerp: return exp_avg_.lerp_(tensors, (1 - beta))
+    return exp_avg_.mul_(beta).add_(tensors)
+
+def ema_sq_(
+    tensors: TensorList,
+    exp_avg_sq_: TensorList,
+    beta: float | NumberList,
+    max_exp_avg_sq_: TensorList | None,
+    pow: float = 2,
+):
+    """
+    Updates `exp_avg_sq_` with EMA of squared `tensors`; if `max_exp_avg_sq_` is not None, updates it with the maximum of the EMA.
+
+    Returns `exp_avg_sq_` or `max_exp_avg_sq_`.
+    """
+    lerp_power_(tensors=tensors, exp_avg_pow_=exp_avg_sq_, beta=beta, pow=pow)
+
+    # AMSGrad
+    if max_exp_avg_sq_ is not None:
+        max_exp_avg_sq_.maximum_(exp_avg_sq_)
+        exp_avg_sq_ = max_exp_avg_sq_
+
+    return exp_avg_sq_
+
+def sqrt_ema_sq_(
+    tensors: TensorList,
+    exp_avg_sq_: TensorList,
+    beta: float | NumberList,
+    max_exp_avg_sq_: TensorList | None,
+    debiased: bool,
+    step: int,
+    pow: float = 2,
+    ema_sq_fn: Callable = ema_sq_,
+):
+    """
+    Updates `exp_avg_sq_` with EMA of squared `tensors` and calculates its square root,
+    with optional AMSGrad and debiasing.
+
+    Returns new tensors.
+    """
+    exp_avg_sq_=ema_sq_fn(
+        tensors=tensors,
+        exp_avg_sq_=exp_avg_sq_,
+        beta=beta,
+        max_exp_avg_sq_=max_exp_avg_sq_,
+        pow=pow,
+    )
+
+    sqrt_exp_avg_sq = root(exp_avg_sq_, pow, inplace=False)
+
+    if debiased: sqrt_exp_avg_sq = debias_second_momentum(sqrt_exp_avg_sq, step=step, beta=beta, pow=pow, inplace=True)
+    return sqrt_exp_avg_sq
+
+
+def centered_ema_sq_(tensors: TensorList, exp_avg_: TensorList, exp_avg_sq_: TensorList,
+                     beta: float | NumberList, max_exp_avg_sq_: TensorList | None = None, pow:float=2):
+    """
+    Updates `exp_avg_` and `exp_avg_sq_` with EMA of `tensors` and squared `tensors`,
+    centers `exp_avg_sq_` by subtracting `exp_avg_` squared.
+
+    Returns `max_exp_avg_sq_` or new tensors.
+    """
+    exp_avg_sq_ = ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta, max_exp_avg_sq_=max_exp_avg_sq_, pow=pow)
+    exp_avg_.lerp_(tensors, 1-beta)
+    exp_avg_sq_ = exp_avg_sq_.addcmul(exp_avg_, exp_avg_, value=-1)
+
+    # AMSGrad
+    if max_exp_avg_sq_ is not None:
+        max_exp_avg_sq_.maximum_(exp_avg_sq_)
+        exp_avg_sq_ = max_exp_avg_sq_
+
+    return exp_avg_sq_
+
+def sqrt_centered_ema_sq_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    max_exp_avg_sq_: TensorList | None,
+    beta: float | NumberList,
+    debiased: bool,
+    step: int,
+    pow: float = 2,
+):
+    """
+    Updates `exp_avg_` and `exp_avg_sq_` with EMA of `tensors` and squared `tensors`,
+    centers `exp_avg_sq_` by subtracting `exp_avg_` squared. Calculates its square root,
+    with optional AMSGrad and debiasing.
+
+    Returns new tensors.
+    """
+    return sqrt_ema_sq_(
+        tensors=tensors,
+        exp_avg_sq_=exp_avg_sq_,
+        beta=beta,
+        max_exp_avg_sq_=max_exp_avg_sq_,
+        debiased=debiased,
+        step=step,
+        pow=pow,
+        ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
+    )
+
+
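The bias-correction helpers above follow the usual Adam-style debiasing. A standalone numeric sketch with plain floats and tensors (torchzero's `TensorList`/`NumberList` types are not used here), showing that the returned multiplier is equivalent to dividing the debiased moments directly:

```python
import torch

def debiased_step_size(step: int, beta1: float, beta2: float, pow: float = 2, alpha: float = 1.0) -> float:
    step_size = alpha
    step_size /= 1.0 - beta1 ** step                  # first-moment bias correction
    step_size *= (1.0 - beta2 ** step) ** (1 / pow)   # second-moment bias correction
    return step_size

beta1, beta2 = 0.9, 0.999
print(debiased_step_size(1, beta1, beta2))      # ~0.316, strong correction early on
print(debiased_step_size(1000, beta1, beta2))   # ~0.795, approaches 1 as steps grow

# dividing debiased moments is the same as scaling the raw update by this multiplier
t, m, v = 10, torch.tensor(0.5), torch.tensor(0.04)
lhs = (m / (1 - beta1 ** t)) / (v / (1 - beta2 ** t)).sqrt()
rhs = (m / v.sqrt()) * debiased_step_size(t, beta1, beta2)
assert torch.allclose(lhs, rhs)
```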
@@ -0,0 +1,120 @@
+from collections.abc import Callable
+from typing import Any
+
+import torch
+
+from ...utils import TensorList
+from .grad_approximator import GradApproximator, GradTarget, _FD_Formula
+
+
+def _forward2(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: float | None):
+    if v_0 is None: v_0 = closure(False)
+    assert param.ndim == 1
+    param[idx] += h
+    v_plus = closure(False)
+    param[idx] -= h
+    return v_0, v_0, (v_plus - v_0) / h # (loss, loss_approx, grad)
+
+def _backward2(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: float | None):
+    if v_0 is None: v_0 = closure(False)
+    assert param.ndim == 1
+    param[idx] -= h
+    v_minus = closure(False)
+    param[idx] += h
+    return v_0, v_0, (v_0 - v_minus) / h
+
+def _central2(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: Any):
+    assert param.ndim == 1
+    param[idx] += h
+    v_plus = closure(False)
+
+    param[idx] -= h * 2
+    v_minus = closure(False)
+
+    param[idx] += h
+    return v_0, v_plus, (v_plus - v_minus) / (2 * h)
+
+def _forward3(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: float | None):
+    if v_0 is None: v_0 = closure(False)
+    assert param.ndim == 1
+    param[idx] += h
+    v_plus1 = closure(False)
+
+    param[idx] += h
+    v_plus2 = closure(False)
+
+    param[idx] -= 2 * h
+    return v_0, v_0, (-3*v_0 + 4*v_plus1 - v_plus2) / (2 * h)
+
+def _backward3(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: float | None):
+    if v_0 is None: v_0 = closure(False)
+    assert param.ndim == 1
+    param[idx] -= h
+    v_minus1 = closure(False)
+
+    param[idx] -= h
+    v_minus2 = closure(False)
+
+    param[idx] += 2 * h
+    return v_0, v_0, (v_minus2 - 4*v_minus1 + 3*v_0) / (2 * h)
+
+def _central4(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_0: Any):
+    assert param.ndim == 1
+
+    param[idx] += h
+    v_plus1 = closure(False)
+
+    param[idx] += h
+    v_plus2 = closure(False)
+
+    param[idx] -= 3 * h
+    v_minus1 = closure(False)
+
+    param[idx] -= h
+    v_minus2 = closure(False)
+
+    param[idx] += 2 * h
+    return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)
+
+_FD_FUNCS = {
+    "forward2": _forward2,
+    "backward2": _backward2,
+    "central2": _central2,
+    "central3": _central2, # they are the same
+    "forward3": _forward3,
+    "backward3": _backward3,
+    "central4": _central4,
+}
+
+
+class FDM(GradApproximator):
+    """Approximate gradients via finite difference method
+
+    Args:
+        h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
+        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+        target (GradTarget, optional): what to set on vars. Defaults to 'closure'.
+    """
+    def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central2', target: GradTarget = 'closure'):
+        defaults = dict(h=h, formula=formula)
+        super().__init__(defaults, target=target)
+
+    @torch.no_grad
+    def approximate(self, closure, params, loss, vars):
+        grads = []
+        loss_approx = None
+
+        for p in params:
+            g = torch.zeros_like(p)
+            grads.append(g)
+
+            settings = self.settings[p]
+            h = settings['h']
+            fd_fn = _FD_FUNCS[settings['formula']]
+
+            p_flat = p.view(-1); g_flat = g.view(-1)
+            for i in range(len(p_flat)):
+                loss, loss_approx, d = fd_fn(closure=closure, param=p_flat, idx=i, h=h, v_0=loss)
+                g_flat[i] = d
+
+        return grads, loss, loss_approx
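For reference, a standalone sanity check of the central-difference formula used by `_central2`, perturbing one flattened coordinate at a time the way `FDM.approximate` does. The closure below is only a stand-in for torchzero's closure convention (called with `False` to skip the backward pass); it is not part of the library.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
closure = lambda backward=False: (x ** 3).sum().item()   # analytic gradient is 3 * x**2

def central2(idx: int, h: float = 1e-3) -> float:
    x_flat = x.view(-1)        # shares storage with x, like p.view(-1) above
    x_flat[idx] += h
    v_plus = closure(False)
    x_flat[idx] -= 2 * h
    v_minus = closure(False)
    x_flat[idx] += h           # restore the parameter
    return (v_plus - v_minus) / (2 * h)

numeric = torch.tensor([central2(i) for i in range(x.numel())], dtype=torch.float64)
print(numeric)                 # ~[3., 12., 27.]
assert torch.allclose(numeric, 3 * x ** 2, atol=1e-4)
```

Note that full FDM needs two closure evaluations per coordinate for the central formulas, which is why the randomized variants (`rfdm.py`, `forward_gradient.py`) exist alongside it.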
@@ -0,0 +1,81 @@
+from collections.abc import Callable
+from functools import partial
+from typing import Any, Literal
+
+import torch
+
+from ...utils import Distributions, NumberList, TensorList, generic_eq
+from ...utils.derivatives import jvp, jvp_fd_central, jvp_fd_forward
+from .grad_approximator import GradApproximator, GradTarget
+from .rfdm import RandomizedFDM
+
+
+class ForwardGradient(RandomizedFDM):
+    """Forward gradient method; same as randomized finite differences, but the directional derivative is estimated via autograd (as a jacobian-vector product).
+
+    Args:
+        n_samples (int, optional): number of random gradient samples. Defaults to 1.
+        distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
+        beta (float, optional):
+            if not 0, acts as momentum on gradient samples, making the subspace spanned by them change slowly. Defaults to 0.
+        pre_generate (bool, optional):
+            whether to pre-generate gradient samples before each step. Defaults to True.
+        jvp_method (str, optional):
+            how to calculate the jacobian-vector product; note that with `forward` and `central` this is identical to randomized finite differences. Defaults to 'autograd'.
+        h (float, optional): finite difference step size if jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+        target (GradTarget, optional): what to set on vars. Defaults to "closure".
+    """
+    PRE_MULTIPLY_BY_H = False
+    def __init__(
+        self,
+        n_samples: int = 1,
+        distribution: Distributions = "gaussian",
+        beta: float = 0,
+        pre_generate = True,
+        jvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+        h: float = 1e-3,
+        target: GradTarget = "closure",
+        seed: int | None | torch.Generator = None,
+    ):
+        super().__init__(h=h, n_samples=n_samples, distribution=distribution, beta=beta, target=target, pre_generate=pre_generate, seed=seed)
+        self.defaults['jvp_method'] = jvp_method
+
+    @torch.no_grad
+    def approximate(self, closure, params, loss, vars):
+        params = TensorList(params)
+        loss_approx = None
+
+        settings = self.settings[params[0]]
+        n_samples = settings['n_samples']
+        jvp_method = settings['jvp_method']
+        h = settings['h']
+        distribution = settings['distribution']
+        default = [None]*n_samples
+        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
+        generator = self._get_generator(settings['seed'], params)
+
+        grad = None
+        for i in range(n_samples):
+            prt = perturbations[i]
+            if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator)
+            else: prt = TensorList(prt)
+
+            if jvp_method == 'autograd':
+                with torch.enable_grad():
+                    loss, d = jvp(partial(closure, False), params=params, tangent=prt)
+
+            elif jvp_method == 'forward':
+                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss, normalize=True, h=h)
+
+            elif jvp_method == 'central':
+                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt, normalize=True, h=h)
+
+            else: raise ValueError(jvp_method)
+
+            if grad is None: grad = prt * d
+            else: grad += prt * d
+
+        assert grad is not None
+        if n_samples > 1: grad.div_(n_samples)
+        return grad, loss, loss_approx
+
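`ForwardGradient` above averages samples of the directional derivative times the tangent, `(g·v)·v`, with the directional derivative obtained from forward-mode autograd. A hedged standalone sketch of the same estimator using `torch.func.jvp` (PyTorch 2.x) instead of torchzero's `jvp` helper; the function `f` and sample count are made up for illustration.

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    return (x ** 2).sum()      # true gradient is 2 * x

x = torch.randn(10)
n_samples = 256
estimate = torch.zeros_like(x)
for _ in range(n_samples):
    v = torch.randn_like(x)                          # gaussian tangent
    _, directional = torch.func.jvp(f, (x,), (v,))   # forward-mode directional derivative
    estimate += directional * v                      # one forward-gradient sample
estimate /= n_samples

# with enough samples the estimate points in roughly the same direction as the true gradient
print(torch.nn.functional.cosine_similarity(estimate, 2 * x, dim=0))
```

With standard Gaussian tangents the estimator is unbiased, and the `beta` and `pre_generate` options above control how those tangents are reused across steps.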