torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -494
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -132
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.7.dist-info/METADATA +0 -120
- torchzero-0.1.7.dist-info/RECORD +0 -104
- torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/gradient_approximation/rfdm.py (removed)
@@ -1,125 +0,0 @@
-from typing import Literal, Any, cast
-
-import torch
-
-from ...utils.python_tools import _ScalarLoss
-from ...tensorlist import Distributions, TensorList
-from ...core import _ClosureType, OptimizerModule, OptimizationVars
-from ._fd_formulas import _FD_Formulas
-from .base_approximator import GradientApproximatorBase
-
-def _two_point_rcd_(closure: _ClosureType, params: TensorList, perturbation: TensorList, eps: TensorList, fx0: _ScalarLoss | None, ):
-    """Two point randomized finite difference (same signature as all other finite difference functions).
-
-    Args:
-        closure (Callable): A closure that reevaluates the model and returns the loss.
-        params (TensorList): TensorList with parameters.
-        perturbation (TensorList): TensorList with perturbation ALREADY MULTIPLIED BY EPSILON.
-        eps (TensorList): Finite difference epsilon.
-        fx0 (ScalarType): Loss at x0, to avoid reevaluating it each time. For some functions this can be None when it isn't needed.
-
-    Returns:
-        TensorList with gradient estimation and approximate loss.
-    """
-    # positive loss
-    params += perturbation
-    loss_pos = closure(False)
-
-    # negative loss
-    params.sub_(perturbation, alpha = 2)
-    loss_neg = closure(False)
-
-    # restore params
-    params += perturbation
-
-    # calculate gradient estimation using central finite differences formula
-    # (we square eps in denominator because perturbation is already multiplied by eps)
-    # grad_est = (perturbation * (loss_pos - loss_neg)) / (2 * eps**2)
-    # is equivalent to the following:
-    return perturbation * eps.map(lambda x: (loss_pos - loss_neg) / (2 * x**2)), loss_pos
-    # also we can't reuse the perturbation tensor and multiply it in place,
-    # since if randomize_every is more than 1, that would break it.
-
-def _two_point_rfd_(closure: _ClosureType, params: TensorList, perturbation: TensorList, eps: TensorList, fx0: _ScalarLoss | None):
-    if fx0 is None: raise ValueError()
-
-    params += perturbation
-    fx1 = closure(False)
-
-    params -= perturbation
-
-    return perturbation * eps.map(lambda x: (fx1 - fx0) / x**2), fx0
-
-def _two_point_rbd_(closure: _ClosureType, params: TensorList, perturbation: TensorList, eps: TensorList, fx0: _ScalarLoss | None):
-    if fx0 is None: raise ValueError()
-
-    params -= perturbation
-    fx1 = closure(False)
-
-    params += perturbation
-
-    return perturbation * eps.map(lambda x: (fx0 - fx1) / x**2), fx0
-
-
-class RandomizedFDM(GradientApproximatorBase):
-    """Gradient approximation via randomized finite differences.
-
-    Args:
-        eps (float, optional): finite difference epsilon. Defaults to 1e-5.
-        formula (_FD_Formulas, optional): finite difference formula. Defaults to 'forward'.
-        n_samples (int, optional): number of times the gradient is approximated and then averaged. Defaults to 1.
-        distribution (Distributions, optional): distribution for random perturbations. Defaults to "normal".
-        target (str, optional):
-            determines what this module sets.
-
-            "ascent" - it creates a new ascent direction but doesn't treat it as the gradient.
-
-            "grad" - it creates the gradient and sets it to the `.grad` attributes (default).
-
-            "closure" - it makes a new closure that sets the estimated gradient to the `.grad` attributes.
-    """
-    def __init__(
-        self,
-        eps: float = 1e-5,
-        formula: _FD_Formulas = "forward",
-        n_samples: int = 1,
-        distribution: Distributions = "normal",
-        target: Literal['ascent', 'grad', 'closure'] = 'grad',
-    ):
-        defaults = dict(eps = eps)
-
-        if formula == 'forward':
-            self._finite_difference = _two_point_rfd_
-            requires_fx0 = True
-
-        elif formula == 'backward':
-            self._finite_difference = _two_point_rbd_
-            requires_fx0 = True
-
-        elif formula == 'central':
-            self._finite_difference = _two_point_rcd_
-            requires_fx0 = False
-
-        else: raise ValueError(f"Unknown formula: {formula}")
-
-        self.n_samples = n_samples
-        self.distribution: Distributions = distribution
-
-        super().__init__(defaults, requires_fx0=requires_fx0, target = target)
-
-    @torch.no_grad
-    def _make_ascent(self, closure, params, fx0):
-        eps = self.get_group_key('eps')
-        fx0_approx = None
-
-        if self.n_samples == 1:
-            grads, fx0_approx = self._finite_difference(closure, params, params.sample_like(eps, self.distribution), eps, fx0)
-
-        else:
-            grads = params.zeros_like()
-            for i in range(self.n_samples):
-                g, fx0_approx = self._finite_difference(closure, params, params.sample_like(eps, self.distribution), eps, fx0)
-                grads += g
-            grads /= self.n_samples
-
-        return grads, fx0, fx0_approx
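
For context on the removed RandomizedFDM module: it estimates the gradient by probing the loss along random directions with a two-point finite difference. Below is a minimal, self-contained sketch of the central-difference estimator it is built on, using plain torch tensors instead of torchzero's TensorList (the function and names are illustrative, not part of the package); the eps**2 in the denominator appears because, as in the removed code, the perturbation is already scaled by eps.

import torch

def rcd_grad_estimate(f, x, eps=1e-3, n_samples=1000):
    # two-point central randomized finite differences:
    # E[u * (f(x + eps*u) - f(x - eps*u)) / (2*eps)] approximates grad f(x)
    grad = torch.zeros_like(x)
    for _ in range(n_samples):
        u = torch.randn_like(x)       # random probe direction
        p = u * eps                   # perturbation already multiplied by eps
        loss_pos = f(x + p)
        loss_neg = f(x - p)
        # p * (...) / (2 * eps**2) is the same as u * (...) / (2 * eps)
        grad += p * (loss_pos - loss_neg) / (2 * eps**2)
    return grad / n_samples

# sanity check on f(x) = sum(x**2), whose true gradient is 2*x
x = torch.tensor([1.0, -2.0, 3.0])
print(rcd_grad_estimate(lambda v: (v**2).sum(), x))
print(2 * x)  # the estimate is noisy but centered on this

Averaging over more samples reduces the variance of the estimate; with n_samples=1 it is the same single-probe estimator the removed module uses by default.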
torchzero/modules/line_search/armijo.py (removed)
@@ -1,56 +0,0 @@
-import torch
-
-from ...tensorlist import TensorList
-from ...core import OptimizationVars
-from .base_ls import LineSearchBase
-
-
-class ArmijoLS(LineSearchBase):
-    """Armijo backtracking line search.
-
-    Args:
-        alpha (float): initial step size.
-        mul (float, optional): lr multiplier on each iteration. Defaults to 0.5.
-        beta (float, optional):
-            Armijo condition parameter, the fraction of the expected linear loss decrease to accept.
-            Larger values mean the loss needs to decrease more for a step size to be accepted. Defaults to 1e-2.
-        max_iter (int, optional): maximum iterations. Defaults to 10.
-        log_lrs (bool, optional): logs learning rates. Defaults to False.
-    """
-    def __init__(
-        self,
-        alpha: float = 1,
-        mul: float = 0.5,
-        beta: float = 1e-2,
-        max_iter: int = 10,
-        log_lrs = False,
-    ):
-        defaults = dict(alpha=alpha)
-        super().__init__(defaults, maxiter=None, log_lrs=log_lrs)
-        self.mul = mul
-        self.beta = beta
-        self.max_iter = max_iter
-
-    @torch.no_grad
-    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-        if vars.closure is None: raise RuntimeError(f"Line searches ({self.__class__.__name__}) require a closure")
-        ascent = vars.maybe_use_grad_(params)
-        grad = vars.maybe_compute_grad_(params)
-        alpha = self.get_first_group_key('alpha')
-        if vars.fx0 is None: vars.fx0 = vars.closure(False)
-
-        # loss decrease per lr=1 if the function were linear
-        decrease_per_lr = (grad*ascent).total_sum()
-
-        for _ in range(self.max_iter):
-            loss = self._evaluate_lr_(alpha, vars.closure, ascent, params)
-
-            # expected decrease
-            expected_decrease = decrease_per_lr * alpha
-
-            if (vars.fx0 - loss) / expected_decrease >= self.beta:
-                return alpha
-
-            alpha *= self.mul
-
-        return 0
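
The removed ArmijoLS accepts a step size once the achieved decrease is at least a `beta` fraction of the decrease predicted by the linear model, i.e. `f(x) - f(x - alpha*d) >= beta * alpha * (grad . d)`. A standalone sketch of that acceptance rule on a plain function, with hypothetical names and none of the torchzero module machinery:

import torch

def armijo_backtrack(f, x, grad, direction, alpha=1.0, mul=0.5, beta=1e-2, max_iter=10):
    # shrink alpha until the actual decrease reaches a beta fraction of the linear prediction
    fx0 = f(x)
    predicted_per_unit_lr = grad @ direction
    for _ in range(max_iter):
        loss = f(x - alpha * direction)            # step against the ascent direction
        if (fx0 - loss) >= beta * alpha * predicted_per_unit_lr:
            return alpha
        alpha *= mul
    return 0.0                                     # no acceptable step found

x = torch.tensor([3.0, -4.0])
g = 2 * x                                          # gradient of f(x) = sum(x**2)
print(armijo_backtrack(lambda v: (v**2).sum(), x, g, g))   # prints 0.5 for this quadratic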
torchzero/modules/line_search/base_ls.py (removed)
@@ -1,139 +0,0 @@
-from typing import Literal
-from abc import ABC, abstractmethod
-
-
-import torch
-
-from ...tensorlist import TensorList
-from ...core import _ClosureType, OptimizationVars, OptimizerModule, _maybe_pass_backward
-from ...utils.python_tools import _ScalarLoss
-
-
-class MaxIterReached(Exception): pass
-
-class LineSearchBase(OptimizerModule, ABC):
-    """Base line search class. This is an abstract class, please don't use it as the optimizer.
-
-    When inheriting from this class, the easiest way is to only override `_find_best_lr`, which should
-    return the final lr to use.
-
-    Args:
-        defaults (dict): dictionary with default parameters for the module.
-        target (str, optional):
-            determines how the _update method is used in the default step method.
-
-            "ascent" - it updates the ascent.
-
-            "grad" - it updates the gradient (and sets `.grad` attributes to the updated gradient).
-
-            "closure" - it makes a new closure that sets the updated ascent to the `.grad` attributes.
-        maxiter (int | None, optional): maximum number of line search iterations
-            (useful for things like scipy.optimize.minimize_scalar, which doesn't have
-            an exact iteration limit). Defaults to None.
-        log_lrs (bool, optional):
-            saves lrs and losses with them into optimizer._lrs (for debugging).
-            Defaults to False.
-    """
-    def __init__(
-        self,
-        defaults: dict,
-        target: Literal['grad', 'ascent', 'closure'] = 'ascent',
-        maxiter=None,
-        log_lrs=False,
-    ):
-        super().__init__(defaults, target=target)
-        self._reset()
-
-        self.maxiter = maxiter
-        self.log_lrs = log_lrs
-        self._lrs: list[dict[float, _ScalarLoss]] = []
-        """this only gets filled if `log_lrs` is True. On each step, a dictionary is added to this list,
-        with all lrs tested at that step as keys and corresponding losses as values."""
-
-    def _reset(self):
-        """Resets `_last_lr`, `_lowest_loss`, `_best_lr`, `_fx0_approx` and `_current_iter`."""
-        self._last_lr = 0
-        self._lowest_loss = float('inf')
-        self._best_lr = 0
-        self._fx0_approx = None
-        self._current_iter = 0
-
-    def _set_lr_(self, lr: float, ascent_direction: TensorList, params: TensorList, ):
-        alpha = self._last_lr - lr
-        if alpha != 0: params.add_(ascent_direction, alpha = alpha)
-        self._last_lr = lr
-
-    # lr is first here so that we can use a partial
-    def _evaluate_lr_(self, lr: float, closure: _ClosureType, ascent: TensorList, params: TensorList, backward=False):
-        """Evaluate `lr`; if the loss is better than the current lowest loss,
-        overrides `self._lowest_loss` and `self._best_lr`.
-
-        Args:
-            closure (ClosureType): closure.
-            params (tl.TensorList): params.
-            ascent_direction (tl.TensorList): ascent.
-            lr (float): lr to evaluate.
-
-        Returns:
-            Loss with the evaluated lr.
-        """
-        # check max iter
-        if self._current_iter == self.maxiter: raise MaxIterReached
-        self._current_iter += 1
-
-        # set new lr and evaluate loss with it
-        self._set_lr_(lr, ascent, params = params)
-        with torch.enable_grad() if backward else torch.no_grad(): self._fx0_approx = _maybe_pass_backward(closure, backward)
-
-        # if it is the best so far, record it
-        if self._fx0_approx < self._lowest_loss:
-            self._lowest_loss = self._fx0_approx
-            self._best_lr = lr
-
-        # log lr and loss
-        if self.log_lrs:
-            self._lrs[-1][lr] = self._fx0_approx
-
-        return self._fx0_approx
-
-    def _evaluate_lr_ensure_float(
-        self,
-        lr: float,
-        closure: _ClosureType,
-        ascent: TensorList,
-        params: TensorList,
-    ) -> float:
-        """Same as _evaluate_lr_ but ensures that the loss value is a float."""
-        v = self._evaluate_lr_(lr, closure, ascent, params)
-        if isinstance(v, torch.Tensor): return v.detach().cpu().item()
-        return float(v)
-
-    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-        """This should return the best lr."""
-        ... # pylint:disable=unnecessary-ellipsis
-
-    @torch.no_grad
-    def step(self, vars: OptimizationVars):
-        self._reset()
-        if self.log_lrs: self._lrs.append({})
-
-        params = self.get_params()
-        ascent_direction = vars.maybe_use_grad_(params)
-
-        try:
-            lr = self._find_best_lr(vars, params) # pylint:disable=assignment-from-no-return
-        except MaxIterReached:
-            lr = self._best_lr
-
-        # if child is None, set the best lr, which updates params, and return the loss
-        if self.next_module is None:
-            self._set_lr_(lr, ascent_direction, params)
-            return self._lowest_loss
-
-        # otherwise undo the update by setting lr to 0 and instead multiply the ascent direction by lr.
-        self._set_lr_(0, ascent_direction, params)
-        ascent_direction.mul_(self._best_lr)
-        vars.ascent = ascent_direction
-        if vars.fx0_approx is None: vars.fx0_approx = self._lowest_loss
-        return self.next_module.step(vars)
-
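
A detail of the removed LineSearchBase worth noting: `_set_lr_` never clones the parameters while probing step sizes; it applies only the difference between the previously applied lr and the new one, so the parameters always sit at `params0 - lr * ascent`, and setting lr back to 0 restores them exactly. A minimal sketch of that bookkeeping with plain tensors (hypothetical class name, not torchzero API):

import torch

class RelativeStepper:
    # keeps params at params0 - last_lr * direction by applying only deltas in place
    def __init__(self):
        self.last_lr = 0.0

    def set_lr_(self, params, direction, lr):
        delta = self.last_lr - lr          # same sign convention as the removed _set_lr_
        if delta != 0:
            params.add_(direction, alpha=delta)
        self.last_lr = lr

params0 = torch.tensor([1.0, 2.0])
params = params0.clone()
d = torch.tensor([0.5, -1.0])
stepper = RelativeStepper()
for lr in (0.1, 1.0, 0.0):                 # probe two step sizes, then restore by going back to 0
    stepper.set_lr_(params, d, lr)
    print(lr, params, params0 - lr * d)    # the last two columns always match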
torchzero/modules/line_search/directional_newton.py (removed)
@@ -1,217 +0,0 @@
-import numpy as np
-import torch
-
-from ...tensorlist import TensorList
-from ...core import OptimizationVars
-from .base_ls import LineSearchBase
-
-_FloatOrTensor = float | torch.Tensor
-def _fit_and_minimize_quadratic_2points_grad(x1:_FloatOrTensor,y1:_FloatOrTensor,y1_prime:_FloatOrTensor,x2:_FloatOrTensor,y2:_FloatOrTensor):
-    """Fits a quadratic to the value and gradient at x1 and the value at x2, and returns the minimizer and the `a` coefficient."""
-    a = (y1_prime * x2 - y2 - y1_prime*x1 + y1) / (x1**2 - x2**2 - 2*x1**2 + 2*x1*x2)
-    b = y1_prime - 2*a*x1
-    # c = -(a*x1**2 + b*x1 - y1)
-    return -b / (2 * a), a
-
-def _ensure_float(x):
-    if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-    elif isinstance(x, np.ndarray): return x.item()
-    return float(x)
-
-class DirectionalNewton(LineSearchBase):
-    """Minimizes a parabola in the direction of the update via one additional forward pass,
-    and uses another forward pass to make sure it didn't overstep (optionally).
-    So in total this performs three forward passes and one backward.
-
-    It is recommended to put LR before DirectionalNewton.
-
-    The first forward and backward pass is used to calculate the value and gradient at the initial parameters.
-    Then a gradient descent step is performed with `lr` learning rate, and the loss is recalculated
-    with the new parameters. A quadratic is fitted to the two points and the gradient;
-    if it has positive curvature, this makes a step towards the minimum, and checks whether the loss decreased
-    with an additional forward pass.
-
-    Args:
-        eps (float, optional):
-            learning rate, also functions as epsilon for directional second derivative estimation. Defaults to 1.
-        max_dist (float | None, optional):
-            maximum distance to step when minimizing the quadratic.
-            If the minimum is further than this distance, minimization is not performed. Defaults to 1e5.
-        validate_step (bool, optional):
-            uses an additional forward pass to check
-            if the step towards the minimum actually decreased the loss. Defaults to True.
-        alpha (float, optional):
-            epsilon for estimating the directional second derivative, also works as the learning rate
-            for when curvature is negative or the loss increases.
-        log_lrs (bool, optional):
-            saves lrs and losses with them into optimizer._lrs (for debugging).
-            Defaults to False.
-
-    Note:
-        While lr scheduling is supported, this uses the lr of the first parameter for all parameters.
-    """
-    def __init__(self, max_dist: float | None = 1e5, validate_step = True, alpha:float=1, log_lrs = False,):
-        super().__init__({"alpha": alpha}, maxiter=None, log_lrs=log_lrs)
-
-        self.max_dist = max_dist
-        self.validate_step = validate_step
-
-    @torch.no_grad
-    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-        if vars.closure is None: raise ValueError('DirectionalNewton requires a closure')
-        closure = vars.closure
-
-        params = self.get_params()
-        grad = vars.maybe_compute_grad_(params)
-        ascent = vars.maybe_use_grad_(params)
-        if vars.fx0 is None: vars.fx0 = vars.closure(False) # at this stage maybe_compute_grad could've evaluated fx0
-
-        alpha: float = self.get_first_group_key('alpha') # this doesn't support variable lrs but we still want to support schedulers
-
-        # directional f'(x1)
-        y1_prime = (grad * ascent).total_sum()
-
-        # f(x2)
-        y2 = self._evaluate_lr_(alpha, closure, ascent, params)
-
-        # if gradients weren't 0
-        if y1_prime != 0:
-            xmin, a = _fit_and_minimize_quadratic_2points_grad(
-                x1=0,
-                y1=vars.fx0,
-                y1_prime=-y1_prime,
-                x2=alpha,
-                # we stepped in the direction of minus gradient times lr,
-                # which is why y1_prime is negative and we multiply x2 by lr.
-                y2=y2
-            )
-            # so we obtained xmin in lr*grad units. We need it in lr units.
-            xmin = _ensure_float(xmin)
-
-            # make sure curvature is positive
-            if a > 0:
-
-                # discard very large steps
-                if self.max_dist is None or xmin <= self.max_dist:
-
-                    # if validate_step is enabled, make sure the loss didn't increase
-                    if self.validate_step:
-                        y_val = self._evaluate_lr_(xmin, closure, ascent, params)
-                        # if it increased, move back to y2.
-                        if y_val > y2:
-                            return float(alpha)
-
-                    return float(xmin)
-
-        return float(alpha)
-
-def _fit_and_minimize_quadratic_3points(
-    x1: _FloatOrTensor,
-    y1: _FloatOrTensor,
-    x2: _FloatOrTensor,
-    y2: _FloatOrTensor,
-    x3: _FloatOrTensor,
-    y3: _FloatOrTensor,
-):
-    """Fits a quadratic to three points."""
-    a = (x1*(y3-y2) + x2*(y1-y3) + x3*(y2-y1)) / ((x1-x2) * (x1 - x3) * (x2 - x3))
-    b = (y2-y1) / (x2-x1) - a*(x1+x2)
-    # c = (y1 - a*x1**2 - b*x1)
-    return (-b / (2 * a), a)
-
-
-def _newton_step_3points(
-    xneg: _FloatOrTensor,
-    yneg: _FloatOrTensor,
-    x0: _FloatOrTensor,
-    y0: _FloatOrTensor,
-    xpos: _FloatOrTensor, # since points are evenly spaced, xpos is x0 + eps; it turns out to be unused
-    ypos: _FloatOrTensor,
-):
-    eps = x0 - xneg
-    dx = (-yneg + ypos) / (2 * eps)
-    ddx = (ypos - 2*y0 + yneg) / (eps**2)
-
-    # xneg is actually x0
-    return xneg - dx / ddx, ddx
-
-class DirectionalNewton3Points(LineSearchBase):
-    """Minimizes a parabola in the direction of the update via two additional forward passes,
-    and uses another forward pass to make sure it didn't overstep (optionally).
-    So in total this performs four forward passes.
-
-    It is recommended to put LR before DirectionalNewton3Points.
-
-    Two steps are performed in the direction of the update with `lr` learning rate.
-    A quadratic is fitted to the three points; if it has positive curvature,
-    this makes a step towards the minimum, and checks whether the loss decreased
-    with an additional forward pass.
-
-    Args:
-        max_dist (float | None, optional):
-            maximum distance to step when minimizing the quadratic.
-            If the minimum is further than this distance, minimization is not performed. Defaults to 1e4.
-        validate_step (bool, optional):
-            uses an additional forward pass to check
-            if the step towards the minimum actually decreased the loss. Defaults to True.
-        alpha (float, optional):
-            epsilon for estimating the directional second derivative, also works as the learning rate
-            for when curvature is negative or the loss increases.
-        log_lrs (bool, optional):
-            saves lrs and losses with them into optimizer._lrs (for debugging).
-            Defaults to False.
-
-    Note:
-        While lr scheduling is supported, this uses the lr of the first parameter for all parameters.
-    """
-    def __init__(self, max_dist: float | None = 1e4, validate_step = True, alpha: float = 1, log_lrs = False,):
-        super().__init__(dict(alpha = alpha), maxiter=None, log_lrs=log_lrs)
-
-        self.alpha = alpha
-        self.max_dist = max_dist
-        self.validate_step = validate_step
-
-    @torch.no_grad
-    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-        if vars.closure is None: raise ValueError('DirectionalNewton3Points requires a closure')
-        closure = vars.closure
-        ascent_direction = vars.ascent
-        if ascent_direction is None: raise ValueError('Ascent direction is None')
-        alpha: float = self.get_first_group_key('alpha')
-
-        if vars.fx0 is None: vars.fx0 = vars.closure(False)
-        params = self.get_params()
-
-        # make a step in the direction and evaluate f(x2)
-        y2 = self._evaluate_lr_(alpha, closure, ascent_direction, params)
-
-        # make a step in the direction and evaluate f(x3)
-        y3 = self._evaluate_lr_(alpha*2, closure, ascent_direction, params)
-
-        # if gradients weren't 0
-        xmin, a = _newton_step_3points(
-            0, vars.fx0,
-            # we stepped in the direction of minus ascent_direction.
-            alpha, y2,
-            alpha * 2, y3
-        )
-        xmin = _ensure_float(xmin)
-
-        # make sure curvature is positive
-        if a > 0:
-
-            # discard very large steps
-            if self.max_dist is None or xmin <= self.max_dist:
-
-                # if validate_step is enabled, make sure the loss didn't increase
-                if self.validate_step:
-                    y_val = self._evaluate_lr_(xmin, closure, ascent_direction, params)
-                    # if it increased, move back to y2.
-                    if y_val > y2 or y_val > y3:
-                        if y3 > y2: return alpha
-                        else: return alpha * 2
-
-                return xmin
-
-        if y3 > y2: return alpha
-        else: return alpha * 2
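
The removed `_newton_step_3points` helper is built on central finite differences: with three evenly spaced samples, f'(x0) is approximately (y_pos - y_neg) / (2*eps) and f''(x0) is approximately (y_pos - 2*y0 + y_neg) / eps**2, and a Newton step from the middle sample lands exactly on the vertex when the sampled function is a parabola. A quick numerical check of those estimates (hypothetical example, not the removed wrapper logic around it):

def newton_step_from_3_points(x0, y_neg, y0, y_pos, eps):
    # Newton step at x0 using central-difference first and second derivatives
    dx = (y_pos - y_neg) / (2 * eps)           # f'(x0)
    ddx = (y_pos - 2 * y0 + y_neg) / eps**2    # f''(x0)
    return x0 - dx / ddx, ddx

# parabola with its vertex at x = 3
f = lambda x: (x - 3) ** 2 + 1
x0, eps = 1.0, 0.5
xmin, curvature = newton_step_from_3_points(x0, f(x0 - eps), f(x0), f(x0 + eps), eps)
print(xmin, curvature)  # 3.0 2.0 -- exact for a quadratic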
torchzero/modules/line_search/grid_ls.py (removed)
@@ -1,158 +0,0 @@
-from typing import Any, Literal
-from collections.abc import Sequence
-
-import numpy as np
-import torch
-
-from ...tensorlist import TensorList
-from ...core import _ClosureType, OptimizationVars
-from .base_ls import LineSearchBase
-
-class GridLS(LineSearchBase):
-    """Tests all `lrs` and picks the best one.
-
-    Args:
-        lrs (Sequence[float] | np.ndarray | torch.Tensor): sequence of lrs to test.
-        stop_on_improvement (bool, optional): stops if the loss improves compared to the current loss. Defaults to False.
-        stop_on_worsened (bool, optional):
-            stops if the loss for the next lr is worse than for the previous one.
-            This assumes that lrs are in ascending order. Defaults to False.
-        log_lrs (bool, optional):
-            saves lrs and losses with them into optimizer._lrs (for debugging).
-            Defaults to False.
-    """
-    def __init__(
-        self,
-        lrs: Sequence[float] | np.ndarray | torch.Tensor,
-        stop_on_improvement=False,
-        stop_on_worsened=False,
-        log_lrs = False,
-    ):
-        super().__init__({}, maxiter=None, log_lrs=log_lrs)
-        self.lrs = lrs
-        self.stop_on_improvement = stop_on_improvement
-        self.stop_on_worsened = stop_on_worsened
-
-    @torch.no_grad
-    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
-        if vars.closure is None: raise ValueError("closure is not set")
-        if vars.ascent is None: raise ValueError("ascent_direction is not set")
-
-        if self.stop_on_improvement:
-            if vars.fx0 is None: vars.fx0 = vars.closure(False)
-            self._lowest_loss = vars.fx0
-
-        for lr in self.lrs:
-            loss = self._evaluate_lr_(float(lr), vars.closure, vars.ascent, params)
-
-            # if worsened
-            if self.stop_on_worsened and loss != self._lowest_loss:
-                break
-
-            # if improved
-            if self.stop_on_improvement and loss == self._lowest_loss:
-                break
-
-        return float(self._best_lr)
-
-
-
-class MultiplicativeLS(GridLS):
-    """Starts with the `init` lr, then keeps multiplying it by `mul` until the loss stops decreasing.
-
-    Args:
-        init (float, optional): initial lr. Defaults to 0.001.
-        mul (float, optional): lr multiplier. Defaults to 2.
-        num (int, optional): maximum number of multiplication steps. Defaults to 10.
-        stop_on_improvement (bool, optional): stops if the loss improves compared to the current loss. Defaults to False.
-        stop_on_worsened (bool, optional):
-            stops if the loss for the next lr is worse than for the previous one.
-            This assumes that lrs are in ascending order. Defaults to True.
-        log_lrs (bool, optional):
-            saves lrs and losses with them into optimizer._lrs (for debugging).
-            Defaults to False.
-    """
-    def __init__(
-        self,
-        init: float = 0.001,
-        mul: float = 2,
-        num=10,
-        stop_on_improvement=False,
-        stop_on_worsened=True,
-    ):
-        super().__init__(
-            [init * mul**i for i in range(num)],
-            stop_on_improvement=stop_on_improvement,
-            stop_on_worsened=stop_on_worsened,
-        )
-
-class BacktrackingLS(GridLS):
-    """Tests the `init` lr, and keeps multiplying it by `mul` until the loss becomes better than the initial loss.
-
-    Note: this doesn't include the Armijo–Goldstein condition.
-
-    Args:
-        init (float, optional): initial lr. Defaults to 1.
-        mul (float, optional): lr multiplier. Defaults to 0.5.
-        num (int, optional): maximum number of multiplication steps. Defaults to 10.
-        stop_on_improvement (bool, optional): stops if the loss improves compared to the current loss. Defaults to True.
-        stop_on_worsened (bool, optional):
-            stops if the loss for the next lr is worse than for the previous one.
-            This assumes that lrs are in ascending order. Defaults to False.
-        log_lrs (bool, optional):
-            saves lrs and losses with them into optimizer._lrs (for debugging).
-            Defaults to False.
-
-    """
-    def __init__(
-        self,
-        init: float = 1,
-        mul: float = 0.5,
-        num=10,
-        stop_on_improvement=True,
-        stop_on_worsened=False,
-        log_lrs = False,
-    ):
-        super().__init__(
-            [init * mul**i for i in range(num)],
-            stop_on_improvement=stop_on_improvement,
-            stop_on_worsened=stop_on_worsened,
-            log_lrs = log_lrs,
-        )
-
-class LinspaceLS(GridLS):
-    """Tests all learning rates from a linspace and picks the best one."""
-    def __init__(
-        self,
-        start: float = 0.001,
-        end: float = 2,
-        steps=10,
-        stop_on_improvement=False,
-        stop_on_worsened=False,
-        log_lrs = False,
-    ):
-        super().__init__(
-            torch.linspace(start, end, steps),
-            stop_on_improvement=stop_on_improvement,
-            stop_on_worsened=stop_on_worsened,
-            log_lrs = log_lrs,
-        )
-
-class ArangeLS(GridLS):
-    """Tests all learning rates from an arange and picks the best one."""
-    def __init__(
-        self,
-        start: float = 0.001,
-        end: float = 2,
-        step=0.1,
-        stop_on_improvement=False,
-        stop_on_worsened=False,
-        log_lrs = False,
-
-    ):
-        super().__init__(
-            torch.arange(start, end, step),
-            stop_on_improvement=stop_on_improvement,
-            stop_on_worsened=stop_on_worsened,
-            log_lrs = log_lrs,
-        )