torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1

torchzero/modules/grad_approximation/grad_approximator.py (new file)
@@ -0,0 +1,66 @@
+import warnings
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Iterable
+from typing import Any, Literal
+
+import torch
+
+from ...core import Module, Vars
+
+GradTarget = Literal['update', 'grad', 'closure']
+_Scalar = torch.Tensor | float
+
+class GradApproximator(Module, ABC):
+    """Base class for gradient approximations.
+    This is an abstract class, to use it, subclass it and override `approximate`.
+
+    Args:
+        defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
+        target (str, optional):
+            whether to set `vars.grad`, `vars.update` or 'vars.closure`. Defaults to 'closure'.
+    """
+    def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
+        super().__init__(defaults)
+        self._target: GradTarget = target
+
+    @abstractmethod
+    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None, vars: Vars) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
+        """Returns a tuple: (grad, loss, loss_approx), make sure this resets parameters to their original values!"""
+
+    def pre_step(self, vars: Vars) -> Vars | None:
+        """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
+        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
+        return vars
+
+    @torch.no_grad
+    def step(self, vars):
+        ret = self.pre_step(vars)
+        if isinstance(ret, Vars): vars = ret
+
+        if vars.closure is None: raise RuntimeError("Gradient approximation requires closure")
+        params, closure, loss = vars.params, vars.closure, vars.loss
+
+        if self._target == 'closure':
+
+            def approx_closure(backward=True):
+                if backward:
+                    # set loss to None because closure might be evaluated at different points
+                    grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None, vars=vars)
+                    for p, g in zip(params, grad): p.grad = g
+                    return l if l is not None else l_approx
+                return closure(False)
+
+            vars.closure = approx_closure
+            return vars
+
+        # if vars.grad is not None:
+        #     warnings.warn('Using grad approximator when `vars.grad` is already set.')
+        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss, vars=vars)
+        if loss_approx is not None: vars.loss_approx = loss_approx
+        if loss is not None: vars.loss = vars.loss_approx = loss
+        if self._target == 'grad': vars.grad = list(grad)
+        elif self._target == 'update': vars.update = list(grad)
+        else: raise ValueError(self._target)
+        return vars
+
+_FD_Formula = Literal['forward2', 'backward2', 'forward3', 'backward3', 'central2', 'central4']
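
The docstring above spells out the subclassing contract: override `approximate`, return `(grad, loss, loss_approx)`, and leave the parameters exactly as you found them. As a minimal illustration (a hypothetical sketch, not part of the package), a coordinate-wise central-difference approximator built on that contract could look like this, assuming torchzero 0.3.2 is installed; the `closure(False)` call and the `self.settings[params[0]]` access pattern are taken from the hunks in this diff, everything else is made up for the example:

```python
# Hypothetical subclass sketch, not shipped with torchzero: coordinate-wise central
# differences (2*n closure evaluations), written against the contract documented above.
import torch
from torchzero.modules.grad_approximation.grad_approximator import GradApproximator

class CoordinateCentralFD(GradApproximator):
    def __init__(self, h: float = 1e-3, target='closure'):
        super().__init__(dict(h=h), target=target)

    @torch.no_grad
    def approximate(self, closure, params, loss, vars):
        h = self.settings[params[0]]['h']
        if loss is None: loss = closure(False)           # loss at the unperturbed point
        grads = []
        for p in params:
            g = torch.empty_like(p)
            flat_p, flat_g = p.view(-1), g.view(-1)
            for i in range(flat_p.numel()):
                orig = flat_p[i].item()
                flat_p[i] = orig + h; f_plus = closure(False)
                flat_p[i] = orig - h; f_minus = closure(False)
                flat_p[i] = orig                         # restore the original value
                flat_g[i] = (f_plus - f_minus) / (2 * h)
            grads.append(g)
        return grads, loss, loss                         # (grad, loss, loss_approx)
```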

torchzero/modules/grad_approximation/rfdm.py (new file)
@@ -0,0 +1,259 @@
+from collections.abc import Callable
+from typing import Any
+from functools import partial
+import torch
+
+from ...utils import TensorList, Distributions, NumberList, generic_eq
+from .grad_approximator import GradApproximator, GradTarget, _FD_Formula
+
+
+def _rforward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+    """p_fn is a function that returns the perturbation.
+    It may return pre-generated one or generate one deterministically from a seed as in MeZO.
+    Returned perturbation must be multiplied by `h`."""
+    if v_0 is None: v_0 = closure(False)
+    params += p_fn()
+    v_plus = closure(False)
+    params -= p_fn()
+    h = h**2 # because perturbation already multiplied by h
+    return v_0, v_0, (v_plus - v_0) / h # (loss, loss_approx, grad)
+
+def _rbackward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+    if v_0 is None: v_0 = closure(False)
+    params -= p_fn()
+    v_minus = closure(False)
+    params += p_fn()
+    h = h**2 # because perturbation already multiplied by h
+    return v_0, v_0, (v_0 - v_minus) / h
+
+def _rcentral2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: Any):
+    params += p_fn()
+    v_plus = closure(False)
+
+    params -= p_fn() * 2
+    v_minus = closure(False)
+
+    params += p_fn()
+    h = h**2 # because perturbation already multiplied by h
+    return v_0, v_plus, (v_plus - v_minus) / (2 * h)
+
+def _rforward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+    if v_0 is None: v_0 = closure(False)
+    params += p_fn()
+    v_plus1 = closure(False)
+
+    params += p_fn()
+    v_plus2 = closure(False)
+
+    params -= p_fn() * 2
+    h = h**2 # because perturbation already multiplied by h
+    return v_0, v_0, (-3*v_0 + 4*v_plus1 - v_plus2) / (2 * h)
+
+def _rbackward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+    if v_0 is None: v_0 = closure(False)
+
+    params -= p_fn()
+    v_minus1 = closure(False)
+
+    params -= p_fn()
+    v_minus2 = closure(False)
+
+    params += p_fn() * 2
+    h = h**2 # because perturbation already multiplied by h
+    return v_0, v_0, (v_minus2 - 4*v_minus1 + 3*v_0) / (2 * h)
+
+def _rcentral4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+    params += p_fn()
+    v_plus1 = closure(False)
+
+    params += p_fn()
+    v_plus2 = closure(False)
+
+    params -= p_fn() * 3
+    v_minus1 = closure(False)
+
+    params -= p_fn()
+    v_minus2 = closure(False)
+
+    params += p_fn() * 2
+    h = h**2 # because perturbation already multiplied by h
+    return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)
+
+_RFD_FUNCS = {
+    "forward2": _rforward2,
+    "backward2": _rbackward2,
+    "central2": _rcentral2,
+    "forward3": _rforward3,
+    "backward3": _rbackward3,
+    "central4": _rcentral4,
+}
+
+
+class RandomizedFDM(GradApproximator):
+    PRE_MULTIPLY_BY_H = True
+    def __init__(
+        self,
+        h: float = 1e-3,
+        n_samples: int = 1,
+        formula: _FD_Formula = "central2",
+        distribution: Distributions = "rademacher",
+        beta: float = 0,
+        pre_generate = True,
+        target: GradTarget = "closure",
+        seed: int | None | torch.Generator = None,
+    ):
+        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, beta=beta, pre_generate=pre_generate, seed=seed)
+        super().__init__(defaults, target=target)
+
+    def reset(self):
+        self.state.clear()
+        generator = self.global_state.get('generator', None) # avoid resetting generator
+        self.global_state.clear()
+        if generator is not None: self.global_state['generator'] = generator
+
+    def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
+        if 'generator' not in self.global_state:
+            if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
+            elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+            else: self.global_state['generator'] = None
+        return self.global_state['generator']
+
+    def pre_step(self, vars):
+        h, beta = self.get_settings('h', 'beta', params=vars.params)
+        settings = self.settings[vars.params[0]]
+        n_samples = settings['n_samples']
+        distribution = settings['distribution']
+        pre_generate = settings['pre_generate']
+
+        if pre_generate:
+            params = TensorList(vars.params)
+            generator = self._get_generator(settings['seed'], vars.params)
+            perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
+
+            if self.PRE_MULTIPLY_BY_H:
+                torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
+
+            if all(i==0 for i in beta):
+                # just use pre-generated perturbations
+                for param, prt in zip(params, zip(*perturbations)):
+                    self.state[param]['perturbations'] = prt
+
+            else:
+                # lerp old and new perturbations. This makes the subspace change gradually
+                # which in theory might improve algorithms with history
+                for i,p in enumerate(params):
+                    state = self.state[p]
+                    if 'perturbations' not in state: state['perturbations'] = [p[i] for p in perturbations]
+
+                cur = [self.state[p]['perturbations'][:n_samples] for p in params]
+                cur_flat = [p for l in cur for p in l]
+                new_flat = [p for l in zip(*perturbations) for p in l]
+                betas = [1-v for b in beta for v in [b]*n_samples]
+                torch._foreach_lerp_(cur_flat, new_flat, betas)
+
+    @torch.no_grad
+    def approximate(self, closure, params, loss, vars):
+        params = TensorList(params)
+        loss_approx = None
+
+        h = self.get_settings('h', params=vars.params, cls=NumberList)
+        settings = self.settings[params[0]]
+        n_samples = settings['n_samples']
+        fd_fn = _RFD_FUNCS[settings['formula']]
+        default = [None]*n_samples
+        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
+        distribution = settings['distribution']
+        generator = self._get_generator(settings['seed'], params)
+
+        grad = None
+        for i in range(n_samples):
+            prt = perturbations[i]
+            if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator).mul_(h)
+            else: prt = TensorList(prt)
+
+            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, v_0=loss)
+            if grad is None: grad = prt * d
+            else: grad += prt * d
+
+        assert grad is not None
+        if n_samples > 1: grad.div_(n_samples)
+        return grad, loss, loss_approx
+
+SPSA = RandomizedFDM
+
+class RDSA(RandomizedFDM):
+    def __init__(
+        self,
+        h: float = 1e-3,
+        n_samples: int = 1,
+        formula: _FD_Formula = "central2",
+        distribution: Distributions = "gaussian",
+        beta: float = 0,
+        pre_generate = True,
+        target: GradTarget = "closure",
+        seed: int | None | torch.Generator = None,
+    ):
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
+
+class GaussianSmoothing(RandomizedFDM):
+    def __init__(
+        self,
+        h: float = 1e-2,
+        n_samples: int = 100,
+        formula: _FD_Formula = "central2",
+        distribution: Distributions = "gaussian",
+        beta: float = 0,
+        pre_generate = True,
+        target: GradTarget = "closure",
+        seed: int | None | torch.Generator = None,
+    ):
+        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
+
+class MeZO(GradApproximator):
+    def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
+                 distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):
+        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
+        super().__init__(defaults, target=target)
+
+    def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
+        return TensorList(params).sample_like(
+            distribution=distribution, generator=torch.Generator(params[0].device).manual_seed(seed)
+        ).mul_(h)
+
+    def pre_step(self, vars):
+        h = self.get_settings('h', params=vars.params)
+        settings = self.settings[vars.params[0]]
+        n_samples = settings['n_samples']
+        distribution = settings['distribution']
+
+        step = vars.current_step
+
+        # create functions that generate a deterministic perturbation from seed based on current step
+        prt_fns = []
+        for i in range(n_samples):
+
+            prt_fn = partial(self._seeded_perturbation, params=vars.params, distribution=distribution, seed=1_000_000*step + i, h=h)
+            prt_fns.append(prt_fn)
+
+        self.global_state['prt_fns'] = prt_fns
+
+    @torch.no_grad
+    def approximate(self, closure, params, loss, vars):
+        params = TensorList(params)
+        loss_approx = None
+
+        h = self.get_settings('h', params=vars.params, cls=NumberList)
+        settings = self.settings[params[0]]
+        n_samples = settings['n_samples']
+        fd_fn = _RFD_FUNCS[settings['formula']]
+        prt_fns = self.global_state['prt_fns']
+
+        grad = None
+        for i in range(n_samples):
+            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h, v_0=loss)
+            if grad is None: grad = prt_fns[i]().mul_(d)
+            else: grad += prt_fns[i]().mul_(d)
+
+        assert grad is not None
+        if n_samples > 1: grad.div_(n_samples)
+        return grad, loss, loss_approx
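
The `_r*` helpers above all follow the same pattern: evaluate the closure at points shifted along a random perturbation that is already multiplied by `h`, form a finite-difference quotient (hence the `h = h**2` lines), and scale the perturbation by it to get an SPSA-style gradient estimate. Below is a standalone sketch of the `central2` case in plain PyTorch, with no torchzero types; the function name and everything inside it are illustrative only:

```python
import torch

def spsa_central2(f, x: torch.Tensor, h: float = 1e-3, n_samples: int = 1,
                  generator: torch.Generator | None = None) -> torch.Tensor:
    """Estimate grad f(x) from symmetric differences along Rademacher directions."""
    grad = torch.zeros_like(x)
    for _ in range(n_samples):
        # Rademacher direction, pre-multiplied by h as in RandomizedFDM.PRE_MULTIPLY_BY_H
        u = (torch.randint(0, 2, x.shape, generator=generator) * 2 - 1).to(x.dtype) * h
        d = (f(x + u) - f(x - u)) / (2 * h**2)   # same divisor as _rcentral2 after h = h**2
        grad += u * d
    return grad / n_samples

# On a quadratic the estimate is unbiased: E[estimate] = 2*x, the true gradient.
x = torch.tensor([1.0, -2.0, 3.0])
print(spsa_central2(lambda z: (z * z).sum(), x, n_samples=1000))
```

With `n_samples=1` this is the plain SPSA estimator; averaging more samples only reduces variance, which is what the `grad.div_(n_samples)` lines above do.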

torchzero/modules/line_search/__init__.py
@@ -1,30 +1,5 @@
-
-
-
-
-from typing import Literal
-
-from ...core import OptimizerModule
-from ..regularization import Normalize
-from .grid_ls import (ArangeLS, BacktrackingLS, GridLS, LinspaceLS,
-                      MultiplicativeLS)
-# from .quad_interp import QuadraticInterpolation2Point
-from .directional_newton import DirectionalNewton3Points, DirectionalNewton
-from .scipy_minimize_scalar import ScipyMinimizeScalarLS
-from .armijo import ArmijoLS
-
-LineSearches = Literal['backtracking', 'brent', 'brent-exact', 'brent-norm', 'multiplicative', 'newton', 'newton3', 'armijo'] | OptimizerModule
-
-def get_line_search(name:str | OptimizerModule) -> OptimizerModule | list[OptimizerModule]:
-    if isinstance(name, str):
-        name = name.strip().lower()
-        if name == 'backtracking': return BacktrackingLS()
-        if name == 'multiplicative': return MultiplicativeLS()
-        if name == 'brent': return ScipyMinimizeScalarLS(maxiter=8)
-        if name == 'brent-exact': return ScipyMinimizeScalarLS()
-        if name == 'brent-norm': return [Normalize(), ScipyMinimizeScalarLS(maxiter=16)]
-        if name == 'newton': return DirectionalNewton(1)
-        if name == 'newton3': return DirectionalNewton3Points(1)
-        if name == 'armijo': return ArmijoLS(1)
-        raise ValueError(f"Unknown line search method: {name}")
-    return name
+from .line_search import LineSearch, GridLineSearch
+from .backtracking import backtracking_line_search, Backtracking, AdaptiveBacktracking
+from .strong_wolfe import StrongWolfe
+from .scipy import ScipyMinimizeScalar
+from .trust_region import TrustRegion
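
The string-keyed `get_line_search` factory and the `LineSearches` literal are gone; callers now import the line-search modules this `__init__.py` re-exports and construct them directly. A construction-only sketch (assuming torchzero 0.3.2 is installed) using the two classes whose signatures appear in the backtracking.py hunk below; how the constructed module is combined with other modules is outside this diff:

```python
# Hypothetical usage sketch: direct construction replaces get_line_search('backtracking').
from torchzero.modules.line_search import Backtracking, AdaptiveBacktracking

bt = Backtracking(init=1.0, beta=0.5, c=1e-4, maxiter=10)   # plain Armijo backtracking
abt = AdaptiveBacktracking(target_iters=1, nplus=2.0)       # also rescales init between steps
```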

torchzero/modules/line_search/backtracking.py (new file)
@@ -0,0 +1,186 @@
+import math
+from collections.abc import Callable
+from operator import itemgetter
+
+import torch
+
+from .line_search import LineSearch
+
+
+def backtracking_line_search(
+    f: Callable[[float], float],
+    g_0: float | torch.Tensor,
+    init: float = 1.0,
+    beta: float = 0.5,
+    c: float = 1e-4,
+    maxiter: int = 10,
+    a_min: float | None = None,
+    try_negative: bool = False,
+) -> float | None:
+    """
+
+    Args:
+        objective_fn: evaluates step size along some descent direction.
+        dir_derivative: directional derivative along the descent direction.
+        alpha_init: initial step size.
+        beta: The factor by which to decrease alpha in each iteration
+        c: The constant for the Armijo sufficient decrease condition
+        max_iter: Maximum number of backtracking iterations (default: 10).
+        min_alpha: Minimum allowable step size to prevent near-zero values (default: 1e-16).
+
+    Returns:
+        step size
+    """
+
+    a = init
+    f_x = f(0)
+
+    for iteration in range(maxiter):
+        f_a = f(a)
+
+        if f_a <= f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
+            # found an acceptable alpha
+            return a
+
+        # decrease alpha
+        a *= beta
+
+        # alpha too small
+        if a_min is not None and a < a_min:
+            return a_min
+
+    # fail
+    if try_negative:
+        def inv_objective(alpha): return f(-alpha)
+
+        v = backtracking_line_search(
+            inv_objective,
+            g_0=-g_0,
+            beta=beta,
+            c=c,
+            maxiter=maxiter,
+            a_min=a_min,
+            try_negative=False,
+        )
+        if v is not None: return -v
+
+    return None
+
+class Backtracking(LineSearch):
+    def __init__(
+        self,
+        init: float = 1.0,
+        beta: float = 0.5,
+        c: float = 1e-4,
+        maxiter: int = 10,
+        min_alpha: float | None = None,
+        adaptive=True,
+        try_negative: bool = False,
+    ):
+        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,min_alpha=min_alpha,adaptive=adaptive, try_negative=try_negative)
+        super().__init__(defaults=defaults)
+        self.global_state['beta_scale'] = 1.0
+
+    def reset(self):
+        super().reset()
+        self.global_state['beta_scale'] = 1.0
+
+    @torch.no_grad
+    def search(self, update, vars):
+        init, beta, c, maxiter, min_alpha, adaptive, try_negative = itemgetter(
+            'init', 'beta', 'c', 'maxiter', 'min_alpha', 'adaptive', 'try_negative')(self.settings[vars.params[0]])
+
+        objective = self.make_objective(vars=vars)
+
+        # # directional derivative
+        d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), vars.get_update()))
+
+        # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
+        if adaptive: beta = beta * self.global_state['beta_scale']
+
+        step_size = backtracking_line_search(objective, d, init=init,beta=beta,
+                                             c=c,maxiter=maxiter,a_min=min_alpha, try_negative=try_negative)
+
+        # found an alpha that reduces loss
+        if step_size is not None:
+            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+            return step_size
+
+        # on fail reduce beta scale value
+        self.global_state['beta_scale'] /= 1.5
+        return 0
+
+def _lerp(start,end,weight):
+    return start + weight * (end - start)
+
+class AdaptiveBacktracking(LineSearch):
+    def __init__(
+        self,
+        init: float = 1.0,
+        beta: float = 0.5,
+        c: float = 1e-4,
+        maxiter: int = 20,
+        min_alpha: float | None = None,
+        target_iters = 1,
+        nplus = 2.0,
+        scale_beta = 0.0,
+        try_negative: bool = False,
+    ):
+        defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,min_alpha=min_alpha,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
+        super().__init__(defaults=defaults)
+
+        self.global_state['beta_scale'] = 1.0
+        self.global_state['initial_scale'] = 1.0
+
+    def reset(self):
+        super().reset()
+        self.global_state['beta_scale'] = 1.0
+        self.global_state['initial_scale'] = 1.0
+
+    @torch.no_grad
+    def search(self, update, vars):
+        init, beta, c, maxiter, min_alpha, target_iters, nplus, scale_beta, try_negative=itemgetter(
+            'init','beta','c','maxiter','min_alpha','target_iters','nplus','scale_beta', 'try_negative')(self.settings[vars.params[0]])
+
+        objective = self.make_objective(vars=vars)
+
+        # directional derivative (0 if c = 0 because it is not needed)
+        if c == 0: d = 0
+        else: d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), update))
+
+        # scale beta
+        beta = beta * self.global_state['beta_scale']
+
+        # scale step size so that decrease is expected at target_iters
+        init = init * self.global_state['initial_scale']
+
+        step_size = backtracking_line_search(objective, d, init=init, beta=beta,
+                                             c=c,maxiter=maxiter,a_min=min_alpha, try_negative=try_negative)
+
+        # found an alpha that reduces loss
+        if step_size is not None:
+
+            # update initial_scale
+            # initial step size satisfied conditions, increase initial_scale by nplus
+            if step_size == init and target_iters > 0:
+                self.global_state['initial_scale'] *= nplus ** target_iters
+                self.global_state['initial_scale'] = min(self.global_state['initial_scale'], 1e32) # avoid overflow error
+
+            else:
+                # otherwise make initial_scale such that target_iters iterations will satisfy armijo
+                init_target = step_size
+                for _ in range(target_iters):
+                    init_target = step_size / beta
+
+                self.global_state['initial_scale'] = _lerp(
+                    self.global_state['initial_scale'], init_target / init, 1-scale_beta
+                )
+
+            # revert beta_scale
+            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+
+            return step_size
+
+        # on fail reduce beta scale value
+        self.global_state['beta_scale'] /= 1.5
+        return 0
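
`backtracking_line_search` is a plain function, so its Armijo test (`f_a <= f_x + c * a * min(g_0, 0)`) can be checked in isolation. A small worked example, assuming torchzero 0.3.2 is installed: for f(x) = x**2 at x = 3 with the update equal to the gradient g = 6, the directional derivative along the descent direction is g_0 = -g*g = -36; a = 1 fails the test (9 > 9 - 0.0036), one halving gives a = 0.5, which lands on the minimizer.

```python
from torchzero.modules.line_search import backtracking_line_search

x, g = 3.0, 6.0                        # f(x) = x**2, gradient at x = 3
phi = lambda a: (x - a * g) ** 2       # loss as a function of the step size
a = backtracking_line_search(phi, g_0=-g * g, init=1.0, beta=0.5, c=1e-4, maxiter=10)
print(a, phi(a))                       # -> 0.5 0.0
```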