torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/scheduling/step_size.py

@@ -1,80 +0,0 @@
-import random
-from typing import Any
-
-from ...core import OptimizerModule
-from ...tensorlist import TensorList
-
-
-class PolyakStepSize(OptimizerModule):
-    """Polyak step-size. Meant to be used at the beginning when ascent is the gradient but other placements may work.
-    This can also work with SGD as SPS (Stochastic Polyak Step-Size) seems to use the same formula.
-
-    Args:
-        max (float | None, optional): maximum possible step size. Defaults to None.
-        min_obj_value (int, optional): (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
-        use_grad (bool, optional):
-            if True, uses dot product of update and gradient to compute the step size.
-            Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
-            Defaults to True.
-        parameterwise (bool, optional):
-            if True, calculate Polyak step-size for each parameter separately,
-            if False calculate one global step size for all parameters. Defaults to False.
-        alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
-    """
-    def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):
-
-        defaults = dict(alpha = alpha)
-        super().__init__(defaults)
-        self.max = max
-        self.min_obj_value = min_obj_value
-        self.use_grad = use_grad
-        self.parameterwise = parameterwise
-
-    def _update(self, vars, ascent):
-        if vars.closure is None: raise ValueError("PolyakStepSize requires closure")
-        if vars.fx0 is None: vars.fx0 = vars.closure(False) # can only happen when placed after SPSA
-
-        alpha = self.get_group_key('alpha')
-
-        if self.parameterwise:
-            if self.use_grad: denom = (ascent*vars.maybe_compute_grad_(self.get_params())).mean()
-            else: denom = ascent.pow(2).mean()
-            polyak_step_size: TensorList | Any = (vars.fx0 - self.min_obj_value) / denom.where(denom!=0, 1) # type:ignore
-            polyak_step_size = polyak_step_size.where(denom != 0, 0)
-            if self.max is not None: polyak_step_size = polyak_step_size.clamp_max(self.max)
-
-        else:
-            if self.use_grad: denom = (ascent*vars.maybe_compute_grad_(self.get_params())).total_mean()
-            else: denom = ascent.pow(2).total_mean()
-            if denom == 0: polyak_step_size = 0 # we converged
-            else: polyak_step_size = (vars.fx0 - self.min_obj_value) / denom
-
-            if self.max is not None:
-                if polyak_step_size > self.max: polyak_step_size = self.max
-
-        ascent.mul_(alpha * polyak_step_size)
-        return ascent
-
-
-
-class RandomStepSize(OptimizerModule):
-    """Uses random global step size from `low` to `high`.
-
-    Args:
-        low (float, optional): minimum learning rate. Defaults to 0.
-        high (float, optional): maximum learning rate. Defaults to 1.
-        parameterwise (bool, optional):
-            if True, generate random step size for each parameter separately,
-            if False generate one global random step size. Defaults to False.
-    """
-    def __init__(self, low: float = 0, high: float = 1, parameterwise=False):
-        super().__init__({})
-        self.low = low; self.high = high
-        self.parameterwise = parameterwise
-
-    def _update(self, vars, ascent):
-        if self.parameterwise:
-            lr = [random.uniform(self.low, self.high) for _ in range(len(ascent))]
-        else:
-            lr = random.uniform(self.low, self.high)
-        return ascent.mul_(lr) # type:ignore
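The removed `PolyakStepSize` module implements the classic Polyak rule: step size = (f(x) − f*) / ⟨update, gradient⟩, optionally clamped to a maximum. Below is a minimal standalone sketch of that formula on a toy quadratic; `polyak_step` and the toy objective are illustrative only, not part of the torchzero API.

```python
import torch

def polyak_step(fx: float, f_min: float, update: torch.Tensor,
                grad: torch.Tensor, max_step: float | None = None) -> float:
    """Polyak step size: (f(x) - f_min) / <update, grad>, optionally clamped to max_step."""
    denom = float(update.dot(grad))
    if denom == 0:
        return 0.0  # treat a zero denominator as "converged"
    step = (fx - f_min) / denom
    return step if max_step is None else min(step, max_step)

# toy quadratic f(x) = 0.5 * ||x||^2, whose minimum value is 0; its gradient is x
x = torch.tensor([3.0, -4.0])
grad = x.clone()
fx = 0.5 * float(x.dot(x))
lr = polyak_step(fx, 0.0, grad, grad)   # = 0.5 for this objective
x = x - lr * grad
```

With `use_grad=False` the module replaces ⟨update, gradient⟩ by ⟨update, update⟩, which the docstring itself warns has no geometric meaning.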
torchzero/modules/smoothing/gaussian_smoothing.py

@@ -1,90 +0,0 @@
-from contextlib import nullcontext
-
-import numpy as np
-import torch
-
-from ...tensorlist import TensorList, Distributions, mean as tlmean
-from ...utils.python_tools import _ScalarLoss
-from ...core import _ClosureType, OptimizationVars, OptimizerModule, _maybe_pass_backward
-
-
-def _numpy_or_torch_mean(losses: list):
-    """Returns the mean of a list of losses, which can be either numpy arrays or torch tensors."""
-    if isinstance(losses[0], torch.Tensor):
-        return torch.mean(torch.stack(losses))
-    return np.mean(losses).item()
-
-class GaussianHomotopy(OptimizerModule):
-    """Samples and averages value and gradients in multiple random points around current position.
-    This effectively applies smoothing to the function.
-
-    Args:
-        n_samples (int, optional): number of gradient samples from around current position. Defaults to 4.
-        sigma (float, optional): how far from current position to sample from. Defaults to 0.1.
-        distribution (tl.Distributions, optional): distribution for random positions. Defaults to "normal".
-        sample_x0 (bool, optional): 1st sample will be x0. Defaults to False.
-        randomize_every (int | None, optional): randomizes the points every n steps. Defaults to 1.
-    """
-    def __init__(
-        self,
-        n_samples: int = 4,
-        sigma: float = 0.1,
-        distribution: Distributions = "normal",
-        sample_x0 = False,
-        randomize_every: int | None = 1,
-    ):
-        defaults = dict(sigma = sigma)
-        super().__init__(defaults)
-        self.n_samples = n_samples
-        self.distribution: Distributions = distribution
-        self.randomize_every = randomize_every
-        self.current_step = 0
-        self.perturbations = None
-        self.sample_x0 = sample_x0
-
-    @torch.no_grad()
-    def step(self, vars: OptimizationVars):
-        if vars.closure is None: raise ValueError('GaussianSmoothing requires closure.')
-        closure = vars.closure
-        params = self.get_params()
-        sigmas = self.get_group_key('sigma')
-
-        # generate random perturbations
-        if self.perturbations is None or (self.randomize_every is not None and self.current_step % self.randomize_every == 0):
-            if self.sample_x0:
-                self.perturbations = [params.sample_like(sigmas, distribution=self.distribution) for _ in range(self.n_samples-1)]
-            else:
-                self.perturbations = [params.sample_like(sigmas, distribution=self.distribution) for _ in range(self.n_samples)]
-
-        @torch.no_grad
-        def smooth_closure(backward = True):
-            losses = []
-            grads = []
-
-            # sample gradient and loss at x0
-            if self.sample_x0:
-                with torch.enable_grad() if backward else nullcontext():
-                    losses.append(closure())
-                if backward: grads.append(params.grad.clone())
-
-            # sample gradients from points around current params
-            # and average them
-            if self.perturbations is None: raise ValueError('who set perturbations to None???')
-            for p in self.perturbations:
-                params.add_(p)
-                with torch.enable_grad() if backward else nullcontext():
-                    losses.append(_maybe_pass_backward(closure, backward))
-                if backward: grads.append(params.grad.clone())
-                params.sub_(p)
-
-            # set the new averaged grads and return average loss
-            if backward: params.set_grad_(tlmean(grads))
-            return _numpy_or_torch_mean(losses)
-
-
-        self.current_step += 1
-        vars.closure = smooth_closure
-        return self._update_params_or_step_with_next(vars)
-
-
-# todo single loop gaussian homotopy?
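The removed `GaussianHomotopy` module wraps the closure so that loss and gradient are averaged over several randomly perturbed copies of the parameters, which amounts to optimizing a Gaussian-smoothed objective. A minimal sketch of that averaging for a single tensor parameter follows, assuming an objective `f` that maps a tensor to a scalar loss; the helper name and toy objective are illustrative, not torchzero's API.

```python
import torch

def smoothed_value_and_grad(f, x: torch.Tensor, sigma: float = 0.1, n_samples: int = 4):
    """Estimate E[f(x + sigma*e)] and its gradient by averaging over perturbed copies of x."""
    losses, grads = [], []
    for _ in range(n_samples):
        # evaluate loss and gradient at a randomly shifted copy of x
        xp = (x + torch.randn_like(x) * sigma).requires_grad_(True)
        loss = f(xp)
        loss.backward()
        losses.append(loss.detach())
        grads.append(xp.grad)
    return torch.stack(losses).mean(), torch.stack(grads).mean(dim=0)

# usage: smoothing softens the kink of |x|, which plain gradient descent handles poorly
value, grad = smoothed_value_and_grad(lambda t: t.abs().sum(), torch.tensor([0.5, -1.0]))
```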
torchzero/modules/weight_averaging/ema.py

@@ -1,72 +0,0 @@
-import torch
-from ...core import OptimizerModule
-
-
-def _reset_stats_hook(optimizer, state):
-    for module in optimizer.unrolled_modules:
-        module: OptimizerModule
-        module.reset_stats()
-
-# the reason why this needs to be at the end is ??? I NEED TO REMEMBER
-class SwitchEMA(OptimizerModule):
-    """Switch-EMA. Every n steps switches params to an exponential moving average of past weights.
-
-    In the paper the switch happens after each epoch.
-
-    Please put this module at the end, after all other modules.
-
-    This can also function as EMA, set `update_every` to None and instead call `set_ema` and `unset_ema` on this module.
-
-
-    Args:
-        update_every (int): number of steps (batches) between setting model parameters to EMA.
-        momentum (int): EMA momentum factor.
-        reset_stats (bool, optional):
-            if True, when setting model parameters to EMA, resets other modules stats such as momentum velocities.
-            It might be better to set this to False if `update_every` is very small. Defaults to True.
-
-    reference
-        https://arxiv.org/abs/2402.09240
-    """
-    def __init__(self, update_every: int | None, momentum: float = 0.99, reset_stats: bool = True):
-        defaults = dict(momentum=momentum)
-        super().__init__(defaults)
-        self.update_every = update_every
-        self.cur_step = 0
-        self.update_every = update_every
-        self._reset_stats = reset_stats
-        self.orig_params = None
-
-    def set_ema(self):
-        """sets module parameters to EMA, stores original parameters that can be restored by calling `unset_ema`"""
-        params = self.get_params()
-        self.orig_params = params.clone()
-        params.set_(self.get_state_key('ema', init = 'params', params=params))
-
-    def unset_ema(self):
-        """Undoes `set_ema`."""
-        if self.orig_params is None: raise ValueError('call `set_ema` first, and then `unset_ema`.')
-        params = self.get_params()
-        params.set_(self.orig_params)
-
-    @torch.no_grad
-    def step(self, vars):
-        # if self.next_module is not None:
-        #     warn(f'EMA should usually be the last module, but {self.next_module.__class__.__name__} is after it.')
-        self.cur_step += 1
-
-        params = self.get_params()
-        # state.maybe_use_grad_(params)
-        # update params with the child. Averaging is always applied at the end.
-        ret = self._update_params_or_step_with_next(vars, params)
-
-        ema = self.get_state_key('ema', init = 'params', params=params)
-        momentum = self.get_group_key('momentum')
-
-        ema.lerp_compat_(params, 1 - momentum)
-
-        if (self.update_every is not None) and (self.cur_step % self.update_every == 0):
-            params.set_(ema.clone())
-            if self._reset_stats: vars.add_post_step_hook(_reset_stats_hook)
-
-        return ret
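The removed `SwitchEMA` module keeps `ema = momentum * ema + (1 - momentum) * params` after every step and periodically copies the EMA back into the live parameters. A minimal sketch of those two operations on plain tensor lists, assuming they are called from your own training loop; the helpers below are illustrative, not torchzero's API.

```python
import torch

@torch.no_grad()
def update_ema(ema: list[torch.Tensor], params: list[torch.Tensor], momentum: float = 0.99):
    # ema <- momentum * ema + (1 - momentum) * params, written as a lerp like the module does
    for e, p in zip(ema, params):
        e.lerp_(p, 1.0 - momentum)

@torch.no_grad()
def switch_to_ema(params: list[torch.Tensor], ema: list[torch.Tensor]):
    # the "switch": overwrite the live parameters with the EMA weights
    for p, e in zip(params, ema):
        p.copy_(e)
```

Calling `update_ema` after every optimizer step and `switch_to_ema` every `update_every` steps (once per epoch in the referenced paper) reproduces the schedule the docstring describes.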
torchzero/modules/weight_averaging/swa.py

@@ -1,171 +0,0 @@
-from ...core import OptimizerModule
-
-
-def _reset_stats_hook(optimizer, state):
-    for module in optimizer.unrolled_modules:
-        module: OptimizerModule
-        module.reset_stats()
-
-class PeriodicSWA(OptimizerModule):
-    """Periodic Stochastic Weight Averaging.
-
-    Please put this module at the end, after all other modules.
-
-    The algorithm is as follows:
-
-    1. perform `pswa_start` normal steps before starting PSWA.
-
-    2. Perform multiple SWA iterations. On each iteration,
-        run SWA algorithm for `num_cycles` cycles,
-        and set weights to the weighted average before starting the next SWA iteration.
-
-    SWA iteration is as follows:
-
-    1. perform `cycle_start` initial steps (can be 0)
-
-    2. for `num_cycles`, after every `cycle_length` steps passed, update the weight average with current model weights.
-
-    3. After `num_cycles` cycles passed, set model parameters to the weight average.
-
-    Args:
-        first_swa (int):
-            number of steps before starting PSWA, authors run PSWA starting from 40th epoch out ot 150 epochs in total.
-        cycle_length (int):
-            number of steps betwen updating the weight average. Authors update it once per epoch.
-        num_cycles (int):
-            Number of weight average updates before setting model weights to the average and proceding to the next cycle.
-            Authors use 20 (meaning 20 epochs since each cycle is 1 epoch).
-        cycle_start (int, optional):
-            number of steps at the beginning of each SWA period before updating the weight average (default: 0).
-        reset_stats (bool, optional):
-            if True, when setting model parameters to SWA, resets other modules stats such as momentum velocities (default: True).
-    """
-    def __init__(self, pswa_start: int, cycle_length: int, num_cycles: int, cycle_start: int = 0, reset_stats:bool = True):
-
-        super().__init__({})
-        self.pswa_start = pswa_start
-        self.cycle_start = cycle_start
-        self.cycle_length = cycle_length
-        self.num_cycles = num_cycles
-        self._reset_stats = reset_stats
-
-
-        self.cur = 0
-        self.period_cur = 0
-        self.swa_cur = 0
-        self.n_models = 0
-
-    def step(self, vars):
-        swa = None
-        params = self.get_params()
-        ret = self._update_params_or_step_with_next(vars, params)
-
-        # start first period after `pswa_start` steps
-        if self.cur >= self.pswa_start:
-
-            # start swa after `cycle_start` steps in the current period
-            if self.period_cur >= self.cycle_start:
-
-                # swa updates on every `cycle_length`th step
-                if self.swa_cur % self.cycle_length == 0:
-                    swa = self.get_state_key('swa') # initialized to zeros for simplicity
-                    swa.mul_(self.n_models).add_(params).div_(self.n_models + 1)
-                    self.n_models += 1
-
-                self.swa_cur += 1
-
-            self.period_cur += 1
-
-        self.cur += 1
-
-        # passed num_cycles in period, set model parameters to SWA
-        if self.n_models == self.num_cycles:
-            self.period_cur = 0
-            self.swa_cur = 0
-            self.n_models = 0
-
-            assert swa is not None # it's created above self.n_models += 1
-
-            params.set_(swa)
-            # add a hook that resets momentum, which also deletes `swa` in this module
-            if self._reset_stats: vars.add_post_step_hook(_reset_stats_hook)
-
-        return ret
-
-class CyclicSWA(OptimizerModule):
-    """Periodic SWA with cyclic learning rate. So it samples the weights, increases lr to `peak_lr`, samples the weights again,
-    decreases lr back to `init_lr`, and samples the weights last time. Then model weights are replaced with the average of the three sampled weights,
-    and next cycle starts. I made this due to a horrible misreading of the original SWA paper but it seems to work well.
-
-    Please put this module at the end, after all other modules.
-
-    Args:
-        cswa_start (int): number of steps before starting the first CSWA cycle.
-        cycle_length (int): length of each cycle in steps.
-        steps_between (int): number of steps between cycles.
-        init_lr (float, optional): initial and final learning rate in each cycle. Defaults to 0.
-        peak_lr (float, optional): peak learning rate of each cycle. Defaults to 1.
-        sample_all (float, optional): if True, instead of sampling 3 weights, it samples all weights in the cycle. Defaults to False.
-        reset_stats (bool, optional):
-            if True, when setting model parameters to SWA, resets other modules stats such as momentum velocities (default: True).
-
-    """
-    def __init__(self, cswa_start: int, cycle_length: int, steps_between: int, init_lr: float = 0, peak_lr: float = 1, sample_all = False, reset_stats: bool=True,):
-        defaults = dict(init_lr = init_lr, peak_lr = peak_lr)
-        super().__init__(defaults)
-        self.cswa_start = cswa_start
-        self.cycle_length = cycle_length
-        self.init_lr = init_lr
-        self.peak_lr = peak_lr
-        self.steps_between = steps_between
-        self.sample_all = sample_all
-        self._reset_stats = reset_stats
-
-        self.cur = 0
-        self.cycle_cur = 0
-        self.n_models = 0
-
-        self.cur_lr = self.init_lr
-
-    def step(self, vars):
-        params = self.get_params()
-
-        # start first period after `cswa_start` steps
-        if self.cur >= self.cswa_start:
-
-            ascent = vars.maybe_use_grad_(params)
-
-            # determine the lr
-            point = self.cycle_cur / self.cycle_length
-            init_lr, peak_lr = self.get_group_keys('init_lr', 'peak_lr')
-            if point < 0.5:
-                p2 = point*2
-                lr = init_lr * (1-p2) + peak_lr * p2
-            else:
-                p2 = (1 - point)*2
-                lr = init_lr * (1-p2) + peak_lr * p2
-
-            ascent *= lr
-            ret = self._update_params_or_step_with_next(vars, params)
-
-            if self.sample_all or self.cycle_cur in (0, self.cycle_length, self.cycle_length // 2):
-                swa = self.get_state_key('swa')
-                swa.mul_(self.n_models).add_(params).div_(self.n_models + 1)
-                self.n_models += 1
-
-                if self.cycle_cur == self.cycle_length:
-                    if not self.sample_all: assert self.n_models == 3, self.n_models
-                    self.n_models = 0
-                    self.cycle_cur = -1
-
-                    params.set_(swa)
-                    if self._reset_stats: vars.add_post_step_hook(_reset_stats_hook)
-
-            self.cycle_cur += 1
-
-        else:
-            ret = self._update_params_or_step_with_next(vars, params)
-
-        self.cur += 1
-
-        return ret
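Both removed averaging modules maintain a running mean of sampled weights, `swa = (swa * n + params) / (n + 1)`, and `CyclicSWA` additionally ramps the learning rate linearly from `init_lr` up to `peak_lr` and back within each cycle. A minimal sketch of those two pieces, with illustrative helper names that are not part of torchzero's API:

```python
import torch

@torch.no_grad()
def update_swa(swa: list[torch.Tensor], params: list[torch.Tensor], n_models: int) -> int:
    # running average of sampled weights: swa <- (swa * n + params) / (n + 1)
    for s, p in zip(swa, params):
        s.mul_(n_models).add_(p).div_(n_models + 1)
    return n_models + 1  # updated model count

def triangular_lr(step_in_cycle: int, cycle_length: int, init_lr: float, peak_lr: float) -> float:
    # CyclicSWA schedule: init_lr -> peak_lr over the first half of the cycle, back down over the second
    point = step_in_cycle / cycle_length
    p2 = 2 * point if point < 0.5 else 2 * (1 - point)
    return init_lr * (1 - p2) + peak_lr * p2
```

`PeriodicSWA` calls the averaging step once per `cycle_length` steps; `CyclicSWA` samples the weights at the start, midpoint, and end of each cycle (or every step with `sample_all=True`) before copying the average back into the model.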
torchzero/optim/experimental/__init__.py

@@ -1,20 +0,0 @@
-"""Optimizers that I haven't tested and various (mostly stupid) ideas go there.
-If something works well I will move it outside of experimental folder.
-Otherwise all optimizers in this category should be considered unlikely to good for most tasks."""
-from .experimental import (
-    HVPDiagNewton,
-    ExaggeratedNesterov,
-    ExtraCautiousAdam,
-    GradMin,
-    InwardSGD,
-    MinibatchRprop,
-    MomentumDenominator,
-    MomentumNumerator,
-    MultistepSGD,
-    RandomCoordinateMomentum,
-    ReciprocalSGD,
-    NoiseSign,
-)
-
-
-from .ray_search import NewtonFDMRaySearch, LBFGSRaySearch