torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
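One practical consequence of the listing: the linear-algebra helpers that lived under `torchzero/utils/linalg` in 0.3.15 (`linear_operator.py`, `qr.py`, `solve.py`, plus the removed `matrix_funcs.py`, `orthogonalize.py`, and `svd.py`) now live in a top-level `torchzero/linalg` package in 0.4.1. A minimal, hypothetical compatibility shim for downstream code, assuming the moved submodules stay importable under their old names (the diff only shows the file moves, not the package re-exports):

```python
# Hypothetical import shim: prefer the 0.4.1 layout, fall back to 0.3.15.
# Submodule names come from the file listing above; whether each one is
# re-exported from the package __init__ is an assumption.
try:
    from torchzero.linalg import linear_operator, qr, solve        # torchzero >= 0.4.1
except ImportError:
    from torchzero.utils.linalg import linear_operator, qr, solve  # torchzero 0.3.15
```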
torchzero/modules/adaptive/psgd/psgd_lra_whiten.py (new file, +116 -0)

@@ -0,0 +1,116 @@
+# pylint:disable=not-callable
+"""all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
+import math
+import warnings
+
+import torch
+
+from ....core import Chainable, TensorTransform
+from ._psgd_utils import _initialize_lra_state_
+from .psgd import lift2single, precond_grad_lra, update_precond_lra_whiten
+
+# matches
+class PSGDLRAWhiten(TensorTransform):
+    """Low rank whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)
+
+    Args:
+        rank (int, optional):
+            Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.
+        init_scale (float | None, optional):
+            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
+        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
+        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
+        damping (float, optional):
+            adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
+        grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
+        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
+        concat_params (bool, optional):
+            if True, treats all parameters as concatenated to a single vector.
+            If False, each parameter is preconditioned separately. Defaults to True.
+        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.
+
+    ###Examples:
+
+    Pure PSGD LRA:
+    ```py
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.LRAWhiten(),
+        tz.m.LR(1e-3),
+    )
+    ```
+
+    Momentum into preconditioner (whitens momentum):
+    ```py
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.EMA(0.9),
+        tz.m.LRAWhiten(),
+        tz.m.LR(1e-3),
+    )
+    ```
+
+    Updating the preconditioner from gradients and applying it to momentum:
+    ```py
+    optimizer = tz.Optimizer(
+        model.parameters(),
+        tz.m.LRAWhiten(inner=tz.m.EMA(0.9)),
+        tz.m.LR(1e-3),
+    )
+    ```
+
+    """
+    def __init__(
+        self,
+        rank: int = 10,
+        init_scale: float | None = None,
+        lr_preconditioner=0.1,
+        betaL=0.9,
+        damping=1e-9,
+        grad_clip_max_amp=float("inf"),
+        update_probability=1.0,
+
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = locals().copy()
+        del defaults["inner"], defaults["self"]
+        super().__init__(defaults, concat_params=concat_params, inner=inner)
+
+    @torch.no_grad
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        _initialize_lra_state_(tensor, state, setting)
+
+    @torch.no_grad
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+
+        g = tensor.ravel().unsqueeze(1) # column vector
+
+        UVd = state["UVd"]
+        if UVd[2] is None: # initialize d on the fly
+            UVd[2] = (torch.mean(g**4) + setting["damping"]**4)**(-1/8) * torch.ones_like(g)
+
+        if torch.rand([]) < setting["update_probability"]: # update preconditioner
+            update_precond_lra_whiten(
+                UVd=UVd,
+                Luvd=state["Luvd"],
+                g=g,
+                lr=setting["lr_preconditioner"],
+                betaL=setting["betaL"],
+                damping=setting["damping"],
+            )
+
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):

+        g = tensor.ravel().unsqueeze(1)
+        pre_grad = precond_grad_lra(UVd=state["UVd"], g=g)
+
+        # norm clipping
+        grad_clip_max_amp = setting["grad_clip_max_amp"]
+        if grad_clip_max_amp < float("inf"): # clip preconditioned gradient
+            amp = torch.sqrt(torch.mean(pre_grad * pre_grad))
+            if amp > grad_clip_max_amp:
+                pre_grad *= grad_clip_max_amp/amp
+
+        return pre_grad.view_as(tensor)
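The docstring examples above refer to the module as `tz.m.LRAWhiten`, while the class added in this file is `PSGDLRAWhiten`, so the `tz.m` spelling is presumably an alias. A sketch of the same "Pure PSGD LRA" setup written against the class name as it appears here, not verified against the released wheel:

```python
import torch.nn as nn
import torchzero as tz

model = nn.Linear(10, 1)

# Mirrors the "Pure PSGD LRA" docstring example: low-rank whitening
# preconditioner followed by a fixed step size. Keyword arguments are the
# ones visible in __init__ above; tz.m.PSGDLRAWhiten as an attribute name
# is an assumption.
optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.PSGDLRAWhiten(rank=10, lr_preconditioner=0.1, update_probability=1.0),
    tz.m.LR(1e-3),
)
```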
torchzero/modules/adaptive/rmsprop.py (+83 -75)

@@ -1,45 +1,11 @@
-from operator import itemgetter
 from typing import Literal
 
 import torch
 
-from ...core import
+from ...core import TensorTransform, Chainable
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-
-
-
-def rmsprop_(
-    tensors_: TensorList,
-    exp_avg_sq_: TensorList,
-    smoothing: float | NumberList,
-    eps: float | NumberList,
-    debiased: bool,
-    step: int,
-    exp_avg_: TensorList | None = None,
-    max_exp_avg_sq_: TensorList | None = None,
-    pow: float = 2,
-
-    # inner args
-    inner: Module | None = None,
-    params: list[torch.Tensor] | None = None,
-    grads: list[torch.Tensor] | None = None,
-):
-    """returns `tensors_`"""
-    if exp_avg_ is not None:
-        sqrt_exp_avg_sq = sqrt_centered_ema_sq_(tensors=tensors_, exp_avg_=exp_avg_,
-                                                exp_avg_sq_=exp_avg_sq_,max_exp_avg_sq_=max_exp_avg_sq_,
-                                                beta=smoothing,debiased=debiased,step=step,pow=pow)
-    else:
-        sqrt_exp_avg_sq = sqrt_ema_sq_(tensors=tensors_,exp_avg_sq_=exp_avg_sq_,max_exp_avg_sq_=max_exp_avg_sq_,
-                                       beta=smoothing,debiased=debiased,step=step,pow=pow)
-
-    if inner is not None:
-        assert params is not None
-        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
-
-    return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
-
-class RMSprop(Transform):
+
+class RMSprop(TensorTransform):
     """Divides graient by EMA of gradient squares.
 
     This implementation is identical to :code:`torch.optim.RMSprop`.
@@ -48,7 +14,7 @@ class RMSprop(Transform):
         smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
         eps (float, optional): epsilon for division. Defaults to 1e-8.
         centered (bool, optional): whether to center EMA of gradient squares using an additional EMA. Defaults to False.
-
+        debias (bool, optional): applies Adam debiasing. Defaults to False.
         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
         pow (float, optional): power used in second momentum power and root. Defaults to 2.
         init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
@@ -60,44 +26,86 @@ class RMSprop(Transform):
         smoothing: float = 0.99,
         eps: float = 1e-8,
         centered: bool = False,
-
+        debias: bool = False,
         amsgrad: bool = False,
-        pow: float = 2,
         init: Literal["zeros", "update"] = "zeros",
+
         inner: Chainable | None = None,
+        exp_avg_sq_tfm: Chainable | None = None,
     ):
-        defaults =
-        [old lines 70-103: content collapsed in the diff view]
+        defaults = locals().copy()
+        del defaults['self'], defaults["inner"], defaults["exp_avg_sq_tfm"]
+        super().__init__(defaults, inner=inner)
+
+        self.set_child('exp_avg_sq', exp_avg_sq_tfm)
+
+    @torch.no_grad
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        if setting["init"] == "zeros":
+            state["exp_avg_sq"] = torch.zeros_like(tensor)
+            if setting["centered"]: state["exp_avg"] = torch.zeros_like(tensor)
+            if setting["amsgrad"]: state["amsgrad"] = torch.zeros_like(tensor)
+
+        else:
+            state["exp_avg_sq"] = tensor ** 2
+            if setting["centered"]: state["exp_avg"] = tensor.clone()
+            if setting["amsgrad"]: state["amsgrad"] = tensor ** 2
+
+    @torch.no_grad
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        self.increment_counter("step", start = 0)
+        fs = settings[0]
+
+        exp_avg_sq = unpack_states(states, tensors, "exp_avg_sq", cls=TensorList)
+
+        # update exponential average
+        smoothing = NumberList(s["smoothing"] for s in settings)
+        exp_avg_sq.mul_(smoothing).addcmul_(tensors, tensors, value=1-smoothing)
+
+        # update mean estimate if centered
+        if fs["centered"]:
+            exp_avg = unpack_states(states, tensors, "exp_avg", cls=TensorList)
+            exp_avg.lerp_(tensors, 1-smoothing)
+
+        # amsgrad
+        if fs["amsgrad"]:
+            exp_avg_sq_max = unpack_states(states, tensors, "exp_avg_sq_max", cls=TensorList)
+            exp_avg_sq_max.maximum_(exp_avg_sq)
+
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state["step"] # 0 on 1st step
+        eps = NumberList(s["eps"] for s in settings)
+        fs = settings[0]
+
+        if fs["amsgrad"]: key = "max_exp_avg_sq"
+        else: key = "exp_avg_sq"
+        exp_avg_sq = TensorList(s[key] for s in states)
+
+        # load mean estimate if centered
+        exp_avg = None
+        if fs['centered']:
+            exp_avg = TensorList(s["exp_avg"] for s in states)
+
+        # debias exp_avg_sq and exp_avg
+        if fs["debias"]:
+            smoothing = NumberList(s["smoothing"] for s in settings)
+            bias_correction = 1 - (smoothing ** (step + 1))
+            exp_avg_sq = exp_avg_sq / bias_correction
+
+            if fs['centered']:
+                assert exp_avg is not None
+                exp_avg = exp_avg / bias_correction
+
+        # apply transform to potentially debiased exp_avg_sq
+        exp_avg_sq = TensorList(self.inner_step_tensors(
+            "exp_avg_sq", exp_avg_sq, params=params, grads=grads, loss=loss, clone=True, must_exist=False
+        ))
+
+        # center
+        if fs["centered"]:
+            assert exp_avg is not None
+            exp_avg_sq = exp_avg_sq.addcmul(exp_avg, exp_avg, value=-1)
+
+        return tensors.div_(exp_avg_sq.sqrt().add_(eps))
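Going by the new signature visible in this hunk, the rewritten RMSprop is configured much like before, with `pow` removed, a new `debias` flag, and an optional `exp_avg_sq_tfm` child transform applied to the second-moment accumulator. A usage sketch under those assumptions (the `tz.m.RMSprop` attribute name follows the pattern of the other docstrings in this diff and is not verified here):

```python
import torch.nn as nn
import torchzero as tz

model = nn.Linear(10, 1)

# Sketch of the 0.4.1 constructor as it appears above:
# centered, debiased RMSprop with an AMSGrad-style max accumulator.
opt = tz.Optimizer(
    model.parameters(),
    tz.m.RMSprop(smoothing=0.99, eps=1e-8, centered=True, debias=True, amsgrad=True),
    tz.m.LR(1e-3),
)
```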
torchzero/modules/adaptive/rprop.py (+48 -47)

@@ -1,8 +1,8 @@
 
 import torch
 
-from ...core import
-from ...utils import NumberList, TensorList,
+from ...core import TensorTransform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
 
 def _bool_ones_like(x):
@@ -126,7 +126,7 @@ def rprop_(
 
 
 
-class Rprop(
+class Rprop(TensorTransform):
     """
     Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
     or `nminus` if it did. Then the update is applied with the sign of the current gradient.
@@ -165,7 +165,7 @@ class Rprop(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -178,7 +178,7 @@ class Rprop(Transform):
         )
 
         tensors = rprop_(
-            tensors_ =
+            tensors_ = TensorList(tensors),
             prev_ = prev,
             allowed_ = allowed,
             magnitudes_ = magnitudes,
@@ -194,7 +194,7 @@ class Rprop(Transform):
         return tensors
 
 
-class ScaleLRBySignChange(
+class ScaleLRBySignChange(TensorTransform):
     """
     learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign,
     or `nminus` if it did.
@@ -218,19 +218,19 @@ class ScaleLRBySignChange(Transform):
         ub=50.0,
         alpha=1.0,
         use_grad=False,
-        target: Target = "update",
     ):
         defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
-        super().__init__(defaults, uses_grad=use_grad
+        super().__init__(defaults, uses_grad=use_grad)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
-        tensors =
-
-
+        tensors = TensorList(tensors)
+        if self._uses_grad:
+            assert grads is not None
+            cur = TensorList(grads)
         else: cur = tensors
 
         nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
@@ -252,7 +252,7 @@ class ScaleLRBySignChange(Transform):
         )
         return tensors
 
-class BacktrackOnSignChange(
+class BacktrackOnSignChange(TensorTransform):
     """Negates or undoes update for parameters where where gradient or update sign changes.
 
     This is part of RProp update rule.
@@ -266,20 +266,21 @@ class BacktrackOnSignChange(Transform):
         Defaults to True.
 
     """
-    def __init__(self, use_grad = False, backtrack = True
-        defaults = dict(use_grad=use_grad, backtrack=backtrack
+    def __init__(self, use_grad = False, backtrack = True):
+        defaults = dict(use_grad=use_grad, backtrack=backtrack)
         super().__init__(defaults, uses_grad=use_grad)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
-        tensors =
-        use_grad = settings[0]['use_grad']
+        tensors = TensorList(tensors)
         backtrack = settings[0]['backtrack']
 
-        if
+        if self._uses_grad:
+            assert grads is not None
+            cur = TensorList(grads)
         else: cur = tensors
 
         tensors = backtrack_on_sign_change_(
@@ -292,54 +293,55 @@ class BacktrackOnSignChange(Transform):
 
         return tensors
 
-class SignConsistencyMask(
+class SignConsistencyMask(TensorTransform):
     """
     Outputs a mask of sign consistency of current and previous inputs.
 
     The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.
 
-    Examples:
-
-    GD that skips update for weights where gradient sign changed compared to previous gradient.
+    ### Examples:
 
-
+    GD that skips update for weights where gradient sign changed compared to previous gradient.
 
-
-
-
-
-
+    ```python
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.Mul(tz.m.SignConsistencyMask()),
+        tz.m.LR(1e-2)
+    )
+    ```
 
     """
-    def __init__(self
-        super().__init__(
+    def __init__(self):
+        super().__init__()
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', cls=TensorList)
         mask = prev.mul_(tensors).gt_(0)
         prev.copy_(tensors)
         return mask
 
 
-class SignConsistencyLRs(
+class SignConsistencyLRs(TensorTransform):
     """Outputs per-weight learning rates based on consecutive sign consistency.
 
-    The learning rate for a weight is multiplied by
+    The learning rate for a weight is multiplied by ``nplus`` when two consecutive update signs are the same, otherwise it is multiplied by ``nplus``. The learning rates are bounded to be in ``(lb, ub)`` range.
 
-    Examples:
+    ### Examples:
 
-
+    GD scaled by consecutive gradient sign consistency
 
-
+    ```python
 
-
-
-
-
-
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.Mul(tz.m.SignConsistencyLRs()),
+        tz.m.LR(1e-2)
+    )
+    ```
 
-
+    """
     def __init__(
         self,
         nplus: float = 1.2,
@@ -347,17 +349,16 @@ class SignConsistencyLRs(Transform):
         lb: float | None = 1e-6,
         ub: float | None = 50,
         alpha: float = 1,
-        target: Target = 'update'
     ):
         defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
-        target =
+        target = TensorList(tensors)
         nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
         prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)
 
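The recurring change in this file is that `Rprop`, `ScaleLRBySignChange`, `BacktrackOnSignChange`, `SignConsistencyMask`, and `SignConsistencyLRs` are rebased from the old `Transform` base onto `TensorTransform`, with the per-step hook now spelled `multi_tensor_apply(self, tensors, params, grads, loss, states, settings)` and the `target` argument dropped. A minimal custom module following the same pattern, inferred from these hunks rather than from the 0.4.1 API docs (the class and its behaviour are hypothetical):

```python
import torch
from torchzero.core import TensorTransform
from torchzero.utils import NumberList, TensorList

class ScaleUpdate(TensorTransform):
    """Hypothetical example module: multiplies the incoming update by a constant factor."""
    def __init__(self, factor: float = 0.5):
        defaults = dict(factor=factor)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # per-parameter settings are unpacked the same way the modules above do it
        factor = NumberList(s['factor'] for s in settings)
        return TensorList(tensors).mul_(factor)
```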
torchzero/modules/adaptive/sam.py (+55 -45)

@@ -1,10 +1,10 @@
 from contextlib import nullcontext
 import torch
-from ...utils import TensorList, NumberList
-from ...core import
+from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
+from ...core import Transform
 
 
-class SAM(
+class SAM(Transform):
     """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412
 
     SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
@@ -22,50 +22,51 @@ class SAM(Module):
         p (float, optional): norm of the SAM objective. Defaults to 2.
         asam (bool, optional):
             enables ASAM variant which makes perturbation relative to weight magnitudes.
-            ASAM requires a much larger
-            The
-            it has larger
+            ASAM requires a much larger ``rho``, like 0.5 or 1.
+            The ``tz.m.ASAM`` class is idential to setting this argument to True, but
+            it has larger ``rho`` by default.
 
-    Examples:
-    SAM-SGD:
+    ### Examples:
 
-
+    SAM-SGD:
 
-
-
-
-
-
+    ```py
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.SAM(),
+        tz.m.LR(1e-2)
+    )
+    ```
 
-
+    SAM-Adam:
 
-
-
-
-
-
-
-
-
+    ```
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.SAM(),
+        tz.m.Adam(),
+        tz.m.LR(1e-2)
+    )
+    ```
 
     References:
-        Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.
+        [Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.](https://arxiv.org/abs/2010.01412#page=3.16)
     """
     def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
         defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
         super().__init__(defaults)
 
     @torch.no_grad
-    def
+    def update_states(self, objective, states, settings):
 
-        params =
-        closure =
-        zero_grad =
+        params = objective.params
+        closure = objective.closure
+        zero_grad = objective.zero_grad
         if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
-        p, rho =
-
-        eps =
-        asam =
+        p, rho = unpack_dicts(settings, 'p', 'rho', cls=NumberList)
+        fs = settings[0]
+        eps = fs['eps']
+        asam = fs['asam']
 
         # 1/p + 1/q = 1
         # okay, authors of SAM paper, I will manually solve your equation
@@ -123,8 +124,7 @@ class SAM(Module):
 
             return sam_loss
 
-
-        return var
+        objective.closure = sam_closure
 
 # different class because defaults for SAM are bad for ASAM
 class ASAM(SAM):
@@ -136,7 +136,7 @@ class ASAM(SAM):
     This implementation modifies the closure to return loss and calculate gradients
     of the SAM objective. All modules after this will use the modified objective.
 
-
+    Note:
         This module requires a closure passed to the optimizer step,
         as it needs to re-evaluate the loss and gradients at two points on each step.
 
@@ -144,20 +144,30 @@ class ASAM(SAM):
         rho (float, optional): Neighborhood size. Defaults to 0.05.
         p (float, optional): norm of the SAM objective. Defaults to 2.
 
-    Examples:
-
+    ### Examples:
+
+    ASAM-SGD:
 
-
+    ```py
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.ASAM(),
+        tz.m.LR(1e-2)
+    )
+    ```
 
-
-        model.parameters(),
-        tz.m.ASAM(),
-        tz.m.Adam(),
-        tz.m.LR(1e-2)
-    )
+    ASAM-Adam:
 
+    ```
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.ASAM(),
+        tz.m.Adam(),
+        tz.m.LR(1e-2)
+    )
+    ```
     References:
-        Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July).
+        [Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR.](https://arxiv.org/abs/2102.11600)
     """
     def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
         super().__init__(rho=rho, p=p, eps=eps, asam=True)
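Since `update_states` raises a `RuntimeError` when no closure is supplied, SAM and ASAM are only usable with a closure-based step. A sketch of what that looks like, assuming `tz.Optimizer.step` accepts a closure the way `torch.optim.LBFGS.step` does (the exact closure signature torchzero expects, for example whether it takes a `backward` flag, should be checked against the library's documentation):

```python
import torch
import torch.nn as nn
import torchzero as tz

model = nn.Linear(10, 1)
opt = tz.Optimizer(model.parameters(), tz.m.SAM(rho=0.05), tz.m.Adam(), tz.m.LR(1e-2))

x, y = torch.randn(32, 10), torch.randn(32, 1)

def closure():
    # re-evaluated by SAM at the original and at the perturbed parameters
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

opt.step(closure)
```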