torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/line_search/scipy_minimize_scalar.py
@@ -1,62 +0,0 @@
-import typing
-
-import torch
-try:
-    import scipy.optimize as scopt
-except ModuleNotFoundError:
-    scopt = typing.cast(typing.Any, None)
-
-from ...tensorlist import TensorList
-from ...core import OptimizationVars
-
-from .base_ls import LineSearchBase, MaxIterReached
-
-if typing.TYPE_CHECKING:
-    import scipy.optimize as scopt
-
-class ScipyMinimizeScalarLS(LineSearchBase):
-    """Line search via `scipy.optimize.minimize_scalar`. All args except maxiter are the same as for it.
-
-    Args:
-        method (Optional[str], optional): 'brent', 'golden' or 'bounded'. Defaults to None.
-        maxiter (Optional[int], optional): hard limit on maximum number of function evaluations. Defaults to None.
-        bracket (optional): bracket. Defaults to None.
-        bounds (optional): bounds. Defaults to None.
-        tol (Optional[float], optional): some kind of tolerance. Defaults to None.
-        options (optional): options for method. Defaults to None.
-        log_lrs (bool, optional): logs lrs and values into `_lrs`. Defaults to False.
-    """
-    def __init__(
-        self,
-        method: str | None = None,
-        maxiter: int | None = None,
-        bracket = None,
-        bounds = None,
-        tol: float | None = None,
-        options = None,
-        log_lrs = False,
-    ):
-        if scopt is None: raise ModuleNotFoundError("scipy is not installed")
-        super().__init__({}, maxiter=maxiter, log_lrs=log_lrs)
-        self.method = method
-        self.tol = tol
-        self.bracket = bracket
-        self.bounds = bounds
-        self.options = options
-
-    @torch.no_grad
-    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-        try:
-            res = scopt.minimize_scalar(
-                self._evaluate_lr_ensure_float,
-                args = (vars.closure, vars.ascent, params),
-                method = self.method,
-                tol = self.tol,
-                bracket = self.bracket,
-                bounds = self.bounds,
-                options = self.options,
-            ) # type:ignore
-        except MaxIterReached:
-            pass
-
-        return float(self._best_lr)

torchzero/modules/meta/__init__.py
@@ -1,12 +0,0 @@
-"""Modules that use other modules."""
-# from .chain import Chain, ChainReturn
-import sys
-
-from .alternate import Alternate
-from .grafting import Graft, IntermoduleCautious, SignGrafting
-from .return_overrides import ReturnAscent, ReturnClosure, SetGrad
-
-# if sys.version_info[1] < 12:
-from .optimizer_wrapper import Wrap, WrapClosure
-# else:
-# from .optimizer_wrapper import Wrap, WrapClosure

torchzero/modules/meta/alternate.py
@@ -1,65 +0,0 @@
-import random
-from collections.abc import Iterable
-from typing import Any, Literal
-
-from ...core import OptimizerModule, _Chainable
-
-
-class Alternate(OptimizerModule):
-    """Alternates stepping with multiple modules.
-
-    Args:
-        modules (Iterable[OptimizerModule | Iterable[OptimizerModule]]): modules to alternate between.
-        mode (int | list[int] | tuple[int] | "random"], optional):
-            can be integer - number of repeats for all modules;
-            list or tuple of integers per each module with number of repeats;
-            "random" to pick module randomly each time. Defaults to 1.
-        seed (int | None, optional): seed for "random" mode. Defaults to None.
-    """
-    def __init__(
-        self,
-        modules: Iterable[_Chainable],
-        mode: int | list[int] | tuple[int] | Literal["random"] = 1,
-        seed: int | None = None
-    ):
-        super().__init__({})
-        modules = list(modules)
-
-        for i,m in enumerate(modules):
-            self._set_child_(i, m)
-
-        self.random = random.Random(seed)
-
-        if isinstance(mode, int): mode = [mode for _ in modules]
-        self.mode: list[int] | tuple[int] | Literal['random'] = mode
-
-        self.cur = 0
-        if self.mode == 'random': self.remaining = 0
-        else:
-            self.remaining = self.mode[0]
-            if len(self.mode) != len(self.children):
-                raise ValueError(f"got {len(self.children)} modules but {len(mode)} repeats, they should be the same")
-
-    def step(self, vars):
-        if self.mode == 'random':
-            module = self.random.choice(list(self.children.values()))
-
-        else:
-            if self.remaining == 0:
-                self.cur += 1
-
-                if self.cur >= len(self.mode):
-                    self.cur = 0
-
-            if self.remaining == 0: self.remaining = self.mode[self.cur]
-
-            module = self.children[self.cur]
-
-        self.remaining -= 1
-
-        if self.next_module is None:
-            return module.step(vars)
-
-        vars.ascent = module.return_ascent(vars)
-        return self._update_params_or_step_with_next(vars)
-

torchzero/modules/meta/grafting.py
@@ -1,195 +0,0 @@
-from collections.abc import Iterable
-from typing import Literal
-import torch
-
-from ...core import OptimizerModule
-from ...tensorlist import TensorList
-
-
-class Graft(OptimizerModule):
-    """
-    Optimizer grafting (magnitude#direction).
-    Takes update of one optimizer and makes its norm same as update of another optimizer.
-    Can be applied to all weights or layerwise.
-
-    Args:
-        magnitude (OptimizerModule | Iterable[OptimizerModule]):
-            module to use magnitude from.
-            If sequence of modules is provided, they will be chained.
-        direction (OptimizerModule | Iterable[OptimizerModule]):
-            module/modules to use direction from.
-            If sequence of modules is provided, they will be chained.
-        ord (int, optional): norm type. Defaults to 2.
-        eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-        layerwise (bool, optional): whether to apply grafting layerwise. Defaults to False.
-
-    reference
-        *Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C.
-        Learning Rate Grafting: Transferability of Optimizer Tuning.*
-    """
-    def __init__(
-        self,
-        magnitude: OptimizerModule | Iterable[OptimizerModule],
-        direction: OptimizerModule | Iterable[OptimizerModule],
-        ord: float = 2,
-        eps: float = 1e-8,
-        layerwise: bool = False,
-        # TODO: channelwise
-    ):
-        super().__init__({})
-        self._set_child_('magnitude', magnitude)
-        self._set_child_('direction', direction)
-        self.ord = ord
-        self.eps = eps
-        self.layerwise = layerwise
-
-
-    @torch.no_grad
-    def step(self, vars):
-        state_copy = vars.copy(clone_ascent=True)
-        magnitude = self.children['magnitude'].return_ascent(state_copy)
-
-        if state_copy.grad is not None: vars.grad = state_copy.grad
-        if state_copy.fx0 is not None: vars.fx0 = state_copy.fx0
-        if state_copy.fx0_approx is not None: vars.fx0_approx = state_copy.fx0_approx
-
-        direction = self.children['direction'].return_ascent(vars)
-
-        if self.layerwise:
-            M = magnitude.norm(self.ord)
-            D = direction.norm(self.ord)
-            D.select_set_(D == 0, M)
-
-        else:
-            M = magnitude.total_vector_norm(self.ord)
-            D = direction.total_vector_norm(self.ord)
-            if D == 0: D = M
-
-        vars.ascent = direction.mul_(M / (D + self.eps))
-        return self._update_params_or_step_with_next(vars)
-
-
-
-class SignGrafting(OptimizerModule):
-    """Weight-wise grafting-like operation where sign of the ascent is taken from first module
-    and magnitude from second module.
-
-    Args:
-        magnitude (OptimizerModule | Iterable[OptimizerModule]):
-            module to take magnitude from.
-            If sequence of modules is provided, they will be chained.
-        sign (OptimizerModule | Iterable[OptimizerModule]):
-            module to take sign from.
-            If sequence of modules is provided, they will be chained.
-    """
-    def __init__(
-        self,
-        magnitude: OptimizerModule | Iterable[OptimizerModule],
-        sign: OptimizerModule | Iterable[OptimizerModule],
-    ):
-        super().__init__({})
-
-        self._set_child_('magnitude', magnitude)
-        self._set_child_('sign', sign)
-
-
-    @torch.no_grad
-    def step(self, vars):
-        state_copy = vars.copy(clone_ascent=True)
-        magnitude = self.children['magnitude'].return_ascent(state_copy)
-
-        # make sure to store grad and fx0 if it was calculated
-        vars.update_attrs_(state_copy)
-
-        sign = self.children['sign'].return_ascent(vars)
-
-        vars.ascent = magnitude.copysign_(sign)
-        return self._update_params_or_step_with_next(vars)
-
-
-class IntermoduleCautious(OptimizerModule):
-    """Negates update for parameters where updates of two modules or module chains have inconsistent sign.
-    Optionally normalizes the update by the number of parameters that are not masked.
-
-    Args:
-        main_module (OptimizerModule | Iterable[OptimizerModule]):
-            main module or sequence of modules to chain, which update will be used with a consistency mask applied.
-        compare_module (OptimizerModule | Iterable[OptimizerModule]):
-            module or sequence of modules to chain, which update will be used to compute a consistency mask.
-            Can also be set to `ascent` to compare to update that is passed `main_module`, or `grad` to compare
-            to gradients.
-        normalize (bool, optional):
-            renormalize update after masking.
-            only has effect when mode is 'zero'. Defaults to False.
-        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
-        mode (str, optional):
-            what to do with updates with inconsistent signs.
-
-            "zero" - set them to zero (as in paper)
-
-            "grad" - set them to the gradient
-
-            "compare_module" - set them to `compare_module`'s update
-
-            "negate" - negate them (same as using update magnitude and gradient sign)
-    """
-    def __init__(
-        self,
-        main_module: OptimizerModule | Iterable[OptimizerModule],
-        compare_module: OptimizerModule | Iterable[OptimizerModule] | Literal['ascent', 'grad'],
-        normalize=False,
-        eps=1e-6,
-        mode: Literal["zero", "grad", "backtrack", "compare_module"] = "zero",
-    ):
-        super().__init__({})
-
-        self._set_child_('main',main_module)
-        if isinstance(compare_module, str): self.compare_mode = compare_module
-        else:
-            self._set_child_('compare', compare_module)
-            self.compare_mode = 'module'
-        self.eps = eps
-        self.normalize = normalize
-        self.mode: Literal["zero", "grad", "backtrack", "compare_module"] = mode
-
-    @torch.no_grad
-    def step(self, vars):
-        params = None
-        state_copy = vars.copy(clone_ascent=True)
-        ascent = self.children['main'].return_ascent(state_copy)
-        vars.update_attrs_(state_copy)
-
-        if self.compare_mode == 'module': compare = self.children['compare'].return_ascent(vars)
-        else:
-            params = self.get_params()
-            if self.compare_mode == 'ascent': compare: TensorList = vars.maybe_use_grad_(params)
-            elif self.compare_mode == 'grad': compare: TensorList = vars.maybe_compute_grad_(params)
-            else: raise ValueError(f'Invalid compare_module: {self.compare_mode}')
-
-        # mask will be > 0 for parameters where both signs are the same
-        mask = (ascent * compare) > 0
-
-        if self.mode == 'backtrack':
-            ascent -= ascent.mul(2).mul_(mask.logical_not_())
-
-        else:
-            # normalize if mode is `zero`
-            if self.normalize and self.mode == 'zero':
-                fmask = mask.to(ascent[0].dtype)
-                fmask /= fmask.total_mean() + self.eps
-            else:
-                fmask = mask
-
-            # apply the mask
-            ascent *= fmask
-
-            if self.mode == 'grad':
-                params = self.get_params()
-                ascent += vars.maybe_compute_grad_(params) * mask.logical_not_()
-
-            elif self.mode == 'compare_module':
-                ascent += compare * mask.logical_not_()
-
-        vars.ascent = ascent
-        return self._update_params_or_step_with_next(vars, params)
-

torchzero/modules/meta/optimizer_wrapper.py
@@ -1,173 +0,0 @@
-from collections.abc import Callable, Sequence
-from typing import Any, overload
-
-import torch
-from typing_extensions import Concatenate, ParamSpec
-
-from ...core import OptimizerModule
-from .return_overrides import SetGrad
-
-K = ParamSpec('K')
-
-class Wrap(OptimizerModule):
-    """
-    Wraps any torch.optim.Optimizer.
-
-    Sets .grad attribute to the current update and steps with the `optimizer`.
-
-    Additionally, if this is not the last module, this takes the update of `optimizer`,
-    undoes it and passes to the next module instead. That means you can chain multiple
-    optimizers together.
-
-    Args:
-        optimizer (torch.optim.Optimizer): optimizer to wrap,
-            or a callable (class) that constructs the optimizer.
-        kwargs:
-            if class is passed, kwargs are passed to the constructor.
-            parameters are passed separately and automatically
-            which is the point of passing a constructor
-            instead of an optimizer directly.
-
-    This can be constructed in two ways.
-    .. code-block:: python
-        wrapper = OptimizerWrapper(torch.optim.SGD(model.parameters(), lr = 0.1))
-        # or
-        wrapper = OptimizerWrapper(torch.optim.SGD, lr = 0.1)
-    """
-
-    @overload
-    def __init__(self, optimizer: torch.optim.Optimizer): ...
-    @overload
-    # def __init__[**K](
-    def __init__(
-        self,
-        optimizer: Callable[Concatenate[Any, K], torch.optim.Optimizer],
-        *args: K.args,
-        **kwargs: K.kwargs,
-        # optimizer: abc.Callable[..., torch.optim.Optimizer],
-        # *args,
-        # **kwargs,
-    ): ...
-    def __init__(self, optimizer, *args, **kwargs):
-
-        super().__init__({})
-        self._optimizer_cls: torch.optim.Optimizer | Callable[..., torch.optim.Optimizer] = optimizer
-        self._args = args
-        self._kwargs = kwargs
-
-    def _initialize_(self, params, set_passed_params):
-        """Initializes this optimizer and all children with the given parameters."""
-        super()._initialize_(params, set_passed_params=set_passed_params)
-        if isinstance(self._optimizer_cls, torch.optim.Optimizer) or not callable(self._optimizer_cls):
-            self.optimizer = self._optimizer_cls
-        else:
-            self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)
-
-    @torch.no_grad
-    def step(self, vars):
-        # check attrs
-        # if self.pass_closure:
-        #     if state.closure is None: raise ValueError('ClosureOptimizerWrapper requires closure.')
-        #     if state.ascent is not None:
-        #         raise ValueError('pass_closure = True, means ascent must be None (not sure though)')
-
-        params = self.get_params()
-
-        if self.next_module is None:
-            # set grad to ascent and make a step with the optimizer
-            g = vars.maybe_use_grad_(params)
-            params.set_grad_(g)
-            vars.fx0 = self.optimizer.step()
-            return vars.get_loss()
-
-
-        params_before_step = params.clone()
-
-        g = vars.maybe_use_grad_(params)
-        params.set_grad_(g)
-        vars.fx0 = self.optimizer.step()
-
-        # calculate update as difference in params
-        vars.ascent = params_before_step - params
-        params.set_(params_before_step)
-        return self.next_module.step(vars)
-
-
-class WrapClosure(OptimizerModule):
-    """
-    Wraps any torch.optim.Optimizer. This only works with modules with :code:`target = "Closure"` argument.
-    The modified closure will be passed to the optimizer.
-
-    Alternative any module can be turned into a closure module by using :any:`MakeClosure` module,
-    in that case this should be placed after MakeClosure.
-
-    Args:
-        optimizer (torch.optim.Optimizer): optimizer to wrap,
-            or a callable (class) that constructs the optimizer.
-        kwargs:
-            if class is passed, kwargs are passed to the constructor.
-            parameters are passed separately and automatically
-            which is the point of passing a constructor
-            instead of an optimizer directly.
-
-    This can be constructed in two ways.
-
-    .. code-block:: python
-
-        wrapper = OptimizerWrapper(torch.optim.SGD(model.parameters(), lr = 0.1))
-        # or
-        wrapper = OptimizerWrapper(torch.optim.SGD, lr = 0.1)
-
-    """
-
-    @overload
-    def __init__(self, optimizer: torch.optim.Optimizer,): ...
-    @overload
-    def __init__(
-        self,
-        optimizer: Callable[Concatenate[Any, K], torch.optim.Optimizer],
-        *args: K.args,
-        **kwargs: K.kwargs,
-        # optimizer: abc.Callable[..., torch.optim.Optimizer],
-        # *args,
-        # **kwargs,
-    ): ...
-    def __init__(self, optimizer, *args, **kwargs):
-
-        super().__init__({})
-        self._optimizer_cls: torch.optim.Optimizer | Callable[..., torch.optim.Optimizer] = optimizer
-        self._args = args
-        self._kwargs = kwargs
-
-    def _initialize_(self, params, set_passed_params):
-        """Initializes this optimizer and all children with the given parameters."""
-        super()._initialize_(params, set_passed_params=set_passed_params)
-        if isinstance(self._optimizer_cls, torch.optim.Optimizer) or not callable(self._optimizer_cls):
-            self.optimizer = self._optimizer_cls
-        else:
-            self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)
-
-    @torch.no_grad
-    def step(self, vars):
-        # check attrs
-        # if self.pass_closure:
-        #     if state.closure is None: raise ValueError('ClosureOptimizerWrapper requires closure.')
-        #     if state.ascent is not None:
-        #         raise ValueError('pass_closure = True, means ascent must be None (not sure though)')
-
-        params = self.get_params()
-
-        if self.next_module is None:
-            # set grad to ascent and make a step with the optimizer
-            vars.fx0 = self.optimizer.step(vars.closure) # type:ignore
-            return vars.get_loss()
-
-
-        params_before_step = params.clone()
-        vars.fx0 = self.optimizer.step(vars.closure) # type:ignore
-
-        # calculate update as difference in params
-        vars.ascent = params_before_step - params
-        params.set_(params_before_step)
-        return self.next_module.step(vars)
-

torchzero/modules/meta/return_overrides.py
@@ -1,46 +0,0 @@
-import torch
-from ...tensorlist import TensorList
-from ...core import OptimizerModule, _get_loss, _ClosureType
-
-class SetGrad(OptimizerModule):
-    """Doesn't update parameters, instead replaces all parameters `.grad` attribute with the current update.
-    You can now step with any pytorch optimizer that utilises the `.grad` attribute."""
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def step(self, vars):
-        if self.next_module is not None: raise ValueError("SetGrad can't have children")
-        params = self.get_params()
-        g = vars.maybe_use_grad_(params) # this may execute the closure which might be modified
-        params.set_grad_(g)
-        return vars.get_loss()
-
-
-class ReturnAscent(OptimizerModule):
-    """Doesn't update parameters, instead returns the update as a TensorList of tensors."""
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def step(self, vars) -> TensorList: # type:ignore
-        if self.next_module is not None: raise ValueError("ReturnAscent can't have children")
-        params = self.get_params()
-        update = vars.maybe_use_grad_(params) # this will execute the closure which might be modified
-        return update
-
-class ReturnClosure(OptimizerModule):
-    """Doesn't update parameters, instead returns the current modified closure.
-    For example, if you put this after :code:`torchzero.modules.FDM(target = "closure")`,
-    the closure will set `.grad` attribute to gradients approximated via finite difference.
-    You can then pass that closure to something that requires closure like `torch.optim.LBFGS`."""
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def step(self, vars) -> _ClosureType: # type:ignore
-        if self.next_module is not None: raise ValueError("ReturnClosure can't have children")
-        if vars.closure is None:
-            raise ValueError("MakeClosure requires closure")
-        return vars.closure
-

torchzero/modules/misc/__init__.py
@@ -1,10 +0,0 @@
-r"""
-This module includes various basic operators, notable LR for setting the learning rate,
-as well as gradient/update clipping and normalization.
-"""
-
-from .basic import Clone, Fill, Grad, Identity, Lambda, Zeros, Alpha, GradToUpdate, MakeClosure
-from .lr import LR
-from .on_increase import NegateOnLossIncrease
-from .multistep import Multistep
-from .accumulate import Accumulate

torchzero/modules/misc/accumulate.py
@@ -1,43 +0,0 @@
-from collections.abc import Callable, Iterable
-
-import torch
-
-from ...tensorlist import TensorList
-
-from ...core import OptimizerModule
-
-
-class Accumulate(OptimizerModule):
-    """Accumulates update over n steps, and steps once updates have been accumulated.
-    Put this as the first module to get gradient accumulation.
-
-    Args:
-        n_steps (int): number of steps (batches) to accumulate the update over.
-        mean (bool, optional):
-            If True, divides accumulated gradients by number of step,
-            since most loss functions calculate the mean of all samples
-            over batch. Defaults to True.
-    """
-    def __init__(self, n_steps: int, mean = True):
-
-        super().__init__({})
-        self.n_steps = n_steps
-        self.mean = mean
-        self.cur_step = 0
-
-    @torch.no_grad
-    def step(self, vars):
-        self.cur_step += 1
-
-        params = self.get_params()
-        accumulated_update = self.get_state_key('accumulated_grads')
-        accumulated_update += vars.maybe_use_grad_(params)
-
-        if self.cur_step % self.n_steps == 0:
-            vars.ascent = accumulated_update.clone()
-            if self.mean: vars.ascent /= self.n_steps
-            accumulated_update.zero_()
-            return self._update_params_or_step_with_next(vars)
-
-
-        return vars.get_loss()