torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
--- a/torchzero/modules/gradient_approximation/base_approximator.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from abc import ABC, abstractmethod
-from functools import partial
-from typing import Any, Literal
-
-import torch
-
-from ...core import (
-    OptimizationVars,
-    OptimizerModule,
-    _ClosureType,
-    _maybe_pass_backward,
-    _ScalarLoss,
-)
-from ...tensorlist import TensorList
-
-
-class GradientApproximatorBase(OptimizerModule, ABC):
-    """Base gradient approximator class. This is an abstract class, please don't use it as the optimizer.
-
-    When inheriting from this class the easiest way is to override `_make_ascent`, which should
-    return the ascent direction (like approximated gradient).
-
-    Args:
-        defaults (dict[str, Any]): defaults
-        requires_fx0 (bool):
-            if True, makes sure to calculate fx0 beforehand.
-            This means `_make_ascent` will always receive a pre-calculated `fx0` that won't be None.
-
-        target (str, optional):
-            determines what this module sets.
-
-            "ascent" - it creates a new ascent direction but doesn't treat is as gradient.
-
-            "grad" - it creates the gradient and sets it to `.grad` attributes (default).
-
-            "closure" - it makes a new closure that sets the estimated gradient to the `.grad` attributes.
-    """
-    def __init__(self, defaults: dict[str, Any], requires_fx0: bool, target: Literal['ascent', 'grad', 'closure']):
-        super().__init__(defaults, target)
-        self.requires_fx0 = requires_fx0
-
-    def _step_make_closure_(self, vars: OptimizationVars, params: TensorList):
-        if vars.closure is None: raise ValueError("gradient approximation requires closure")
-        closure = vars.closure
-
-        if self.requires_fx0: fx0 = vars.evaluate_fx0_(False)
-        else: fx0 = vars.fx0
-
-        def new_closure(backward=True) -> _ScalarLoss:
-            if backward:
-                g, ret_fx0, ret_fx0_approx = self._make_ascent(closure, params, fx0)
-                params.set_grad_(g)
-
-                if ret_fx0 is None: return ret_fx0_approx # type:ignore
-                return ret_fx0
-
-            return closure(False)
-
-        vars.closure = new_closure
-
-    def _step_make_target_(self, vars: OptimizationVars, params: TensorList):
-        if vars.closure is None: raise ValueError("gradient approximation requires closure")
-
-        if self.requires_fx0: fx0 = vars.evaluate_fx0_(False)
-        else: fx0 = vars.fx0
-
-        g, vars.fx0, vars.fx0_approx = self._make_ascent(vars.closure, params, fx0)
-        if self._default_step_target == 'ascent': vars.ascent = g
-        elif self._default_step_target == 'grad': vars.set_grad_(g, params)
-        else: raise ValueError(f"Unknown target {self._default_step_target}")
-
-    @torch.no_grad
-    def step(self, vars: OptimizationVars):
-        params = self.get_params()
-        if self._default_step_target == 'closure':
-            self._step_make_closure_(vars, params)
-
-        else:
-            self._step_make_target_(vars, params)
-
-        return self._update_params_or_step_with_next(vars, params)
-
-    @abstractmethod
-    @torch.no_grad
-    def _make_ascent(
-        self,
-        # vars: OptimizationVars,
-        closure: _ClosureType,
-        params: TensorList,
-        fx0: Any,
-    ) -> tuple[TensorList, _ScalarLoss | None, _ScalarLoss | None]:
-        """This should return a tuple of 3 elements:
-
-        .. code:: py
-
-            (ascent, fx0, fx0_approx)
-
-        Args:
-            closure (_ClosureType): closure
-            params (TensorList): parameters
-            fx0 (Any): fx0, can be None unless :target:`requires_fx0` is True on this module.
-
-        Returns:
-            (ascent, fx0, fx0_approx)
-        """
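The `target='closure'` branch above is the piece worth keeping in mind while reading the rest of this diff: the module rewraps the user's closure so that a `backward=True` call writes an *estimated* gradient into `.grad` instead of running autograd. A minimal sketch of that pattern in plain PyTorch (the names `wrap_closure` and `estimate_grads` are illustrative, not torchzero API):

```python
import torch

def wrap_closure(params, closure, estimate_grads):
    """Wrap `closure` so that a backward=True call fills .grad with estimated gradients.

    `estimate_grads(closure, params)` is any estimator returning (grads, loss),
    e.g. one of the finite-difference routines in fdm.py below.
    """
    def new_closure(backward=True):
        if backward:
            grads, loss = estimate_grads(closure, params)
            for p, g in zip(params, grads):
                p.grad = g          # estimated gradient takes the place of autograd's
            return loss
        return closure(False)       # plain forward pass, no gradient
    return new_closure
```

Any closure-driven optimizer (for example `torch.optim.LBFGS`) can then run on top of a gradient-free estimator, which is what the `"closure"` target enables.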
--- a/torchzero/modules/gradient_approximation/fdm.py
+++ /dev/null
@@ -1,125 +0,0 @@
-from typing import Literal, Any
-from warnings import warn
-import torch
-
-from ...utils.python_tools import _ScalarLoss
-from ...tensorlist import TensorList
-from ...core import _ClosureType, OptimizerModule, OptimizationVars
-from ._fd_formulas import _FD_Formulas
-from .base_approximator import GradientApproximatorBase
-
-def _two_point_fd_(closure: _ClosureType, idx: int, pvec: torch.Tensor, gvec: torch.Tensor, eps: _ScalarLoss, fx0: _ScalarLoss, ):
-    """Two point finite difference (same signature for all other finite differences functions).
-
-    Args:
-        closure (Callable): A closure that reevaluates the model and returns the loss.
-        idx (int): Flat index of the current parameter.
-        pvec (Tensor): Flattened view of the current parameter tensor.
-        gvec (Tensor): Flattened view of the current parameter tensor gradient.
-        eps (float): Finite difference epsilon.
-        fx0 (ScalarType): Loss at fx0, to avoid reevaluating it each time. On some functions can be None when it isn't needed.
-
-    Returns:
-        This modifies `gvec` in place.
-        This returns loss, not necessarily at fx0 (for example central difference never evaluate at fx0).
-        So this should be assigned to fx0_approx.
-    """
-    pvec[idx] += eps
-    fx1 = closure(False)
-    gvec[idx] = (fx1 - fx0) / eps
-    pvec[idx] -= eps
-    return fx0
-
-def _two_point_bd_(closure: _ClosureType, idx: int, pvec: torch.Tensor, gvec: torch.Tensor, eps: _ScalarLoss, fx0: _ScalarLoss, ):
-    pvec[idx] += eps
-    fx1 = closure(False)
-    gvec[idx] = (fx0 - fx1) / eps
-    pvec[idx] -= eps
-    return fx0
-
-def _two_point_cd_(closure: _ClosureType, idx: int, pvec: torch.Tensor, gvec: torch.Tensor, eps: _ScalarLoss, fx0 = None, ):
-    pvec[idx] += eps
-    fxplus = closure(False)
-    pvec[idx] -= eps * 2
-    fxminus = closure(False)
-    gvec[idx] = (fxplus - fxminus) / (2 * eps)
-    pvec[idx] += eps
-    return fxplus
-
-def _three_point_fd_(closure: _ClosureType, idx: int, pvec: torch.Tensor, gvec: torch.Tensor, eps: _ScalarLoss, fx0: _ScalarLoss, ):
-    pvec[idx] += eps
-    fx1 = closure(False)
-    pvec[idx] += eps
-    fx2 = closure(False)
-    gvec[idx] = (-3*fx0 + 4*fx1 - fx2) / (2 * eps)
-    pvec[idx] -= 2 * eps
-    return fx0
-
-def _three_point_bd_(closure: _ClosureType, idx: int, pvec: torch.Tensor, gvec: torch.Tensor, eps: _ScalarLoss, fx0: _ScalarLoss, ):
-    pvec[idx] -= eps
-    fx1 = closure(False)
-    pvec[idx] -= eps
-    fx2 = closure(False)
-    gvec[idx] = (fx2 - 4*fx1 + 3*fx0) / (2 * eps)
-    pvec[idx] += 2 * eps
-    return fx0
-
-
-class FDM(GradientApproximatorBase):
-    """Gradient approximation via finite difference.
-
-    This performs :math:`num_parameters + 1` or :math:`num_parameters * 2` evaluations per step, depending on formula.
-
-    Args:
-        eps (float, optional): finite difference epsilon. Defaults to 1e-5.
-        formula (_FD_Formulas, optional): finite difference formula. Defaults to 'forward'.
-        n_points (T.Literal[2, 3], optional): number of points, 2 or 3. Defaults to 2.
-        target (str, optional):
-            determines what this module sets.
-
-            "ascent" - it creates a new ascent direction but doesn't treat is as gradient.
-
-            "grad" - it creates the gradient and sets it to `.grad` attributes (default).
-
-            "closure" - it makes a new closure that sets the estimated gradient to the `.grad` attributes.
-    """
-    def __init__(
-        self,
-        eps: float = 1e-5,
-        formula: _FD_Formulas = "forward",
-        n_points: Literal[2, 3] = 2,
-        target: Literal["ascent", "grad", "closure"] = "grad",
-    ):
-        defaults = dict(eps = eps)
-
-        if formula == 'central':
-            self._finite_difference_ = _two_point_cd_ # this is both 2 and 3 point formula
-            requires_fx0 = False
-
-        elif formula == 'forward':
-            if n_points == 2: self._finite_difference_ = _two_point_fd_
-            else: self._finite_difference_ = _three_point_fd_
-            requires_fx0 = True
-
-        elif formula == 'backward':
-            if n_points == 2: self._finite_difference_ = _two_point_bd_
-            else: self._finite_difference_ = _three_point_bd_
-            requires_fx0 = True
-
-        else: raise ValueError(f'{formula} is not valid.')
-
-        super().__init__(defaults, requires_fx0=requires_fx0, target = target)
-
-    @torch.no_grad
-    def _make_ascent(self, closure, params, fx0):
-        grads = params.zeros_like()
-        epsilons = self.get_group_key('eps')
-
-        fx0_approx = None
-        for p, g, eps in zip(params, grads, epsilons):
-            flat_param = p.view(-1)
-            flat_grad = g.view(-1)
-            for idx in range(flat_param.numel()):
-                fx0_approx = self._finite_difference_(closure, idx, flat_param, flat_grad, eps, fx0)
-
-        return grads, fx0, fx0_approx
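For reference, the two-point forward formula implemented by `_two_point_fd_` above is `g_i ≈ (f(x + eps·e_i) − f(x)) / eps`, which costs `numel + 1` evaluations per step, matching the FDM docstring. A self-contained illustration on a flat tensor (the helper `fd_gradient` is illustrative, not part of the package):

```python
import torch

def fd_gradient(f, x, eps=1e-5):
    """Two-point forward-difference estimate of the gradient of a scalar function f at x."""
    flat = x.view(-1)
    grad = torch.zeros_like(flat)
    fx0 = f(x)                          # evaluated once and reused for every coordinate
    for i in range(flat.numel()):
        flat[i] += eps                  # perturb one coordinate in place
        grad[i] = (f(x) - fx0) / eps
        flat[i] -= eps                  # restore it
    return grad.view_as(x)

x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
print(fd_gradient(lambda v: (v ** 2).sum(), x))   # ≈ tensor([2., 4., 6.])
```

The central variant gives up the cached `fx0` and evaluates twice per coordinate instead, which is the `num_parameters * 2` cost mentioned in the docstring.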
--- a/torchzero/modules/gradient_approximation/forward_gradient.py
+++ /dev/null
@@ -1,163 +0,0 @@
-from collections.abc import Iterable
-from typing import Literal
-
-import torch
-import torch.autograd.forward_ad as fwAD
-
-from ...core import OptimizerModule, _ClosureType
-from ...tensorlist import TensorList
-from ...random import Distributions
-from ...utils.torch_tools import swap_tensors_no_use_count_check
-from .base_approximator import GradientApproximatorBase
-
-def get_forward_gradient(
-    params: Iterable[torch.Tensor],
-    closure: _ClosureType,
-    n_samples: int,
-    distribution: Distributions,
-    mode: Literal["jvp", "grad", "fd"],
-    fd_eps: float = 1e-4,
-):
-    """Evaluates forward gradient of a closure w.r.t iterable of parameters with a random tangent vector.
-
-    Args:
-        params (Iterable[torch.Tensor]): iterable of parameters of the model.
-        closure (_ClosureType):
-            A closure that reevaluates the model and returns the loss.
-            Closure must accept `backward = True` boolean argument. Forward gradient will always call it as
-            `closure(False)`, unless `mode = "grad"` which requires a backward pass.
-        n_samples (int): number of forward gradients to evaluate and average.
-        distribution (Distributions): distribution for random tangent vector.
-        mode (str):
-            "jvp" - uses forward mode AD, usually slightly slower than backward mode AD but uses significantly less memory.
-
-            "grad" - evaluates gradient with `loss.backward()` which may be faster but uses all the memory, mainly useful for
-            benchmarking as there is probably no point in forward gradient if full gradient is available.
-
-            "fd" - uses finite difference to estimate JVP in two forward passes,
-            doesn't require the objective to be autodiffable. Equivalent to randomized FDM.
-
-        fd_eps (float, optional): epsilon for finite difference, only has effect if mode is "fd". Defaults to 1e-4.
-
-    Returns:
-        TensorList: list of estimated gradients of the same structure and shape as `params`.
-
-    Reference:
-        Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022).
-        Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
-        https://arxiv.org/abs/2202.08587
-    """
-    if not isinstance(params, TensorList): params = TensorList(params)
-    params = params.with_requires_grad()
-
-    orig_params = None
-    grad = None
-    loss = None
-    for _ in range(n_samples):
-
-        # generate random vector
-        tangents = params.sample_like(fd_eps if mode == 'fd' else 1, distribution)
-
-        if mode == 'jvp':
-            if orig_params is None:
-                orig_params = params.clone().requires_grad_()
-
-            # evaluate jvp with it
-            with fwAD.dual_level():
-
-                # swap to duals
-                for param, clone, tangent in zip(params, orig_params, tangents):
-                    dual = fwAD.make_dual(clone, tangent)
-                    torch.utils.swap_tensors(param, dual)
-
-                loss = closure(False)
-                jvp = fwAD.unpack_dual(loss).tangent
-
-        elif mode == 'grad':
-            with torch.enable_grad(): loss = closure()
-            jvp = tangents.mul(params.ensure_grad_().grad).sum()
-
-        elif mode == 'fd':
-            loss = closure(False)
-            params += tangents
-            loss2 = closure(False)
-            params -= tangents
-            jvp = (loss2 - loss) / fd_eps**2
-
-        else:
-            raise ValueError(mode)
-
-        # update grad estimate
-        if grad is None: grad = tangents * jvp
-        else: grad += tangents * jvp
-
-        # swap back to original params
-        if orig_params is not None:
-            for param, orig in zip(params, orig_params):
-                swap_tensors_no_use_count_check(param, orig)
-
-    assert grad is not None
-    assert loss is not None
-    if n_samples > 1:
-        grad /= n_samples
-
-    return grad, loss
-
-class ForwardGradient(GradientApproximatorBase):
-    """Evaluates jacobian-vector product with a random vector using forward mode autodiff (torch.autograd.forward_ad), which is
-    the true directional derivative in the direction of that vector.
-
-    Args:
-        n_samples (int): number of forward gradients to evaluate and average.
-        distribution (Distributions): distribution for random tangent vector.
-        mode (str):
-            "jvp" - uses forward mode AD, usually slightly slower than backward mode AD but uses significantly less memory,
-            because it doesn't have to store intermediate activations.
-
-            "grad" - evaluates gradient with `loss.backward()` which may be faster but uses all the memory, mainly useful for
-            benchmarking as there is probably no point in forward gradient if full gradient is available.
-
-            "fd" - uses finite difference to estimate JVP in two forward passes,
-            doesn't require the objective to be autodiffable. Equivalent to randomized FDM.
-
-        fd_eps (float, optional): epsilon for finite difference, only has effect if mode is "fd". Defaults to 1e-4.
-        target (str, optional):
-            determines what this module sets.
-
-            "ascent" - it creates a new ascent direction but doesn't treat is as gradient.
-
-            "grad" - it creates the gradient and sets it to `.grad` attributes (default).
-
-            "closure" - it makes a new closure that sets the estimated gradient to the `.grad` attributes.
-
-    Reference:
-        Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022).
-        Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
-        https://arxiv.org/abs/2202.08587
-    """
-    def __init__(
-        self,
-        n_samples: int = 1,
-        distribution: Distributions = "normal",
-        mode: Literal["jvp", "grad", "fd"] = "jvp",
-        fd_eps: float = 1e-4,
-        target: Literal['ascent', 'grad', 'closure'] = 'grad',
-    ):
-        super().__init__({}, requires_fx0=False, target = target)
-        self.distribution: Distributions = distribution
-        self.n_samples = n_samples
-        self.mode: Literal["jvp", "grad", "fd"] = mode
-        self.fd_eps = fd_eps
-
-
-    def _make_ascent(self, closure, params, fx0):
-        g, fx0 = get_forward_gradient(
-            params=params,
-            closure=closure,
-            n_samples=self.n_samples,
-            distribution=self.distribution,
-            mode=self.mode,
-            fd_eps=self.fd_eps,
-        )
-
-        return g, fx0, None
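The `'jvp'` mode above is the forward-gradient trick from Baydin et al.: one forward pass with dual tensors yields the exact directional derivative `v·∇f(x)`, and `(v·∇f(x))·v` is an unbiased estimate of `∇f(x)` when `v` is standard normal. A minimal sketch using the public `torch.autograd.forward_ad` API directly (the helper name `forward_gradient` is illustrative, not the removed module's API):

```python
import torch
import torch.autograd.forward_ad as fwAD

def forward_gradient(f, x):
    """Single-sample forward-gradient estimate of the gradient of f at x."""
    v = torch.randn_like(x)                          # random tangent direction
    with fwAD.dual_level():
        dual = fwAD.make_dual(x, v)                  # x carrying the tangent v
        jvp = fwAD.unpack_dual(f(dual)).tangent      # directional derivative v . grad f(x)
    return jvp * v                                   # E[jvp * v] = grad f(x)

x = torch.tensor([1.0, 2.0, 3.0])
print(forward_gradient(lambda t: (t ** 2).sum(), x))   # noisy estimate of [2., 4., 6.]
```

Averaging several such samples, which is what `n_samples` controls above, reduces the variance of the estimate.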
--- a/torchzero/modules/gradient_approximation/newton_fdm.py
+++ /dev/null
@@ -1,198 +0,0 @@
-import typing as T
-
-import torch
-
-from ...utils.python_tools import _ScalarLoss
-from ...tensorlist import TensorList
-from ...core import _ClosureType, OptimizerModule
-from ..second_order.newton import (LINEAR_SYSTEM_SOLVERS,
-                                   FallbackLinearSystemSolvers,
-                                   LinearSystemSolvers, _fallback_gd)
-from ._fd_formulas import _FD_Formulas
-
-
-def _three_point_2cd_(
-    closure: _ClosureType,
-    idx1: int,
-    idx2: int,
-    p1: torch.Tensor,
-    p2: torch.Tensor,
-    g1: torch.Tensor,
-    hessian: torch.Tensor,
-    eps1: _ScalarLoss,
-    eps2: _ScalarLoss,
-    i1: int,
-    i2: int,
-    fx0: _ScalarLoss,
-):
-    """Second order three point finite differences (same signature for all other 2nd order finite differences functions).
-
-    Args:
-        closure (ClosureType): _description_
-        idx1 (int): _description_
-        idx2 (int): _description_
-        p1 (torch.Tensor): _description_
-        p2 (torch.Tensor): _description_
-        g1 (torch.Tensor): _description_
-        g2 (torch.Tensor): _description_
-        hessian (torch.Tensor): _description_
-        eps1 (ScalarType): _description_
-        eps2 (ScalarType): _description_
-        i1 (int): _description_
-        i23 (int): _description_
-        fx0 (ScalarType): _description_
-
-    """
-    # same param
-    if i1 == i2 and idx1 == idx2:
-        p1[idx1] += eps1
-        fxplus = closure(False)
-
-        p1[idx1] -= 2*eps1
-        fxminus = closure(False)
-
-        p1[idx1] += eps1
-
-        g1[idx1] = (fxplus - fxminus) / (2 * eps1)
-        hessian[i1, i2] = (fxplus - 2*fx0 + fxminus) / eps1**2
-
-    else:
-        p1[idx1] += eps1
-        p2[idx2] += eps2
-        fxpp = closure(False)
-        p1[idx1] -= eps1*2
-        fxnp = closure(False)
-        p2[idx2] -= eps2*2
-        fxnn = closure(False)
-        p1[idx1] += eps1*2
-        fxpn = closure(False)
-
-        p1[idx1] -= eps1
-        p2[idx2] += eps2
-
-        hessian[i1, i2] = (fxpp - fxpn - fxnp + fxnn) / (4 * eps1 * eps2)
-
-
-class NewtonFDM(OptimizerModule):
-    """Newton method with gradient and hessian approximated via finite difference.
-
-    Args:
-        eps (float, optional):
-            epsilon for finite difference.
-            Note that with float32 this needs to be quite high to avoid numerical instability. Defaults to 1e-2.
-        diag (bool, optional):
-            whether to only approximate diagonal elements of the hessian.
-            If true, ignores `solver` and `fallback`. Defaults to False.
-        solver (LinearSystemSolvers, optional):
-            solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
-        fallback (FallbackLinearSystemSolvers, optional):
-            what to do if solver fails. Defaults to "safe_diag"
-            (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
-        validate (bool, optional):
-            validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
-            If not, undo the step and perform a gradient descent step.
-        tol (float, optional):
-            only has effect if `validate` is enabled.
-            If loss increased by `loss * tol`, perform gradient descent step.
-            Set this to 0 to guarantee that loss always decreases. Defaults to 1.
-        gd_lr (float, optional):
-            only has effect if `validate` is enabled.
-            Gradient descent step learning rate. Defaults to 1e-2.
-
-    """
-    def __init__(
-        self,
-        eps: float = 1e-2,
-        diag=False,
-        solver: LinearSystemSolvers = "cholesky_lu",
-        fallback: FallbackLinearSystemSolvers = "safe_diag",
-        validate=False,
-        tol: float = 1,
-        gd_lr = 1e-2,
-    ):
-        defaults = dict(eps = eps)
-        super().__init__(defaults)
-        self.diag = diag
-        self.solver = LINEAR_SYSTEM_SOLVERS[solver]
-        self.fallback = LINEAR_SYSTEM_SOLVERS[fallback]
-
-        self.validate = validate
-        self.gd_lr = gd_lr
-        self.tol = tol
-
-    @torch.no_grad
-    def step(self, vars):
-        """Returns a new ascent direction."""
-        if vars.closure is None: raise ValueError('NewtonFDM requires a closure.')
-        if vars.ascent is not None: raise ValueError('NewtonFDM got ascent direction')
-
-        params = self.get_params()
-        epsilons = self.get_group_key('eps')
-
-        # evaluate fx0.
-        if vars.fx0 is None: vars.fx0 = vars.closure(False)
-
-        # evaluate gradients and hessian via finite differences.
-        grads = params.zeros_like()
-        numel = params.total_numel()
-        hessian = torch.zeros((numel, numel), dtype = params[0].dtype, device = params[0].device)
-
-        cur1 = 0
-        for p1, g1, eps1 in zip(params, grads, epsilons):
-            flat_param1 = p1.view(-1)
-            flat_grad1 = g1.view(-1)
-            for idx1 in range(flat_param1.numel()):
-
-                cur2 = 0
-                for p2, eps2 in zip(params, epsilons):
-
-                    flat_param2 = p2.view(-1)
-                    for idx2 in range(flat_param2.numel()):
-                        if self.diag and (idx1 != idx2 or cur1 != cur2):
-                            cur2 += 1
-                            continue
-                        _three_point_2cd_(
-                            closure = vars.closure,
-                            idx1 = idx1,
-                            idx2 = idx2,
-                            p1 = flat_param1,
-                            p2 = flat_param2,
-                            g1 = flat_grad1,
-                            hessian = hessian,
-                            eps1 = eps1,
-                            eps2 = eps2,
-                            fx0 = vars.fx0,
-                            i1 = cur1,
-                            i2 = cur2,
-                        )
-                        cur2 += 1
-                cur1 += 1
-
-        gvec = grads.to_vec()
-        if self.diag:
-            hdiag = hessian.diag()
-            hdiag[hdiag == 0] = 1
-            newton_step = gvec / hdiag
-        else:
-            newton_step, success = self.solver(hessian, gvec)
-            if not success:
-                newton_step, success = self.fallback(hessian, gvec)
-                if not success:
-                    newton_step, success = _fallback_gd(hessian, gvec)
-
-        # update params or pass the gradients to the child.
-        vars.ascent = grads.from_vec(newton_step)
-
-
-        # validate if newton step decreased loss
-        if self.validate:
-
-            params.sub_(vars.ascent)
-            fx1 = vars.closure(False)
-            params.add_(vars.ascent)
-
-            # if loss increases, set ascent direction to gvec times lr
-            if fx1 - vars.fx0 > vars.fx0 * self.tol:
-                vars.ascent = grads.from_vec(gvec) * self.gd_lr
-
-        return self._update_params_or_step_with_next(vars, params)